Transformer
import math

import torch
import torch.nn as nn

from params import params

# NOTATION:
#   W_k = W (k as subscript)
#   Wk  = W (k as superscript)
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, mask: torch.Tensor):
        super().__init__()
        self.mask = mask
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # TODO: Why no biases?
        # Notice, there are no biases in the paper's formulation:
        #   MultiHead(Q, K, V) = Concat(head_1, ..., head_h) Wo
        #   where head_i = Attention(Q Wq_i, K Wk_i, V Wv_i)
        # A confusing part: there should be multiple attention heads,
        # each with its own copy of Wq, Wk, Wv -- but we represent all of them
        # with a single d_model x d_model matrix per projection and split it
        # into heads afterwards (see the sketch after this class).
        self.Wq = nn.Linear(d_model, d_model, bias=False)
        self.Wk = nn.Linear(d_model, d_model, bias=False)
        self.Wv = nn.Linear(d_model, d_model, bias=False)
        # Wo: output projection applied after concatenating the heads
        self.linear = nn.Linear(d_model, d_model)
    def forward(self, x):
        orig_shape = x.shape
        batch_size, sequence_len, d_model = x.shape
        assert params['batch_size'] == batch_size
        assert params['sequence_len'] == sequence_len
        assert params['d_model'] == d_model
        Q, K, V = self.Wq(x), self.Wk(x), self.Wv(x)
        expected_shape = (batch_size, sequence_len, d_model)
        assert Q.shape == expected_shape
        assert K.shape == expected_shape
        assert V.shape == expected_shape
        Q, K, V = self.split_into_heads(Q), self.split_into_heads(K), self.split_into_heads(V)
        expected_shape = (batch_size, self.num_heads, sequence_len, self.d_k)
        assert Q.shape == expected_shape
        assert K.shape == expected_shape
        assert V.shape == expected_shape
        K_T = K.transpose(2, 3)
        assert K_T.shape == (batch_size, self.num_heads, self.d_k, sequence_len)
        # For tensors with more than two dimensions, matrix multiplication operates
        # on the last two dimensions; the leading (batch-like) dimensions must match.
        query_attention_to_keys = Q @ K_T
        query_attention_to_keys = query_attention_to_keys / math.sqrt(self.d_k)
        assert query_attention_to_keys.shape == (batch_size, self.num_heads, sequence_len, sequence_len)
        assert self.mask.shape == (sequence_len, sequence_len)
        # From paper:
        # > We need to prevent leftward
        # > information flow in the decoder to preserve the auto-regressive property. We implement this
        # > inside of scaled dot-product attention by masking out (setting to −∞) all values in the input
        # > of the softmax which correspond to illegal connections.
        query_attention_to_keys.masked_fill_(self.mask == 0, -1e9)
        query_attention_to_keys_normalized = torch.softmax(query_attention_to_keys, dim=3)
        combined_value_vectors = query_attention_to_keys_normalized @ V
        assert combined_value_vectors.shape == (batch_size, self.num_heads, sequence_len, self.d_k)
        transposed = combined_value_vectors.transpose(1, 2)
        assert transposed.shape == (batch_size, sequence_len, self.num_heads, self.d_k)
        # transpose() returns a non-contiguous view, so make it contiguous before view()
        concatted = transposed.contiguous().view(batch_size, sequence_len, d_model)
        out = self.linear(concatted)
        assert out.shape == orig_shape
        return out

    def split_into_heads(self, tensor):
        batch_size, sequence_len, d_model = tensor.shape
        assert self.d_k * self.num_heads == d_model
        return tensor.view(batch_size, sequence_len, self.num_heads, self.d_k).transpose(1, 2)
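
# Illustrative sketch (not part of the original gist): the "single giant matrix"
# trick above is equivalent to giving each head its own d_model x d_k projection.
# Projecting once with a d_model x d_model matrix and then splitting the result
# into heads gives the same values as projecting with column slices of that matrix.
# The function name and variables below are hypothetical, for demonstration only.
def _demo_single_matrix_trick(batch_size=2, sequence_len=4, d_model=8, num_heads=2):
    d_k = d_model // num_heads
    x = torch.randn(batch_size, sequence_len, d_model)
    W = torch.randn(d_model, d_model)
    # One big projection, then reshape/transpose into heads (as split_into_heads does)
    fused = (x @ W).view(batch_size, sequence_len, num_heads, d_k).transpose(1, 2)
    # Per-head projections: head h uses columns [h*d_k, (h+1)*d_k) of W
    per_head = torch.stack(
        [x @ W[:, h * d_k:(h + 1) * d_k] for h in range(num_heads)], dim=1
    )
    assert torch.allclose(fused, per_head, atol=1e-6)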
class DecoderLayer(nn.Module):
    def __init__(self, num_heads: int, d_model: int, d_ff: int, mask: torch.Tensor):
        super().__init__()
        self.d_ff = d_ff
        self.attention = MultiHeadAttention(d_model, num_heads, mask)
        # TODO: How was this epsilon chosen?
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        # On the role of the position-wise feed-forward network:
        # https://stats.stackexchange.com/questions/485910/what-is-the-role-of-feed-forward-layer-in-transformer-neural-network-architectur
        self.lin1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(d_ff, d_model)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
    def forward(self, x):
        batch_size, sequence_len, d_model = params['batch_size'], params['sequence_len'], params['d_model']
        attention = self.attention(x)
        assert attention.shape == (batch_size, sequence_len, d_model)
        # add & normalize
        x = x + attention
        x = self.norm1(x)
        assert x.shape == (batch_size, sequence_len, d_model)
        before_ffn = x
        x = self.lin1(x)
        assert x.shape == (batch_size, sequence_len, self.d_ff)
        x = self.relu(x)
        x = self.lin2(x)
        assert x.shape == (batch_size, sequence_len, d_model)
        # add & normalize again
        x = self.norm2(before_ffn + x)
        return x
class Transformer(nn.Module):
    def __init__(self, vocab_size: int, num_heads: int, d_model: int, d_ff: int, sequence_len: int, layers: int, mask: torch.Tensor):
        super().__init__()
        self.vocab_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.positional_embedding = nn.Embedding(num_embeddings=sequence_len, embedding_dim=d_model)
        assert self.vocab_embedding.weight.shape == (vocab_size, d_model)
        assert self.positional_embedding.weight.shape == (sequence_len, d_model)
        self.decoder_layers = nn.ModuleList(DecoderLayer(num_heads, d_model, d_ff, mask) for _ in range(layers))
        # Maps the output embeddings back to token logits.
        # If you wanted to re-use (tie) the input embedding matrix instead, you could do
        #   torch.matmul(decoder_output, self.vocab_embedding.weight.transpose(0, 1))
        # inside of `forward`.
        # TODO: How does gradient flow work if we re-use the embedding matrix?
        # TODO: Why don't we need to subtract positional encodings if using tied weights?
        # https://github.com/tunz/transformer-pytorch/blob/e7266679f0b32fd99135ea617213f986ceede056/model/transformer.py#L292
        self.linear = nn.Linear(d_model, vocab_size)
    def forward(self, x):
        batch_size, sequence_len = x.shape
        assert params['batch_size'] == batch_size
        assert params['sequence_len'] == sequence_len
        # The positional embedding is indexed by position (0..sequence_len-1), not by
        # token id; broadcasting adds it to every sequence in the batch.
        positions = torch.arange(sequence_len, device=x.device)
        embeddings = self.vocab_embedding(x) + self.positional_embedding(positions)
        assert embeddings.shape == (batch_size, sequence_len, params['d_model'])
        decoder_output = embeddings
        for layer in self.decoder_layers:
            # Feed each layer the previous layer's output, not the raw embeddings
            decoder_output = layer(decoder_output)
        return self.linear(decoder_output)
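
if __name__ == "__main__":
    # Minimal usage sketch (not from the original gist). The hyperparameters below
    # are illustrative assumptions, and `params` is assumed to be a plain dict
    # defined in params.py; adjust to match your own params module.
    batch_size, sequence_len, d_model = 2, 16, 64
    vocab_size, num_heads, d_ff, layers = 1000, 4, 256, 2
    params.update({'batch_size': batch_size, 'sequence_len': sequence_len, 'd_model': d_model})
    # Causal (lower-triangular) mask: position i may only attend to positions <= i
    mask = torch.tril(torch.ones(sequence_len, sequence_len))
    model = Transformer(vocab_size, num_heads, d_model, d_ff, sequence_len, layers, mask)
    tokens = torch.randint(0, vocab_size, (batch_size, sequence_len))
    logits = model(tokens)
    assert logits.shape == (batch_size, sequence_len, vocab_size)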