# `torch.compile`-enabled Llama 3
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, time, os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype=torch.float16
)
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
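# `fullgraph=True` makes the compiler trace the forward pass as a single graph (erroring out on
# graph breaks), and `mode="reduce-overhead"` uses CUDA graphs to cut per-step kernel launch
# overhead. The static KV cache requested in `generate` below keeps tensor shapes fixed across
# decoding steps, which is what allows the compiled graph to be reused.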
# The first iteration is slow (compilation overhead). Subsequent iterations are
# faster than the uncompiled call.
inputs = tokenizer(["The quick", "Foo"], return_tensors="pt", padding=True).to(model.device)
for i in range(10):
    start = time.time()
    gen_out = model.generate(
        **inputs, do_sample=False, cache_implementation="static", max_length=100
    )
    print(f"Time taken: {time.time() - start:.2f}s")
# Time on RTX3090: 43.16s / 25.95s / 1.99s / 1.99s / ...
inputs = tokenizer(["The quick brown", "Foo bar"], return_tensors="pt", padding=True).to(model.device)
start = time.time()
gen_out = model.generate(
    **inputs, do_sample=False, cache_implementation="static", max_length=100
)
print(f"Time taken: {time.time() - start:.2f}s")
# Time on RTX3090: 37.88s -> much slower than the previous call, despite having to generate fewer tokens
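# The slowdown happens because the new prompts tokenize to a different padded length, so the
# compiled forward sees a new input shape and `torch.compile` has to recompile.
# Optional check (not part of the original gist; assumes PyTorch 2.1+, where
# `torch._logging.set_logs` is available) -- uncomment before re-running to log every recompile
# together with its reason:
# torch._logging.set_logs(recompiles=True)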
# -------------------------------------------------------------------------------------
# Same example, now with `pad_to_multiple_of`
torch._dynamo.reset()
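# Resetting dynamo drops all previously compiled graphs, so the runs below start from a cold
# compilation cache, just like the first example.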
# The first iteration is slow (compilation overhead). Subsequent iterations are
# faster than the uncompiled call.
inputs = tokenizer(["The quick", "Foo"], return_tensors="pt", padding=True, pad_to_multiple_of=8).to(model.device)
for i in range(10):
    start = time.time()
    gen_out = model.generate(
        **inputs, do_sample=False, cache_implementation="static", max_length=100
    )
    print(f"Time taken: {time.time() - start:.2f}s")
# Time on RTX3090: 42.96s / 27.61s / 1.88s / 1.88s / ...
inputs = tokenizer(["The quick brown", "Foo bar"], return_tensors="pt", padding=True, pad_to_multiple_of=8).to(model.device)
start = time.time()
gen_out = model.generate(
    **inputs, do_sample=False, cache_implementation="static", max_length=100
)
print(f"Time taken: {time.time() - start:.2f}s")
# Time on RTX3090: 1.88s -> same as before :D
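# Because the new prompts also pad up to the same multiple of 8, the compiled graph is reused
# as-is and the call runs at steady-state speed. A prompt long enough to land in the next
# multiple-of-8 bucket would presumably still trigger a recompilation the first time that
# bucket is seen.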