diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b7f6d72 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ezdiff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +num-traits = "0.2" diff --git a/README.md b/README.md new file mode 100644 index 0000000..cfd5325 --- /dev/null +++ b/README.md @@ -0,0 +1,78 @@ +# ezDiff + +A tiny forward automatic differentiation library for learning purposes. If you need something fully featured, check out [`hyperdual`](https://crates.io/crates/hyperdual). + +## What is automatic differentiation? + +AutoDiff is a way to automatically calculate derivatives. AutoDiff is *not* symbolic (computer algebra) or numerical differentiation, instead it computes derivatives at the same time as the regular values are being evaluated, by using dual numbers. + +### Dual Numbers + +They sound fancy, but it's essentially just using a tuple of values `(x, dx)`, instead of a single value. Instead of computing `f(x) = y` and passing in `x` to find `y`, we pass in a tuple `(x, dx)` and compute *both* `(y, dy/dx)` at the same time using operator overloading. + +### Operator Overloading + +Operator overloading is a technique where you *overload* mathematical operators in a programming language (`+`, `-`, `*`, `/`) with your own implementation. To do this in Rust, all we need to do is take our dual number type, and implement the math operation traits on it: `impl Add for Dual { ... }`. + +### How does it actually *work*? + +So, we've talked in abstract how we replace `x` when evaluating `f(x)` with a tuple `(x, dx)`, then do *something* to that number with operator overloading. This is where we get to the neat trick (mathematicians *hate* this one weird trick). The trick is conceptually simple - use the chain rule. + +The chain rule tells us that when we are trying to find the derivative of some really complicated function, we can split up the complicated function into many small, simple functions that are easy to evaluate. + +$\frac{d}{dx}[f(g(x))] = f'(g(x)) * g'(x)$ + +Let's say we have some function $y = cos(x^2)$. I don't know how to find the derivative of that, so let's split it up into some smaller functions I do know how to evaluate. Let's say: + +$f(g(x)) = cos(g(x))$ and $g(x) = x^2$ + +We no longer need to find the derivative of $cos(x^2)$! Now we only need to find the derivative of $cos(u)$ and $x^2$ separately, then mush them together. + +So, how do dual numbers come into play here? If you take a look at how we evaluated $cos(x^2)$ by breaking it into smaller parts, well, that's also how we evaluate a function normally! I can't easily compute $cos(x^2)$, but I can compute $x^2$, then plug it into $cos()$. We can take advantage of this to compute the "primal" (x) and the derivative (d/dx). Let's do this step-by-step: + +1. Let's start by defining a dual number as `Dual = (x, dx)`. +2. Plug it into the function we want to evaluate: `cos(Dual^2)` +3. We can begin evaluating by following the normal order of operations, and find the value of `Dual^2` + + `Dual^2 = (x^2, 2*x*dx)` + + Here we calculate the primal value on the left, and the derivative on the right. The primal is just, well, the normal operation, `x^2`. For the derivative, all we need to do is answer: what is the *derivative* of `x^2`? Here we can simply use the power rule: $\frac{d}{dx} x^n = nx^{n-1} dx$. Concretely, $\frac{d}{dx} x^2 = 2*x* dx$. + +4. We can continue with the order of operations and evaluate $cos(Dual)$, using the rules for derivatives of trig functions: $\frac{d}{dx}cos(x) = -sin(x)dx$. + + `Dual.cos() = (x.cos(), -x.sin() * dx)` + + Now, remember at this point `Dual` already has values inside it from when we evaluated `Dual^2`. When we substitute that in, we see that our `Dual` contains: + + `Dual.x = (x^2).cos()` + + `Dual.dx = -x.sin() * (2*x*dx)` + + Anywhere we see `x` or `dx`, we replace that with the previous value stored in the dual number. + +5. That's really all there is to it. If we plug in our initial conditions into `Dual`, let's say at $x = 5$ we start with `Dual(5, 1)`. Note the derivative always starts as $1$. + + `y = (x^2).cos() = 0.9912028118634736` + + `dy/dx = -x.sin() * (2*x*dx) = 1.3235175009777302` + +In this example, we only needed to know the derivative of $cos(x)$ and $x^2$. We encode this by overloading those operators (defining math operators for our `Dual` type), and letting the language evaluate our function normally. All we have to do is: + +```rs +let x = Dual::new(5.0); +let y = (x.pow(2.0)).cos(); +dbg!(y); +``` + +which returns + +``` +y = Dual { + val: 0.9912028118634736, + dot: 1.3235175009777302, +} +``` + +## Why is this useful? + +Instead of needing to iteratively approximate the value (numerical, finite differences), or attempt to find a symbolic representation of the derivative (computer algebra), we can compute the derivative damn near for free, and the compiler can optimize it inline with our code. Because the implementation is just defining the derivative as a math operation the same way the equation would normally be evaluated, the computational complexity of the derivative is proportional to the complexity of the original equation! \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..947b6ef --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,428 @@ +use std::{ + fmt::Debug, + ops::{Add, Div, Mul, Neg, Sub}, +}; + +use num_traits::{Float, Pow}; + +#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +pub struct Dual { + x: F, + dx: F, +} + +impl Dual { + #[inline] + pub fn new(val: F) -> Self { + Self { + x: val, + dx: F::one(), + } + } + + #[inline] + pub fn sqrt(self) -> Self { + self.pow(F::from(0.5).unwrap()) + } + + #[inline] + pub fn exp(self) -> Self { + Dual { + x: self.x.exp(), + dx: self.x.exp() * self.dx, + } + } + + #[inline] + pub fn ln(self) -> Self { + Dual { + x: self.x.ln(), + dx: self.x.powi(-1) * self.dx, + } + } + + #[inline] + pub fn log(self, base: F) -> Self { + Dual { + x: self.x.log(base), + dx: (base.ln() * self.x).powi(-1) * self.dx, + } + } + + #[inline] + pub fn sin(self) -> Self { + Dual { + x: self.x.sin(), + dx: self.x.cos() * self.dx, + } + } + + #[inline] + pub fn cos(self) -> Self { + Dual { + x: self.x.cos(), + dx: -self.x.sin() * self.dx, + } + } + + #[inline] + pub fn tan(self) -> Self { + Dual { + x: self.x.tan(), + dx: self.x.cos().powi(-2) * self.dx, + } + } + + #[inline] + pub fn asin(self) -> Self { + Dual { + x: self.x.asin(), + dx: (F::one() - self.x.powi(2)).sqrt().powi(-1) * self.dx, + } + } + + #[inline] + pub fn acos(self) -> Self { + Dual { + x: self.x.acos(), + dx: -(F::one() - self.x.powi(2)).sqrt().powi(-1) * self.dx, + } + } + + #[inline] + pub fn atan(self) -> Self { + Dual { + x: self.x.atan(), + dx: (F::one() + self.x.powi(2)).powi(-1) * self.dx, + } + } + + pub fn value(&self) -> F { + self.x + } + + pub fn derivative(&self) -> F { + self.dx + } +} + +impl Neg for Dual { + type Output = Dual; + + fn neg(self) -> Self::Output { + Dual { + x: self.x.neg(), + dx: self.dx.neg(), + } + } +} + +// Sum rule +impl Add for Dual { + type Output = Dual; + + fn add(self, rhs: Self) -> Self::Output { + Dual { + x: self.x + rhs.x, + dx: self.dx + rhs.dx, + } + } +} + +// Sum constant +impl Add for Dual { + type Output = Dual; + + fn add(self, rhs: F) -> Self::Output { + Dual { + x: self.x + rhs, + dx: self.dx, + } + } +} + +// Sum constant +impl Add> for f32 { + type Output = Dual; + + fn add(self, rhs: Dual) -> Self::Output { + Dual { + x: rhs.x + self, + dx: rhs.dx, + } + } +} + +// Sum constant +impl Add> for f64 { + type Output = Dual; + + fn add(self, rhs: Dual) -> Self::Output { + Dual { + x: rhs.x + self, + dx: rhs.dx, + } + } +} + +// Sum constant +impl Add> for (F,) { + type Output = Dual; + + fn add(self, rhs: Dual) -> Self::Output { + Dual { + x: rhs.x + self.0, + dx: rhs.dx, + } + } +} + +// Difference rule +impl Sub for Dual { + type Output = Dual; + + fn sub(self, rhs: Self) -> Self::Output { + Dual { + x: self.x - rhs.x, + dx: self.dx - rhs.dx, + } + } +} + +// Product rule +impl Mul for Dual { + type Output = Dual; + + fn mul(self, rhs: Dual) -> Self::Output { + Dual { + x: self.x * rhs.x, + dx: self.x * rhs.dx + rhs.x * self.dx, + } + } +} + +// Constant multiple rule +impl Mul for Dual { + type Output = Dual; + + fn mul(self, rhs: F) -> Self::Output { + Dual { + x: self.x * rhs, + dx: self.dx * rhs, + } + } +} + +// Constant multiple rule +impl Mul> for f32 { + type Output = Dual; + + fn mul(self, rhs: Dual) -> Self::Output { + Dual { + x: self * rhs.x, + dx: self * rhs.dx, + } + } +} + +// Constant multiple rule +impl Mul> for f64 { + type Output = Dual; + + fn mul(self, rhs: Dual) -> Self::Output { + Dual { + x: self * rhs.x, + dx: self * rhs.dx, + } + } +} + +// Quotient rule +impl Div for Dual { + type Output = Dual; + + fn div(self, rhs: Dual) -> Self::Output { + Dual { + x: self.x / rhs.x, + dx: (self.x * rhs.dx + rhs.x * self.dx) / (rhs.x * rhs.x), + } + } +} + +// Power rule +impl Pow for Dual { + type Output = Dual; + + fn pow(self, rhs: F) -> Self::Output { + Dual { + x: self.x.powf(rhs), + dx: rhs * self.x.powf(rhs - F::one()) * self.dx, // n * x^(n-1) * d/dx + } + } +} + +// Inverse(?) power rule a^x +impl Pow> for (F,) { + type Output = Dual; + + fn pow(self, rhs: Dual) -> Self::Output { + Dual { + x: self.0.powf(rhs.x), + dx: self.0.ln() * self.0.powf(rhs.x) * rhs.dx, + } + } +} + +// Inverse(?) power rule a^x +impl Pow> for f32 { + type Output = Dual; + + fn pow(self, rhs: Dual) -> Self::Output { + Dual { + x: self.powf(rhs.x), + dx: self.ln() * self.powf(rhs.x) * rhs.dx, + } + } +} + +// Inverse(?) power rule a^x +impl Pow> for f64 { + type Output = Dual; + + fn pow(self, rhs: Dual) -> Self::Output { + Dual { + x: self.powf(rhs.x), + dx: self.ln() * self.powf(rhs.x) * rhs.dx, + } + } +} + +#[macro_export] +macro_rules! dual { + ($a:expr) => {{ + Dual::new($a) + }}; +} + +impl From for Dual { + #[inline] + fn from(input: f32) -> Self { + Dual::new(input) + } +} + +impl From for Dual { + #[inline] + fn from(input: f64) -> Self { + Dual::new(input) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn simple() { + let x = dual!(3.0); + let y = x * x + 2.0; + assert_eq!(y.x, 11.0); + assert_eq!(y.dx, 6.0); + } + + #[test] + fn sin() { + let sin = |x: Dual<_>| x.sin(); + let y_1 = sin(dual!(1.0)); + assert_eq!(y_1.x, 0.8414709848078965); + assert_eq!(y_1.dx, 0.5403023058681398); + } + + #[test] + fn cos() { + let cos = |x: Dual<_>| x.cos(); + let y_1 = cos(dual!(1.0)); + assert_eq!(y_1.x, 0.5403023058681398); + assert_eq!(y_1.dx, -0.8414709848078965); + } + + #[test] + fn tan() { + let tan = |x: Dual<_>| x.tan(); + let y_1 = tan(dual!(1.0)); + assert_eq!(y_1.x, 1.5574077246549023); + assert_eq!(y_1.dx, 3.425518820814759); + } + + #[test] + fn asin() { + let asin = |x: Dual<_>| x.asin(); + let y_05 = asin(dual!(0.5)); + assert_eq!(y_05.x, 0.5235987755982989); + assert_eq!(y_05.dx, 1.1547005383792517); + } + + #[test] + fn acos() { + let acos = |x: Dual<_>| x.acos(); + let y_05 = acos(dual!(0.5)); + assert_eq!(y_05.x, 1.0471975511965979); + assert_eq!(y_05.dx, -1.1547005383792517); + } + + #[test] + fn atan() { + let atan = |x: Dual<_>| x.atan(); + let y_05 = atan(dual!(0.5)); + assert_eq!(y_05.x, 0.4636476090008061); + assert_eq!(y_05.dx, 0.8); + } + + #[test] + fn sqrt() { + let sqrt = |x: Dual<_>| x.sqrt(); + let y_1 = sqrt(dual!(1.0)); + assert_eq!(y_1.x, 1.0); + assert_eq!(y_1.dx, 0.5); + } + + #[test] + fn exp() { + let exp = |x: Dual<_>| x.exp(); + let y_1 = exp(dual!(1.0)); + assert_eq!(y_1.x, std::f32::consts::E); + assert_eq!(y_1.dx, std::f32::consts::E); + } + + #[test] + fn ln() { + let ln = |x: Dual<_>| x.ln(); + let y_2 = ln(dual!(2.0)); + assert_eq!(y_2.x, 0.6931471805599453); + assert_eq!(y_2.dx, 0.5); + } + + #[test] + fn log() { + let log = |x: Dual<_>| x.log(10.0); + let y_2 = log(dual!(2.0)); + assert_eq!(y_2.x, 0.30102999566398114); + assert_eq!(y_2.dx, 0.21714724095162588); + } + + #[test] + fn add_mul_consts() { + let f = |x: Dual| 1.0 + x * 3.0; + let y_2 = f(dual!(2.0)); + assert_eq!(y_2.x, 7.0); + assert_eq!(y_2.dx, 3.0); + } + + #[test] + fn product() { + let f = |x: Dual| x.sin() * x.cos(); + let y_1 = f(dual!(1.0)); + assert_eq!(y_1.x, 0.45464867); + assert_eq!(y_1.dx, -0.4161468); + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..be57eec --- /dev/null +++ b/src/main.rs @@ -0,0 +1,13 @@ +use ezdiff::*; +use num_traits::Pow; + +pub fn main() { + let x = dual!(5.0); + let y = (x.pow(2.0)).cos(); + dbg!(y); + + // f(x) = cos(x^2) + 3x + let f = |x: f32| (dual!(x).pow(2.0)).cos() + 3.0 * dual!(x); + // evaluate at x = 2 + dbg!(f(2.0)); +}