Created November 8, 2020 19:51
Save emilk/c027311e5d0e8b69953c83a3ec283b74 to your computer and use it in GitHub Desktop.
Dual numbers in Rust, for automatic differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{
    f64::consts::{LN_10, LN_2},
    fmt,
    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
/// Numeric operations shared by plain `f64` and [`Dual`].
///
/// Write numeric code once, generic over `S: Scalar`, and evaluate it either with `f64`
/// (just the value) or with `Dual` (the value plus its derivative).
///
/// Every arithmetic operator and its compound-assignment form is required with both
/// `Self` and `f64` on the right-hand side, so generic code can freely mix scalars and
/// plain constants (`s * 2.0`, `s *= k`, ...).
///
/// The methods mirror a subset of the inherent methods on
/// [`f64`](https://doc.rust-lang.org/std/primitive.f64.html), in the same order.
/// The trigonometric functions (`sin`, `cos`, `tan`, `asin`, `acos`, `atan`, `atan2`)
/// are not part of the trait yet.
pub trait Scalar:
    fmt::Debug
    + fmt::Display
    + std::iter::Sum
    + Copy
    + Clone
    + Default
    + From<f64>
    + Add<f64, Output = Self>
    + Add<Self, Output = Self>
    + AddAssign<f64>
    + AddAssign<Self>
    + Div<f64, Output = Self>
    + Div<Self, Output = Self>
    + DivAssign<f64>
    + DivAssign<Self>
    + Mul<f64, Output = Self>
    + Mul<Self, Output = Self>
    // `MulAssign` was missing although the other operators all require their
    // compound-assignment form and both implementors in this file provide it.
    + MulAssign<f64>
    + MulAssign<Self>
    + Neg<Output = Self>
    + Sub<f64, Output = Self>
    + Sub<Self, Output = Self>
    + SubAssign<f64>
    + SubAssign<Self>
    + PartialOrd<Self>
    + PartialOrd<f64>
{
    /// Absolute value.
    fn abs(self) -> Self;
    /// `(self * a) + b`.
    fn mul_add(self, a: Self, b: Self) -> Self;
    /// `self` raised to an integer power.
    fn powi(self, n: i32) -> Self;
    /// `self` raised to a constant floating-point power.
    fn powf(self, n: f64) -> Self;
    /// `self` raised to a power of the same scalar type (so the exponent may itself
    /// carry a derivative).
    fn pow(self, n: Self) -> Self;
    /// Square root.
    fn sqrt(self) -> Self;
    /// `e^self`.
    fn exp(self) -> Self;
    /// `2^self`.
    fn exp2(self) -> Self;
    /// Natural logarithm.
    fn ln(self) -> Self;
    /// Logarithm with respect to a constant `base`.
    fn log(self, base: f64) -> Self;
    /// Base-2 logarithm.
    fn log2(self) -> Self;
    /// Base-10 logarithm.
    fn log10(self) -> Self;
    /// Cube root.
    fn cbrt(self) -> Self;
    /// `sqrt(self² + other²)`.
    fn hypot(self, other: Self) -> Self;
    /// `true` if the value is NaN (for `Dual`: if either component is NaN).
    fn is_nan(self) -> bool;
    /// `true` if the value is neither infinite nor NaN (for `Dual`: both components).
    fn is_finite(self) -> bool;
    /// `1 / self`.
    fn recip(self) -> Self;
}
#[rustfmt::skip] | |
impl Scalar for f64 { | |
#[inline] fn abs(self) -> Self { f64::abs(self) } | |
#[inline] fn mul_add(self, a: Self, b: Self) -> Self { f64::mul_add(self, a, b) } | |
#[inline] fn powi(self, n: i32) -> Self { f64::powi(self, n) } | |
#[inline] fn powf(self, n: f64) -> Self { f64::powf(self, n) } | |
#[inline] fn pow(self, n: Self) -> Self { f64::powf(self, n) } | |
#[inline] fn sqrt(self) -> Self { f64::sqrt(self) } | |
#[inline] fn exp(self) -> Self { f64::exp(self) } | |
#[inline] fn exp2(self) -> Self { f64::exp2(self) } | |
#[inline] fn ln(self) -> Self { f64::ln(self) } | |
#[inline] fn log(self, base: f64) -> Self { f64::log(self, base) } | |
#[inline] fn log2(self) -> Self { f64::log2(self) } | |
#[inline] fn log10(self) -> Self { f64::log10(self) } | |
#[inline] fn cbrt(self) -> Self { f64::cbrt(self) } | |
#[inline] fn hypot(self, other: Self) -> Self { f64::hypot(self, other) } | |
#[inline] fn is_nan(self) -> bool { f64::is_nan(self) } | |
#[inline] fn is_finite(self) -> bool { f64::is_finite(self) } | |
#[inline] fn recip(self) -> Self { f64::recip(self) } | |
} | |
// ---------------------------------------------------------------------------- | |
/// A dual number `x + eε`, where `ε` is an infinitesimal with `ε² = 0`.
///
/// Evaluating a function on `Dual { x, e: 1.0 }` yields the function's value in `x` and
/// its derivative in `e` (see `diff`). This is forward-mode automatic differentiation.
///
/// Equality and ordering are derived field by field: `x` first, then `e`.
#[derive(Clone, Copy, PartialEq, PartialOrd, Default)]
pub struct Dual {
    /// The real (value) part.
    pub x: f64,
    /// The infinitesimal part: the coefficient of `ε`, which can be thought of as `dx`.
    /// Once const generics are used this could become a vector, so that several variables
    /// can be differentiated in a single pass.
    pub e: f64,
}
impl Dual { | |
pub fn constant(x: f64) -> Dual { | |
Dual { x, e: 0.0 } | |
} | |
} | |
impl fmt::Debug for Dual { | |
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
write!(f, "{} + {}ε", self.x, self.e) | |
} | |
} | |
impl fmt::Display for Dual { | |
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | |
write!(f, "{} + {}ε", self.x, self.e) | |
} | |
} | |
impl From<f64> for Dual { | |
#[inline] | |
fn from(x: f64) -> Dual { | |
Dual { x, e: 0.0 } | |
} | |
} | |
impl PartialEq<f64> for Dual { | |
fn eq(&self, other: &f64) -> bool { | |
self.x.eq(other) | |
} | |
fn ne(&self, other: &f64) -> bool { | |
self.x.ne(other) | |
} | |
} | |
impl PartialOrd<f64> for Dual { | |
fn partial_cmp(&self, other: &f64) -> Option<std::cmp::Ordering> { | |
self.x.partial_cmp(other) | |
} | |
} | |
impl std::iter::Sum for Dual { | |
fn sum<I>(iter: I) -> Self | |
where | |
I: Iterator<Item = Dual>, | |
{ | |
let mut sum = Self::default(); | |
for item in iter { | |
sum += item; | |
} | |
sum | |
} | |
} | |
impl Scalar for Dual { | |
#[inline] | |
fn abs(self) -> Self { | |
if self.x < 0.0 { | |
-self | |
} else { | |
self | |
} | |
} | |
#[inline] | |
fn mul_add(self, a: Self, b: Self) -> Self { | |
Dual { | |
x: self.x.mul_add(a.x, b.x), | |
e: self.x * a.e + self.e * a.x + b.e, | |
} | |
} | |
fn powi(self, n: i32) -> Self { | |
Dual { | |
x: self.x.powi(n), | |
e: self.e * (n as f64) * self.x.powi(n - 1), | |
} | |
} | |
fn powf(self, n: f64) -> Self { | |
Dual { | |
x: self.x.powf(n), | |
e: self.e * n * self.x.powf(n - 1.0), | |
} | |
} | |
fn pow(self, n: Self) -> Self { | |
if n.e == 0.0 { | |
self.powf(n.x) | |
} else if self.x == 0.0 && 1.0 < n.x { | |
Dual { x: 0.0, e: 0.0 } | |
} else if self.x == 0.0 && n.x == 1.0 { | |
self | |
} else { | |
let x = self.x.powf(n.x); | |
Dual { | |
x, | |
e: self.e * self.x.powf(n.x - 1.0) * n.x + self.x.ln() * x * n.e, | |
} | |
} | |
} | |
#[inline] | |
fn sqrt(self) -> Self { | |
let x = self.x.sqrt(); | |
Dual { | |
x, | |
e: self.e / (x * 2.0), | |
} | |
} | |
#[inline] | |
fn exp(self) -> Self { | |
let x = self.x.exp(); | |
Dual { x, e: self.e * x } | |
} | |
#[inline] | |
fn exp2(self) -> Self { | |
let x = self.x.exp2(); | |
Dual { | |
x, | |
e: self.e * x * 2.0.ln(), | |
} | |
} | |
#[inline] | |
fn ln(self) -> Self { | |
Dual { | |
x: self.x.ln(), | |
e: self.e / self.x, | |
} | |
} | |
#[inline] | |
fn log(self, base: f64) -> Self { | |
Dual { | |
x: self.x.log(base), | |
e: self.e / (self.x * base.ln()), | |
} | |
} | |
#[inline] | |
fn log2(self) -> Self { | |
Dual { | |
x: self.x.log2(), | |
e: self.e / (self.x * 2.0.ln()), | |
} | |
} | |
#[inline] | |
fn log10(self) -> Self { | |
Dual { | |
x: self.x.log10(), | |
e: self.e / (self.x * 10.0.ln()), | |
} | |
} | |
#[inline] | |
fn cbrt(self) -> Self { | |
Dual { | |
x: self.x.cbrt(), | |
e: self.e / (3.0 * (self.x * self.x).cbrt()), | |
} | |
} | |
#[inline] | |
fn hypot(self, y: Self) -> Self { | |
let x = self.x.hypot(y.x); | |
Dual { | |
x, | |
e: (self.e * self.x + y.e * y.x) / x, | |
} | |
} | |
#[inline] | |
fn is_nan(self) -> bool { | |
self.x.is_nan() || self.e.is_nan() | |
} | |
#[inline] | |
fn is_finite(self) -> bool { | |
self.x.is_finite() && self.e.is_finite() | |
} | |
/// 1 / self | |
#[inline] | |
fn recip(self) -> Self { | |
Dual { | |
x: self.x.recip(), | |
e: -self.e / (self.x * self.x), | |
} | |
} | |
} | |
// ---------------------------------------------------------------------------- | |
impl Add<f64> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn add(self, rhs: f64) -> Self::Output { | |
Dual { | |
x: self.x + rhs, | |
e: self.e, | |
} | |
} | |
} | |
impl Add<Dual> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn add(self, rhs: Dual) -> Self::Output { | |
Dual { | |
x: self.x + rhs.x, | |
e: self.e + rhs.e, | |
} | |
} | |
} | |
impl AddAssign<f64> for Dual { | |
#[inline] | |
fn add_assign(&mut self, rhs: f64) { | |
self.x += rhs; | |
} | |
} | |
impl AddAssign<Dual> for Dual { | |
#[inline] | |
fn add_assign(&mut self, rhs: Dual) { | |
self.x += rhs.x; | |
self.e += rhs.e; | |
} | |
} | |
impl Div<f64> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn div(self, rhs: f64) -> Self::Output { | |
Dual { | |
x: self.x / rhs, | |
e: self.e / rhs, | |
} | |
} | |
} | |
impl Div<Dual> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn div(self, rhs: Dual) -> Self::Output { | |
Dual { | |
x: self.x / rhs.x, | |
e: (self.e - self.x * rhs.e / rhs.x) / rhs.x, | |
} | |
} | |
} | |
impl DivAssign<f64> for Dual { | |
#[inline] | |
fn div_assign(&mut self, rhs: f64) { | |
self.x /= rhs; | |
self.e /= rhs; | |
} | |
} | |
impl DivAssign<Dual> for Dual { | |
#[inline] | |
fn div_assign(&mut self, rhs: Dual) { | |
*self = *self / rhs; | |
} | |
} | |
impl Mul<f64> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn mul(self, rhs: f64) -> Self::Output { | |
Dual { | |
x: self.x * rhs, | |
e: self.e * rhs, | |
} | |
} | |
} | |
impl Mul<Dual> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn mul(self, rhs: Dual) -> Self::Output { | |
Dual { | |
x: self.x * rhs.x, | |
e: self.x * rhs.e + self.e * rhs.x, | |
} | |
} | |
} | |
impl MulAssign<f64> for Dual { | |
#[inline] | |
fn mul_assign(&mut self, rhs: f64) { | |
self.x *= rhs; | |
self.e *= rhs; | |
} | |
} | |
impl MulAssign<Dual> for Dual { | |
#[inline] | |
fn mul_assign(&mut self, rhs: Dual) { | |
*self = *self * rhs; | |
} | |
} | |
impl Neg for Dual { | |
type Output = Dual; | |
#[inline] | |
fn neg(self) -> Self::Output { | |
Dual { | |
x: -self.x, | |
e: -self.e, | |
} | |
} | |
} | |
impl Sub<f64> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn sub(self, rhs: f64) -> Self::Output { | |
Dual { | |
x: self.x - rhs, | |
e: self.e, | |
} | |
} | |
} | |
impl Sub<Dual> for Dual { | |
type Output = Dual; | |
#[inline] | |
fn sub(self, rhs: Dual) -> Self::Output { | |
Dual { | |
x: self.x - rhs.x, | |
e: self.e - rhs.e, | |
} | |
} | |
} | |
impl SubAssign<f64> for Dual { | |
#[inline] | |
fn sub_assign(&mut self, rhs: f64) { | |
self.x -= rhs; | |
} | |
} | |
impl SubAssign<Dual> for Dual { | |
#[inline] | |
fn sub_assign(&mut self, rhs: Dual) { | |
self.x -= rhs.x; | |
self.e -= rhs.e; | |
} | |
} | |
// ---------------------------------------------------------------------------- | |
/// Evaluates the derivative of `f` at `x` | |
pub fn diff<F>(f: F, x: f64) -> f64 | |
where | |
F: FnOnce(Dual) -> Dual, | |
{ | |
f(Dual { x, e: 1.0 }).e | |
} | |
/// Evaluates the gradient of `f` at `x` | |
pub fn grad<F>(f: F, x: &[f64]) -> Vec<f64> | |
where | |
F: Fn(&[Dual]) -> Dual, | |
{ | |
let mut x: Vec<Dual> = x.iter().map(|&x| x.into()).collect(); | |
let mut results = Vec::new(); | |
for i in 0..x.len() { | |
x[i].e = 1.0; | |
results.push(f(&x).e); | |
x[i].e = 0.0; | |
} | |
results | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.