Skip to content

Instantly share code, notes, and snippets.

@emilk
Created November 8, 2020 19:51
Show Gist options
  • Save emilk/c027311e5d0e8b69953c83a3ec283b74 to your computer and use it in GitHub Desktop.
Save emilk/c027311e5d0e8b69953c83a3ec283b74 to your computer and use it in GitHub Desktop.
Dual numbers in Rust, for automatic differentiation
use std::{
fmt,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
/// A numeric type that can be used interchangeably with `f64` in generic math code.
///
/// The supertrait list requires the usual arithmetic operators against both
/// `Self` and a plain `f64`, comparison against both, plus formatting,
/// summation, `Copy`, `Default` and conversion from `f64`.
/// The methods mirror the inherent methods of `f64` with the same names.
pub trait Scalar:
fmt::Debug
+ fmt::Display
+ std::iter::Sum
+ Copy
+ Clone
+ Default
+ From<f64>
+ Add<f64, Output = Self>
+ Add<Self, Output = Self>
+ AddAssign<f64>
+ AddAssign<Self>
+ Div<f64, Output = Self>
+ Div<Self, Output = Self>
+ DivAssign<f64>
+ DivAssign<Self>
+ Mul<f64, Output = Self>
+ Mul<Self, Output = Self>
+ Neg<Output = Self>
+ Sub<f64, Output = Self>
+ Sub<Self, Output = Self>
+ SubAssign<f64>
+ SubAssign<Self>
+ PartialOrd<Self>
+ PartialOrd<f64>
{
// Same order as in https://doc.rust-lang.org/std/primitive.f64.html
// Some yet to be implemented
/// Absolute value.
fn abs(self) -> Self;
/// Multiply-add: `self * a + b`.
fn mul_add(self, a: Self, b: Self) -> Self;
/// Raises `self` to an integer power.
fn powi(self, n: i32) -> Self;
/// Raises `self` to a plain floating-point power.
fn powf(self, n: f64) -> Self;
/// Raises `self` to a power that is itself a `Self`.
fn pow(self, n: Self) -> Self;
/// Square root.
fn sqrt(self) -> Self;
/// Exponential function, `e^self`.
fn exp(self) -> Self;
/// Base-2 exponential, `2^self`.
fn exp2(self) -> Self;
/// Natural logarithm.
fn ln(self) -> Self;
/// Logarithm with respect to a plain `f64` base.
fn log(self, base: f64) -> Self;
/// Base-2 logarithm.
fn log2(self) -> Self;
/// Base-10 logarithm.
fn log10(self) -> Self;
/// Cube root.
fn cbrt(self) -> Self;
/// Length of the hypotenuse with legs `self` and `other`.
fn hypot(self, other: Self) -> Self;
// Trigonometric functions: declared here as placeholders, not implemented yet.
// fn sin(self) -> Self;
// fn cos(self) -> Self;
// fn tan(self) -> Self;
// fn asin(self) -> Self;
// fn acos(self) -> Self;
// fn atan(self) -> Self;
// fn atan2(self, other: Self) -> Self;
/// Whether the value is NaN.
fn is_nan(self) -> bool;
/// Whether the value is finite (neither infinite nor NaN).
fn is_finite(self) -> bool;
/// Reciprocal, `1 / self`.
fn recip(self) -> Self;
}
/// `f64` satisfies [`Scalar`] by forwarding every method to the inherent
/// `f64` method of the same name (`pow` forwards to `powf`).
///
/// Method-call syntax is safe here: inherent methods take precedence over
/// trait methods, so none of these calls recurse into the trait.
impl Scalar for f64 {
    #[inline]
    fn abs(self) -> Self {
        self.abs()
    }

    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        self.mul_add(a, b)
    }

    #[inline]
    fn powi(self, n: i32) -> Self {
        self.powi(n)
    }

    #[inline]
    fn powf(self, n: f64) -> Self {
        self.powf(n)
    }

    // `f64` has no inherent `pow`; a float exponent is exactly `powf`.
    #[inline]
    fn pow(self, n: Self) -> Self {
        self.powf(n)
    }

    #[inline]
    fn sqrt(self) -> Self {
        self.sqrt()
    }

    #[inline]
    fn exp(self) -> Self {
        self.exp()
    }

    #[inline]
    fn exp2(self) -> Self {
        self.exp2()
    }

    #[inline]
    fn ln(self) -> Self {
        self.ln()
    }

    #[inline]
    fn log(self, base: f64) -> Self {
        self.log(base)
    }

    #[inline]
    fn log2(self) -> Self {
        self.log2()
    }

    #[inline]
    fn log10(self) -> Self {
        self.log10()
    }

    #[inline]
    fn cbrt(self) -> Self {
        self.cbrt()
    }

    #[inline]
    fn hypot(self, other: Self) -> Self {
        self.hypot(other)
    }

    #[inline]
    fn is_nan(self) -> bool {
        self.is_nan()
    }

    #[inline]
    fn is_finite(self) -> bool {
        self.is_finite()
    }

    #[inline]
    fn recip(self) -> Self {
        self.recip()
    }
}
// ----------------------------------------------------------------------------
/// A dual number `x + e·ε`, where the unit `ε` is defined by `ε² = 0`.
///
/// The derived `Default` is `0 + 0ε`. The derived `PartialEq` / `PartialOrd`
/// take both fields into account (`x` first, then `e`).
#[derive(Copy, Clone, Default, PartialEq, PartialOrd)]
pub struct Dual {
/// The real (value) part.
pub x: f64,
/// epsilon, or can be thought of as dx.
/// With const generics, we should change this into a vector
/// so that we can differentiate on several variables in one go.
pub e: f64,
}
impl Dual {
pub fn constant(x: f64) -> Dual {
Dual { x, e: 0.0 }
}
}
impl fmt::Debug for Dual {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} + {}ε", self.x, self.e)
}
}
impl fmt::Display for Dual {
    /// Writes the number as `<x> + <e>ε`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{value} + {eps}ε", value = self.x, eps = self.e)
    }
}
impl From<f64> for Dual {
    /// Converts a plain number into a dual number with a zero epsilon part.
    #[inline]
    fn from(x: f64) -> Dual {
        let e = 0.0;
        Dual { x, e }
    }
}
impl PartialEq<f64> for Dual {
    /// Compares only the value part `x` with `other`; the epsilon part is ignored.
    ///
    /// `ne` is intentionally not overridden: the provided default (`!eq`)
    /// already gives the right answer, including for NaN.
    #[inline]
    fn eq(&self, other: &f64) -> bool {
        self.x == *other
    }
}
impl PartialOrd<f64> for Dual {
    /// Orders the value part `x` against `other`; the epsilon part is ignored.
    /// Yields `None` whenever the underlying `f64` comparison does (NaN).
    fn partial_cmp(&self, other: &f64) -> Option<std::cmp::Ordering> {
        f64::partial_cmp(&self.x, other)
    }
}
impl std::iter::Sum for Dual {
    /// Adds up every item of `iter`, starting from `Dual::default()`
    /// (which is therefore the result for an empty iterator).
    fn sum<I>(iter: I) -> Self
    where
        I: Iterator<Item = Dual>,
    {
        iter.fold(Self::default(), |mut total, term| {
            total += term;
            total
        })
    }
}
/// Forward-mode differentiation rules: each method applies the `f64`
/// operation to the value part `x` and propagates the epsilon part `e`
/// with the matching derivative (chain rule).
impl Scalar for Dual {
    /// Absolute value. The whole number is negated when `x < 0`; otherwise
    /// (including `x == 0` and NaN) it is returned unchanged.
    #[inline]
    fn abs(self) -> Self {
        if self.x < 0.0 {
            -self
        } else {
            self
        }
    }

