Created
September 20, 2023 14:38
-
-
Save RepComm/fbeb918719c90e4a2af169142da14874 to your computer and use it in GitHub Desktop.
Arbitrary math function tensorflow learning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import { Tensor, layers, sequential, tensor, train } from "@tensorflow/tfjs-node"; | |
/**
 * Builds the (uncompiled) regression network used by this example.
 *
 * Architecture: a sequential model with one hidden dense layer of 16 ReLU
 * units that accepts a single scalar feature per sample (`inputShape: [1]`),
 * followed by a dense output layer with one unit and no explicit activation.
 *
 * @returns The sequential model; the caller still has to `compile` it.
 */
function create_model () {
  const model = sequential();

  // Hidden layer: one scalar feature in, 16 ReLU units out.
  model.add(layers.dense({
    units: 16,
    activation: "relu",
    inputShape: [1]
  }));

  // Output layer: a single value per sample (the predicted f(x)).
  model.add(layers.dense({
    units: 1
  }));

  return model;
}
/**
 * Linearly interpolates between two numbers.
 *
 * Uses the weighted-sum form `from*(1-by) + to*by`, which yields exactly
 * `from` when `by` is 0 and exactly `to` when `by` is 1. `by` is not clamped,
 * so values outside [0, 1] extrapolate.
 *
 * @param from Value returned when `by` is 0.
 * @param to Value returned when `by` is 1.
 * @param by Interpolation factor.
 * @returns The interpolated value.
 */
function lerp (from: number, to: number, by: number): number {
  return from * (1 - by) + to * by;
}
/**
 * Generates a random training set ("quiz") and its expected outputs
 * ("answer key") for an arbitrary one-argument numeric function.
 *
 * Each sample input is drawn by interpolating between `min` and `max` with
 * `Math.random()`, and `f` is evaluated once per sample, in sample order.
 *
 * @param f Function whose behaviour the model should learn.
 * @param samples Number of samples to generate (default 100).
 * @param min Lower bound of the sampled inputs (default 0).
 * @param max Upper bound of the sampled inputs (default 100).
 * @returns `quiz` (the sampled inputs) and `answerKey` (`f` applied to each
 *   input), each created with `tensor` from a flat array of `samples`
 *   numbers. Nothing in this function disposes them.
 */
function generate_quiz (f: (n: number)=> number, samples: number = 100, min: number = 0, max: number = 100) {
  const inputs = new Array<number>(samples);
  const expected = new Array<number>(samples);

  for (let i = 0; i < samples; i++) {
    const x = lerp(min, max, Math.random());
    inputs[i] = x;
    expected[i] = f(x);
  }

  return {
    quiz: tensor(inputs),
    answerKey: tensor(expected)
  };
}
/**
 * End-to-end example: trains the model from `create_model` to approximate
 * `Math.sin` on samples drawn between 0 and 10, logging the loss after every
 * epoch, then logs the model's prediction for a single test input.
 *
 * All tensors created here are disposed before the function returns.
 */
async function main () {
  // Acquire the machine learning model.
  const model = create_model();

  // Use the Adam optimizer with a learning rate of 0.01.
  // Original author's observation: plain "sgd" (stochastic gradient descent)
  // was a little too aggressive for this problem and the loss got stuck
  // oscillating around ~0.4, whereas with Adam the loss went down to ~0.001
  // after enough epochs of training.
  const optimizer = train.adam(0.01);

  // The model must be compiled before it can be trained; this is where the
  // loss function and optimizer are attached to it.
  model.compile({
    loss: "meanSquaredError",
    optimizer: optimizer
  });

  // Create a quiz based on Math.sin: 200 samples between 0 and 10.
  // generate_quiz is this file's own helper (not a TensorFlow API) for
  // building training data from arbitrary functions.
  const quiz = generate_quiz((n) => {
    return Math.sin(n);
  }, 200, 0, 10);

  try {
    // Train the model with the quiz and grade it against the answer key.
    // Awaited so that prediction only happens once training has finished.
    await model.fit(quiz.quiz, quiz.answerKey, {
      epochs: 1000,
      verbose: 0,
      callbacks: {
        onEpochEnd: (ep, logs) => {
          // After each epoch, report how well the model did.
          // `logs` is optional in the callback signature, so guard the access.
          console.log(`Epoch: ${ep}: loss = ${logs?.loss}`);
        }
      }
    });
  } finally {
    // The training tensors are no longer needed; release their memory.
    quiz.quiz.dispose();
    quiz.answerKey.dispose();
  }

  // The value we want Math.sin() of.
  const test = 1;

  // Predict Math.sin(test) using a 1x1 input (one sample, one feature).
  const input = tensor([
    [test]
  ]);
  try {
    const output = model.predict(input) as Tensor;
    try {
      const prediction = await output.data();
      // Perfect prediction for Math.sin(1) is 0.8414...
      // The original author reports 0.84434 after ~1000 epochs of training.
      console.log("Predicted f(", test, "):", prediction);
    } finally {
      output.dispose();
    }
  } finally {
    input.dispose();
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment