Last active
January 16, 2017 20:25
-
-
Save mthomure/45e6606c5a625ba0baa3d92db0a821e0 to your computer and use it in GitHub Desktop.
First steps toward an idiomatic Clojure wrapper for MALLET
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(ns learn-mallet.core | |
(:import [cc.mallet.optimize Optimizable$ByGradientValue | |
ConjugateGradient GradientAscent LimitedMemoryBFGS | |
OptimizerEvaluator$ByGradient])) | |
;; Requires the dependency [cc.mallet/mallet "2.0.8"] (Leiningen coordinates).
(defprotocol IProblem
  "Something whose current optimization state can be inspected as data."
  (problem-state [this]
    "Returns a map describing the problem's current state."))
(extend-protocol IProblem
  Optimizable$ByGradientValue
  ;; Snapshot of a MALLET optimizable: the current parameters, the gradient
  ;; at those parameters, and the objective value, each copied out into
  ;; persistent Clojure data.
  (problem-state [^Optimizable$ByGradientValue optimizable]
    (let [size     (.getNumParameters optimizable)
          params   (double-array size)
          gradient (double-array size)]
      ;; MALLET fills caller-supplied arrays rather than returning new ones.
      (.getParameters optimizable params)
      (.getValueGradient optimizable gradient)
      {:params   (vec params)
       :value    (.getValue optimizable)
       :gradient (vec gradient)})))
(defn- coll->array!
  "Copies the elements of `coll`, in order, into the double array `arr`
  starting at index 0, coercing each element to a double. Returns `arr`.

  Slots of `arr` beyond the length of `coll` are left untouched. A `coll`
  longer than `arr` throws ArrayIndexOutOfBoundsException."
  [coll ^doubles arr]
  ;; An eager reduce performs the writes directly, without building the
  ;; throwaway lazy seq that `doall`/`map-indexed` produced. `aset` on a
  ;; ^doubles-hinted array compiles to a primitive store, whereas
  ;; `aset-double` goes through java.lang.reflect.Array.
  (reduce (fn [i x]
            (aset arr i (double x))
            (inc i))
          0
          coll)
  arr)
(defn problem
  "Builds a MALLET `Optimizable$ByGradientValue` from plain Clojure functions.

  `f` is called with the current parameter array and its result is returned
  as the objective value. `g` is called with the same array and must return a
  collection of partial derivatives; they are copied into the array MALLET
  supplies. `initial-params` is any collection of numbers. It is copied into a
  fresh double array that the returned object owns and overwrites whenever
  MALLET sets parameters.

  NOTE(review): `f` and `g` receive the live internal array, not a copy, so
  they should neither retain nor mutate it."
  [f g initial-params]
  (let [state (double-array initial-params)]
    (reify Optimizable$ByGradientValue
      (getNumParameters [_]
        (alength state))
      (getParameters [_ dest]
        ;; Copy out into the caller's buffer.
        (System/arraycopy state 0 dest 0 (alength state)))
      (getParameter [_ i]
        (aget state i))
      (setParameters [_ src]
        ;; Copy in from the caller's buffer; `state` keeps its identity.
        (System/arraycopy src 0 state 0 (alength state)))
      (setParameter [_ i v]
        (aset-double state i v))
      (getValue [_]
        (f state))
      (getValueGradient [_ dest]
        (coll->array! (g state) dest)))))
(defn ->evaluator
  "Adapts a two-argument Clojure function into a MALLET
  `OptimizerEvaluator$ByGradient`.

  On every `evaluate` call, `f` receives the optimizable and the iteration
  number. Its result, coerced to a boolean, is returned to MALLET.
  NOTE(review): presumably MALLET stops early on false; confirm in its docs."
  [f]
  (reify OptimizerEvaluator$ByGradient
    (evaluate [_this optimizable iteration]
      (f optimizable iteration))))
(defn- optimize!
  "Runs `optimizer`, a MALLET optimizer already constructed over `problem`,
  and returns a map describing the outcome.

  Options:
    :max-iterations  when truthy, `(.optimize optimizer max-iterations)` is
                     called; otherwise `(.optimize optimizer)`.
    :tolerance       when truthy, passed to `.setTolerance` before running.

  The result is `(problem-state problem)` merged with:
    :converged?      the value returned by `.optimize`, or false when an
                     IllegalArgumentException was caught;
    :num-iterations  the last iteration number passed to the evaluator, or 0
                     if the evaluator was never called;
    :exception       the caught IllegalArgumentException, present only when
                     one occurred."
  [optimizer problem {:keys [max-iterations tolerance]}]
  (when tolerance
    (.setTolerance optimizer tolerance))
  (let [last-iter (atom 0)]
    ;; The evaluator only records progress. Always answering true leaves the
    ;; stopping decision to the optimizer itself.
    (.setEvaluator optimizer
                   (->evaluator (fn [_ iter]
                                  (reset! last-iter iter)
                                  true)))
    (let [[converged? ex]
          (try
            [(if max-iterations
               (.optimize optimizer max-iterations)
               (.optimize optimizer))
             nil]
            ;; L-BFGS may throw this when it cannot step in the current
            ;; direction. That does not necessarily mean the optimization
            ;; failed, but success cannot be claimed either, so it is reported
            ;; as not converged.
            ;; NOTE(review): this catch is broad. Any IllegalArgumentException
            ;; (including ones thrown by the user's f/g) is swallowed here.
            (catch IllegalArgumentException caught
              [false caught]))]
      (merge (problem-state problem)
             {:converged?     converged?
              :num-iterations @last-iter}
             (when ex
               {:exception ex})))))
(defn lbfgs!
  "Optimizes `problem` with MALLET's `LimitedMemoryBFGS`.

  Accepts the keyword options understood by `optimize!` (:max-iterations,
  :tolerance) and returns its result map."
  [problem & {:as args}]
  (-> (LimitedMemoryBFGS. problem)
      (optimize! problem args)))
(defn conjugate-gradient!
  "Optimizes `problem` with MALLET's `ConjugateGradient`.

  When :step-size is supplied, it is passed to the two-argument constructor.
  The other keyword options (:max-iterations, :tolerance) go to `optimize!`,
  whose result map is returned."
  [problem & {:keys [step-size] :as args}]
  (optimize! (if step-size
               (ConjugateGradient. problem step-size)
               (ConjugateGradient. problem))
             problem
             args))
(defn gradient-ascent!
  "Optimizes `problem` with MALLET's `GradientAscent`.

  When :step-size is supplied, it is applied with `.setInitialStepSize`.
  The other keyword options (:max-iterations, :tolerance) go to `optimize!`,
  whose result map is returned."
  [problem & {:keys [step-size] :as args}]
  (let [ascent (cond-> (GradientAscent. problem)
                 step-size (doto (.setInitialStepSize step-size)))]
    (optimize! ascent problem args)))
;;;;;;;;;;;;;;;;;;;;;; | |
;; see http://mallet.cs.umass.edu/optimization.php | |
(defn problem-1
  "Example problem: f(x, y) = -3x^2 - 4y^2 + 2x - 4y + 18, started at (0, 0).

  The gradient is (-6x + 2, -8y - 4), which vanishes at x = 1/3, y = -1/2."
  []
  (problem (fn [[x y]]
             (+ (* -3 x x)
                (* -4 y y)
                (* 2 x)
                (* -4 y)
                18))
           (fn [[x y]]
             [(+ (* x -6) 2)
              (- (* y -8) 4)])
           [0 0]))
(defn sqr
  "Returns `x` multiplied by itself."
  [x]
  (* x x))
;; see http://www.cas.mcmaster.ca/~cs4te3/tutorials/BFGS.pdf | |
(defn problem-2
  "Example problem: the negation of e^(x-1) + e^(1-y) + (x-y)^2, started at
  (0, 0).

  The objective and its gradient are both negated, so maximizing this
  problem minimizes the original expression."
  []
  (letfn [(objective [[x y]]
            (- (+ (Math/exp (dec x))
                  (Math/exp (inc (- y)))
                  (sqr (- x y)))))
          (gradient [[x y]]
            [(- (+ (Math/exp (dec x))
                   (* 2 (- x y))))
             (- (+ (- (Math/exp (inc (- y))))
                   (* -2 (- x y))))])]
    (problem objective gradient [0 0])))
(comment
  ;; REPL usage examples. The forms use an `lm` alias and an unqualified
  ;; `pprint`, neither of which is set up by the ns declaration. The require
  ;; below provides both, so these forms can be evaluated from any namespace
  ;; (e.g. `user`).
  (require '[learn-mallet.core :as lm]
           '[clojure.pprint :refer [pprint]])
  (pprint (lm/lbfgs! (lm/problem-1)))
  (pprint (lm/lbfgs! (lm/problem-2))))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment