solver

Classes

StateSpaceSolver

A solver that implements Kalman filtering and RTS smoothing for state space GPs.

Module Contents

class smolgp.solvers.solver.StateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

Bases: equinox.Module

A solver that implements Kalman filtering and RTS smoothing for state space GPs.

Given a StateSpaceModel kernel and a set of observed coordinates, this solver computes the Kalman filtered and Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances using jax.lax.scan for efficient JIT-compiled sequential computation.
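The recursion described above can be sketched with jax.lax.scan for the simplest possible case. This is an illustration only, not smolgp's implementation: it assumes a scalar Ornstein-Uhlenbeck (Matern-1/2) state with a zero-mean stationary prior, and the function name kalman_filter_ou and its parameters sigma and ell are hypothetical.

```python
import jax
import jax.numpy as jnp


def kalman_filter_ou(t, y, r, sigma=1.0, ell=1.0):
    """Kalman filter for a scalar Ornstein-Uhlenbeck state (illustrative).

    t: sorted times (N,); y: observations (N,); r: noise variances (N,).
    Returns predicted and filtered means/variances plus the innovations
    v and their variances s, each of shape (N,).
    """
    p_inf = sigma**2  # stationary prior variance
    # Time step into each observation; the first step starts from the prior.
    dt = jnp.concatenate([jnp.zeros(1, dtype=t.dtype), jnp.diff(t)])

    def step(carry, inputs):
        m, p = carry
        dt_k, y_k, r_k = inputs
        a = jnp.exp(-dt_k / ell)      # state transition over dt_k
        q = p_inf * (1.0 - a**2)      # process noise accumulated over dt_k
        m_pred = a * m                # predict
        p_pred = a**2 * p + q
        v = y_k - m_pred              # innovation
        s = p_pred + r_k              # innovation variance
        k = p_pred / s                # Kalman gain
        m_filt = m_pred + k * v       # update
        p_filt = (1.0 - k) * p_pred
        return (m_filt, p_filt), (m_pred, p_pred, m_filt, p_filt, v, s)

    # Start from the stationary prior: mean 0, variance p_inf.
    init = (jnp.zeros((), dtype=y.dtype), jnp.asarray(p_inf, dtype=y.dtype))
    _, outputs = jax.lax.scan(step, init, (dt, y, r))
    return outputs


# Example call on four irregularly spaced observations.
t = jnp.array([0.0, 0.5, 1.5, 2.0])
y = jnp.array([0.3, -0.1, 0.8, 0.5])
r = jnp.full(4, 0.1)
m_pred, p_pred, m_filt, p_filt, v, s = jax.jit(kalman_filter_ou)(t, y, r)
```

Because the whole loop is a single scan, the function can be wrapped in jax.jit and compiled once, regardless of N.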

Predictions at arbitrary test coordinates are handled by predict(), which dispatches among retrodiction, interpolation, and extrapolation depending on whether each test point falls before, between, or after the observed data.

Parameters:
  • kernel (StateSpaceModel) – The kernel function; must be a StateSpaceModel instance.

  • X (JAXArray) – The input coordinates with leading dimension of size N.

  • noise (JAXArray) – Observation noise covariance array of shape (N, D, D), where D is the observation dimension.
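As a minimal sketch of the (N, D, D) layout, assuming independent scalar measurements (D = 1) with hypothetical per-point standard deviations yerr, the noise array could be assembled by broadcasting. Whether smolgp offers a helper for this is not stated here, so the construction below uses plain jax.numpy only.

```python
import jax.numpy as jnp

# Hypothetical per-observation standard deviations for N = 4 data points.
yerr = jnp.array([0.1, 0.2, 0.1, 0.3])

N, D = yerr.shape[0], 1  # D = 1: one scalar measurement per coordinate

# One (D, D) covariance matrix per observation, stacked to shape (N, D, D).
noise = (yerr**2)[:, None, None] * jnp.eye(D)[None, :, :]
```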

kernel

The kernel defining the state space model.

Type:

StateSpaceModel

X

The observed input coordinates.

Type:

JAXArray

noise

Per-observation noise covariance matrices, shape (N, D, D).

Type:

JAXArray

t_states

Sortable scalar coordinates derived from X via coord_to_sortable().

Type:

JAXArray

X: tinygp.helpers.JAXArray
kernel: smolgp.kernels.base.StateSpaceModel
noise: tinygp.helpers.JAXArray
t_states: tinygp.helpers.JAXArray

normalization() → tinygp.helpers.JAXArray

Kalman(y, return_v_S=False) → Any

Wrapper for the Kalman filter used with this solver.

RTS(kalman_results) → Any

Wrapper for the RTS smoother used with this solver.
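For orientation, the backward pass of an RTS smoother can be written as a reverse jax.lax.scan over the filter outputs. The sketch below is not the library's code: it assumes a scalar state, and the function name rts_smoother and its argument layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def rts_smoother(a, m_pred, p_pred, m_filt, p_filt):
    """Scalar Rauch-Tung-Striebel smoother (illustrative).

    a[k] is the transition coefficient into step k (a[0] is unused); the
    remaining inputs are the Kalman predicted and filtered means and
    variances, each of shape (N,).
    """

    def step(carry, inputs):
        m_next, p_next = carry                       # smoothed state at k + 1
        m_f, p_f, a_next, m_p_next, p_p_next = inputs
        g = p_f * a_next / p_p_next                  # smoother gain
        m_s = m_f + g * (m_next - m_p_next)
        p_s = p_f + g**2 * (p_next - p_p_next)
        return (m_s, p_s), (m_s, p_s)

    # At the last point the smoothed state equals the filtered state.
    init = (m_filt[-1], p_filt[-1])
    xs = (m_filt[:-1], p_filt[:-1], a[1:], m_pred[1:], p_pred[1:])
    # reverse=True walks k = N-2, ..., 0 but stacks outputs in forward order.
    _, (m_s, p_s) = jax.lax.scan(step, init, xs, reverse=True)
    m_smooth = jnp.concatenate([m_s, m_filt[-1:]])
    p_smooth = jnp.concatenate([p_s, p_filt[-1:]])
    return m_smooth, p_smooth
```

For a Markovian kernel such as the Ornstein-Uhlenbeck kernel, the smoothed means and variances produced this way agree with the dense GP posterior at the observed coordinates, which makes a convenient correctness check.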

condition(y, return_v_S=False) → tinygp.helpers.JAXArray

Compute the Kalman-predicted, filtered, and RTS-smoothed means and covariances at each input coordinate.

predict(X_test, conditioned_results) → tinygp.helpers.JAXArray

Algorithm for making predictions at arbitrary coordinates X_test.

Parameters:
  • X_test (JAXArray) – The test coordinates, in the same format as self.X, with leading dimension M.

  • conditioned_results (tuple) – The output of condition().

Returns:

A pair (pred_mean, pred_var) of arrays with leading dimension M = len(X_test), giving the predicted state means and covariances at each test coordinate.

Return type:

tuple

Each test point is handled by one of three cases depending on its position relative to the observed data:

  1. Retrodiction: the test point precedes all observations and is smoothed backward from the first data point using the stationary prior.

  2. Interpolation: the test point falls between two observations; it is Kalman-predicted forward from the nearest past point, then RTS-smoothed backward from the nearest future point.

  3. Extrapolation: the test point follows all observations and is Kalman-predicted forward from the final filtered state.
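One way to picture the three-way dispatch is with jnp.searchsorted on the sorted observation times. This is a sketch under stated assumptions rather than the logic predict() actually uses: classify_test_points and the integer case labels are hypothetical, and test times that coincide with the first or last observation are treated as interpolation here, which may differ from the library's behavior.

```python
import jax.numpy as jnp

RETRODICT, INTERPOLATE, EXTRAPOLATE = 0, 1, 2  # hypothetical case labels


def classify_test_points(t_obs, t_test):
    """Label each test time and locate its neighbouring observations.

    t_obs must be sorted. Returns (case, past, future), each of shape (M,).
    """
    n = t_obs.shape[0]
    # Number of observations at or before each test time.
    idx = jnp.searchsorted(t_obs, t_test, side="right")
    case = jnp.where(
        t_test < t_obs[0],
        RETRODICT,
        jnp.where(t_test > t_obs[-1], EXTRAPOLATE, INTERPOLATE),
    )
    # Indices of the nearest past and future observations, clipped to stay
    # in range; they are only meaningful for the INTERPOLATE case.
    past = jnp.clip(idx - 1, 0, n - 1)
    future = jnp.clip(idx, 0, n - 1)
    return case, past, future


t_obs = jnp.array([0.0, 1.0, 2.0, 3.0])
t_test = jnp.array([-0.5, 0.4, 2.5, 3.7])
case, past, future = classify_test_points(t_obs, t_test)
# case   -> [0, 1, 1, 2]
# past   -> [0, 0, 2, 3]
# future -> [0, 1, 3, 3]
```

Computing the labels as arrays, rather than branching in Python, keeps the classification compatible with jax.jit and jax.vmap.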