from __future__ import annotations
from typing import Any
import jax
import jax.numpy as jnp
import equinox as eqx
from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.solver import QuasisepSolver
from smolgp.kernels.base import StateSpaceModel
from smolgp.solvers.kalman import KalmanFilter
from smolgp.solvers.rts import RTSSmoother
[docs]
class StateSpaceSolver(eqx.Module):
    r"""A solver that implements Kalman filtering and RTS smoothing for state space GPs.

    Given a :class:`~smolgp.kernels.StateSpaceModel` kernel and a set of observed
    coordinates, this solver computes the Kalman filtered and
    Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances using
    ``jax.lax.scan`` for efficient JIT-compiled sequential computation.
    Predictions at arbitrary test coordinates are handled by :meth:`predict`,
    which dispatches among retrodiction, interpolation, and extrapolation
    depending on whether each test point falls before, between, or after
    the observed data.

    :param kernel: The kernel function; must be a
        :class:`~smolgp.kernels.StateSpaceModel` instance.
    :type kernel: StateSpaceModel
    :param X: The input coordinates with leading dimension of size ``N``.
    :type X: JAXArray
    :param noise: Observation noise covariance array of shape ``(N, D, D)``,
        where ``D`` is the observation dimension.
    :type noise: JAXArray

    Attributes:
        kernel (StateSpaceModel): The kernel defining the state space model.
        X (JAXArray): The observed input coordinates.
        noise (JAXArray): Per-observation noise covariance matrices, shape ``(N, D, D)``.
        t_states (JAXArray): Sortable scalar coordinates derived from ``X`` via
            :meth:`~smolgp.kernels.StateSpaceModel.coord_to_sortable`.
    """

    X: JAXArray
    kernel: StateSpaceModel
    noise: JAXArray
    t_states: JAXArray

    def __init__(
        self,
        kernel: StateSpaceModel,
        X: JAXArray,
        noise: JAXArray,
    ):
        """Build a :class:`StateSpaceSolver` for a given kernel and coordinates.

        Args:
            kernel: The kernel function.
            X: The input coordinates.
            noise: Observation noise covariance array of shape ``(N, D, D)``.
        """
        self.kernel = kernel
        self.X = X
        self.noise = noise
        # Cache the sortable coordinates once; the filter, smoother and
        # condition() all run on these.
        self.t_states = self.kernel.coord_to_sortable(X)

    def normalization(self) -> JAXArray:
        """Return the normalization computed by tinygp's ``QuasisepSolver``.

        This is not (yet) computed in state space form. Instead the kernel,
        coordinates and noise are handed to
        :class:`tinygp.solvers.quasisep.solver.QuasisepSolver`. Only entry
        ``(0, 0)`` of each per-observation noise block is forwarded, so this
        effectively assumes scalar observations (``D == 1``).
        """
        # TODO: do we want/can we implement this in state space? For now,
        # fall back to quasisep.

        class _NoiseAdapter:
            """Minimal noise stand-in exposing only ``diagonal()``.

            NOTE(review): assumes QuasisepSolver reads the noise solely via
            ``diagonal()`` in the installed tinygp version; confirm.
            """

            def __init__(self, n):
                self._n = n

            def diagonal(self):
                # One variance per observation, shape (N,). Indexing
                # [0, 0, :] would instead return the first row of the
                # *first* observation's (D, D) block only, which then
                # broadcasts a single variance onto every data point.
                return self._n[:, 0, 0]

        return QuasisepSolver(
            self.kernel, self.X, _NoiseAdapter(self.noise)
        ).normalization()

    def Kalman(self, y, return_v_S=False) -> Any:
        """Run the Kalman filter on ``y`` at this solver's state coordinates.

        Args:
            y: Observations at ``self.X``. A 1-D array is promoted to shape
                ``(N, 1)`` before filtering.
            return_v_S: Forwarded to :func:`KalmanFilter`. When true, the
                result carries two extra outputs ``v`` and ``S`` (presumably
                the innovations and their covariances; see KalmanFilter).

        Returns:
            The :func:`KalmanFilter` output. :meth:`condition` unpacks it as
            ``(m_filtered, P_filtered, m_predicted, P_predicted)``, optionally
            followed by ``(v, S)``.
        """
        y_nd = y[:, None] if y.ndim == 1 else y
        return KalmanFilter(
            self.kernel, self.t_states, y_nd, self.noise, return_v_S=return_v_S
        )

    def RTS(self, kalman_results) -> Any:
        """Run the RTS smoother over the given Kalman filter results.

        Args:
            kalman_results: ``(m_filtered, P_filtered, m_predicted, P_predicted)``
                as produced by :meth:`Kalman` without ``return_v_S``.

        Returns:
            The :func:`RTSSmoother` output, which :meth:`condition` unpacks as
            ``(m_smoothed, P_smoothed)``.
        """
        return RTSSmoother(self.kernel, self.t_states, kalman_results)

    def condition(self, y, return_v_S=False) -> tuple[Any, Any, Any]:
        """Compute Kalman predicted, filtered and RTS smoothed states at ``self.X``.

        Args:
            y: Observations at ``self.X`` (see :meth:`Kalman`).
            return_v_S: If true, also return the filter's ``(v, S)`` outputs.

        Returns:
            tuple: ``(t_states, conditioned_states, v_S)`` where ``t_states``
            is the array of sortable state coordinates, ``conditioned_states``
            is ``((m_predicted, P_predicted), (m_filtered, P_filtered),
            (m_smoothed, P_smoothed))``, and ``v_S`` is ``(v, S)`` or ``None``.
            This tuple is the ``conditioned_results`` argument of
            :meth:`predict`.
        """
        # Forward pass: Kalman filtering.
        kalman_results = self.Kalman(y, return_v_S=return_v_S)
        if return_v_S:
            m_filtered, P_filtered, m_predicted, P_predicted, v, S = kalman_results
            v_S = (v, S)
        else:
            m_filtered, P_filtered, m_predicted, P_predicted = kalman_results
            v_S = None

        # Backward pass: RTS smoothing.
        m_smoothed, P_smoothed = self.RTS(
            (m_filtered, P_filtered, m_predicted, P_predicted)
        )

        conditioned_states = (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        )
        # self.t_states was computed from self.X in __init__ and is what the
        # filter and smoother ran on, so reuse it instead of recomputing.
        return self.t_states, conditioned_states, v_S

    @jax.jit
    def predict(self, X_test, conditioned_results) -> tuple[JAXArray, JAXArray]:
        """Algorithm for making predictions at arbitrary coordinates ``X_test``.

        Args:
            X_test (JAXArray): The test coordinates, in the same format as
                ``self.X`` (they are mapped through ``coord_to_sortable``).
            conditioned_results (tuple): The output of :meth:`condition`. Its
                first element may be the array of sortable state coordinates
                (as :meth:`condition` returns it) or a tuple whose first entry
                is that array, e.g. ``(t_states, instid, obsid, stateid)``.
                Any further entries of such a tuple are ignored.

        Returns:
            tuple: A pair ``(pred_mean, pred_var)`` of arrays with leading
            dimension ``M = len(X_test)``, giving the predicted state means
            and covariances at each test coordinate.

        Each test point is handled by one of three cases depending on its
        position relative to the observed data:

        1. **Retrodiction**: the test point precedes all observations. It is
           smoothed backward from the first data point using the stationary
           prior.
        2. **Interpolation**: the test point falls between two observations.
           It is Kalman-predicted forward from the nearest past point, then
           RTS-smoothed backward from the nearest future point.
        3. **Extrapolation**: the test point follows all observations. It is
           Kalman-predicted forward from the final filtered state.
        """
        state_coords, conditioned_states, _ = conditioned_results
        (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        ) = conditioned_states

        # condition() returns the coordinates as a bare array. Unpacking that
        # into four names (the tuple layout below) raises for any N != 4, so
        # accept both forms and keep only the coordinate array.
        if isinstance(state_coords, (tuple, list)):
            t_states = state_coords[0]
        else:
            t_states = state_coords

        t_test = self.kernel.coord_to_sortable(X_test)
        K = len(t_states)  # number of conditioned states
        M = len(t_test)  # number of test points

        Pinf = self.kernel.stationary_covariance()
        if not isinstance(Pinf, JAXArray):
            # Multicomponent models return a structured matrix. Retrodiction
            # needs a dense array for jnp.linalg.solve.
            Pinf = Pinf.to_dense()

        # Prior mean used for retrodiction.
        # TODO: use the mean function of the base kernel instead of zeros.
        m0 = jnp.zeros(self.kernel.dimension)

        # Index of the first state strictly after each test point.
        # TODO: this assumes t_states is sorted; should that be enforced?
        k_nexts = jnp.searchsorted(t_states, t_test, side="right")

        past = k_nexts <= 0  # before all data -> retrodict
        future = k_nexts >= K  # at/after last datum -> extrapolate
        # Branch index for lax.switch: 0 = retrodict, 1 = interpolate,
        # 2 = extrapolate. `future` wins if both flags hold (only when K == 0).
        cases = jnp.where(future, 2, jnp.where(past, 0, 1))

        # Shorthand for the kernel's transition / process-noise functions.
        A = self.kernel.transition_matrix
        Q = self.kernel.process_noise

        def kalman(k_prev, ktest):
            """Kalman-predict from filtered state ``k_prev`` to test point ``ktest``.

            There is no observation at the test coordinate, so no update step
            follows the prediction.
            """
            dt = t_test[ktest] - t_states[k_prev]
            A_star = A(0, dt)  # transition matrix from t_k to t_star
            Q_star = Q(0, dt)  # process noise from t_k to t_star
            m_k = m_filtered[k_prev]
            P_k = P_filtered[k_prev]
            m_star_pred = A_star @ m_k
            P_star_pred = A_star @ P_k @ A_star.T + Q_star
            return m_star_pred, P_star_pred

        def smooth(k_next, ktest, m_star_pred, P_star_pred):
            """RTS-smooth the prediction at ``ktest`` using future state ``k_next``.

            ``m_star_pred`` / ``P_star_pred`` are the predicted mean and
            covariance at the test point (from ``kalman`` or the prior).
            """
            m_pred_next = m_predicted[k_next]
            P_pred_next = P_predicted[k_next]
            m_hat_next = m_smoothed[k_next]
            P_hat_next = P_smoothed[k_next]

            dt = t_states[k_next] - t_test[ktest]
            A_k = A(0, dt)

            # Smoothing gain G = P_star_pred A_k^T P_pred_next^{-1}, computed
            # with a linear solve instead of an explicit inverse for stability.
            G_k = jnp.linalg.solve(P_pred_next.T, (P_star_pred @ A_k.T).T).T

            m_star_hat = m_star_pred + G_k @ (m_hat_next - m_pred_next)
            P_star_hat = P_star_pred + G_k @ (P_hat_next - P_pred_next) @ G_k.T
            return m_star_hat, P_star_hat

        def retrodict(ktest):
            """Smooth backward from the first state, starting from the prior."""
            return smooth(0, ktest, m0, Pinf)

        def interpolate(ktest):
            """Predict from the nearest past state, then smooth from the next one."""
            k_next = k_nexts[ktest]
            k_prev = k_next - 1
            m_star_pred, P_star_pred = kalman(k_prev, ktest)
            return smooth(k_next, ktest, m_star_pred, P_star_pred)

        def extrapolate(ktest):
            """Kalman-predict forward from the last filtered state."""
            return kalman(-1, ktest)

        def predict_point(ktest):
            """Dispatch a single test point to its retro/inter/extrapolation branch."""
            return jax.lax.switch(
                cases[ktest], (retrodict, interpolate, extrapolate), ktest
            )

        pred_mean, pred_var = jax.vmap(predict_point)(jnp.arange(M))
        return pred_mean, pred_var