Skip to content

Instantly share code, notes, and snippets.

@typio
Last active May 26, 2018 20:34
Show Gist options
  • Save typio/0227400e0e3b1df58bb9506452a83c55 to your computer and use it in GitHub Desktop.
Save typio/0227400e0e3b1df58bb9506452a83c55 to your computer and use it in GitHub Desktop.
Shiffman's TensorFlow.js linear regression
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <meta http-equiv="X-UA-Compatible" content="ie=edge">
  <title>linear regression</title>
  <!-- Load p5 and TensorFlow.js before sketch.js: the sketch calls tf at load time and relies on p5 globals. -->
  <script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.6.1/p5.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.7.0"></script>
  <script src="./sketch.js"></script>
</head>
<body>
</body>
</html>
// Training samples collected from mouse clicks, mapped from canvas pixels
// into data space (see mousePressed). The arrays are mutated in place with
// push() but never reassigned.
const x_vals = [];
const y_vals = [];

// Model parameters for y = m * x + b; assigned as tf.Variables in setup().
let m;
let b;

// Step size for stochastic gradient descent.
const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);
/**
 * p5 setup hook: create a 400x400 canvas, then initialize the slope `m`
 * and intercept `b` as tf.Variables seeded with p5's random(1).
 */
function setup() {
  createCanvas(400, 400);
  const randomScalar = () => tf.scalar(random(1));
  m = tf.variable(randomScalar());
  b = tf.variable(randomScalar());
}
/**
 * Mean squared error: mean((pred - labels)^2).
 * @param {tf.Tensor} pred - predicted values
 * @param {tf.Tensor} labels - target values
 * @returns {tf.Tensor} the mean of the squared residuals
 */
function loss(pred, labels) {
  const residuals = pred.sub(labels);
  const squared = residuals.square();
  return squared.mean();
}
/**
 * Evaluate the linear model y = m * x + b element-wise over the inputs.
 * @param {number[]} x - input values
 * @returns {tf.Tensor} predicted y values, one per input
 */
function predict(x) {
  return tf.tensor1d(x).mul(m).add(b);
}
/**
 * p5 mouse handler: record a click as a training sample.
 * Pixel coordinates are mapped into [0, 1]; y is flipped so the bottom of
 * the canvas is 0. Clicks outside the canvas are ignored.
 */
function mousePressed() {
  // p5 fires mousePressed for clicks anywhere on the page, not just the
  // canvas, and map() does not clamp. Without this guard, off-canvas clicks
  // would add samples outside [0, 1] that are never drawn but still skew the fit.
  const isOutsideCanvas =
    mouseX < 0 || mouseX > width || mouseY < 0 || mouseY > height;
  if (isOutsideCanvas) {
    return;
  }
  const x = map(mouseX, 0, width, 0, 1);
  const y = map(mouseY, 0, height, 1, 0);
  x_vals.push(x);
  y_vals.push(y);
}
/**
 * p5 draw loop, run every frame: clear the canvas, plot the clicked samples,
 * take one optimizer step on them, then draw the current fit y = m*x + b.
 */
function draw() {
  background(51);
  stroke(255);
  strokeWeight(4);
  drawPoints();
  trainStep();
  drawRegressionLine();
}

/** Plot every stored (x, y) sample, mapping data space [0, 1] to canvas pixels. */
function drawPoints() {
  for (let i = 0; i < x_vals.length; i++) {
    const px = map(x_vals[i], 0, 1, 0, width);
    // y is flipped so larger values appear higher on the canvas.
    const py = map(y_vals[i], 0, 1, height, 0);
    point(px, py);
  }
}

/** Run one optimizer step minimizing loss() over all samples; no-op when empty. */
function trainStep() {
  if (x_vals.length === 0) {
    return;
  }
  // tidy() disposes the label tensor and the intermediates created inside it.
  tf.tidy(() => {
    const labels = tf.tensor1d(y_vals);
    optimizer.minimize(() => loss(predict(x_vals), labels));
  });
}

/** Draw the model's current prediction as a line from x = 0 to x = 1. */
function drawRegressionLine() {
  const lineX = [0, 1];
  // The tensor returned from tidy() survives it; read it, then free it explicitly.
  const predictedTensor = tf.tidy(() => predict(lineX));
  const lineY = predictedTensor.dataSync();
  predictedTensor.dispose();

  const x1 = map(lineX[0], 0, 1, 0, width);
  const x2 = map(lineX[1], 0, 1, 0, width);
  const y1 = map(lineY[0], 0, 1, height, 0);
  const y2 = map(lineY[1], 0, 1, height, 0);
  line(x1, y1, x2, y2);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment