solvers#

In smolgp, “solvers” provide a swappable low-level interface for the Bayesian filtering and smoothing algorithms required for GP conditioning. New solvers can be contributed as external packages or pull requests to the smolgp GitHub repository.

The four built-in solvers are:

  1. StateSpaceSolver: Standard Kalman filter and RTS smoother for instantaneous kernels (see smolgp.kernels.base). This is the default solver.

  2. IntegratedStateSpaceSolver: Kalman filter and RTS smoother for integrated (time-averaged) measurement kernels (see smolgp.kernels.integrated).

  3. ParallelStateSpaceSolver: GPU-parallelised version of StateSpaceSolver with \(O(\log N)\) parallel time complexity (span) in the number of observations \(N\) on compatible hardware.

  4. ParallelIntegratedStateSpaceSolver: GPU-parallelised version of IntegratedStateSpaceSolver.

All solvers are exact up to numerical precision.

Users generally do not need to instantiate solvers directly; GaussianProcess selects the appropriate solver automatically based on the kernel type.

Classes#

StateSpaceSolver

A solver that implements Kalman filtering and RTS smoothing for state space GPs.

ParallelStateSpaceSolver

A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing.

IntegratedStateSpaceSolver

A solver that uses jax.lax.scan to implement Kalman filtering and RTS smoothing for integrated measurements.

ParallelIntegratedStateSpaceSolver

A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing for integrated measurements.

Package Contents#

class smolgp.solvers.StateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)[source]#

Bases: equinox.Module

A solver that implements Kalman filtering and RTS smoothing for state space GPs.

Given a StateSpaceModel kernel and a set of observed coordinates, this solver computes the Kalman filtered and Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances using jax.lax.scan for efficient JIT-compiled sequential computation.
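As a rough illustration of the filtering and smoothing recursions described above, the following sketch runs a Kalman filter and RTS smoother for a scalar Ornstein-Uhlenbeck (Matérn-1/2) process in plain Python. It is not smolgp's implementation: the function name kalman_rts_ou, the choice of kernel, and the scalar state are assumptions made for this example, and an ordinary for loop stands in for jax.lax.scan.

```python
import math


def kalman_rts_ou(t, y, r, sigma=1.0, ell=1.0):
    """Kalman filter and RTS smoother for a scalar OU (Matern-1/2) process.

    Hypothetical helper for illustration only (not part of smolgp).
    t: sorted observation times, y: observations, r: noise variances.
    Returns ((m_filt, P_filt), (m_smooth, P_smooth)) as lists of floats.
    """
    n = len(t)
    a_all, m_pred, P_pred, m_filt, P_filt = [], [], [], [], []
    m, P = 0.0, sigma**2  # stationary prior: zero mean, variance sigma^2
    t_prev = t[0]
    for k in range(n):
        # Predict: exact discretisation of the OU process over the time gap.
        a = math.exp(-(t[k] - t_prev) / ell)
        mp = a * m
        Pp = a * a * P + sigma**2 * (1.0 - a * a)
        # Update: condition the predicted state on the observation y[k].
        S = Pp + r[k]  # innovation variance
        K = Pp / S     # Kalman gain
        m = mp + K * (y[k] - mp)
        P = (1.0 - K) * Pp
        a_all.append(a)
        m_pred.append(mp)
        P_pred.append(Pp)
        m_filt.append(m)
        P_filt.append(P)
        t_prev = t[k]

    # RTS smoother: backward pass starting from the final filtered state.
    m_s, P_s = list(m_filt), list(P_filt)
    for k in range(n - 2, -1, -1):
        G = P_filt[k] * a_all[k + 1] / P_pred[k + 1]  # smoother gain
        m_s[k] = m_filt[k] + G * (m_s[k + 1] - m_pred[k + 1])
        P_s[k] = P_filt[k] + G * G * (P_s[k + 1] - P_pred[k + 1])
    return (m_filt, P_filt), (m_s, P_s)


t = [0.0, 0.7, 1.5, 3.0]
y = [0.3, -0.1, 0.8, 0.2]
r = [0.1, 0.1, 0.2, 0.1]
(m_f, P_f), (m_s, P_s) = kalman_rts_ou(t, y, r)
```

For this model the smoothed means agree with direct GP conditioning, K (K + R)^{-1} y, up to floating-point error, and the smoothed variances never exceed the filtered ones because the backward pass adds information from later observations.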

Predictions at arbitrary test coordinates are handled by predict(), which dispatches among retrodiction, interpolation, and extrapolation depending on whether each test point falls before, between, or after the observed data.

Parameters:
  • kernel (StateSpaceModel) – The kernel function; must be a StateSpaceModel instance.

  • X (JAXArray) – The input coordinates with leading dimension of size N.

  • noise (JAXArray) – Observation noise covariance array of shape (N, D, D), where D is the observation dimension.

kernel#

The kernel defining the state space model.

Type:

StateSpaceModel

X#

The observed input coordinates.

Type:

JAXArray

noise#

Per-observation noise covariance matrices, shape (N, D, D).

Type:

JAXArray

t_states#

Sortable scalar coordinates derived from X via coord_to_sortable().

Type:

JAXArray

X: tinygp.helpers.JAXArray#
kernel: smolgp.kernels.base.StateSpaceModel#
noise: tinygp.helpers.JAXArray#
t_states: tinygp.helpers.JAXArray#
normalization() tinygp.helpers.JAXArray[source]#
Kalman(y, return_v_S=False) Any[source]#

Wrapper for the Kalman filter used with this solver.

RTS(kalman_results) Any[source]#

Wrapper for the RTS smoother used with this solver.

condition(y, return_v_S=False) tinygp.helpers.JAXArray[source]#

Compute the Kalman-predicted, filtered, and RTS-smoothed means and covariances at each of the input coordinates.

predict(X_test, conditioned_results) tinygp.helpers.JAXArray[source]#

Algorithm for making predictions at arbitrary coordinates X_test.

Parameters:
  • X_test (JAXArray) – The test coordinates, with leading dimension of size M; each coordinate has the same form as the entries of self.X.

  • conditioned_results (tuple) – The output of condition().

Returns:

A pair (pred_mean, pred_var) of arrays with leading dimension M = len(X_test), giving the predicted state means and covariances at each test coordinate.

Return type:

tuple

Each test point is handled by one of three cases depending on its position relative to the observed data:

  1. Retrodiction — test point precedes all observations: smoothed backward from the first data point using the stationary prior.

  2. Interpolation — test point falls between two observations: Kalman-predicted forward from the nearest past point, then RTS-smoothed backward from the nearest future point.

  3. Extrapolation — test point follows all observations: Kalman-predicted forward from the final filtered state.
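The three-way dispatch above can be sketched with a sorted search over the observation times. This is only an illustration under stated assumptions: classify_test_points is a hypothetical helper, the coordinates are taken to be scalar times, and the treatment of test points that coincide exactly with an observation is a choice made here, not necessarily what predict() does.

```python
from bisect import bisect_right


def classify_test_points(t_obs, t_test):
    """Label each test time by its position relative to sorted observations.

    Hypothetical helper for illustration only (not part of smolgp).
    """
    labels = []
    for t in t_test:
        # Number of observations at or before t.
        idx = bisect_right(t_obs, t)
        if idx == 0:
            labels.append("retrodiction")   # before the first observation
        elif idx == len(t_obs):
            labels.append("extrapolation")  # at or after the last observation
        else:
            labels.append("interpolation")  # between two observations
    return labels


labels = classify_test_points([0.0, 1.0, 2.0], [-1.0, 0.5, 3.0])
```

A vectorised implementation would typically compute the same indices for all test points at once and then select among the three branches per point.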

class smolgp.solvers.ParallelStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)[source]#

Bases: equinox.Module

A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing
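To give a feel for how an associative scan turns a sequential recursion into a logarithmic number of parallel rounds, the sketch below applies the idea to the simple affine recursion x_k = a_k x_{k-1} + b_k. It is a plain-Python illustration, not smolgp's code and not how jax.lax.associative_scan is implemented: the names combine, associative_scan, and sequential are hypothetical, and the elements used for parallel Kalman filtering are more involved than the (a, b) pairs used here.

```python
def combine(first, second):
    """Compose two affine maps x -> a*x + b, applying `first` then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a2 * a1, a2 * b1 + b2)


def associative_scan(op, elems):
    """Inclusive prefix scan in ceil(log2(N)) rounds (Hillis-Steele scheme).

    Within one round every combination is independent of the others,
    so parallel hardware could evaluate a whole round at once.
    """
    out = list(elems)
    shift, rounds = 1, 0
    while shift < len(out):
        out = [
            out[i] if i < shift else op(out[i - shift], out[i])
            for i in range(len(out))
        ]
        shift *= 2
        rounds += 1
    return out, rounds


def sequential(steps, x0):
    """Reference: run the recursion one step at a time."""
    xs, x = [], x0
    for a, b in steps:
        x = a * x + b
        xs.append(x)
    return xs


steps = [(0.5, 1.0), (2.0, -1.0), (0.5, 3.0), (1.0, 2.0), (2.0, 0.0)]
x0 = 1.0
prefixes, rounds = associative_scan(combine, steps)
# Each prefix (A_k, B_k) maps the initial value directly to x_k.
x_parallel = [a * x0 + b for a, b in prefixes]
x_sequential = sequential(steps, x0)
```

The scan only works because combine is associative; the operator need not be commutative, which is why the order of the two arguments matters.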

X: tinygp.helpers.JAXArray#
kernel: smolgp.kernels.base.StateSpaceModel#
noise: tinygp.helpers.JAXArray#
normalization() tinygp.helpers.JAXArray[source]#
Kalman(y, return_v_S=True) Any[source]#

Wrapper for the Kalman filter used with this solver.

RTS(kalman_results) Any[source]#

Wrapper for the RTS smoother used with this solver.

condition(y, return_v_S=False) tinygp.helpers.JAXArray[source]#

Compute the Kalman-predicted, filtered, and RTS-smoothed means and covariances at each of the input coordinates.

predict(X_test, conditioned_results) tinygp.helpers.JAXArray[source]#

Algorithm for making predictions at arbitrary coordinates X_test.

Parameters:
  • X_test – The test coordinates.

  • conditioned_results – The output of condition().

Returns:

A pair (pred_mean, pred_var): the predicted means and variances of the states at X_test.

Return type:

tuple

Each test point is handled by one of three cases:

  1. Retrodiction: smoothing from the first data point, using the prior as the prediction.

  2. Interpolation: filtering from the most recent data point and smoothing from the next future point.

  3. Extrapolation: predicting from the final filtered point.

class smolgp.solvers.IntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)[source]#

Bases: equinox.Module

A solver that uses jax.lax.scan to implement Kalman filtering and RTS smoothing for integrated measurements
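To illustrate what a time-averaged measurement is and why it differs from an instantaneous one, the sketch below averages a signal over a finite exposure window and compares the result with the value at the window midpoint. It makes no use of smolgp: time_average and the exposure argument are hypothetical names, and the way smolgp's integrated kernels parametrise the exposure is not assumed here.

```python
import math


def time_average(f, t_mid, exposure, n_sub=1000):
    """Midpoint-rule average of f over [t_mid - exposure/2, t_mid + exposure/2].

    Hypothetical helper for illustration only (not part of smolgp).
    """
    start = t_mid - exposure / 2.0
    h = exposure / n_sub
    return sum(f(start + (i + 0.5) * h) for i in range(n_sub)) / n_sub


t_mid, exposure = 1.0, 2.0
instantaneous = math.sin(t_mid)
averaged = time_average(math.sin, t_mid, exposure)
# Analytic average of sin over the window: sin(t_mid) * sin(d/2) / (d/2).
analytic = math.sin(t_mid) * math.sin(exposure / 2.0) / (exposure / 2.0)
```

The averaged value is damped relative to the instantaneous one, and the damping grows with the exposure length, so treating long exposures as point measurements would bias the inferred signal.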

X: tinygp.helpers.JAXArray#
kernel: smolgp.kernels.base.StateSpaceModel#
noise: tinygp.helpers.JAXArray#
state_coords: tinygp.helpers.JAXArray#
normalization() tinygp.helpers.JAXArray[source]#
Kalman(y, return_v_S=False) Any[source]#

Wrapper for the Kalman filter used with this solver.

RTS(kalman_results) Any[source]#

Wrapper for the RTS smoother used with this solver.

condition(y, return_v_S=False) tinygp.helpers.JAXArray[source]#

Compute the Kalman-predicted, filtered, and RTS-smoothed means and covariances at each of the input coordinates.

predict(X_test, conditioned_results) tinygp.helpers.JAXArray[source]#

Algorithm for making predictions at arbitrary coordinates X_test.

Parameters:
  • X_test – The test coordinates.

  • conditioned_results – The output of condition().

Returns:

A pair (pred_mean, pred_var): the predicted means and variances of the states at X_test.

Return type:

tuple

Each test point is handled by one of three cases:

  1. Retrodiction: smoothing from the first data point, using the prior as the prediction.

  2. Interpolation: filtering from the most recent data point and smoothing from the next future point.

  3. Extrapolation: predicting from the final filtered point.

class smolgp.solvers.ParallelIntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)[source]#

Bases: equinox.Module

A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing for integrated measurements

X: tinygp.helpers.JAXArray#
kernel: smolgp.kernels.base.StateSpaceModel#
noise: tinygp.helpers.JAXArray#
state_coords: tinygp.helpers.JAXArray#
_state_coords: tinygp.helpers.JAXArray#
normalization() tinygp.helpers.JAXArray[source]#
Kalman(y, return_v_S=True) Any[source]#

Wrapper for the Kalman filter used with this solver.

RTS(kalman_results) Any[source]#

Wrapper for the RTS smoother used with this solver.

condition(y, return_v_S=True) tinygp.helpers.JAXArray[source]#

Compute the Kalman-predicted, filtered, and RTS-smoothed means and covariances at each of the input coordinates.

predict(X_test, conditioned_results) tinygp.helpers.JAXArray[source]#

Algorithm for making predictions at arbitrary coordinates X_test.

Parameters:
  • X_test – The test coordinates.

  • conditioned_results – The output of condition().

  • observation_model – (optional) The observation model H for the test points; should be a function of the same form as self.kernel.observation_model.

Each test point is handled by one of three cases:

  1. Retrodiction: smoothing from the first data point, using the prior as the prediction.

  2. Interpolation: filtering from the most recent data point and smoothing from the next future point.

  3. Extrapolation: predicting from the final filtered point.