RNN from scratch
import os
import pickle

import numpy as np
import matplotlib.pyplot as plt

# Configuration
## RNN definition
### maximum length of sequence
T = 6
### hidden state vector dimension
hidden_dim = 32
### output length
output_dim = 8
## training params
### cutoff for linear gradient (defined but unused below; l_grad clips at +/-1.0)
alpha = 0.025
### learning rate
eps = 1e-1
### number of training epochs
n_epochs = 10000
### number of samples to reserve for test
test_set_size = 4
### number of samples to generate
n_samples = 50

rng = np.random.default_rng(2882)
# the "hidden layer". aka the transition matrix. these are the weights in the RNN. | |
# shape: hidden_dim x hidden_dim | |
W = rng.normal(0, (hidden_dim * hidden_dim) ** -0.75, size=(hidden_dim, hidden_dim)) | |
_, W = np.linalg.qr(W, mode='complete') | |
# input matrix. translates from input vector to W | |
# shape: hidden_dim x T | |
U = rng.normal(0, (hidden_dim * T) ** -0.75, size=(hidden_dim, T)) | |
svd_u, _, svd_vh = np.linalg.svd(U, full_matrices=False) | |
U = np.dot(svd_u, svd_vh) | |
# output matrix. translates from W to the output vector | |
# shape: output_dim x hidden_dim | |
V = rng.normal(0, (output_dim * hidden_dim) ** -0.75, size=(output_dim, hidden_dim)) | |
svd_u, _, svd_vh = np.linalg.svd(V, full_matrices=False) | |
V = np.dot(svd_u, svd_vh) | |
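
# optional sanity check (an addition, not in the original gist): after the
# QR/SVD steps W is orthogonal and U, V are semi-orthogonal, so each linear map
# has unit spectral norm -- a common guard against exploding RNN gradients.
assert np.allclose(np.dot(W.T, W), np.eye(hidden_dim))
assert np.allclose(np.dot(U.T, U), np.eye(T))
assert np.allclose(np.dot(V, V.T), np.eye(output_dim))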
# this is the formula used to update the hidden state:
#   s_t = sigmoid(U @ x_t + W @ s_{t-1})
def new_hidden_state(x, s):
    u = np.dot(U, x)
    w = np.dot(W, s)
    rv = 1 / (1 + np.exp(-(u + w)))
    return rv
def el_mul(v, m):
    # scale row i of m by v[i]; same result as the original column-by-column loop
    return v[:, np.newaxis] * m

def l_grad(dy):
    # clip each error component to [-1, 1] (idiomatic form of the original loop)
    return np.clip(dy, -1.0, 1.0)
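
# math note (a sketch of what the training loop further below implements): with
#     s_t = sigmoid(U x_t + W s_{t-1})   and   D_t = diag(s_t * (1 - s_t)),
# the per-sequence sensitivities are built up recursively during the forward pass:
#     S_t^U = D_t W S_{t-1}^U + D_t x_t^T      (rnn_ds_dU below)
#     S_t^W = D_t W S_{t-1}^W + D_t s_{t-1}^T  (rnn_ds_dW below)
# and, with the clipped output error g = clip(V s_T - y, -1, 1), the gradients are
#     dL/dV = g s_T^T,   dL/dU = diag(V^T g) S_T^U,   dL/dW = diag(V^T g) S_T^W
# note this folds the full Jacobian tensor into elementwise sensitivities, the
# usual "from scratch" simplification rather than exact backprop through time.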
# plot each test window: prediction in red, target in green, one panel per window
def plot_tests(step):
    v_lines = []
    for plot_i, x_y in enumerate(X_test):
        xs = x_y[:T]
        ys = x_y[T:]
        rnn_s = np.zeros(hidden_dim, dtype=np.float64)
        for t in range(T):
            x_i = np.zeros(T, dtype=np.float64)
            x_i[t] = xs[t]
            rnn_s = new_hidden_state(x_i, rnn_s)
        y_hat = np.dot(V, rnn_s)
        x = x_grid[:output_dim] + dx_grid * (output_dim + 1) * plot_i
        v_lines.append(dx_grid * (output_dim + 1) * plot_i - dx_grid)
        plt.plot(x, y_hat, "r")
        plt.plot(x, ys, "g")
    for x_pos in v_lines[1:]:
        plt.vlines(x_pos, -1, 1)
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_ticks([])
    frame1.set_ylim([-1.1, 1.1])
    os.makedirs("step_plots", exist_ok=True)  # savefig fails if the directory is missing
    plt.savefig(f"step_plots/{step:06d}.png", format='png')
    plt.clf()
# set up training data:
# let's use sin as our target function.
x_grid = np.linspace(0, 4 * np.pi, num=n_samples + test_set_size + T + output_dim)
dx_grid = x_grid[1] - x_grid[0]
sin_wave = np.sin(x_grid)
n_data_points = sin_wave.shape[0]
# note: this recount still includes the held-out test windows, so below
# n_samples = len(X) + test_set_size
n_samples = n_data_points - T - output_dim
X = []
for i in range(0, n_samples):
    X.append(sin_wave[i : i + T + output_dim])
rng.shuffle(X)  # use the seeded generator so the train/test split is reproducible
X_test = X[:test_set_size]
X = X[test_set_size:]
print(f"n_data_points: {n_data_points}")
print(f"n_samples: {len(X)}")
print(f"n_test: {len(X_test)}")
print(f"input length : {T}")
print(f"hidden_dim length: {hidden_dim}")
print(f"output length: {output_dim}")

# scale the learning rate down since gradients are summed, not averaged, per epoch
eps = eps / n_samples
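
# a note on the batch layout used below: each row x_y is a sliding window over
# the sine wave -- x_y[:T] is the T-step input, fed one scalar per time step,
# and x_y[T:] holds the next output_dim values the network must predict.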
for e_i in range(n_epochs):
    loss = 0
    dL_dV = 0
    dL_dU = 0
    dL_dW = 0
    for x_y in X:
        xs = x_y[:T]
        ys = x_y[T:]
        # forward pass, accumulating the sensitivities of s_t to U and W as we go
        rnn_s = np.zeros(hidden_dim, dtype=np.float64)
        rnn_ds_dU = np.zeros((hidden_dim, T), dtype=np.float64)
        rnn_ds_dW = np.zeros((hidden_dim, hidden_dim), dtype=np.float64)
        for t in range(T):
            x_i = np.zeros(T, dtype=np.float64)
            x_i[t] = xs[t]
            p_rnn_s = rnn_s
            rnn_s = new_hidden_state(x_i, rnn_s)
            # derivs: ds is the sigmoid derivative s_t * (1 - s_t)
            ds = rnn_s * (1 - rnn_s)
            ds_W = el_mul(ds, W)
            rnn_ds_dU = np.dot(ds_W, rnn_ds_dU)
            rnn_ds_dU += np.outer(ds, x_i)
            rnn_ds_dW = np.dot(ds_W, rnn_ds_dW)
            rnn_ds_dW += np.outer(ds, p_rnn_s)
        # output error, clipped before it enters the weight gradients
        dy = np.dot(V, rnn_s) - ys
        rnn_dL_dV = np.outer(l_grad(dy), rnn_s)
        dyV = np.dot(l_grad(dy), V)
        loss_i = (0.5 * dy**2).sum()
        rnn_dL_dW = el_mul(dyV, rnn_ds_dW)
        rnn_dL_dU = el_mul(dyV, rnn_ds_dU)
        loss += loss_i
        dL_dV += rnn_dL_dV
        dL_dW += rnn_dL_dW
        dL_dU += rnn_dL_dU
    if (e_i + 1) % 100 == 0 or e_i == 0:
        print(
            f"{e_i}: total loss: {loss}\n\t\t<error> per data point: {np.sqrt(loss/n_samples/output_dim)}"
        )
        print(f"    dV range: {np.max(dL_dV) - np.min(dL_dV)}")
        print(f"    dU range: {np.max(dL_dU) - np.min(dL_dU)}")
        print(f"    dW range: {np.max(dL_dW) - np.min(dL_dW)}")
        plot_tests(e_i)
    # full-batch gradient descent step
    W = W - eps * dL_dW
    V = V - eps * dL_dV
    U = U - eps * dL_dU

with open('weights.pkl', 'wb') as f:
    pickle.dump([U, W, V], f, protocol=pickle.HIGHEST_PROTOCOL)
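
# minimal usage sketch (an addition, not part of the original gist): reload the
# pickled weights and forecast output_dim points from one T-step input window.
with open('weights.pkl', 'rb') as f:
    U, W, V = pickle.load(f)
window = sin_wave[:T]  # any T-step input sequence works here
s = np.zeros(hidden_dim, dtype=np.float64)
for t in range(T):
    x_i = np.zeros(T, dtype=np.float64)
    x_i[t] = window[t]
    s = new_hidden_state(x_i, s)
print(np.dot(V, s))  # predicted next output_dim values of the wave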