from __future__ import annotations
import jax
import jax.numpy as jnp
from tinygp.helpers import JAXArray
[docs]
def KalmanFilter(kernel, X, y, R, return_v_S=False):
    """
    Convenience front-end that assembles the inputs of the jitted
    kalman_filter function from a kernel and runs the filter.

    The prior state mean is a zero vector of length ``kernel.dimension`` and
    the prior covariance is ``kernel.stationary_covariance()``.

    Parameters:
        kernel: StateSpaceModel kernel
        X: data coordinates, e.g. time or (time, texp, instid)
        y: observations, shape (N, D)
        R: observation noise covariance, shape (N, D, D)
        return_v_S: if True, return everything kalman_filter produces,
            including the innovations ``v`` and their covariances ``S``

    Returns:
        m_filtered: filtered means
        P_filtered: filtered covariances
        m_predicted: predicted means
        P_predicted: predicted covariances
        v, S: only when ``return_v_S`` is True, appended after the four
            entries above
    """
    # Model pieces supplied by the kernel
    observation = kernel.observation_model
    transition = kernel.transition_matrix
    process_noise = kernel.process_noise

    # Prior state: zero mean with the stationary covariance
    prior_mean = jnp.zeros(kernel.dimension)
    prior_cov = kernel.stationary_covariance()
    if not isinstance(prior_cov, JAXArray):
        # jax.lax.scan needs a dense array in its carry
        prior_cov = prior_cov.to_dense()

    times = kernel.coord_to_sortable(X)
    results = kalman_filter(
        transition,
        process_noise,
        observation,
        R,
        times,
        y,
        prior_mean,
        prior_cov,
    )

    if return_v_S:
        return results

    # Drop the innovation terms for the default four-element return
    m_filtered, P_filtered, m_predicted, P_predicted, _, _ = results
    return m_filtered, P_filtered, m_predicted, P_predicted
[docs]
@jax.jit
def kalman_filter(A, Q, H, R, t, y, m0, P0):
    r"""
    Jax implementation of the Kalman filter algorithm

    See Theorem 4.2 (pdf page 77) in "Bayesian Filtering and Smoothing"
    by Simo S{\"a}rkk{\"a} for detailed description of the algorithm and notation.
    e.g. _prev is _{k-1}
         _pred is _k^{-}

    Total runtime complexity is O(N*d^3) where N is the number
    of time steps and d is the dimension of the state vector.

    Parameters:
        A: callable, ``A(0, Delta)`` gives the state transition matrix for a
            time lag ``Delta``
        Q: callable, ``Q(0, Delta)`` gives the process noise covariance for a
            time lag ``Delta``
        H: callable, ``H(t[k])`` gives the observation matrix at step k
        R: per-step observation noise covariance, indexed as ``R[k]``
        t: time stamps, indexed as ``t[k]``; integer or floating dtype
        y: per-step observations, indexed as ``y[k]``
        m0: prior state mean
        P0: prior state covariance

        Because the whole function is jitted, ``A``, ``Q`` and ``H`` are
        passed through ``jax.jit`` as arguments and therefore have to be
        callables that JAX can treat as pytrees
        (e.g. ``jax.tree_util.Partial`` objects).

    Returns:
        Tuple ``(m_filtered, P_filtered, m_predicted, P_predicted, v, S)``;
        each entry stacks the corresponding per-step quantity along a new
        leading axis of length N.
    """
    N = len(t)  # number of data points

    def step(carry, k):
        """
        Routine for a single step of the Kalman filter

        Parameters:
            carry: (m_prev, P_prev) - previous state mean and covariance
            k: index of the current time step

        Returns:
            - Conditioned state (m_k) and covariance (P_k) to carry to next iteration
            - Per-step output collected by the scan
              (m_k, P_k, m_pred, P_pred, v_k, S_k)
        """
        # Unpack previous state and covariance
        m_prev, P_prev = carry

        # Time lag since the previous observation. At k == 0 there is no
        # previous observation, so the lag is zero and the prior (m0, P0)
        # is propagated over no time.
        # jnp.where promotes integer time stamps to float; a lax.cond with a
        # float literal in one branch rejects them for mismatched branch types.
        Delta = jnp.where(k > 0, t[k] - t[k - 1], 0.0)

        # Get transition matrix and process noise for this lag
        A_prev = A(0, Delta)
        Q_prev = Q(0, Delta)

        # Predict (Eq. 4.20)
        m_pred = A_prev @ m_prev
        P_pred = A_prev @ P_prev @ A_prev.T + Q_prev

        # Update (Eq. 4.21)
        ## TODO: let t be a tuple and use coord_to_sortable to get time axis out
        ## That way we can pass the full t into H and it can use e.g. instid etc.
        H_k = H(t[k])  # observation model for this time step
        y_pred = H_k @ m_pred  # predicted observation
        v_k = y[k] - y_pred  # "innovation" or "surprise" term
        S_k = H_k @ P_pred @ H_k.T + R[k]  # uncertainty in predicted observation

        # Kalman gain K_k = P_pred @ H_k.T @ inv(S_k), obtained from a linear
        # solve on the transposed system instead of an explicit inverse
        # (more stable).
        K_k = jnp.linalg.solve(S_k.T, (P_pred @ H_k.T).T).T

        m_k = m_pred + K_k @ v_k  # conditioned state estimate
        P_k = P_pred - K_k @ S_k @ K_k.T  # conditioned covariance estimate

        return (m_k, P_k), (m_k, P_k, m_pred, P_pred, v_k, S_k)

    # Initialize carry with prior state and covariance
    init_carry = (m0, P0)

    # Run the filter over all time steps and return the stacked results
    _, outputs = jax.lax.scan(step, init_carry, jnp.arange(N))
    return outputs