@viig99
Last active September 9, 2024 12:05
Efficient Transformer Classifier
import torch
import torch.nn as nn
import math
from torch import Tensor
from torch.nn import functional as F
from dataclasses import dataclass
from typing import Optional, Literal
def count_parameters(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())
def model_size_in_megabytes(model: nn.Module) -> float:
    param_size = 0
    for param in model.parameters():
        param_size += param.numel() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.numel() * buffer.element_size()
    total_size_in_bytes = param_size + buffer_size
    return total_size_in_bytes / (1024**2)  # Convert bytes to megabytes
@dataclass
class TransformerConfigs:
    vocab_size: int = 30528
    num_layers: int = 12
    hidden_dim: int = 384
    ff_inner_dim: Optional[int] = None
    atten_num_heads: int = 6
    dropout: float = 0.0
    rope_size: int = 2048
    rope_base: int = 10000
    num_classes: int = 1
    pooling: Literal["mean", "cls"] = "mean"
    qkv_bias: bool = False
    atten_num_inner_heads: Optional[int] = None
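# Hedged example (added, not part of the original gist): a smaller configuration
# with grouped-query attention; atten_num_inner_heads < atten_num_heads means the
# 6 query heads share 2 key/value heads inside MultiHeadAttentionWithROPE below.
def _example_small_config() -> TransformerConfigs:
    return TransformerConfigs(
        num_layers=4,
        hidden_dim=384,
        atten_num_heads=6,
        atten_num_inner_heads=2,
        num_classes=3,
        pooling="cls",
    )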
def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000, dtype: torch.dtype = torch.bfloat16
) -> Tensor:
    """Precompute the rotary-embedding cache of shape (seq_len, n_elem // 2, 2)."""
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Apply rotary position embeddings to a (B, T, n_heads, head_dim) tensor."""
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
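# Hedged usage sketch (added, not part of the original gist): a quick shape check
# for the rotary helpers above, using the default config sizes as assumptions.
def _example_rotary_shapes() -> None:
    cfg = TransformerConfigs()
    head_dim = cfg.hidden_dim // cfg.atten_num_heads  # 384 // 6 = 64
    freqs = precompute_freqs_cis(cfg.rope_size, head_dim, cfg.rope_base, torch.float32)
    q = torch.randn(2, 16, cfg.atten_num_heads, head_dim)  # (B, T, n_heads, head_dim)
    q_rot = apply_rotary_emb(q, freqs[:16])  # slice the cache to the sequence length
    assert q_rot.shape == q.shape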
class MultiHeadAttentionWithROPE(nn.Module):
    """Multi-head attention with rotary position embeddings and optional
    grouped-query attention: num_inner_heads key/value heads are shared across
    num_heads query heads (num_inner_heads=None gives standard MHA)."""
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        qkv_bias: bool = False,
        num_inner_heads: Optional[int] = None,
    ):
        super().__init__()
        if embed_dim % num_heads > 0:
            raise ValueError(
                "Invalid num_heads, embed_dim should be divisible by num_heads."
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.num_inner_heads = num_heads if num_inner_heads is None else num_inner_heads
        self.embed_dim = embed_dim
        self.qkv = nn.Linear(
            embed_dim,
            (self.num_heads + 2 * self.num_inner_heads) * self.head_dim,
            bias=qkv_bias,
        )
        self.proj = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.dropout = dropout
    def forward(self, x: Tensor, freqs_cis: Tensor, attn_mask=None) -> Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        kv_size = self.num_inner_heads * self.head_dim
        q, k, v = qkv.split([self.embed_dim, kv_size, kv_size], dim=-1)
        q = q.view(B, T, self.num_heads, self.head_dim)
        k = k.view(B, T, self.num_inner_heads, self.head_dim)
        v = v.view(B, T, self.num_inner_heads, self.head_dim)
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)
        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
        # Expand the key/value heads so every query head has a matching kv head
        k = k.repeat_interleave(self.num_heads // self.num_inner_heads, dim=1)
        v = v.repeat_interleave(self.num_heads // self.num_inner_heads, dim=1)
        context_vec = nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        # Combine heads: C == self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(B, T, C)
        context_vec = self.proj(context_vec)
        return context_vec
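# Hedged usage sketch (added, not part of the original gist): running the attention
# block stand-alone with a boolean padding mask broadcastable to (B, n_heads, T, T);
# True marks positions that may be attended to.
def _example_attention_forward() -> Tensor:
    B, T, C = 2, 16, 384
    attn = MultiHeadAttentionWithROPE(C, num_heads=6, num_inner_heads=2)
    freqs = precompute_freqs_cis(T, C // 6, dtype=torch.float32)
    mask = torch.ones(B, 1, 1, T, dtype=torch.bool)
    return attn(torch.randn(B, T, C), freqs, mask)  # -> (B, T, C)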
class FeedForward(nn.Module):
    def __init__(self, in_dim: int, inner_dim: Optional[int]) -> None:
        super().__init__()
        inner_dim = inner_dim or in_dim * 4
        self.w1 = nn.Linear(in_dim, inner_dim, bias=False)
        self.w3 = nn.Linear(in_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, in_dim, bias=False)
    def forward(self, x: Tensor) -> Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
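# Note (added): this is the LLaMA-style SwiGLU feed-forward, w2(silu(w1(x)) * w3(x)),
# with the inner width defaulting to 4x the model width when ff_inner_dim is None.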
class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
    def __init__(self, dim: int, bias: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
    def forward(self, input: Tensor) -> Tensor:
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x: Tensor) -> Tensor:
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
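# Note (added): RMSNorm rescales each token by the root mean square of its features,
# y = x / sqrt(mean(x**2) + eps) * weight, i.e. LayerNorm without mean-centering or bias.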
class Pooling(nn.Module):
    POOLING_TYPES = ["cls", "mean"]
    def __init__(self, mode: str) -> None:
        super().__init__()
        if mode not in self.POOLING_TYPES:
            raise ValueError(f"Invalid pooling mode: {mode}")
        self.mode = mode
    def forward(self, x: Tensor, attn_mask: Tensor) -> Tensor:
        if self.mode == "cls":
            return x[:, 0]
        # Masked mean pooling: average only over non-padded positions.
        mask = attn_mask.view(attn_mask.size(0), -1, 1).float()
        sum_embeddings = (x * mask).sum(dim=1)
        num_mask_elements = torch.clamp(mask.sum(dim=1), min=1e-9)
        return sum_embeddings / num_mask_elements
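# Hedged usage sketch (added, not part of the original gist): masked mean pooling
# over a batch where the second sequence has 4 padded positions.
def _example_mean_pooling() -> Tensor:
    pool = Pooling("mean")
    x = torch.randn(2, 16, 384)
    attn_mask = torch.ones(2, 16, dtype=torch.long)
    attn_mask[1, 12:] = 0  # mark the last 4 tokens of the second sample as padding
    return pool(x, attn_mask)  # -> (2, 384); padded positions are excluded from the mean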
class TransformerBlock(nn.Module):
    def __init__(self, config: TransformerConfigs) -> None:
        super().__init__()
        self.attention = MultiHeadAttentionWithROPE(
            config.hidden_dim,
            config.atten_num_heads,
            config.dropout,
            config.qkv_bias,
            config.atten_num_inner_heads,
        )
        self.feed_forward = FeedForward(config.hidden_dim, config.ff_inner_dim)
        self.ffn_norm = RMSNorm(config.hidden_dim)
        self.attention_norm = RMSNorm(config.hidden_dim)
    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
class Transformer(nn.Module):
    def __init__(self, config: TransformerConfigs) -> None:
        super().__init__()
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.num_layers)
        )
        self.norm = RMSNorm(config.hidden_dim)
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                config.rope_size,
                config.hidden_dim // config.atten_num_heads,
                config.rope_base,
                self.tok_embeddings.weight.dtype,
            ),
            persistent=False,
        )
        self.apply(self._init_weights)
        # Apply special scaled init to the residual projections, per the GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith("w2.weight"):
                nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.num_layers)
                )
    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        input_len = x.size(1)
        x = self.tok_embeddings(x)
        freqs_cis = self.freqs_cis[:input_len, ...]
        # `mask` should be a boolean padding mask of shape (B, T): SDPA treats a
        # boolean attn_mask as keep/ignore and a float mask as an additive bias.
        mask = mask[:, None, None, :]
        for layer in self.layers:
            x = layer(x, freqs_cis, mask)
        x = self.norm(x)
        return x
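# Hedged end-to-end sketch (added, not part of the original gist): encoding a padded
# batch of token ids with a shrunken version of the default configuration.
def _example_transformer_forward() -> Tensor:
    config = TransformerConfigs(num_layers=2)
    model = Transformer(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    mask = torch.ones(2, 16, dtype=torch.bool)  # True = real token, False = padding
    return model(tokens, mask)  # -> (2, 16, hidden_dim)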
class TransformerWithClassificationHead(nn.Module):
    def __init__(self, config: TransformerConfigs) -> None:
        super().__init__()
        self.transformer = Transformer(config)
        self.classifier = nn.Linear(config.hidden_dim, config.num_classes)
        self.pool_layer = Pooling(config.pooling)
    def forward(self, x: Tensor, mask: Tensor) -> Tensor:
        x = self.transformer(x, mask)
        x = self.pool_layer(x, mask)
        return self.classifier(x)
    @classmethod
    def load_from_smollm(
        cls, smol_path: str, num_classes: int, pooling_type: Literal["mean", "cls"]
    ) -> "TransformerWithClassificationHead":
        from safetensors import safe_open
        from collections import defaultdict
        config = TransformerConfigs(
            vocab_size=49152,
            num_layers=30,
            hidden_dim=576,
            ff_inner_dim=1536,
            atten_num_heads=9,
            dropout=0.0,
            rope_size=2048,
            rope_base=10000,
            num_classes=num_classes,
            pooling=pooling_type,
            qkv_bias=False,
            atten_num_inner_heads=3,
        )
        model = cls(config)
        state_dict = model.state_dict()
        # Mapping for renaming checkpoint keys to this module's parameter names
        key_mapping = {
            "embed_tokens": "tok_embeddings",
            "input_layernorm": "attention_norm",
            "post_attention_layernorm": "ffn_norm",
            "self_attn": "attention",
            "o_proj": "proj",
            "gate_proj": "w1",
            "up_proj": "w3",
            "down_proj": "w2",
            "mlp": "feed_forward",
        }
        # Keys that must be merged into the fused qkv projection
        merge_keys = [".k_proj.", ".q_proj.", ".v_proj."]
        def replace_keys(key: str) -> str:
            """Replace parts of the key using the defined mapping."""
            for old_key, new_key in key_mapping.items():
                key = key.replace(f".{old_key}.", f".{new_key}.")
            return key.replace("model.", "transformer.")
        def needs_merging(key: str) -> bool:
            """Check if the key contains any of the merge keys."""
            return any(merge_key in key for merge_key in merge_keys)
        tensors = {}
        grouped_weights = defaultdict(dict)
        # Load tensors from the safetensors file
        with safe_open(smol_path, "pt", device="cpu") as f:  # type: ignore
            for key in f.keys():
                transformed_key = replace_keys(key)
                if needs_merging(key):
                    # Group the q/k/v weights by layer so they can be concatenated
                    group_name = ".".join(transformed_key.split(".")[:4])
                    layer_num = int(transformed_key.split(".")[2])
                    weight_type = transformed_key.split(".")[4][0]
                    if layer_num < config.num_layers:
                        grouped_weights[group_name][weight_type] = f.get_tensor(key)
                elif transformed_key in state_dict:
                    # Directly assign matching tensors
                    tensors[transformed_key] = f.get_tensor(key)
        # Combine the Q, K, V weights into the fused qkv projection
        for group_name, weights in grouped_weights.items():
            tensors[f"{group_name}.qkv.weight"] = torch.cat(
                [weights["q"], weights["k"], weights["v"]], dim=0
            )
        # Identify missing weights (e.g. the freshly initialised classifier head)
        missing_weights = set(state_dict.keys()) - set(tensors.keys())
        if missing_weights:
            print(f"Weights not loaded: {missing_weights}")
        # Load the state dict into the model
        model.load_state_dict(tensors, strict=False)
        return model
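# Hedged usage sketch (added, not part of the original gist): the classifier can also
# be trained from scratch, without the SmolLM checkpoint, e.g. for a 3-way task.
def _example_classifier_from_scratch() -> Tensor:
    config = TransformerConfigs(num_layers=2, num_classes=3, pooling="cls")
    model = TransformerWithClassificationHead(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    mask = torch.ones(2, 16, dtype=torch.bool)
    return model(tokens, mask)  # -> (2, 3) logits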
if __name__ == "__main__":
    smol_path = "smolm-checkpoint/model.safetensors"
    smol_model = TransformerWithClassificationHead.load_from_smollm(
        smol_path, 11, "mean"
    )
    print(f"Model size: {model_size_in_megabytes(smol_model):.2f} MB")
    print(f"Number of parameters: {count_parameters(smol_model)/1e6:.2f} M")