    /// `self * a + b`. The value part uses `f64::mul_add`; the epsilon part
    /// follows the product rule plus `b.e`.
    #[inline]
    fn mul_add(self, a: Self, b: Self) -> Self {
        Dual {
            x: self.x.mul_add(a.x, b.x),
            e: self.x * a.e + self.e * a.x + b.e,
        }
    }

    /// Integer power: `d(x^n) = n * x^(n-1) * dx`.
    fn powi(self, n: i32) -> Self {
        if n == 0 {
            // x^0 is the constant 1, so its derivative is 0. The general
            // formula would evaluate 0 * x^(-1), which is NaN at x == 0.
            return Dual {
                x: self.x.powi(0),
                e: 0.0,
            };
        }
        // `n - 1` overflows only for `i32::MIN`; fall back to a float exponent there.
        let lowered = match n.checked_sub(1) {
            Some(m) => self.x.powi(m),
            None => self.x.powf(f64::from(n) - 1.0),
        };
        Dual {
            x: self.x.powi(n),
            e: self.e * f64::from(n) * lowered,
        }
    }

    /// Power with a plain `f64` exponent: `d(x^n) = n * x^(n-1) * dx`.
    fn powf(self, n: f64) -> Self {
        if n == 0.0 {
            // Same constant-1 case as in `powi`: avoid 0 * inf = NaN at x == 0.
            return Dual {
                x: self.x.powf(n),
                e: 0.0,
            };
        }
        Dual {
            x: self.x.powf(n),
            e: self.e * n * self.x.powf(n - 1.0),
        }
    }

    /// Power where the exponent is a dual number too.
    ///
    /// General rule: `d(x^n) = n * x^(n-1) * dx + ln(x) * x^n * dn`.
    fn pow(self, n: Self) -> Self {
        if n.e == 0.0 {
            // The exponent does not vary: plain power rule.
            self.powf(n.x)
        } else if self.x == 0.0 && 1.0 < n.x {
            // Special-cased so that ln(0) never enters the result.
            Dual { x: 0.0, e: 0.0 }
        } else if self.x == 0.0 && n.x == 1.0 {
            self
        } else {
            let x = self.x.powf(n.x);
            Dual {
                x,
                e: self.e * self.x.powf(n.x - 1.0) * n.x + self.x.ln() * x * n.e,
            }
        }
    }

    /// Square root: `d(sqrt x) = dx / (2 * sqrt x)`.
    #[inline]
    fn sqrt(self) -> Self {
        let x = self.x.sqrt();
        Dual {
            x,
            e: self.e / (x * 2.0),
        }
    }

    /// Exponential: `d(e^x) = e^x * dx`.
    #[inline]
    fn exp(self) -> Self {
        let x = self.x.exp();
        Dual { x, e: self.e * x }
    }

    /// Base-2 exponential: `d(2^x) = 2^x * ln 2 * dx`.
    #[inline]
    fn exp2(self) -> Self {
        let x = self.x.exp2();
        Dual {
            x,
            e: self.e * x * std::f64::consts::LN_2,
        }
    }

    /// Natural logarithm: `d(ln x) = dx / x`.
    #[inline]
    fn ln(self) -> Self {
        Dual {
            x: self.x.ln(),
            e: self.e / self.x,
        }
    }

    /// Logarithm to a constant `base`: `d(log_b x) = dx / (x * ln b)`.
    #[inline]
    fn log(self, base: f64) -> Self {
        Dual {
            x: self.x.log(base),
            e: self.e / (self.x * base.ln()),
        }
    }

    /// Base-2 logarithm: `d(log2 x) = dx / (x * ln 2)`.
    #[inline]
    fn log2(self) -> Self {
        Dual {
            x: self.x.log2(),
            e: self.e / (self.x * std::f64::consts::LN_2),
        }
    }

    /// Base-10 logarithm: `d(log10 x) = dx / (x * ln 10)`.
    #[inline]
    fn log10(self) -> Self {
        Dual {
            x: self.x.log10(),
            e: self.e / (self.x * std::f64::consts::LN_10),
        }
    }

    /// Cube root: `d(cbrt x) = dx / (3 * cbrt(x^2))`.
    #[inline]
    fn cbrt(self) -> Self {
        Dual {
            x: self.x.cbrt(),
            e: self.e / (3.0 * (self.x * self.x).cbrt()),
        }
    }

    /// Hypotenuse `h = sqrt(self^2 + y^2)`: `dh = (self * dself + y * dy) / h`.
    #[inline]
    fn hypot(self, y: Self) -> Self {
        let x = self.x.hypot(y.x);
        Dual {
            x,
            e: (self.e * self.x + y.e * y.x) / x,
        }
    }

    /// True when either part is NaN.
    #[inline]
    fn is_nan(self) -> bool {
        self.x.is_nan() || self.e.is_nan()
    }

    /// True only when both parts are finite.
    #[inline]
    fn is_finite(self) -> bool {
        self.x.is_finite() && self.e.is_finite()
    }

    /// 1 / self: `d(1/x) = -dx / x^2`.
    #[inline]
    fn recip(self) -> Self {
        Dual {
            x: self.x.recip(),
            e: -self.e / (self.x * self.x),
        }
    }
}
// ----------------------------------------------------------------------------
impl Add<f64> for Dual {
    type Output = Dual;

    /// Adds a plain number: the value part shifts, the epsilon part is kept.
    #[inline]
    fn add(self, rhs: f64) -> Self::Output {
        let Dual { x, e } = self;
        Dual { x: x + rhs, e }
    }
}
impl Add<Dual> for Dual {
    type Output = Dual;

    /// Component-wise sum of value and epsilon parts.
    #[inline]
    fn add(self, rhs: Dual) -> Self::Output {
        let (x, e) = (self.x + rhs.x, self.e + rhs.e);
        Dual { x, e }
    }
}
impl AddAssign<f64> for Dual {
    /// In-place addition of a plain number; only the value part changes.
    #[inline]
    fn add_assign(&mut self, rhs: f64) {
        let value = &mut self.x;
        *value += rhs;
    }
}
impl AddAssign<Dual> for Dual {
#[inline]
fn add_assign(&mut self, rhs: Dual) {
self.x += rhs.x;
self.e += rhs.e;
}
}
impl Div<f64> for Dual {
    type Output = Dual;

    /// Divides both the value part and the epsilon part by a plain number.
    #[inline]
    fn div(self, rhs: f64) -> Self::Output {
        let Dual { x, e } = self;
        Dual {
            x: x / rhs,
            e: e / rhs,
        }
    }
}
impl Div<Dual> for Dual {
    type Output = Dual;

    /// Quotient rule: `d(a / b) = (da - a * db / b) / b`.
    #[inline]
    fn div(self, rhs: Dual) -> Self::Output {
        let quotient = self.x / rhs.x;
        // Kept as (a * db) / b so rounding matches the direct formula.
        let cross = self.x * rhs.e / rhs.x;
        Dual {
            x: quotient,
            e: (self.e - cross) / rhs.x,
        }
    }
}
impl DivAssign<f64> for Dual {
#[inline]
fn div_assign(&mut self, rhs: f64) {
self.x /= rhs;
self.e /= rhs;
}
}
impl DivAssign<Dual> for Dual {
#[inline]
fn div_assign(&mut self, rhs: Dual) {
*self = *self / rhs;
}
}
impl Mul<f64> for Dual {
    type Output = Dual;

    /// Scales both the value part and the epsilon part by a plain number.
    #[inline]
    fn mul(self, rhs: f64) -> Self::Output {
        let Dual { x, e } = self;
        Dual {
            x: x * rhs,
            e: e * rhs,
        }
    }
}
impl Mul<Dual> for Dual {
    type Output = Dual;

    /// Product rule: `d(a * b) = a * db + da * b`.
    #[inline]
    fn mul(self, rhs: Dual) -> Self::Output {
        let value = self.x * rhs.x;
        let slope = self.x * rhs.e + self.e * rhs.x;
        Dual { x: value, e: slope }
    }
}
impl MulAssign<f64> for Dual {
#[inline]
fn mul_assign(&mut self, rhs: f64) {
self.x *= rhs;
self.e *= rhs;
}
}
impl MulAssign<Dual> for Dual {
#[inline]
fn mul_assign(&mut self, rhs: Dual) {
*self = *self * rhs;
}
}
impl Neg for Dual {
    type Output = Dual;

    /// Negates both the value part and the epsilon part.
    #[inline]
    fn neg(self) -> Self::Output {
        let Dual { x, e } = self;
        Dual { x: -x, e: -e }
    }
}
impl Sub<f64> for Dual {
    type Output = Dual;

    /// Subtracts a plain number: the value part shifts, the epsilon part is kept.
    #[inline]
    fn sub(self, rhs: f64) -> Self::Output {
        let Dual { x, e } = self;
        Dual { x: x - rhs, e }
    }
}
impl Sub<Dual> for Dual {
    type Output = Dual;

    /// Component-wise difference of value and epsilon parts.
    #[inline]
    fn sub(self, rhs: Dual) -> Self::Output {
        let (x, e) = (self.x - rhs.x, self.e - rhs.e);
        Dual { x, e }
    }
}
impl SubAssign<f64> for Dual {
    /// In-place subtraction of a plain number; only the value part changes.
    #[inline]
    fn sub_assign(&mut self, rhs: f64) {
        let value = &mut self.x;
        *value -= rhs;
    }
}
impl SubAssign<Dual> for Dual {
#[inline]
fn sub_assign(&mut self, rhs: Dual) {
self.x -= rhs.x;
self.e -= rhs.e;
}
}
// ----------------------------------------------------------------------------
/// Evaluates the derivative of `f` at `x`.
///
/// `f` is called once with the dual number `x + 1ε`; the epsilon part of
/// its result is returned.
pub fn diff<F>(f: F, x: f64) -> f64
where
    F: FnOnce(Dual) -> Dual,
{
    let seed = Dual { x, e: 1.0 };
    let Dual { e: derivative, .. } = f(seed);
    derivative
}
/// Evaluates the gradient of `f` at `x`
pub fn grad<F>(f: F, x: &[f64]) -> Vec<f64>
where
F: Fn(&[Dual]) -> Dual,
{
let mut x: Vec<Dual> = x.iter().map(|&x| x.into()).collect();
let mut results = Vec::new();
for i in 0..x.len() {
x[i].e = 1.0;
results.push(f(&x).e);
x[i].e = 0.0;
}
results
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment