Created
July 8, 2024 10:37
-
-
Save gante/09fffc31a362ce1f603ba38f70f44fa7 to your computer and use it in GitHub Desktop.
yarn checks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Assumes: | |
1. transformers on this branch (https://github.com/huggingface/transformers/pull/30910) | |
2. yarn pip installed (https://github.com/jquesnelle/yarn) | |
3. HF login with read token (`huggingface-cli login`) | |
""" | |
import torch | |
from huggingface_hub import hf_hub_download | |
from transformers import AutoConfig, AutoTokenizer | |
from transformers.models.llama.modeling_llama import LlamaYarnScalingRotaryEmbedding, LlamaDynamicYarnScalingRotaryEmbedding | |
from scaled_rope.LlamaYaRNScaledRotaryEmbedding import LlamaYaRNScaledRotaryEmbedding | |
from scaled_rope.LlamaDynamicYaRNScaledRotaryEmbedding import LlamaDynamicYaRNScaledRotaryEmbedding | |
model_id="meta-llama/Meta-Llama-3-8B" | |
filenames = ["config.json", "generation_config.json", "model-00001-of-00004.safetensors", "model-00002-of-00004.safetensors", | |
"model-00003-of-00004.safetensors", "model-00004-of-00004.safetensors", "model.safetensors.index.json", | |
"special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"] | |
for filename in filenames: | |
downloaded_model_path = hf_hub_download(repo_id=model_id, filename=filename) | |
print(downloaded_model_path) | |
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") | |
model_config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B") | |
def generate_hf_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Return the (cos, sin) rotary embeddings from the HF implementation.

    `method` picks the static ("yarn") or dynamic ("dynamic_yarn") variant;
    any other value raises ValueError.
    """
    token_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # Positions 0..seq_len-1, one row per batch entry.
    position_ids = torch.ones_like(token_ids).cumsum(dim=1) - 1
    # The embedding layer only reads dtype and device from its first argument.
    placeholder = torch.ones_like(token_ids, dtype=torch.float32)
    rope_classes = {
        "yarn": LlamaYarnScalingRotaryEmbedding,
        "dynamic_yarn": LlamaDynamicYarnScalingRotaryEmbedding,
    }
    if method not in rope_classes:
        raise ValueError("Invalid method specified")
    rope = rope_classes[method](dim=dim)
    return rope(placeholder, position_ids)
def generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn"):
    """Return the (cos, sin) rotary embeddings from the reference yarn repo.

    `method` picks the static ("yarn") or dynamic ("dynamic_yarn") variant;
    any other value raises ValueError.
    """
    token_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # The embedding layer only reads dtype and device from its first argument.
    placeholder = torch.ones_like(token_ids, dtype=torch.float32)
    rope_classes = {
        "yarn": LlamaYaRNScaledRotaryEmbedding,
        "dynamic_yarn": LlamaDynamicYaRNScaledRotaryEmbedding,
    }
    if method not in rope_classes:
        raise ValueError("Invalid method specified")
    rope = rope_classes[method](dim=dim)
    # The reference implementation takes an explicit sequence length.
    return rope(placeholder, seq_len=token_ids.size(1))
input_text = "This is a large test input. " * 1200 # sequence length > 8k | |
dim = model_config.hidden_size // model_config.num_attention_heads | |
hf_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="yarn") | |
hf_yarn_embeddings_cos = hf_yarn_embeddings[0][0] | |
hf_yarn_embeddings_sin = hf_yarn_embeddings[1][0] | |
hf_dynamic_yarn_embeddings = generate_hf_embeddings(input_text, dim, tokenizer, method="dynamic_yarn") | |
hf_dynamic_yarn_embeddings_cos = hf_dynamic_yarn_embeddings[0][0] | |
hf_dynamic_yarn_embeddings_sin = hf_dynamic_yarn_embeddings[1][0] | |
yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="yarn") | |
yarn_embeddings_cos = yarn_embeddings[0][0, 0] | |
yarn_embeddings_sin = yarn_embeddings[1][0, 0] | |
dynamic_yarn_embeddings = generate_yarn_embeddings(input_text, dim, tokenizer, method="dynamic_yarn") | |
dynamic_yarn_embeddings_cos = dynamic_yarn_embeddings[0][0, 0] | |
dynamic_yarn_embeddings_sin = dynamic_yarn_embeddings[1][0, 0] | |
assert torch.allclose(hf_yarn_embeddings_cos, yarn_embeddings_cos), "Yarn embeddings do not match!" | |
assert torch.allclose(hf_yarn_embeddings_sin, yarn_embeddings_sin), "Yarn embeddings do not match!" | |
assert torch.allclose(hf_dynamic_yarn_embeddings_cos, dynamic_yarn_embeddings_cos), "Dynamic Yarn embeddings do not match!" | |
assert torch.allclose(hf_dynamic_yarn_embeddings_sin, dynamic_yarn_embeddings_sin), "Dynamic Yarn embeddings do not match!" | |
print("Embeddings match successfully!") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment