This is the LSTM implementation that the 2nd edition of Data Science from Scratch promised to share.
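For orientation, each call to forward below computes the standard LSTM gate equations, with one twist: the updated cell state is passed through an extra tanh before producing the hidden state. In the notation of the code (@ is a matrix-vector product, * is elementwise):

    f = sigmoid(w_f @ x + u_f @ prev_h + b_f)    # forget gate
    i = sigmoid(w_i @ x + u_i @ prev_h + b_i)    # input gate
    o = sigmoid(w_o @ x + u_o @ prev_h + b_o)    # output gate
    pre_c = tanh(w_c @ x + u_c @ prev_h + b_c)   # candidate cell state
    c = tanh(f * prev_c + i * pre_c)             # new cell state (note the extra tanh)
    h = o * c                                    # new hidden state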
# NOTE: relies on helpers from the book's code (Data Science from Scratch, 2nd ed.):
# Layer, Tensor, random_tensor, tensor_apply, sigmoid, tanh, and dot.

class Lstm(Layer):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Forget-gate weights
        self.w_f = random_tensor(hidden_dim, input_dim, variance=2/(hidden_dim + input_dim))
        self.u_f = random_tensor(hidden_dim, hidden_dim, variance=1/hidden_dim)
        self.b_f = random_tensor(hidden_dim)

        # Input-gate weights
        self.w_i = random_tensor(hidden_dim, input_dim, variance=2/(hidden_dim + input_dim))
        self.u_i = random_tensor(hidden_dim, hidden_dim, variance=1/hidden_dim)
        self.b_i = random_tensor(hidden_dim)

        # Output-gate weights
        self.w_o = random_tensor(hidden_dim, input_dim, variance=2/(hidden_dim + input_dim))
        self.u_o = random_tensor(hidden_dim, hidden_dim, variance=1/hidden_dim)
        self.b_o = random_tensor(hidden_dim)

        # Cell-state weights
        self.w_c = random_tensor(hidden_dim, input_dim, variance=2/(hidden_dim + input_dim))
        self.u_c = random_tensor(hidden_dim, hidden_dim, variance=1/hidden_dim)
        self.b_c = random_tensor(hidden_dim)

        # Reset the hidden state
        self.reset_hidden_state()

    def reset_hidden_state(self) -> None:
        self.h = [0 for _ in range(self.hidden_dim)]
        self.c = [0 for _ in range(self.hidden_dim)]
    def forward(self, input: Tensor) -> Tensor:
        # Remember these for use in backprop
        self.input = input
        self.prev_c = self.c
        self.prev_h = self.h

        # Forget gate
        self.pre_f = [dot(self.w_f[h], input) + dot(self.u_f[h], self.h) + self.b_f[h]
                      for h in range(self.hidden_dim)]
        self.f = tensor_apply(sigmoid, self.pre_f)

        # Input gate
        self.pre_i = [dot(self.w_i[h], input) + dot(self.u_i[h], self.h) + self.b_i[h]
                      for h in range(self.hidden_dim)]
        self.i = tensor_apply(sigmoid, self.pre_i)

        # Output gate
        self.pre_o = [dot(self.w_o[h], input) + dot(self.u_o[h], self.h) + self.b_o[h]
                      for h in range(self.hidden_dim)]
        self.o = tensor_apply(sigmoid, self.pre_o)

        # Pre-cell-state (candidate cell state)
        self.pre_pre_c = [dot(self.w_c[h], input) + dot(self.u_c[h], self.h) + self.b_c[h]
                          for h in range(self.hidden_dim)]
        self.pre_c = tensor_apply(tanh, self.pre_pre_c)

        # Forget some cell state (using the forget gate), and
        # add some new cell state (using the input gate and the pre-cell-state)
        self.updated_c = [self.f[h] * self.c[h] + self.i[h] * self.pre_c[h]
                          for h in range(self.hidden_dim)]

        # Apply a nonlinearity to updated_c to get the new cell state
        self.c = tensor_apply(tanh, self.updated_c)

        # Update the hidden state using the output gate and the new cell state
        self.h = [self.o[h] * self.c[h] for h in range(self.hidden_dim)]

        return self.h
    def backward(self, gradient: Tensor) -> Tensor:
        # Let's chug along one step at a time.
        # self.h = self.o * self.c, so:
        grad_o = [gradient[h] * self.c[h] for h in range(self.hidden_dim)]
        grad_c = [self.o[h] * gradient[h] for h in range(self.hidden_dim)]

        # self.c = tanh(self.updated_c), so:
        grad_updated_c = [grad_c[h] * (1 - self.c[h] ** 2) for h in range(self.hidden_dim)]

        # self.updated_c = self.f * self.prev_c + self.i * self.pre_c, so:
        grad_f = [grad_updated_c[h] * self.prev_c[h] for h in range(self.hidden_dim)]
        grad_prev_c = [self.f[h] * grad_updated_c[h] for h in range(self.hidden_dim)]
        grad_i = [grad_updated_c[h] * self.pre_c[h] for h in range(self.hidden_dim)]
        grad_pre_c = [self.i[h] * grad_updated_c[h] for h in range(self.hidden_dim)]

        # self.pre_c = tanh(self.pre_pre_c), so:
        grad_pre_pre_c = [grad_pre_c[h] * (1 - self.pre_c[h] ** 2) for h in range(self.hidden_dim)]

        # self.pre_pre_c = w_c @ input + u_c @ prev_h + b_c (gradients wrt the input come later):
        self.grad_w_c = [[grad_pre_pre_c[h] * self.input[i] for i in range(self.input_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_u_c = [[grad_pre_pre_c[h] * self.prev_h[h2] for h2 in range(self.hidden_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_b_c = grad_pre_pre_c

        # self.o = sigmoid(self.pre_o), so:
        grad_pre_o = [grad_o[h] * (1 - self.o[h]) * self.o[h] for h in range(self.hidden_dim)]

        # self.pre_o = w_o @ input + u_o @ prev_h + b_o, so:
        self.grad_w_o = [[grad_pre_o[h] * self.input[i] for i in range(self.input_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_u_o = [[grad_pre_o[h] * self.prev_h[h2] for h2 in range(self.hidden_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_b_o = grad_pre_o

        # self.i = sigmoid(self.pre_i), so:
        grad_pre_i = [grad_i[h] * (1 - self.i[h]) * self.i[h] for h in range(self.hidden_dim)]

        # self.pre_i = w_i @ input + u_i @ prev_h + b_i, so:
        self.grad_w_i = [[grad_pre_i[h] * self.input[i] for i in range(self.input_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_u_i = [[grad_pre_i[h] * self.prev_h[h2] for h2 in range(self.hidden_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_b_i = grad_pre_i

        # self.f = sigmoid(self.pre_f), so:
        grad_pre_f = [grad_f[h] * (1 - self.f[h]) * self.f[h] for h in range(self.hidden_dim)]

        # self.pre_f = w_f @ input + u_f @ prev_h + b_f, so:
        self.grad_w_f = [[grad_pre_f[h] * self.input[i] for i in range(self.input_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_u_f = [[grad_pre_f[h] * self.prev_h[h2] for h2 in range(self.hidden_dim)]
                         for h in range(self.hidden_dim)]
        self.grad_b_f = grad_pre_f

        # Gradient wrt the input: the input enters each pre-activation through
        # its w matrix, so sum those four contributions.
        return [sum(grad_pre_f[h] * self.w_f[h][i] +
                    grad_pre_i[h] * self.w_i[h][i] +
                    grad_pre_o[h] * self.w_o[h][i] +
                    grad_pre_pre_c[h] * self.w_c[h][i]
                    for h in range(self.hidden_dim))
                for i in range(self.input_dim)]
    def params(self):
        return [
            self.w_o, self.u_o, self.b_o,
            self.w_i, self.u_i, self.b_i,
            self.w_f, self.u_f, self.b_f,
            self.w_c, self.u_c, self.b_c
        ]

    def grads(self):
        return [
            self.grad_w_o, self.grad_u_o, self.grad_b_o,
            self.grad_w_i, self.grad_u_i, self.grad_b_i,
            self.grad_w_f, self.grad_u_f, self.grad_b_f,
            self.grad_w_c, self.grad_u_c, self.grad_b_c
        ]
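Not part of the gist, but here is a minimal sketch of how the layer is driven, assuming the book's helpers noted above are available; the dimensions, sequence data, and squared-error gradient below are placeholders. The layer is stateful: reset the hidden state at the start of each sequence, call forward once per timestep, and call backward with the gradient of the loss with respect to that timestep's output, which fills in the grad_* attributes that grads() exposes to an optimizer. Also note that backward only propagates through the current timestep (the gradients with respect to the previous hidden and cell states are not carried further back), so training amounts to truncated backpropagation through time.

    import random

    # Hypothetical dimensions and data, just to exercise the layer.
    input_dim, hidden_dim, seq_len = 8, 16, 5
    sequence = [[random.random() for _ in range(input_dim)] for _ in range(seq_len)]
    target = [0.0] * hidden_dim            # placeholder target for a squared-error loss

    lstm = Lstm(input_dim, hidden_dim)
    lstm.reset_hidden_state()              # start of a new sequence
    for x in sequence:
        h = lstm.forward(x)                # hidden state for this timestep
        # placeholder squared-error gradient dL/dh for this timestep
        grad_h = [2 * (h_j - t_j) for h_j, t_j in zip(h, target)]
        grad_x = lstm.backward(grad_h)     # also fills self.grad_* for params()/grads()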