Skip to content

Commit

Permalink
Adding a Kalman-filtering-based solver as a baseline (dfm#67)
Browse files Browse the repository at this point in the history
* adding kalman solver as benchmark

* lint

* adding news fragment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dfm and pre-commit-ci[bot] authored Mar 10, 2022
1 parent cff64c1 commit b18db83
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 13 deletions.
7 changes: 3 additions & 4 deletions docs/benchmarks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@
"try:\n",
" import celerite2\n",
"except ImportError:\n",
" %pip install -q celerite2\n",
"\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)"
" %pip install -q celerite2"
]
},
{
Expand Down Expand Up @@ -130,6 +127,8 @@
"import celerite2\n",
"import tinygp\n",
"\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"sigma = 1.5\n",
"rho = 2.5\n",
"jitter = 0.1\n",
Expand Down
2 changes: 2 additions & 0 deletions news/67.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added a minimal solver based on Kalman filtering to use as a baseline for
checking the performance of the :class:`tinygp.solvers.QuasisepSolver`.
27 changes: 18 additions & 9 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ class Celerite(Quasisep):
k(\tau)=\exp(-c\,\tau)\,\left[a\,\cos(d\,\tau)+b\,\sin(d\,\tau)\right]
for :math:`\tau = |x_i - x_j|`.
In order to be positive definite, the parameters of this kernel must satisfy
:math:`a\,c - b\,d > 0`, and you will see NaNs if you use parameters that
don't satisfy this relationship.
"""
a: JAXArray
b: JAXArray
Expand All @@ -322,23 +326,24 @@ def design_matrix(self) -> JAXArray:
return jnp.array([[-self.c, -self.d], [self.d, -self.c]])

def stationary_covariance(self) -> JAXArray:
    """The stationary state covariance for this kernel's state-space form.

    NOTE(review): the previous text mixed the removed rows of the old
    amplitude-dependent covariance with the new rows, yielding a malformed
    4-row array; only the new 2x2 form is kept.  In this rotated basis the
    stationary covariance depends only on ``c`` and ``d`` — the amplitudes
    ``a`` and ``b`` are presumably absorbed into ``observation_model``;
    confirm against that method.
    """
    c = self.c
    d = self.d
    return jnp.array(
        [
            [1, -c / d],
            [-c / d, 1 + 2 * jnp.square(c) / jnp.square(d)],
        ]
    )

def observation_model(self, X: JAXArray) -> JAXArray:
    """The observation vector mapping the rotated state onto the process.

    NOTE(review): a stale ``return jnp.array([1.0, 0.0])`` (the removed old
    body) was left above the new code, making everything below unreachable;
    it has been deleted.  Per the class docstring, this form requires
    ``a * c - b * d > 0`` — otherwise ``g`` is the square root of a
    negative number and the output is NaN.

    Args:
        X: The input coordinate (unused; the model is stationary).
    """
    a = self.a
    b = self.b
    c = self.c
    d = self.d
    s = jnp.square(c) + jnp.square(d)
    f = jnp.sqrt(a * c + b * d)
    g = jnp.sqrt((a * c - b * d) * s)
    return jnp.array([d * f, c * f - g]) / jnp.sqrt(2 * c * s)

def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
dt = X2 - X1
Expand Down Expand Up @@ -484,6 +489,10 @@ class Matern32(Quasisep):
scale: JAXArray
sigma: JAXArray = field(default_factory=lambda: jnp.ones(()))

def noise(self) -> JAXArray:
    """Process-noise amplitude of the state-space form of this kernel.

    NOTE(review): appears to be the white-noise power driving the
    Matern-3/2 SDE — confirm against the ``Quasisep.noise`` contract.
    """
    # The Matern-3/2 characteristic frequency is sqrt(3) / scale.
    omega = np.sqrt(3) / self.scale
    return 4 * omega**3

def design_matrix(self) -> JAXArray:
    """The companion-form drift matrix of the Matern-3/2 state-space model."""
    omega = np.sqrt(3) / self.scale
    # Second-order companion form: [[0, 1], [-omega^2, -2*omega]].
    bottom_row = [-omega * omega, -2.0 * omega]
    return jnp.array([[0.0, 1.0], bottom_row])
Expand Down
128 changes: 128 additions & 0 deletions src/tinygp/solvers/kalman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

__all__ = ["kalman_filter"]

from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np

from tinygp.helpers import JAXArray, dataclass
from tinygp.kernels.base import Kernel
from tinygp.noise import Diagonal, Noise
from tinygp.solvers.solver import Solver


@dataclass
class KalmanSolver(Solver):
    """A scalable solver that uses Kalman filtering

    This implementation is very limited and it is meant primarily as a
    baseline for checking the performance of the
    :class:`tinygp.solvers.QuasisepSolver`; most of the :class:`Solver`
    interface raises ``NotImplementedError``.

    You generally won't instantiate this object directly but, if you do, you'll
    probably want to use the :func:`KalmanSolver.init` method instead of the
    usual constructor.
    """

    X: JAXArray  # input coordinates (assumed ordered)
    A: JAXArray  # per-step transition matrices between consecutive inputs
    H: JAXArray  # per-step observation vectors
    s: JAXArray  # per-step innovation variances from the filter pass
    K: JAXArray  # per-step Kalman gains from the filter pass

    @classmethod
    def init(
        cls,
        kernel: Kernel,
        X: JAXArray,
        noise: Noise,
        *,
        covariance: Optional[Any] = None,
    ) -> "KalmanSolver":
        """Build a :class:`KalmanSolver` for a given kernel and coordinates

        Args:
            kernel: The kernel function. This must be an instance of a subclass
                of :class:`tinygp.kernels.quasisep.Quasisep`.
            X: The input coordinates.
            noise: The noise model for the process. This must be diagonal for
                this solver.
            covariance: Not yet supported by this solver.
        """
        from tinygp.kernels.quasisep import Quasisep

        assert isinstance(kernel, Quasisep)
        assert isinstance(noise, Diagonal)
        assert covariance is None

        Pinf = kernel.stationary_covariance()
        # Pair each input with its predecessor (the first with itself) to
        # get one transition matrix per step.
        A = jax.vmap(kernel.transition_matrix)(
            jax.tree_util.tree_map(lambda y: jnp.append(y[0], y[:-1]), X), X
        )
        H = jax.vmap(kernel.observation_model)(X)
        # The variances and gains depend only on the kernel and noise (not
        # on the observed values), so precompute them once here.
        s, K = kalman_gains(Pinf, A, H, noise.diag)
        return cls(X=X, A=A, H=H, s=s, K=K)

    def variance(self) -> JAXArray:
        """Not supported by this baseline solver."""
        raise NotImplementedError

    def covariance(self) -> JAXArray:
        """Not supported by this baseline solver."""
        raise NotImplementedError

    def normalization(self) -> JAXArray:
        """The Gaussian log-normalization term, ``0.5 * sum(log(2*pi*s))``.

        NOTE(review): presumably equal to half the log-determinant of the
        full covariance via the innovations decomposition — confirm against
        the other solvers.
        """
        return 0.5 * jnp.sum(jnp.log(2 * np.pi * self.s))

    def solve_triangular(
        self, y: JAXArray, *, transpose: bool = False
    ) -> JAXArray:
        """Whiten ``y``: the filter innovations scaled by their std devs.

        Only the forward (non-transposed) solve is supported.
        """
        assert not transpose
        return kalman_filter(self.A, self.H, self.K, y) / jnp.sqrt(self.s)

    def dot_triangular(self, y: JAXArray) -> JAXArray:
        """Not supported by this baseline solver."""
        raise NotImplementedError

    def condition(
        self, kernel: Kernel, X_test: Optional[JAXArray], noise: Noise
    ) -> Any:
        """Not supported by this baseline solver."""
        raise NotImplementedError


@jax.jit
def kalman_gains(
    Pinf: JAXArray, A: JAXArray, H: JAXArray, diag: JAXArray
) -> Tuple[JAXArray, JAXArray]:
    """Run the covariance half of a Kalman filter forward pass.

    Args:
        Pinf: The stationary state covariance.
        A: The per-step transition matrices.
        H: The per-step observation vectors.
        diag: The per-step observational noise variances.

    Returns:
        A tuple ``(s, K)`` with the innovation variance and Kalman gain
        at every step.
    """

    def update(cov, inputs):  # type: ignore
        trans, obs, var = inputs

        # Predict: propagate the deviation from stationarity one step.
        predicted = Pinf + trans.transpose() @ (cov - Pinf) @ trans

        # Innovation variance and gain for this observation.
        cross = predicted @ obs
        innov_var = obs @ cross + var
        gain = cross / innov_var

        # Correct: rank-one downdate of the predicted covariance.
        corrected = predicted - innov_var * jnp.outer(gain, gain)

        return corrected, (innov_var, gain)

    _, (s, K) = jax.lax.scan(update, Pinf, (A, H, diag))
    return (s, K)


@jax.jit
def kalman_filter(
    A: JAXArray, H: JAXArray, K: JAXArray, y: JAXArray
) -> JAXArray:
    """Run the mean half of a Kalman filter forward pass.

    Args:
        A: The per-step transition matrices.
        H: The per-step observation vectors.
        K: The per-step Kalman gains (see :func:`kalman_gains`).
        y: The observed data.

    Returns:
        The innovation (one-step prediction residual) at every step.
    """

    def advance(mean, inputs):  # type: ignore
        trans, obs, gain, observed = inputs

        # Predict the state mean, form the residual, then correct.
        predicted = trans.transpose() @ mean
        innovation = observed - obs @ predicted
        corrected = predicted + gain * innovation

        return corrected, innovation

    start = jnp.zeros_like(H[0])
    return jax.lax.scan(advance, start, (A, H, K, y))[1]
73 changes: 73 additions & 0 deletions tests/test_solvers/test_kalman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# mypy: ignore-errors

import jax
import jax.numpy as jnp
import numpy as np
import pytest

from tinygp import GaussianProcess
from tinygp.kernels import quasisep
from tinygp.solvers import QuasisepSolver
from tinygp.solvers.kalman import KalmanSolver, kalman_filter, kalman_gains


@pytest.fixture
def random():
    """A seeded RNG so every test in this module sees the same draws."""
    seed = 84930
    return np.random.default_rng(seed)


@pytest.fixture
def data(random):
    """A small sorted 1D dataset with a smooth, noise-free signal."""
    xs = np.sort(random.uniform(-3, 3, 50))
    return xs, np.sin(xs)


@pytest.fixture(
    # A representative sample of quasiseparable kernels: bare kernels with
    # explicit hyperparameters, a scaled kernel, sums, products, and a
    # CARMA model, so the solver tests cover each construction path.
    params=[
        quasisep.Matern32(sigma=1.8, scale=1.5),
        1.8**2 * quasisep.Matern32(1.5),
        quasisep.Matern52(sigma=1.8, scale=1.5),
        quasisep.Exp(sigma=1.8, scale=1.5),
        quasisep.Cosine(sigma=1.8, scale=1.5),
        quasisep.SHO(sigma=1.8, omega=1.5, quality=3.0),
        quasisep.SHO(sigma=1.8, omega=1.5, quality=0.2),
        quasisep.Celerite(1.1, 0.8, 0.9, 0.1),
        1.5 * quasisep.Matern52(1.5) + 0.3 * quasisep.Exp(1.5),
        quasisep.Matern52(1.5) * quasisep.SHO(omega=1.5, quality=0.1),
        1.5 * quasisep.Matern52(1.5) * quasisep.Celerite(1.1, 0.8, 0.9, 0.1),
        quasisep.CARMA.init(
            alpha=np.array([1.4, 2.3, 1.5]), beta=np.array([0.1, 0.5])
        ),
    ]
)
def kernel(request):
    # Parameterized fixture: each test using it runs once per kernel above.
    return request.param


def test_filter(kernel, data):
    """The standalone filter functions reproduce the reference likelihood."""
    x, y = data
    diag = jnp.full_like(x, 0.1)

    # Reference value from the default GaussianProcess machinery.
    expected = GaussianProcess(kernel, x, diag=diag).log_probability(y)

    # Rebuild the same quantities by hand and run the two filter passes.
    Pinf = kernel.stationary_covariance()
    x_prev = jnp.append(x[0], x[:-1])
    A = jax.vmap(kernel.transition_matrix)(x_prev, x)
    H = jax.vmap(kernel.observation_model)(x)
    s, K = kalman_gains(Pinf, A, H, diag)
    resid = kalman_filter(A, H, K, y)
    computed = -0.5 * jnp.sum(jnp.square(resid) / s + jnp.log(2 * jnp.pi * s))

    np.testing.assert_allclose(computed, expected)


def test_consistent_with_direct(kernel, data):
    """The Kalman and quasiseparable solvers agree on the same data."""
    x, y = data
    kalman_gp = GaussianProcess(kernel, x, diag=0.1, solver=KalmanSolver)
    quasisep_gp = GaussianProcess(kernel, x, diag=0.1, solver=QuasisepSolver)

    np.testing.assert_allclose(
        kalman_gp.log_probability(y), quasisep_gp.log_probability(y)
    )
    np.testing.assert_allclose(
        kalman_gp.solver.normalization(), quasisep_gp.solver.normalization()
    )

0 comments on commit b18db83

Please sign in to comment.