@gante
Created July 8, 2024 10:37
yarn checks
"""
Assumes:
1. transformers installed from this branch (https://github.com/huggingface/transformers/pull/30910)
2. the reference YaRN repo pip-installed (https://github.com/jquesnelle/yarn)
3. HF login with a read token (`huggingface-cli login`)
"""
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaYarnScalingRotaryEmbedding, LlamaDynamicYarnScalingRotaryEmbedding
from scaled_rope.LlamaYaRNScaledRotaryEmbedding import LlamaYaRNScaledRotaryEmbedding
from scaled_rope.LlamaDynamicYaRNScaledRotaryEmbedding import LlamaDynamicYaRNScaledRotaryEmbedding
model_id="meta-llama/Meta-Llama-3-8B"
filenames = ["config.json", "generation_config.json", "model-00001-of-00004.safetensors", "model-00002-of-00004.safetensors",
"model-00003-of-00004.safetensors", "model-00004-of-00004.safetensors", "model.safetensors.index.json",
"special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"]
for filename in filenames:
downloaded_model_path = hf_hub_download(repo_id=model_id, filename=filename)
print(downloaded_model_path)
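
# Only the tokenizer and the config are actually used below -- the safetensors downloads
# presumably just pre-populate the local cache, since the weights are never loaded here.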
tokenizer = AutoTokenizer.from_pretrained(model_id)
model_config = AutoConfig.from_pretrained(model_id)

def generate_hf_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Builds the rotary embeddings from the `transformers` branch (forward takes position ids)."""
    inputs = tokenizer(input_text, return_tensors="pt")
    # position ids 0, 1, ..., seq_len - 1, with the same shape as the input ids
    position_ids = torch.ones_like(inputs.input_ids).cumsum(dim=1) - 1
    # the embeddings only need the right dtype and device from `x`, the input
    dummy_input = torch.ones_like(inputs.input_ids, dtype=torch.float32)
    if method == "yarn":
        embedding = LlamaYarnScalingRotaryEmbedding(dim=dim)
    elif method == "dynamic_yarn":
        embedding = LlamaDynamicYarnScalingRotaryEmbedding(dim=dim)
    else:
        raise ValueError(f"Invalid method specified: {method}")
    return embedding(dummy_input, position_ids)

def generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Builds the rotary embeddings from the reference scaled_rope repo (forward takes seq_len)."""
    inputs = tokenizer(input_text, return_tensors="pt")
    # the embeddings only need the right dtype and device from `x`, the input
    dummy_input = torch.ones_like(inputs.input_ids, dtype=torch.float32)
    if method == "yarn":
        embedding = LlamaYaRNScaledRotaryEmbedding(dim=dim)
    elif method == "dynamic_yarn":
        embedding = LlamaDynamicYaRNScaledRotaryEmbedding(dim=dim)
    else:
        raise ValueError(f"Invalid method specified: {method}")
    seq_len = inputs.input_ids.size(1)
    return embedding(dummy_input, seq_len=seq_len)

input_text = "This is a large test input. " * 1200  # sequence length > 8k tokens
dim = model_config.hidden_size // model_config.num_attention_heads  # per-head (rotary) dimension
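
# Note on the indexing below: the two implementations return (cos, sin) with different leading
# batch/singleton dimensions, so both are sliced down to (seq_len, dim) before comparison --
# hence `[0][0]` for the transformers outputs and `[0][0, 0]` for the scaled_rope outputs.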
hf_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="yarn")
hf_yarn_embeddings_cos = hf_yarn_embeddings[0][0]
hf_yarn_embeddings_sin = hf_yarn_embeddings[1][0]

hf_dynamic_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="dynamic_yarn")
hf_dynamic_yarn_embeddings_cos = hf_dynamic_yarn_embeddings[0][0]
hf_dynamic_yarn_embeddings_sin = hf_dynamic_yarn_embeddings[1][0]

yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn")
yarn_embeddings_cos = yarn_embeddings[0][0, 0]
yarn_embeddings_sin = yarn_embeddings[1][0, 0]

dynamic_yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="dynamic_yarn")
dynamic_yarn_embeddings_cos = dynamic_yarn_embeddings[0][0, 0]
dynamic_yarn_embeddings_sin = dynamic_yarn_embeddings[1][0, 0]

assert torch.allclose(hf_yarn_embeddings_cos, yarn_embeddings_cos), "YaRN cos embeddings do not match!"
assert torch.allclose(hf_yarn_embeddings_sin, yarn_embeddings_sin), "YaRN sin embeddings do not match!"
assert torch.allclose(hf_dynamic_yarn_embeddings_cos, dynamic_yarn_embeddings_cos), "Dynamic YaRN cos embeddings do not match!"
assert torch.allclose(hf_dynamic_yarn_embeddings_sin, dynamic_yarn_embeddings_sin), "Dynamic YaRN sin embeddings do not match!"
print("Embeddings match successfully!")