Last active
January 11, 2020 07:37
-
-
Save TiarkRompf/ef47cdf03f2fe1c3481b280f2a7784ab to your computer and use it in GitHub Desktop.
Reverse-mode automatic differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package autodiff | |
import org.scalatest._ | |
import org.scalatest.Assertions._ | |
import scala.collection.mutable._ | |
class AutoDiffSpec extends FunSuite {

  // ---------------------------------------------------------------
  // Computation graph: a flat buffer of nodes. A node's id (Sym) is
  // its index in the buffer; nodes refer to their inputs by index.
  // ---------------------------------------------------------------
  val graph = new ArrayBuffer[Exp]

  /** Append node `e` to the graph and return its id.
    * Structurally identical nodes are shared (hash consing), which
    * doubles as common-subexpression elimination. */
  def emit(e: Exp): Int = {
    val i = graph.indexOf(e) // hash consing / common subexpression elim
    if (i >= 0) i else try graph.size finally graph += e
  }

  /** Id for the x-th free input variable. Ids start at 1,000,000 so
    * they do not collide with graph indices in practice.
    * NOTE(review): `emit` does not actually enforce this bound — a
    * graph with >= 1e6 nodes would alias an input id. Acceptable for
    * a test suite, but a guard in `emit` would make it safe. */
  def input(x: Int) = 1*1000*1000 + x

  // nodes refer to their inputs by index
  type Sym = Int
  abstract class Exp
  case class Const(a: Double) extends Exp
  case class Plus(a: Sym, b: Sym) extends Exp
  case class Times(a: Sym, b: Sym) extends Exp
  case class Div(a: Sym, b: Sym) extends Exp
  case class Log(a: Sym) extends Exp
  case class Sin(a: Sym) extends Exp
  case class Cos(a: Sym) extends Exp

  // ---------------------------------------------------------------
  // Smart constructors (LMS-style). The Double overloads lift
  // literals into Const nodes; `times` folds the x*1 == x identity
  // so iterated backprop does not blow up the graph.
  // ---------------------------------------------------------------
  def const(x: Double) = emit(Const(x))
  def plus(x: Double, y: Double): Double = x + y
  def plus(x: Double, y: Sym): Sym = plus(const(x), y)
  def plus(x: Sym, y: Double): Sym = plus(x, const(y))
  def plus(x: Sym, y: Sym): Sym = emit(Plus(x,y))
  def times(x: Double, y: Double): Double = x * y
  def times(x: Double, y: Sym): Sym = times(const(x), y)
  def times(x: Sym, y: Double): Sym = times(x, const(y))
  def times(x: Sym, y: Sym): Sym =
    if (x < graph.size && graph(x) == Const(1)) y       // 1 * y == y
    else if (y < graph.size && graph(y) == Const(1)) x  // x * 1 == x
    else emit(Times(x,y))                               // (bounds check: input ids exceed graph.size)
  def div(x: Double, y: Double): Double = x / y
  def div(x: Double, y: Sym): Sym = div(const(x), y)
  def div(x: Sym, y: Sym): Sym = emit(Div(x,y))
  def log(x: Sym) = emit(Log(x))
  def sin(x: Sym) = emit(Sin(x))
  def cos(x: Sym) = emit(Cos(x))

  /** Dump the current graph, one `xN = node` line per entry. */
  def printGraph() = {
    println("IR Graph:")
    for ((e,i) <- graph.zipWithIndex) {
      println(s"x$i = $e")
    }
  }

  // ---------------------------------------------------------------
  // Reverse-mode AD. Starting from result node y ("dependent var"),
  // return a mapping from node id x to a graph node computing ∂y/∂x.
  // Adjoints are themselves graph nodes, so backprop can be applied
  // to its own output for higher-order derivatives (see test2).
  // ---------------------------------------------------------------
  def backprop(y: Sym): Map[Sym,Sym] = backprop(Map(y -> const(1)))
  def backprop(adj: Map[Sym,Sym]): Map[Sym,Sym] = {
    println("Backprop:")
    // accumulate contribution `s` into the adjoint of node i
    def add(i: Sym, s: Sym) = {
      if (adj.contains(i))
        adj(i) = plus(adj(i), s)
      else
        adj(i) = s
    }
    // walk nodes from outputs to inputs; `ins` is a snapshot, so the
    // fresh adjoint nodes emitted below are not themselves visited
    val ins = graph.zipWithIndex.toList.reverse // reverse mode
    for ((e,i) <- ins if adj contains i) {
      // chain rule: x' += i' * ∂xi/∂x, where `s` encodes ∂xi/∂x
      def add1(j: Sym, s: Sym) = add(j, times(adj(i),s))
      // each case pushes adjoints to the node's inputs and returns
      // (input, pretty-printed partial) pairs for the trace below
      val dep = e match {
        case Const(x) =>
          Nil
        case Plus(x,y) =>
          add1(x, const(1)); add1(y, const(1))
          List((x,s"1"),(y,s"1"))
        case Times(x,y) =>
          add1(x, y); add1(y, x)
          List((x,s"x$y"),(y,s"x$x"))
        case Div(x,y) =>
          // ∂(x/y)/∂x = 1/y,  ∂(x/y)/∂y = -x/y^2  (was unimplemented ???)
          add1(x, div(const(1), y))
          add1(y, times(const(-1), div(x, times(y, y))))
          List((x,s"1/x$y"),(y,s"-x$x/x$y^2"))
        case Log(x) =>
          add1(x,div(const(1),x))
          List((x,s"1/x$x"))
        case Sin(x) =>
          add1(x,cos(x))
          List((x,s"cos(x$x)"))
        case Cos(x) =>
          add1(x,times(const(-1),sin(x)))
          List((x,s"-1*sin(x$x)")) // fixed trace string: d(cos x)/dx = -sin x, not -cos x
      }
      for ((d,e) <- dep) {
        println(s"x$d' += x$i' * ∂x$i/∂x$d = x$i' * $e")
      }
    }
    println("---")
    printGraph()
    println("Adjoints:")
    for ((k,v) <- adj) {
      println(s"x$k' = $v")
    }
    adj
  }

  /** Evaluate node y. `env` maps input ids (and memoized nodes) to
    * values; it is mutated in place as a memo table, so shared
    * subgraphs are evaluated once. */
  def eval(y: Sym)(implicit env: Map[Sym,Double]): Double = {
    env.getOrElse(y, { // by-name default: only computed on a miss
      println(graph(y))
      val r = graph(y) match {
        case Const(x) => x
        case Plus(x,y) => eval(x) + eval(y)
        case Times(x,y) => eval(x) * eval(y)
        case Div(x,y) => eval(x) / eval(y)
        case Log(x) => math.log(eval(x))
        case Sin(x) => math.sin(eval(x))
        case Cos(x) => math.cos(eval(x))
      }
      env(y) = r
      r
    })
  }

  // ----- test cases -----
  test("test1 - deriv") { // d(x*x)/dx -> 2*x
    def f(x1: Sym, x2: Sym) = times(x1,x2)
    graph += Const(0) // node 0 placeholder so real nodes start at 1
    val x1 = input(1)
    val y = f(x1, x1)
    val nine = eval(y)(Map(x1 -> 3.0))
    assert(nine == 9.0)
    val adj = backprop(y)
    val six = eval(adj(x1))(Map(x1 -> 3.0)) // 2*3
    assert(six == 6.0)
    graph.clear
  }
  test("test2 - 2nd order deriv") { // iterate backprop for d²/dx²
    def f(x1: Sym, x2: Sym, x3: Sym) = times(x1,times(x2,x3))
    graph += Const(0)
    val x1 = input(1)
    val y = f(x1, x1, x1) // y = x^3
    println(s"y: $y")
    val sixtyFour = eval(y)(Map(x1 -> 4.0))
    assert(sixtyFour == 64.0)
    val adj = backprop(y)
    val y1 = adj(x1) // 3x^2
    val fortyEight = eval(y1)(Map(x1 -> 4.0))
    assert(fortyEight == 48.0)
    val adj2 = backprop(y1)
    val y2 = adj2(x1) // 6x
    val twentyFour = eval(y2)(Map(x1 -> 4.0))
    assert(twentyFour == 24.0)
    graph.clear
  }
  test("test3 - 2D grad") {
    def f(x1: Sym, x2: Sym) = plus(plus(log(x1), times(x1, x2)), sin(x2)) // - sin in paper
    graph += Const(0)
    val x1 = input(1); val x2 = input(2)
    val y = f(x1, x2)
    printGraph()
    backprop(y)
    // XXX not currently testing anything
    graph.clear
  }
  test("test4 - opt") { // from: http://diffsharp.github.io/DiffSharp/examples-gradientdescent.html
    // stage a 2-argument function, return a plain Double evaluator
    def fun(f: (Sym,Sym) => Sym) = {
      val x1 = input(1); val x2 = input(2)
      val y = f(x1, x2)
      (x0: (Double,Double)) =>
        eval(y)(Map(x1 -> x0._1, x2 -> x0._2))
    }
    // stage a function's gradient, return a plain evaluator
    def grad(f: (Sym,Sym) => Sym) = {
      val x1 = input(1); val x2 = input(2)
      val y = f(x1, x2)
      val adj = backprop(y)
      println(adj)
      (x0: (Double,Double)) =>
        (eval(adj(x1))(Map(x1 -> x0._1, x2 -> x0._2)), // FIXME: efficiency <- should eval together
         eval(adj(x2))(Map(x1 -> x0._1, x2 -> x0._2)))
    }
    // gradient descent: step against the gradient until its L2 norm
    // drops below epsilon
    def gd(f: (Sym,Sym) => Sym, x0: (Double,Double), eta: Double, epsilon: Double) = {
      def desc(x: (Double,Double)): (Double,Double) = {
        def l2norm(x: (Double,Double)) = math.sqrt(x._1*x._1 + x._2*x._2)
        val g = grad(f)(x)
        val x1 = (x._1 - eta * g._1, x._2 - eta * g._2)
        if (l2norm(g) < epsilon) x else desc(x1)
      }
      desc(x0)
    }
    def f(x1: Sym, x2: Sym) = plus(sin(x1), cos(x2))
    val xmin = gd(f, (1.0,1.0), 0.9, 0.00001)
    assert(xmin == (-1.5707907586270724,3.141591963803541)) // (-π/2, π)
    val fxmin = fun(f)(xmin)
    assert(fxmin == -1.9999999999842597)
    println(xmin)
    println(fxmin)
    graph.clear
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment