Skip to content

Instantly share code, notes, and snippets.

@adamnemecek
Forked from jliszka/Dual.scala
Created August 5, 2024 19:59
Show Gist options
  • Save adamnemecek/94465982e1e94a9727f8b74c8a45f7de to your computer and use it in GitHub Desktop.
Save adamnemecek/94465982e1e94a9727f8b74c8a45f7de to your computer and use it in GitHub Desktop.
/** A truncated dual (Taylor) number represented as a `rank` x `rank` matrix.
  *
  * Arithmetic on these values is matrix arithmetic, evaluated lazily: every operator
  * returns a new `Dual` whose cells are computed on demand from its operands.
  *
  * `apply` memoizes cells keyed by `r - c` (one entry per diagonal), so a subclass's
  * `get` must be Toeplitz: each cell may depend only on `c - r`.
  *
  * `inv` and `exp` additionally assume that `this - this(1, 1) * I` is nilpotent
  * (strictly upper triangular, like the shift matrix `D`). That assumption is why their
  * series stop after `rank` terms.
  *
  * For `x = a * I + D`, the first row of `f(x)` (what `toString` prints) holds
  * `f(a), f'(a), f''(a)/2!, ...`, up to `rank` terms.
  *
  * @param rank matrix dimension, i.e. the number of Taylor coefficients tracked.
  *             Cell indices passed to `apply`/`get` are 1-based.
  */
abstract class Dual(val rank: Int) {
  self =>

  // Memo table: one cached value per diagonal (key = r - c).
  private val memo = scala.collection.mutable.HashMap[Int, Double]()

  /** Raw, unmemoized value of cell (r, c), 1-based. Must depend only on `c - r`. */
  protected def get(r: Int, c: Int): Double

  /** Memoized value of cell (r, c), 1-based. Computed at most once per diagonal. */
  def apply(r: Int, c: Int): Double = memo.getOrElseUpdate(r - c, self.get(r, c))

  /** Cell-wise sum. */
  def +(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = self(r, c) + other(r, c)
  }

  /** Cell-wise difference. */
  def -(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = self(r, c) - other(r, c)
  }

  /** Negation.
    *
    * Declared without `()`: a unary operator with an empty parameter list is
    * deprecated in Scala 2.13. Prefix use (`-x`) is unchanged.
    */
  def unary_- : Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = -self(r, c)
  }

  /** Matrix product (the dual-number product). */
  def *(other: Dual): Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = (1 to rank).map(i => self(r, i) * other(i, c)).sum
  }

  /** Scales every cell by `x`. */
  def *(x: Double): Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = self(r, c) * x
  }

  /** Dual-number quotient: `this * other.inv`. */
  def /(other: Dual): Dual = self * other.inv

  /** Divides every cell by `x`. */
  def /(x: Double): Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = self(r, c) / x
  }

  /** Multiplicative inverse.
    *
    * Writes `this = a * (I - n)` with `n = -(this - a * I) / a`. Assuming `n` is
    * nilpotent, the geometric series `I + n + ... + n^(rank-1)` is exact.
    *
    * @throws IllegalArgumentException if the real part `this(1, 1)` is zero; the value
    *                                  is then not invertible.
    */
  def inv: Dual = {
    val a = self(1, 1)
    require(a != 0.0, "Dual.inv: real part is zero, so the value is not invertible")
    val n = -(self - I * a) / a
    List.iterate(I, rank)(_ * n).reduce(_ + _) / a
  }

  /** Identity matrix of the same dimension.
    *
    * Lazy because every Dual would otherwise eagerly build an identity, which would
    * build its own identity, and so on without end.
    */
  lazy val I: Dual = new Dual(rank) {
    def get(r: Int, c: Int): Double = if (r == c) 1.0 else 0.0
  }

  /** Integer power by repeated squaring.
    *
    * A negative exponent raises `inv` to `-p`, so it throws (via `inv`) when the real
    * part is zero.
    */
  def pow(p: Int): Dual = {
    // Long exponent so that negating Int.MinValue cannot overflow.
    @scala.annotation.tailrec
    def loop(base: Dual, e: Long, acc: Dual): Dual =
      if (e == 0L) acc
      else loop(base * base, e / 2, if (e % 2 == 0L) acc else acc * base)

    if (p >= 0) loop(self, p.toLong, self.I) else loop(inv, -(p.toLong), self.I)
  }

  /** Exponential.
    *
    * Uses `e^(a I + N) = e^a * e^N`, which holds because `a I` and `N` commute. The
    * series for `e^N` is truncated after `rank` terms, which is exact when `N` is
    * nilpotent.
    */
  def exp: Dual = {
    val a = self(1, 1)
    val nil = self - I * a
    val eN = List
      .iterate((I, 1), rank) { case (term, k) => (term * nil / k, k + 1) }
      .map(_._1)
      .reduce(_ + _)
    I * math.exp(a) * eN
  }

  /** The first row of the matrix, space-separated. */
  override def toString: String =
    (1 to rank).map(c => self(1, c)).mkString(" ")
}
/** The `rank` x `rank` identity matrix, i.e. the dual number 1.
  *
  * `rank` is a plain constructor argument forwarded to `Dual`. Re-declaring it as
  * `override val` only added a duplicate field; `rank` remains readable through
  * `Dual.rank`.
  */
class I(rank: Int) extends Dual(rank) {
  def get(r: Int, c: Int): Double = if (r == c) 1.0 else 0.0
}
/** The `rank` x `rank` shift matrix: ones on the first superdiagonal, zeros elsewhere.
  *
  * As a dual number this is the infinitesimal ε, with `ε^rank = 0`. Evaluating
  * `f(a * I + D)` therefore yields the Taylor coefficients of `f` at `a`.
  *
  * `rank` is a plain constructor argument forwarded to `Dual`. Re-declaring it as
  * `override val` only added a duplicate field; `rank` remains readable through
  * `Dual.rank`.
  */
class D(rank: Int) extends Dual(rank) {
  def get(r: Int, c: Int): Double = if (r + 1 == c) 1.0 else 0.0
}
/** Worked examples: evaluate a function at `a * one + d` and read its Taylor
  * coefficients at `a` off the first row of the result.
  */
object Examples {
  // The shift matrix (the infinitesimal). `Dual` is abstract, so it must be a concrete `D`.
  val d: D = new D(10)
  val one: Dual = d.I

  def f(x: Dual): Dual = x.pow(4)
  // First row: 16, 32, 24, 8, 1, 0, ... — the expansion of (2 + ε)^4.
  val f2: Dual = f(one * 2 + d)

  def g(x: Dual): Dual = x.pow(2) * 4 / (one - x).pow(3)
  // First row: g(3), g'(3), g''(3)/2!, ... (10 terms), starting with g(3) = -4.5.
  val g3: Dual = g(one * 3 + d)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment