Created
March 21, 2024 15:42
-
-
Save gante/d795ec9fd04503a5fdf708111787d6af to your computer and use it in GitHub Desktop.
v4.39 Llama 2 + torch.compile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# `torch.compile` enabled Llama 2 🏎️
#
# Benchmark: compile the model's forward pass and time 10 consecutive
# greedy `generate` calls. The first iterations are slow (compilation
# overhead); subsequent iterations are much faster than the uncompiled call.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto", torch_dtype=torch.float16
)
# fullgraph=True forbids graph breaks; "reduce-overhead" mode uses CUDA graphs
# to minimize per-call launch overhead after warmup.
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")

# The first iteration is slow (compilation overhead). The subsequent iterations are
# faster than the uncompiled call.
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
for _ in range(10):
    # perf_counter is monotonic and high-resolution — the right clock for
    # benchmarking (time.time() can jump with wall-clock adjustments).
    start = time.perf_counter()
    gen_out = model.generate(
        **inputs, do_sample=False, cache_implementation="static", max_new_tokens=100
    )
    print(f"Time taken: {time.perf_counter() - start}")
# Time on my RTX 3090 machine: 34.25s / 19.32s / 0.39s / 0.39s / ...
# Commenting out the compilation line: 1.69s / 1.42s / 1.42s / ...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment