;; primitiv XOR example (Clojure)
;; Gist vbkaisetsu/492efc3f4b8757562d16f8cc03de14f5, created December 27, 2017 14:39.
(ns xor-example.core | |
(:import | |
[primitiv Device Graph Parameter Shape] | |
[primitiv functions functions$batch] | |
[primitiv.devices Naive] | |
[primitiv.initializers XavierUniform Constant] | |
[primitiv.optimizers SGD])) | |
(defn -main
  "Train a small 2-8-1 tanh network on the four XOR patterns with SGD
  (learning rate 0.1). After each epoch, print the network output for every
  sample and the batch mean of the squared error.

  args: optional command-line arguments. If the first one is present, it is
  parsed with Long/parseLong as the number of training epochs (default 10).
  The variadic arity lets the namespace be launched via `lein run` or
  `clojure -M -m` with or without extra arguments. The previous fixed `[]`
  arity threw an ArityException whenever any argument was passed."
  [& args]
  (let [epochs      (if-let [n (first args)] (Long/parseLong n) 10)
        ;; The four XOR samples, flattened two values per sample.
        ;; NOTE(review): assumes functions/input reads each consecutive pair
        ;; as one batch element of Shape [2] — confirm against primitiv docs.
        input-data  [1 1
                     1 -1
                     -1 1
                     -1 -1]
        ;; Targets: +1 when both inputs agree, -1 when they differ.
        output-data [1 -1 -1 1]
        dev         (Naive.)
        g           (Graph.)]
    ;; Install dev and g as the defaults; presumably the parameters and graph
    ;; nodes created below pick these up implicitly.
    (Device/set_default dev)
    (Graph/set_default g)
    (let [pw1       (Parameter. (Shape. [8 2]) (XavierUniform.))
          pb1       (Parameter. (Shape. [8]) (Constant. 0))
          pw2       (Parameter. (Shape. [1 8]) (XavierUniform.))
          pb2       (Parameter. (Shape. []) (Constant. 0))
          optimizer (SGD. 0.1)]
      (.add optimizer [pw1 pb1 pw2 pb2])
      (doseq [epoch (range epochs)]
        ;; Rebuild the forward pass from scratch each epoch; presumably
        ;; .clear drops the nodes from the previous iteration.
        (.clear g)
        (let [x        (functions/input (Shape. [2] 4) input-data)
              w1       (functions/parameter pw1)
              b1       (functions/parameter pb1)
              w2       (functions/parameter pw2)
              b2       (functions/parameter pb2)
              ;; hidden layer: tanh(W1 x + b1)
              h        (functions/tanh (functions/add (functions/matmul w1 x) b1))
              ;; output layer: W2 h + b2
              y        (functions/add (functions/matmul w2 h) b2)
              y-val    (.to_array y)
              t        (functions/input (Shape. [] 4) output-data)
              diff     (functions/subtract t y)
              ;; squared error averaged over the batch
              loss     (functions$batch/mean (functions/multiply diff diff))
              loss-val (.to_float loss)]
          (println "epoch " epoch ":")
          ;; Use a distinct name so the graph node `y` is not shadowed.
          (doseq [[j y-j] (map-indexed vector y-val)]
            (println " [" j "]: " y-j))
          (println " loss: " loss-val)
          (.reset_gradients optimizer)
          (.backward loss)
          (.update optimizer))))))
;; epoch 0 :
;; [ 0 ]: -0.51549953
;; [ 1 ]: 0.7876314
;; [ 2 ]: -0.7876314
;; [ 3 ]: 0.51549953
;; loss: 1.4430515
;;
;; ~~~~
;;
;; ~~~~
;;
;; epoch 8 :
;; [ 0 ]: 0.06495705
;; [ 1 ]: -0.05289926
;; [ 2 ]: -0.056721658
;; [ 3 ]: 0.05311704
;; loss: 0.88941664
;; epoch 9 :
;; [ 0 ]: 0.07288616
;; [ 1 ]: -0.072105706
;; [ 2 ]: -0.05547892
;; [ 3 ]: 0.06541988
;; loss: 0.871522
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment