
@gante
Created March 21, 2024 15:42
v4.39 Llama 2 + torch.compile
# `torch.compile` enabled Llama 2 🏎️
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, time
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto", torch_dtype=torch.float16
)
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
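# `fullgraph=True` asks the compiler to capture the forward pass as a single graph
# (it errors out on graph breaks), and `mode="reduce-overhead"` uses CUDA graphs to
# cut kernel launch overhead during decoding.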
# The first iterations are slow (compilation overhead). The subsequent iterations are
# much faster than the uncompiled call.
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
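# `cache_implementation="static"` pre-allocates a fixed-shape KV cache; keeping tensor
# shapes constant lets the compiled forward pass be reused instead of recompiled.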
for i in range(10):
    start = time.time()
    gen_out = model.generate(
        **inputs, do_sample=False, cache_implementation="static", max_new_tokens=100
    )
    print(f"Time taken: {time.time() - start}")
# Time on my RTX 3090 machine: 34.25s / 19.32s / 0.39s / 0.39s / ...
# Commenting out the compilation line: 1.69s / 1.42s / 1.42s / ...
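
# Not part of the original gist: a minimal sketch of how you might inspect the last
# generated text, assuming the standard `tokenizer.batch_decode` API.
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])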