Created
May 29, 2024 10:40
-
-
Save gante/effbc18dee06ce88c200eb2eb1e1d583 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# `torch.compile`-enabled Llama 3
#
# Times greedy, static-cache generation with a compiled forward pass. It shows that
# a batch with a new prompt length is slow again, and that padding prompts with
# `pad_to_multiple_of` keeps the timing fast.
import os
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Greedy decoding with a static KV cache. `max_length` counts prompt + generated tokens,
# so a longer prompt leaves fewer tokens to generate.
GENERATION_KWARGS = dict(do_sample=False, cache_implementation="static", max_length=100)
NUM_RUNS = 10

# Left padding keeps each prompt's last token adjacent to the tokens being generated.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
# Batched prompts need a pad token; reuse EOS for it.
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.float16
)
# fullgraph=True makes graph breaks an error instead of a silent eager fallback;
# "reduce-overhead" mode uses CUDA graphs to cut per-call launch overhead.
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")


def timed_generate(batch):
    """Run `model.generate` on `batch`, print the elapsed time, and return the output ids.

    CUDA kernels run asynchronously, so the device is synchronized before starting
    and before stopping the clock. Otherwise work queued by an earlier call could be
    billed to this one, or work from this call missed.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    output = model.generate(**batch, **GENERATION_KWARGS)
    if use_cuda:
        torch.cuda.synchronize()
    print(f"Time taken: {time.perf_counter() - start:.2f}s")
    return output


# The first iterations are slow (compilation overhead; the timings below show the
# first two calls paying it). Subsequent iterations are faster than the uncompiled call.
inputs = tokenizer(["The quick", "Foo"], return_tensors="pt", padding=True).to(model.device)
for _ in range(NUM_RUNS):
    gen_out = timed_generate(inputs)
# Time on RTX3090: 43.16s / 25.95s/ 1.99s / 1.99s / ...

# A longer prompt presumably changes the padded input shape, which forces a
# recompilation (compare with the `pad_to_multiple_of` section below).
inputs = tokenizer(["The quick brown", "Foo bar"], return_tensors="pt", padding=True).to(
    model.device
)
gen_out = timed_generate(inputs)
# Time on RTX3090: 37.88 -> much slower than the previous call, despite having to generate fewer tokens

# -------------------------------------------------------------------------------------
# Same example, now with `pad_to_multiple_of`
# Clear Dynamo's compilation caches so this section starts from a cold compile.
torch._dynamo.reset()
# The first iterations are slow (compilation overhead). Subsequent iterations are
# faster than the uncompiled call.
inputs = tokenizer(
    ["The quick", "Foo"], return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(model.device)
for _ in range(NUM_RUNS):
    gen_out = timed_generate(inputs)
# Time on RTX3090: 42.96s / 27.61s/ 1.88s / 1.88s / ...

# Both prompt batches presumably pad to the same multiple-of-8 length, so the
# already-compiled graphs are reused.
inputs = tokenizer(
    ["The quick brown", "Foo bar"], return_tensors="pt", padding=True, pad_to_multiple_of=8
).to(model.device)
gen_out = timed_generate(inputs)
# Time on RTX3090: 1.88 -> same as before :D
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment