changes to model.py
"""Full definition of a LLaMA Language Model, all of it in this single file. | |
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. | |
""" | |
# mypy: ignore-errors | |
import math | |
from dataclasses import dataclass | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from typing_extensions import Self | |
MaskCache = torch.Tensor | |
RoPECache = torch.Tensor | |
KVCache = Tuple[torch.Tensor, torch.Tensor] | |
def find_multiple(n: int, k: int) -> int: | |
if n % k == 0: | |
return n | |
return n + k - (n % k) | |
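# Illustrative values (sketch): find_multiple(32000, 64) == 32000 and
# find_multiple(32001, 64) == 32064; this is how padded_vocab_size is derived below.
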
class LinearInt8(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # the weight is stored as a raw int8 buffer and cast back to the activation
        # dtype in forward(); the dtype argument is unused since the weight is fixed to int8
        self.register_buffer("weight", torch.empty((out_features, in_features), device=device, dtype=torch.int8))
        # if bias:
        #     self.register_buffer("bias", torch.empty(out_features, device=device, dtype=torch.int8))
        # else:
        #     self.register_buffer("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype))


# nn.Linear = LinearInt8
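# A hedged usage sketch (not from the original gist): applying the monkey-patch above
# before the model is constructed would route every nn.Linear through LinearInt8,
# after which an int8 state dict could be loaded. The checkpoint path is hypothetical.
#
#     nn.Linear = LinearInt8
#     llama = LLaMA.from_name("7B")
#     llama.load_state_dict(torch.load("llama_7b_int8.pth"))  # hypothetical int8 checkpoint
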
@dataclass
class LLaMAConfig:
    block_size: int = 2048
    vocab_size: int = 32000
    padded_vocab_size: Optional[int] = None
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096

    def __post_init__(self):
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, 64)

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "7B": dict(n_layer=32, n_head=32, n_embd=4096),
    "13B": dict(n_layer=40, n_head=40, n_embd=5120),
    "30B": dict(n_layer=60, n_head=52, n_embd=6656),
    "65B": dict(n_layer=80, n_head=64, n_embd=8192),
}

class KVCache(nn.Module):
    # note: this class shadows the `KVCache` tuple alias defined above, so the
    # KVCache type hints further down now refer to this module
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_size, device='cuda', dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_size)
        self.k_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype))
        self.v_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        # aten.index_put is the out-of-place variant: it returns a copy of the cache
        # with the given positions overwritten, leaving the stored tensors untouched
        k_out = torch.ops.aten.index_put(self.k_cache, [None, None, input_pos], k_val)
        v_out = torch.ops.aten.index_put(self.v_cache, [None, None, input_pos], v_val)
        return k_out, v_out
        # self.k_cache = torch.ops.aten.index_put(self.k_cache, [None, None, input_pos], k_val)
        # self.v_cache = torch.ops.aten.index_put(self.v_cache, [None, None, input_pos], v_val)
        # return self.k_cache, self.v_cache

class KVCacheAggregator(nn.Module):
    def __init__(self):
        super().__init__()
        self.kv_caches = nn.ModuleList([])

    def initialize(self, layers, max_batch_size, max_seq_length, n_heads, head_size, device='cuda', dtype=torch.bfloat16):
        self.kv_caches = nn.ModuleList(
            [KVCache(max_batch_size, max_seq_length, n_heads, head_size, device, dtype) for _ in range(layers)]
        )

    def __getitem__(self, idx):
        return self.kv_caches[idx]

    def clear(self):
        self.kv_caches = nn.ModuleList([])

class LLaMA(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        self.rope_cache: Optional[RoPECache] = None
        self.mask_cache: Optional[MaskCache] = None
        self.kv_caches = KVCacheAggregator()
        self.max_batch_size = None
        self.max_seq_length = None

    def setup_caches(self, max_batch_size, max_seq_length, device='cuda', dtype=torch.bfloat16):
        head_size = self.config.n_embd // self.config.n_head
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.kv_caches.initialize(
            layers=self.config.n_layer,
            max_batch_size=max_batch_size,
            max_seq_length=max_seq_length,
            n_heads=self.config.n_head,
            head_size=head_size,
            device=device,
            dtype=dtype,
        )
        self.rope_cache = build_rope_cache(
            seq_len=self.config.block_size,
            n_elem=self.config.n_embd // self.config.n_head,
            dtype=dtype,
            device=device,
        )
        ones = torch.ones((self.config.block_size, self.config.block_size), device=device, dtype=torch.bool)
        self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T = idx.size()

        block_size = self.config.block_size
        max_seq_length = self.max_seq_length
        if max_seq_length is None:
            max_seq_length = block_size
        assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
        assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
        assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}"

        rope = self.rope_cache.index_select(0, input_pos)
        mask = self.mask_cache.index_select(2, input_pos)
        mask = mask[:, :, :, :max_seq_length]

        # forward the model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

        for i, block in enumerate(self.transformer.h):
            x, new_kv_cache = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (b, t, vocab_size)
        return logits

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(LLaMAConfig.from_name(name))

    def reset_cache(self) -> None:
        self.kv_caches.clear()

class Block(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        mask: MaskCache,
        max_seq_length: int,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
        x = x + h
        x = x + self.mlp(self.rms_2(x))
        return x, new_kv_cache

class CausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size

    def forward(
        self,
        x: torch.Tensor,
        rope: RoPECache,
        mask: MaskCache,
        max_seq_length: int,
        input_pos: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size)
        q = q.view(B, T, self.n_head, head_size)
        v = v.view(B, T, self.n_head, head_size)

        q = apply_rope(q, rope)
        k = apply_rope(k, rope)

        k = k.transpose(1, 2)  # (B, nh, T, hs)
        q = q.transpose(1, 2)  # (B, nh, T, hs)
        v = v.transpose(1, 2)  # (B, nh, T, hs)

        if kv_cache is not None:
            k, v = kv_cache.update(input_pos, k, v)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # att = att.masked_fill(mask[:, :, :T, :T] == 0, float('-inf'))
        # att = F.softmax(att, dim=-1)
        # y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # efficient attention using Flash Attention CUDA kernels
        # y = F.scaled_dot_product_attention(q, k, v)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y, kv_cache

class MLP(nn.Module):
    def __init__(self, config: LLaMAConfig) -> None:
        super().__init__()
        hidden_dim = 4 * config.n_embd
        n_hidden = int(2 * hidden_dim / 3)
        n_hidden = find_multiple(n_hidden, 256)

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed

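# In other words (sketch of the math above): RMSNorm(x) = scale * x / sqrt(mean(x**2) + eps),
# i.e. LayerNorm without the mean subtraction and without a bias term.
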
def build_rope_cache(
    seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
) -> RoPECache:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta).float()

    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

    # this is to mimic the behaviour of complex32, else we will get different results
    if dtype in (torch.float16, torch.bfloat16, torch.int8):
        cache = cache.half()
    return cache

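# Example (sketch): build_rope_cache(seq_len=2048, n_elem=128, dtype=torch.float32,
# device=torch.device("cpu")) returns a (2048, 64, 2) tensor holding the (cos, sin)
# pair for each position and each of the n_elem // 2 rotation frequencies.
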
def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor:
    # truncate to support variable sizes
    T = x.size(1)
    rope_cache = rope_cache[:T]

    # cast because the reference does
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
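

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original gist): builds a tiny,
# hypothetical config on CPU in float32, prefills an 8-token prompt, then runs a
# single-step decode to exercise the KV-cache update path. The config values and
# prompt are illustrative only; real checkpoints use the configs in llama_configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    config = LLaMAConfig(block_size=128, vocab_size=1000, n_layer=2, n_head=4, n_embd=64)
    model = LLaMA(config)
    model.setup_caches(max_batch_size=1, max_seq_length=32, device="cpu", dtype=torch.float32)

    with torch.no_grad():
        prompt = torch.randint(0, config.vocab_size, (1, 8))
        input_pos = torch.arange(0, 8)
        logits = model(prompt, input_pos)              # (1, 8, padded_vocab_size)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        logits = model(next_token, torch.tensor([8]))  # single-token step at position 8
        print(logits.shape)                            # torch.Size([1, 1, 1024])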