-
-
Save zecho/db37b985fa72464cc4e035d1c16f1c7b to your computer and use it in GitHub Desktop.
fizzbuzz with Nx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
defmodule Mlearning do
  @moduledoc false

  @doc """
  Encodes an integer as the feature vector `[rem(x, 3), rem(x, 5), rem(x, 15)]`.

  These remainders are the network's only inputs; `x` itself is never fed in.
  """
  @spec mods(integer()) :: [integer()]
  def mods(x) do
    [rem(x, 3), rem(x, 5), rem(x, 15)]
  end

  @doc """
  Returns the one-hot FizzBuzz label for `n`.

  Class order is `[fizz, buzz, fizzbuzz, other]`; this must match the index
  decoding done by the `guess` closure in `world/0`.
  """
  @spec fizzbuzz(integer()) :: [0 | 1]
  def fizzbuzz(n) do
    # For integers, n is divisible by 15 exactly when it is divisible by both 3 and 5.
    case {rem(n, 3), rem(n, 5)} do
      {0, 0} -> [0, 0, 1, 0]
      {0, _} -> [1, 0, 0, 0]
      {_, 0} -> [0, 1, 0, 0]
      _ -> [0, 0, 0, 1]
    end
  end

  defmodule Demo do
    @moduledoc false
    import Nx.Defn

    @doc "ReLU with an explicit gradient: passes `g` where `x > 0`, zero elsewhere."
    defn relu(x) do
      custom_grad(
        Nx.max(x, 0),
        [x],
        fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
      )
    end

    @doc """
    Softmax over the last axis.

    The per-row maximum is subtracted before exponentiating so large logits do
    not overflow to `inf` (which would yield NaN); the shift cancels
    mathematically, so `stop_grad` keeps it out of the gradient.
    """
    defn softmax(logits) do
      shifted = logits - stop_grad(Nx.reduce_max(logits, axes: [-1], keep_axes: true))
      exps = Nx.exp(shifted)
      exps / Nx.sum(exps, axes: [-1], keep_axes: true)
    end

    @doc """
    Categorical cross-entropy of the network on `numbers` against one-hot `labels`.

    Summed over classes (last axis), then averaged over any leading batch axes.
    Computed from the logits via a stable log-softmax, so a predicted probability
    of exactly 0 cannot produce `log(0) * 0 = NaN`.
    """
    defn loss({_w1, _b1, _w2, _b2} = params, numbers, labels) do
      logits = forward(params, numbers)
      shifted = logits - stop_grad(Nx.reduce_max(logits, axes: [-1], keep_axes: true))
      log_probs = shifted - Nx.log(Nx.sum(Nx.exp(shifted), axes: [-1], keep_axes: true))
      -Nx.mean(Nx.sum(labels * log_probs, axes: [-1]))
    end

    @doc """
    Performs one plain gradient-descent step on `params` and returns the new tuple.

    ## Options

      * `:learning_rate` - step size (default `0.01`)
    """
    defn update({w1, b1, w2, b2} = params, numbers, labels, opts \\ []) do
      opts = keyword!(opts, learning_rate: 0.01)
      lr = opts[:learning_rate]
      {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, numbers, labels))
      {w1 - lr * grad_w1, b1 - lr * grad_b1, w2 - lr * grad_w2, b2 - lr * grad_b2}
    end

    @doc """
    Initializes `{w1, b1, w2, b2}` for a 3 -> 10 -> 4 network from N(0, 0.1).

    `key` is split into four independent subkeys; reusing one key for every
    tensor would draw all parameters from the same random stream.
    """
    defn init_params(key) do
      keys = Nx.Random.split(key, parts: 4)

      w1 =
        Nx.Random.normal_split(keys[0], 0.0, 0.1,
          shape: {3, 10},
          names: [:input, :hidden],
          type: {:f, 32}
        )

      b1 =
        Nx.Random.normal_split(keys[1], 0.0, 0.1,
          shape: {10},
          names: [:hidden],
          type: {:f, 32}
        )

      w2 =
        Nx.Random.normal_split(keys[2], 0.0, 0.1,
          shape: {10, 4},
          names: [:hidden, :output],
          type: {:f, 32}
        )

      b2 =
        Nx.Random.normal_split(keys[3], 0.0, 0.1,
          shape: {4},
          names: [:output],
          type: {:f, 32}
        )

      {w1, b1, w2, b2}
    end

    @doc "Class probabilities (softmax over the last axis) for `numbers`."
    defn predict({_w1, _b1, _w2, _b2} = params, numbers) do
      params
      |> forward(numbers)
      |> softmax()
    end

    # Shared network body: dense -> relu -> dense, returning raw logits.
    defnp forward({w1, b1, w2, b2}, numbers) do
      numbers
      |> Nx.dot(w1)
      |> Nx.add(b1)
      |> relu()
      |> Nx.dot(w2)
      |> Nx.add(b2)
    end
  end

  @doc """
  Trains the network on 1..1000 for 5 epochs of per-sample SGD, then prints
  predictions for a few sample inputs. Returns `:ok`.
  """
  @spec world() :: :ok
  def world do
    # Build the training tensors once; they are reused unchanged every epoch.
    data =
      for n <- 1..1000 do
        {Nx.tensor(mods(n)), Nx.tensor(fizzbuzz(n))}
      end

    starting_params = Demo.init_params(Nx.Random.key(1))

    params =
      Enum.reduce(1..5, starting_params, fn _epoch, epoch_params ->
        Enum.reduce(data, epoch_params, fn {numbers, labels}, cur_params ->
          Demo.update(cur_params, numbers, labels)
        end)
      end)

    guess = fn x ->
      class =
        params
        |> Demo.predict(Nx.tensor(mods(x)))
        |> Nx.argmax()
        |> Nx.to_number()

      # Index order mirrors the one-hot layout produced by fizzbuzz/1.
      case class do
        0 -> "fizz"
        1 -> "buzz"
        2 -> "fizzbuzz"
        3 -> "womp"
      end
    end

    samples = [
      {3, "3"},
      {5, "5"},
      {15, "15"},
      {16, "16"},
      {15_432_115, "15,432,115"},
      {20_399_985, "20,399,985"},
      {20_399_997, "20,399,997"},
      {20_399_998, "20,399,998"}
    ]

    Enum.each(samples, fn {x, label} -> IO.inspect(guess.(x), label: label) end)
    :ok
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment