Last active
April 2, 2020 11:55
-
-
Save davegreenwood/7e0b17526796f7866bfbe23d4a069e2c to your computer and use it in GitHub Desktop.
PyTorch demo to learn line parameters.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Fit a 2d line to data using auto differentiation to learn w, b.""" | |
# %% | |
import torch | |
import matplotlib.pyplot as plt | |
# size of data - more data will give a more accurate result
N = 500
# Seed the global torch RNG so the noise below (and therefore the saved plots
# and the learned w, b) is reproducible from run to run.
SEED = 0
torch.manual_seed(SEED)
# we can add noise to data to make it more realistic
noise = torch.randn(N)
def get_data(w, b, n=100):
    """Sample ``n`` evenly spaced points on the line ``y = w * x + b``.

    ``x`` runs from 0 to 10 inclusive. ``w`` and ``b`` may be plain floats or
    tensors; when they are tensors with ``requires_grad=True`` the returned
    ``y`` stays attached to the autograd graph, which is how the training loop
    learns the line parameters.

    Returns:
        ``(x, y)`` -- the sample positions and the corresponding line values.
    """
    xs = torch.linspace(0, 10.0, n)
    ys = w * xs + b
    return xs, ys
# %% Start
# Ground-truth line and the initial guess, both of the form y = w * x + b.
w_true, b_true = 2.1, 1.8
w_start, b_start = 1.0, 0.0
x_true, y_true = get_data(w_true, b_true, N)
x_start, y_start = get_data(w_start, b_start, N)

# This is our noisy data!!  The same noise sample is added to x and
# subtracted from y, so each point is displaced along the (1, -1) direction.
x = x_true + noise
y = y_true - noise

# Plot the starting guess, the true line and the noisy observations.
fig, ax = plt.subplots(1, figsize=[9, 9])
ax.plot(x_start, y_start, "g", label="start")
ax.plot(x_true, y_true, "r", label="true")
ax.plot(x, y, "+k", label="data")
ax.legend()
plt.savefig("start.png")
# %% Learning
# learn these parameters by setting requires_grad = True!!
W = torch.tensor([w_start], requires_grad=True)
B = torch.tensor([b_start], requires_grad=True)
n_evals = 20000
# x spans [0, 10], so the MSE loss is badly conditioned in (w, b): its Hessian
# has eigenvalues of roughly 68 and 0.5. With lr=1e-4 the slow (mostly-b)
# direction only shrinks to ~40% of its initial error in 20000 steps, leaving b
# well short of b_true. 0.01 is still safely below the 2/68 ~= 0.029 stability
# limit of plain SGD and converges within a few thousand steps.
learning_rate = 0.01
optimizer = torch.optim.SGD([W, B], lr=learning_rate)
loss_func = torch.nn.MSELoss()
# repeatedly calculate the loss and back prop
for i in range(n_evals):
    optimizer.zero_grad()
    x_p, y_p = get_data(W, B, N)
    # x_p is the fixed linspace grid and does not depend on W or B, so only the
    # y residual carries gradient. MSELoss takes (input=prediction, target).
    loss = loss_func(y_p, y)
    loss.backward()
    optimizer.step()
    if i % 1000 == 0:
        # report every 1000 evals
        log = f"eval: {i}, w: {W.item():0.3f}, b: {B.item():0.3f}"
        print(log)
# final parameters after the last update
log = f"eval: {i}, w: {W.item():0.3f}, b: {B.item():0.3f}"
print(log)
# %% Results
# Recompute the prediction from the final W, B. The x_p, y_p left over from the
# training loop were produced before the last optimizer.step(), so they lag the
# printed parameters by one update. Under no_grad the result is not attached to
# the autograd graph and can be handed straight to matplotlib.
with torch.no_grad():
    x_p, y_p = get_data(W, B, N)
fig, ax = plt.subplots(1, figsize=[9, 9])
ax.plot(x_p, y_p, "b")
ax.plot(x_true, y_true, "r")
ax.plot(x, y, "+k")
ax.legend(["predicted", "true", "data"])
plt.savefig("result.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment