Simple self-attention, with an understanding of the vectorised form
# %%
import torch
import torch.nn.functional as F  # for softmax

torch.manual_seed(0)  # for reproducibility
# somehow on my m1 mac, the randomness is not reproducible. todo: figure out why

b = 1  # mini-batch size
t = 3  # sequence length
k = 2  # dimension of each vector in the sequence

# create data for consistent use in the notebook
X_batched = torch.randn(b, t, k)
X = torch.randn(t, k)

print(f"""
X_batched.shape: {X_batched.shape}
X.shape: {X.shape}
X_batched: {X_batched}
X: {X}
""")
# %% [markdown]
# # Simple self-attention
#
# ```python
#
# """
# ------------
# CONVENTIONS:
# ------------
#
# formula for the simple version of self-attention:
#
# 1. Given X = [x1,
#               x2,
#               ..,
#               xt]  where each xi is a vector of dimension k,
#
#    the shape of X is (t, k).
#
# 2. We want to compute the self-attention output yi for each xi, where i goes from 0 to t-1.
#
#    E.g.
#
#    y1, y2, y3, y4, y5, y6
#     |   |   |   |   |   |
#    ------------------------
#    self-attention weights
#    ------------------------
#     |   |   |   |   |   |
#    x1, x2, x3, x4, x5, x6
#
#    y1 = W11 * x1 + W12 * x2 + W13 * x3 + W14 * x4 + W15 * x5 + W16 * x6
#
#    W11 : weight for x1
#    W12 : weight for x2
#    W13 : weight for x3, and so on
#
# 3. Finding W
#
#    For a single yi, we find the weight of each time step by taking the dot product of the
#    corresponding xi with every time step of X, including itself,
#
#    W1 = [x1, x2, x3, x4, x5, x6] * [x1, x2, x3, x4, x5, x6]T
#
#    and normalise the contribution of each weight with a softmax.
#
# """
#
# ```
# %% [markdown]
# ## Simple for-loop way

# %%
def calc_self_attention(X):
    # X is a sequence of vectors of shape (t, k).
    # The W matrix is filled like this:
    # W11, W12, W13, W14, W15, W16  # weights for y1
    # W21, W22, W23, W24, W25, W26  # weights for y2
    # W31, W32, W33, W34, W35, W36  # weights for y3
    # W41, W42, W43, W44, W45, W46  # weights for y4
    # W51, W52, W53, W54, W55, W56  # weights for y5
    # W61, W62, W63, W64, W65, W66  # weights for y6
    t = X.shape[0]
    W = torch.zeros(t, t)

    # fill W with the pairwise dot products
    for i in range(t):
        x_i = X[i]  # a vector of dimension k
        for j in range(t):
            W[i, j] = torch.dot(x_i, X[j])

    # turn the raw scores into self-attention weights with a softmax over dim 1
    W = F.softmax(W, dim=1)
    return W  # self-attention weights


W = calc_self_attention(X)

# now calculate the self-attention vectors
Y = torch.zeros_like(X)
for i in range(t):
    for j in range(t):
        Y[i] += W[i, j] * X[j]

for i in range(t):
    inp = X
    weights = W[i]
    out = Y[i]
    print(f"input  : {inp}")
    print(f"weights: {weights}")
    print(f"output : {out}")
    print()
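# %% [markdown]
# A small sanity check (added): each row of `W` comes out of a softmax, so every row should
# sum to 1.

# %%
print(W.sum(dim=1))  # expected: a vector of ones of length t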
# %% [markdown]
# ## Matrix multiplication way

# %%
# vectorised version
def calculate_y(X, W):
    """
    Parameters
    ----------
    X : torch.Tensor
        X is a sequence of vectors of shape t x k
    W : torch.Tensor
        W is a matrix of shape t x t

    Returns
    -------
    Y : torch.Tensor
        Y is a sequence of vectors of shape t x k

    Notes
    -----
    W is the self-attention weights matrix,
    X is the input sequence of vectors,
    Y is the output sequence of vectors,
    t is the sequence length,
    k is the dimension of each vector in the sequence.

    W = [w11, w12, ..., w1t]
        [w21, w22, ..., w2t]
        [ .., ..,  ...,  ..]
        [wt1, wt2, ..., wtt]

    Each row of W holds the self-attention weights for one vector in X,
    that is y1 = w11 * x1 + w12 * x2 + ... + w1t * xt, and

    Y = [w11 * x1 + w12 * x2 + ... + w1t * xt]
        [w21 * x1 + w22 * x2 + ... + w2t * xt]
        [ .., ..,  ...,  ..]
        [wt1 * x1 + wt2 * x2 + ... + wtt * xt]

    where each xi is a vector of dimension k, so the resulting Y is a sequence of
    vectors of shape t x k:

    Y = ^      [w11 * x11 + w12 * x21 + ... + w1t * xt1,  w11 * x12 + w12 * x22 + ... + w1t * xt2,  ...,  w11 * x1k + w12 * x2k + ... + w1t * xtk]
        |      [w21 * x11 + w22 * x21 + ... + w2t * xt1,  w21 * x12 + w22 * x22 + ... + w2t * xt2,  ...,  w21 * x1k + w22 * x2k + ... + w2t * xtk]
        t rows [ ..,  ..,  ...,  ..]
        |      [wt1 * x11 + wt2 * x21 + ... + wtt * xt1,  wt1 * x12 + wt2 * x22 + ... + wtt * xt2,  ...,  wt1 * x1k + wt2 * x2k + ... + wtt * xtk]
               <------------------------------------------------ k columns ------------------------------------------------>
    """
    Y = W @ X
    return Y  # self-attention vectors for a sequence
# %%
def calc_self_attention(X):
    """
    Parameters
    ----------
    X : torch.Tensor
        X is a sequence of vectors of shape t x k

    Returns
    -------
    W : torch.Tensor
        W is a matrix of shape t x t

    Notes
    -----
    W is the self-attention weights matrix,
    X is the input sequence of vectors,
    t is the sequence length,
    k is the dimension of each vector in the sequence.

    The attention vectors are then calculated as

        Y = W @ X

    For this, a row of W holds the self-attention weights for one vector in X,
    that is y1 = w11 * x1 + w12 * x2 + ... + w1t * xt.

    Filling the W matrix therefore looks like this:

    W11, W12, W13, W14, W15, W16   # weights for y1
        W11 : attention of x1 with x1
        W12 : attention of x1 with x2
        W13 : attention of x1 with x3
        and so on
    W21, W22, W23, W24, W25, W26   # weights for y2
        W21 : attention of x2 with x1
        W22 : attention of x2 with x2
        and so on

    Considering this, we can calculate the W matrix as

        W = X @ X.T

    and normalise the contribution of each vector in X with a softmax:

        W = F.softmax(W, dim=1)

    so that

        y1 = W11 * x1 + W12 * x2 + ... + W1t * xt   # k dimensions
        y2 = W21 * x1 + W22 * x2 + ... + W2t * xt   # k dimensions
    """
    W = X @ X.T              # raw self-attention scores
    W = F.softmax(W, dim=1)  # normalise the weights with a softmax
    return W
# %%
W_vectorised = calc_self_attention(X)
Y_vectorised = calculate_y(X, W_vectorised)  # self-attention vectors for a sequence
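# %% [markdown]
# A quick check (added): the loop-based `W` and the vectorised `W_vectorised` should match,
# since both are softmax-normalised pairwise dot products between time steps.

# %%
torch.allclose(W, W_vectorised)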
# %%
Y_vectorised

# %%
Y

# %%
torch.allclose(Y, Y_vectorised)
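# %% [markdown]
# ## Batched version (added sketch)
#
# `X_batched` was created at the top but never used above. The same computation extends to a
# batch of sequences with batched matrix multiplication; the names `calc_self_attention_batched`,
# `W_batched` and `Y_batched` below are illustrative additions, not part of the original notebook.

# %%
def calc_self_attention_batched(X_b):
    # added sketch: X_b has shape (b, t, k); raw scores are pairwise dot products within each sequence
    W_b = torch.bmm(X_b, X_b.transpose(1, 2))  # (b, t, t)
    W_b = F.softmax(W_b, dim=2)                # normalise over the last dimension
    return W_b


W_batched = calc_self_attention_batched(X_batched)
Y_batched = torch.bmm(W_batched, X_batched)    # (b, t, k) self-attention vectors per sequence

# a batch of size 1 built from X should reproduce the un-batched result
Y_from_X = torch.bmm(calc_self_attention_batched(X.unsqueeze(0)), X.unsqueeze(0))
print(torch.allclose(Y_from_X.squeeze(0), Y_vectorised))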