Skip to content

Instantly share code, notes, and snippets.

@zmaril
Last active December 16, 2015 14:19
Show Gist options
  • Save zmaril/5447488 to your computer and use it in GitHub Desktop.
Save zmaril/5447488 to your computer and use it in GitHub Desktop.
A simple probabilistic programming library based on rejection sampling; an approximation of Church.
(ns hacklheber.core)
(defn flip
  "Flips a coin, returning true or false at random.

  With no arguments the coin is fair. Given a bias p, returns true with
  probability p: (rand) is uniform on [0, 1), so p <= 0 always yields
  false and p >= 1 always yields true."
  ([] (flip 0.5))
  ([p] (< (rand) p)))
(defn- memo-bangs
  "Rewrites a single let-style binding pair [k v].

  When k is a symbol whose name ends in a bang (e.g. f!), the init
  expression v is wrapped in clojure.core/memoize, so repeated calls with
  the same arguments reuse the first result. Every other binding target,
  including vector and map destructuring patterns (which have no name),
  is returned unchanged as [k v]."
  [[k v]]
  ;; Only plain symbols carry a name; calling `name` on a destructuring
  ;; vector/map would throw during macroexpansion.
  (if (and (symbol? k)
           (= \! (last (name k))))
    [k `(memoize ~v)]
    [k v]))
(defn- find-clause
  "Looks up a keyword clause by its keyword.

  clauses is a sequence of (keyword body) pairs. Returns the body of the
  earliest pair whose keyword equals clause-key, or nil when none does."
  [clauses clause-key]
  (let [matching (filter #(= clause-key (first %)) clauses)]
    (second (first matching))))
(defmacro query-by-rejection
  "Draws one sample from a distribution by rejection. See the examples
  further below.

  bindings is a let-style vector whose pairs are evaluated in order on
  every attempt. It may also contain keyword clauses:

    :where expr         condition evaluated after all the bindings (and
                        able to refer to them); defaults to true when
                        the clause is absent.
    :memobound [v ...]  vars rebound with `binding` to memoized versions
                        of themselves for the length of one attempt, so
                        they must be declared ^:dynamic.

  A binding symbol ending in ! has its init expression wrapped in
  memoize. Attempts repeat until :where is truthy, and the value of
  body-expr from that attempt is returned. An unsatisfiable :where loops
  forever.

  NOTE(review): body-expr runs inside the binding form, so a lazy result
  that calls a memobound var after being returned will see the var's
  outer value rather than the memoized one."
  [[& bindings] & body-expr]
  (let [{clauses true pairs false} (group-by (comp keyword? first)
                                              (partition-all 2 bindings))
        ;; Distinguish "no :where clause" (accept every attempt) from an
        ;; explicit :where nil/false (reject every attempt).
        where-expr (if (some #(= :where (first %)) clauses)
                     (find-clause clauses :where)
                     true)
        ;; [v (memoize v) ...] in the shape `binding` expects.
        memobound (->> (find-clause clauses :memobound)
                       (mapcat (juxt identity (fn [v] `(memoize ~v))))
                       vec)
        let-pairs (vec (mapcat memo-bangs pairs))]
    `(loop []
       ;; The binding form finishes before `recur`, which is not allowed
       ;; to cross the try/finally that `binding` expands into.
       (let [[accepted# result#] (binding ~memobound
                                   (let [~@let-pairs accepted# ~where-expr]
                                     [accepted# (when accepted# (do ~@body-expr))]))]
         (if accepted#
           result#
           (recur))))))
(defn normalized-frequencies
  "Takes a collection and computes the normalized frequencies of its
  elements: a map from each distinct element to the fraction of the
  collection it makes up, as a float. An empty collection yields {}."
  [col]
  (let [freqs (frequencies col)
        ;; Named `total` to avoid shadowing clojure.core/count.
        total (reduce + (vals freqs))]
    (into {} (for [[k v] freqs]
               [k (float (/ v total))]))))
(defmacro sample-by-rejection
  "Given a sample count n followed by the arguments of a
  query-by-rejection form, expands to a lazy sequence of n independent
  draws from that query. Each draw is computed when its element of the
  sequence is realized."
  [n & body]
  `(map (fn [_#] (query-by-rejection ~@body))
        (range ~n)))
;;Example queries taken from Church
;;http://projects.csail.mit.edu/church/wiki/Conditioning
(defn ^:dynamic eye-color
  "Taken from the Church examples."
  [person]
  ;; person does not affect the draw; each call is independent. The var
  ;; is ^:dynamic so it can be rebound (e.g. to a memoized version).
  (rand-nth '[blue green brown]))
;;Persistent randomized functions
;; eye-color is rebound to a memoized version for each attempt, so bob-1
;; and bob-2 always agree while alice-1 is drawn independently.
(query-by-rejection
  [bob-1 (eye-color :bob)
   alice-1 (eye-color :alice)
   bob-2 (eye-color :bob)
   :memobound [eye-color]
   :where (flip 0.01)]
  (vector bob-1 alice-1 bob-2))
;;A complex query
(defn complex-samples
  "Medical-diagnosis example adapted from the Church conditioning
  tutorial. Returns n samples (default 10000) of (lung-cancer TB),
  conditioned on the patient having a cough, chest pain and shortness
  of breath."
  ([] (complex-samples 10000))
  ([n]
   (sample-by-rejection
    n
    [works-in-hospital (flip 0.01)
     smokes (flip 0.2)
     lung-cancer (or (flip 0.01)
                     (and smokes (flip 0.02)))
     TB (or (flip 0.005)
            (and works-in-hospital (flip 0.01)))
     cold (or (flip 0.2)
              (and works-in-hospital (flip 0.25)))
     stomach-flu (flip 0.1)
     other (flip 0.1)
     cough (or (and cold (flip 0.5))
               (and lung-cancer (flip 0.3))
               (and TB (flip 0.7))
               (and other (flip 0.01)))
     fever (or (and cold (flip 0.3))
               (and stomach-flu (flip 0.5))
               (and TB (flip 0.2))
               (and other (flip 0.01)))
     chest-pain (or (and lung-cancer (flip 0.4))
                    (and TB (flip 0.5))
                    (and other (flip 0.01)))
     shortness-of-breath (or (and lung-cancer (flip 0.4))
                             (and TB (flip 0.5))
                             (and other (flip 0.01)))
     :where (and cough chest-pain shortness-of-breath)]
    (list lung-cancer TB))))
(defn ^:dynamic strength
  "Tug-of-war strength of person: 10 or 5, each with probability 1/2.
  Declared dynamic so a query can rebind it to a memoized version, giving
  each person a single strength within one attempt."
  [person]
  (if-not (flip) 5 10))
(defn lazy
  "Returns true with probability 1/3, independently on every call
  (person does not affect the draw). contribution halves a person's
  strength when this is true."
  [person]
  (flip 1/3))
(defn contribution
  "Effective pull of person: their strength, halved when they are lazy.
  Halving an odd strength yields a ratio (e.g. 5/2)."
  [person]
  ;; lazy is consulted before strength, matching the original draw order.
  (let [slacking? (lazy person)
        full (strength person)]
    (if slacking?
      (/ full 2)
      full)))
(defn total-pulling
  "Sum of the contributions of every member of team, in order; 0 for an
  empty team."
  [team]
  (apply + (map contribution team)))
(defn winner
  "Returns the symbol team1 or team2 naming the side with the larger total
  pull; ties go to team1. team1's total is computed before team2's."
  [team1 team2]
  (let [pull1 (total-pulling team1)
        pull2 (total-pulling team2)]
    (if (> pull2 pull1)
      'team2
      'team1)))
;;Using persistent values defined outside the query
(defn tug-of-war-sample
  "Tug-of-war example adapted from the Church conditioning tutorial.
  Returns n samples (default 10000) of bob's strength, conditioned on team
  (bob mary) beating (tom sue) and team (bob sue) beating (tom jim).

  Only strength is listed under :memobound, so each person keeps one
  strength across both matches within an attempt, while lazy is redrawn
  on every pull."
  ([] (tug-of-war-sample 10000))
  ([n]
   (sample-by-rejection
    n
    [:memobound [strength]
     :where (and (= 'team1 (winner '(bob mary) '(tom sue)))
                 (= 'team1 (winner '(bob sue) '(tom jim))))]
    (strength 'bob))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment