
@NTT123
Last active September 19, 2024 12:43
"""
This script converts a Hugging Face LLaMA3 model checkpoint to the original LLaMA3 checkpoint format.
Usage example:
python convert_hf_to_llama3.py --hf_model_path "path/to/hf/model" --output_path "path/to/output"
"""
import argparse
import json
import os

import torch
from transformers import LlamaForCausalLM


def write_json(data, path):
    with open(path, "w") as f:
        json.dump(data, f)


def hf_to_llama3(hf_model_path, output_path):
    os.makedirs(output_path, exist_ok=True)

    # Load the HF model
    model = LlamaForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
    config = model.config
    state_dict = model.state_dict()

    # Helper function to un-permute the q/k projection weights. HF's conversion
    # script permutes these weights for its rotary embedding layout; this
    # reverses that permutation to recover the original layout.
    def unpermute(w, n_heads, dim1=config.hidden_size, dim2=config.hidden_size):
        return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    # Prepare the state dict in the original LLaMA3 format
    llama3_state_dict = {}

    # Convert per-layer weights
    for layer_i in range(config.num_hidden_layers):
        layer_prefix = f"model.layers.{layer_i}."
        llama3_state_dict.update({
            f"layers.{layer_i}.attention.wq.weight": unpermute(
                state_dict[f"{layer_prefix}self_attn.q_proj.weight"],
                n_heads=config.num_attention_heads,
            ),
            f"layers.{layer_i}.attention.wk.weight": unpermute(
                state_dict[f"{layer_prefix}self_attn.k_proj.weight"],
                n_heads=config.num_key_value_heads,
                dim1=config.hidden_size * config.num_key_value_heads // config.num_attention_heads,
            ),
            f"layers.{layer_i}.attention.wv.weight": state_dict[f"{layer_prefix}self_attn.v_proj.weight"],
            f"layers.{layer_i}.attention.wo.weight": state_dict[f"{layer_prefix}self_attn.o_proj.weight"],
            f"layers.{layer_i}.feed_forward.w1.weight": state_dict[f"{layer_prefix}mlp.gate_proj.weight"],
            f"layers.{layer_i}.feed_forward.w2.weight": state_dict[f"{layer_prefix}mlp.down_proj.weight"],
            f"layers.{layer_i}.feed_forward.w3.weight": state_dict[f"{layer_prefix}mlp.up_proj.weight"],
            f"layers.{layer_i}.attention_norm.weight": state_dict[f"{layer_prefix}input_layernorm.weight"],
            f"layers.{layer_i}.ffn_norm.weight": state_dict[f"{layer_prefix}post_attention_layernorm.weight"],
        })

    # Convert the embedding, final norm, and output head weights
    llama3_state_dict.update({
        "tok_embeddings.weight": state_dict["model.embed_tokens.weight"],
        "norm.weight": state_dict["model.norm.weight"],
        "output.weight": state_dict["lm_head.weight"],
    })

    # Save the weights
    torch.save(llama3_state_dict, os.path.join(output_path, "consolidated.00.pth"))

    # Save params.json. Note: use_scaled_rope and ffn_dim_multiplier are
    # hardcoded below; adjust them if your model uses different values.
    params = {
        "dim": config.hidden_size,
        "n_layers": config.num_hidden_layers,
        "n_heads": config.num_attention_heads,
        "n_kv_heads": config.num_key_value_heads,
        "vocab_size": config.vocab_size,
        "norm_eps": config.rms_norm_eps,
        "max_seq_len": config.max_position_embeddings,
        "use_scaled_rope": True,
        "ffn_dim_multiplier": 1.3,
    }
    write_json(params, os.path.join(output_path, "params.json"))
    print(f"LLaMA3 checkpoint saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert HF LLaMA model to LLaMA3 format")
    parser.add_argument("--hf_model_path", type=str, required=True, help="Path to the HF model")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
    args = parser.parse_args()
    hf_to_llama3(args.hf_model_path, args.output_path)
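
As a quick sanity check on the converted output, a minimal sketch along these lines (the sanity_check.py file name and the positional output-path argument are illustrative, not part of the gist) loads consolidated.00.pth and params.json and prints a few expected keys and shapes:

# sanity_check.py -- minimal sketch for inspecting the converted checkpoint
import json
import os
import sys

import torch

output_path = sys.argv[1]  # same directory passed as --output_path above

# Load the converted weights on CPU along with the params.json metadata
state_dict = torch.load(os.path.join(output_path, "consolidated.00.pth"), map_location="cpu")
with open(os.path.join(output_path, "params.json")) as f:
    params = json.load(f)

print("params.json:", params)
print("number of tensors:", len(state_dict))
# Spot-check a few keys the converter is expected to write
for key in ["tok_embeddings.weight", "norm.weight", "output.weight", "layers.0.attention.wq.weight"]:
    print(key, tuple(state_dict[key].shape), state_dict[key].dtype)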