solvers
In smolgp, “solvers” provide a swappable low-level interface for the
Bayesian filtering and smoothing algorithms required for GP conditioning.
New solvers can be contributed as external packages or pull requests to the
smolgp GitHub repository.
The four built-in solvers are:
- StateSpaceSolver: Standard Kalman filter and RTS smoother for instantaneous kernels (see smolgp.kernels.base). This is the default solver.
- IntegratedStateSpaceSolver: Kalman filter and RTS smoother for integrated (time-averaged) measurement kernels (see smolgp.kernels.integrated).
- ParallelStateSpaceSolver: GPU-parallelised version of StateSpaceSolver. On compatible hardware its parallel depth (number of sequential steps) is \(O(\log N)\), compared with \(O(N)\) for StateSpaceSolver.
- ParallelIntegratedStateSpaceSolver: GPU-parallelised version of IntegratedStateSpaceSolver.
All solvers are exact up to numerical precision.
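The exactness claim can be checked on a toy case. The sketch below is plain NumPy, not smolgp's API. It assumes a hypothetical scalar Ornstein-Uhlenbeck (Matérn-1/2) kernel, whose state-space form is one-dimensional. It computes the log marginal likelihood two ways: with the O(N) Kalman prediction-error decomposition, and with the O(N^3) dense Cholesky computation.

```python
import numpy as np

# Hypothetical toy model: Ornstein-Uhlenbeck (Matern-1/2) kernel
# k(tau) = sigma2 * exp(-lam * |tau|), whose state-space form is scalar.
sigma2, lam = 1.3, 0.7

rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 10.0, 25))  # observation times
noise = np.full(t.size, 0.1)             # per-point noise variances
y = rng.normal(size=t.size)              # arbitrary observed values


def dense_loglike(t, y, noise):
    """O(N^3) reference: log N(y | 0, K + diag(noise)) via Cholesky."""
    K = sigma2 * np.exp(-lam * np.abs(t[:, None] - t[None, :]))
    L = np.linalg.cholesky(K + np.diag(noise))
    alpha = np.linalg.solve(L, y)
    return (-0.5 * alpha @ alpha - np.log(np.diag(L)).sum()
            - 0.5 * t.size * np.log(2.0 * np.pi))


def kalman_loglike(t, y, noise):
    """O(N) prediction-error decomposition of the same log likelihood."""
    m, P = 0.0, sigma2  # stationary prior at the first observation
    ll = 0.0
    for k in range(t.size):
        if k > 0:
            A = np.exp(-lam * (t[k] - t[k - 1]))  # discrete transition
            m, P = A * m, A * A * P + sigma2 * (1.0 - A * A)
        v = y[k] - m        # innovation
        S = P + noise[k]    # innovation variance
        ll += -0.5 * (v * v / S + np.log(2.0 * np.pi * S))
        gain = P / S
        m, P = m + gain * v, P - gain * S * gain
    return ll


print(dense_loglike(t, y, noise), kalman_loglike(t, y, noise))
```

The two printed values should agree to floating-point precision. The state-space route only ever inverts the small innovation variances, never the full N×N covariance.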
Users generally do not need to instantiate solvers directly; GaussianProcess
selects the appropriate solver automatically based on the kernel type.
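Classes

| Class | Description |
| --- | --- |
| StateSpaceSolver | A solver that implements Kalman filtering and RTS smoothing for state space GPs. |
| ParallelStateSpaceSolver | A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing. |
| IntegratedStateSpaceSolver | A solver that uses jax.lax.scan to implement Kalman filtering and RTS smoothing for integrated measurements. |
| ParallelIntegratedStateSpaceSolver | A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing for integrated measurements. |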
Package Contents
- class smolgp.solvers.StateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)
Bases: equinox.Module
A solver that implements Kalman filtering and RTS smoothing for state space GPs.
Given a StateSpaceModel kernel and a set of observed coordinates, this solver computes the Kalman filtered and Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances. It uses jax.lax.scan for efficient JIT-compiled sequential computation.
Predictions at arbitrary test coordinates are handled by predict(). It dispatches among retrodiction, interpolation, and extrapolation depending on whether each test point falls before, between, or after the observed data.
- Parameters:
kernel (StateSpaceModel) – The kernel function; must be a StateSpaceModel instance.
X (JAXArray) – The input coordinates with leading dimension of size N.
noise (JAXArray) – Observation noise covariance array of shape (N, D, D), where D is the observation dimension.
- kernel
The kernel defining the state space model.
- Type: StateSpaceModel
- X
The observed input coordinates.
- Type: JAXArray
- noise
Per-observation noise covariance matrices, shape (N, D, D).
- Type: JAXArray
- t_states
Sortable scalar coordinates derived from X via coord_to_sortable().
- Type: JAXArray
- X: tinygp.helpers.JAXArray
- noise: tinygp.helpers.JAXArray
- t_states: tinygp.helpers.JAXArray
- condition(y, return_v_S=False) → tinygp.helpers.JAXArray
Compute the Kalman predicted, filtered, and RTS smoothed means and covariances at each of the input coordinates, given the observations y.
- predict(X_test, conditioned_results) → tinygp.helpers.JAXArray
Make predictions at arbitrary coordinates X_test.
- Parameters:
X_test (JAXArray) – The test coordinates, in the same format as self.X but with leading dimension M.
conditioned_results (tuple) – The output of condition().
- Returns:
A pair (pred_mean, pred_var) of arrays with leading dimension M = len(X_test). They give the predicted state means and covariances at each test coordinate.
- Return type:
tuple
Each test point is handled by one of three cases depending on its position relative to the observed data:
Retrodiction — test point precedes all observations: smoothed backward from the first data point using the stationary prior.
Interpolation — test point falls between two observations: Kalman-predicted forward from the nearest past point, then RTS-smoothed backward from the nearest future point.
Extrapolation — test point follows all observations: Kalman-predicted forward from the final filtered state.
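The three cases can be illustrated on a toy model. The sketch below is plain NumPy rather than smolgp code. It assumes a hypothetical scalar Ornstein-Uhlenbeck kernel with per-point noise variances. The names filter_smooth, smooth_step, and predict are illustrative, not the library's.

The sketch does three things:
- runs a Kalman filter and RTS smoother, the role played by condition();
- dispatches each test time with np.searchsorted;
- checks the result against a dense GP prediction.

```python
import numpy as np

sigma2, lam = 1.3, 0.7  # hypothetical OU kernel hyperparameters


def trans(dt):
    """Discrete OU transition over dt: x' = A x + q, with q ~ N(0, Q)."""
    A = np.exp(-lam * dt)
    return A, sigma2 * (1.0 - A * A)


def filter_smooth(t, y, noise):
    """Kalman filter + RTS smoother; plays the role of condition()."""
    N = t.size
    mp, Pp, mf, Pf = (np.empty(N) for _ in range(4))
    m, P = 0.0, sigma2  # stationary prior
    for k in range(N):
        if k > 0:
            A, Q = trans(t[k] - t[k - 1])
            m, P = A * m, A * A * P + Q
        mp[k], Pp[k] = m, P                  # Kalman predicted
        g = P / (P + noise[k])
        m, P = m + g * (y[k] - m), (1.0 - g) * P
        mf[k], Pf[k] = m, P                  # Kalman filtered
    ms, Ps = mf.copy(), Pf.copy()
    for k in range(N - 2, -1, -1):           # RTS smoothed
        A, _ = trans(t[k + 1] - t[k])
        G = Pf[k] * A / Pp[k + 1]
        ms[k] = mf[k] + G * (ms[k + 1] - mp[k + 1])
        Ps[k] = Pf[k] + G * G * (Ps[k + 1] - Pp[k + 1])
    return mf, Pf, ms, Ps


def smooth_step(m, P, t_from, t_to, m_next, P_next):
    """One RTS step: refine a forward estimate (m, P) at t_from using the
    smoothed estimate (m_next, P_next) at the later time t_to."""
    A, Q = trans(t_to - t_from)
    P_pred = A * A * P + Q
    G = P * A / P_pred
    return m + G * (m_next - A * m), P + G * G * (P_next - P_pred)


def predict(t_test, t, results):
    mf, Pf, ms, Ps = results
    mean, var = np.empty(t_test.size), np.empty(t_test.size)
    for i, ts in enumerate(t_test):
        idx = np.searchsorted(t, ts, side="right")
        if idx == 0:  # retrodiction: start from the stationary prior
            mean[i], var[i] = smooth_step(0.0, sigma2, ts, t[0], ms[0], Ps[0])
        elif idx == t.size:  # extrapolation from the last filtered state
            A, Q = trans(ts - t[-1])
            mean[i], var[i] = A * mf[-1], A * A * Pf[-1] + Q
        else:  # interpolation: predict from t[idx-1], smooth from t[idx]
            A, Q = trans(ts - t[idx - 1])
            m_minus, P_minus = A * mf[idx - 1], A * A * Pf[idx - 1] + Q
            mean[i], var[i] = smooth_step(m_minus, P_minus, ts, t[idx],
                                          ms[idx], Ps[idx])
    return mean, var


rng = np.random.default_rng(2)
t = np.sort(rng.uniform(0.0, 10.0, 20))
noise = np.full(t.size, 0.1)
y = np.cos(t) + rng.normal(scale=0.3, size=t.size)
t_test = np.array([-1.0, 3.3, 7.1, 12.0])  # before, between, after the data

mean, var = predict(t_test, t, filter_smooth(t, y, noise))


def kern(a, b):
    return sigma2 * np.exp(-lam * np.abs(a[:, None] - b[None, :]))


# Dense O(N^3) reference prediction for comparison
C = kern(t, t) + np.diag(noise)
Ks = kern(t_test, t)
mean_dense = Ks @ np.linalg.solve(C, y)
var_dense = sigma2 - np.einsum("ij,ji->i", Ks, np.linalg.solve(C, Ks.T))
```

Because the model is Markov, each test point needs only two stored states: the filtered state just before it and the smoothed state just after it. Each prediction therefore costs a binary search plus a constant-size update.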
- class smolgp.solvers.ParallelStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)
Bases: equinox.Module
A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing.
- X: tinygp.helpers.JAXArray
- noise: tinygp.helpers.JAXArray
- condition(y, return_v_S=False) → tinygp.helpers.JAXArray
Compute the Kalman predicted, filtered, and RTS smoothed means and covariances at each of the input coordinates, given the observations y.
- predict(X_test, conditioned_results) → tinygp.helpers.JAXArray
Make predictions at arbitrary coordinates X_test.
- Parameters:
X_test – The test coordinates.
conditioned_results – The output of self.condition().
- Returns:
pred_mean – Predicted means of the states at X_test.
pred_var – Predicted variances of the states at X_test.
- Return type:
tuple (pred_mean, pred_var)
There are three cases:
- Retrodiction: smoothing from the first data point, using the prior as the prediction.
- Interpolation: filtering from the most recent data point and smoothing from the next future point.
- Extrapolation: predicting from the final filtered point.
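jax.lax.associative_scan computes every prefix of a sequence under an associative binary operator in logarithmic depth. The toy sketch below shows the idea in NumPy with a Hillis-Steele scan over affine maps x ↦ a·x + b. Once the gains are fixed, the Kalman mean recursion takes this affine form. This is a stand-in for illustration, not how ParallelStateSpaceSolver is implemented: a full parallel Kalman filter also needs scan elements that carry covariances.

```python
import numpy as np


def combine(left, right):
    """Compose two affine maps: apply `left` first, then `right`.
    x -> a_r * (a_l * x + b_l) + b_r. This operator is associative."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


def sequential_scan(a, b, x0=0.0):
    """Reference loop with O(N) depth: x_k = a_k * x_{k-1} + b_k."""
    out, x = np.empty_like(b), x0
    for k in range(b.size):
        x = a[k] * x + b[k]
        out[k] = x
    return out


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan over affine maps (assumes x0 = 0).
    Each round is one vectorised combine over all elements, so the number
    of rounds (the parallel depth) is ceil(log2 N)."""
    a, b = a.copy(), b.copy()
    shift, rounds = 1, 0
    while shift < a.size:
        a_new, b_new = combine((a[:-shift], b[:-shift]), (a[shift:], b[shift:]))
        a[shift:], b[shift:] = a_new, b_new
        shift *= 2
        rounds += 1
    return b, rounds  # with x0 = 0 the accumulated offsets are the states


rng = np.random.default_rng(3)
a = rng.uniform(0.2, 0.99, 1000)  # e.g. (1 - gain_k) * A_k in a Kalman mean update
b = rng.normal(size=1000)         # e.g. gain_k * y_k
states, rounds = parallel_scan(a, b)
```

With 1000 elements the scan finishes in 10 rounds instead of 1000 sequential steps. This simple variant does O(N log N) total work. Work-efficient (Blelloch-style) scans keep the total work at O(N) with the same logarithmic depth.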
- class smolgp.solvers.IntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)
Bases: equinox.Module
A solver that uses jax.lax.scan to implement Kalman filtering and RTS smoothing for integrated measurements.
- X: tinygp.helpers.JAXArray
- noise: tinygp.helpers.JAXArray
- state_coords: tinygp.helpers.JAXArray
- condition(y, return_v_S=False) → tinygp.helpers.JAXArray
Compute the Kalman predicted, filtered, and RTS smoothed means and covariances at each of the input coordinates, given the observations y.
- predict(X_test, conditioned_results) → tinygp.helpers.JAXArray
Make predictions at arbitrary coordinates X_test.
- Parameters:
X_test – The test coordinates.
conditioned_results – The output of self.condition().
- Returns:
pred_mean – Predicted means of the states at X_test.
pred_var – Predicted variances of the states at X_test.
- Return type:
tuple (pred_mean, pred_var)
There are three cases:
- Retrodiction: smoothing from the first data point, using the prior as the prediction.
- Interpolation: filtering from the most recent data point and smoothing from the next future point.
- Extrapolation: predicting from the final filtered point.
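One common way to handle a time-averaged measurement in a state-space model is to augment the state with the running integral z of the process, where dz/dt = x. The measurement is then z at the end of the exposure divided by the exposure length. We assume this construction for illustration only; it may not match the internals of IntegratedStateSpaceSolver.

The NumPy sketch below makes two further assumptions: a hypothetical OU kernel and a hypothetical exposure length delta. It uses Van Loan's matrix-exponential trick to get the augmented transition and process noise. It then checks that the variance of the time average matches the closed form for the OU kernel.

```python
import numpy as np


def expm(M, terms=30):
    """Matrix exponential by scaling and squaring with a Taylor series
    (adequate for the small, well-scaled matrices in this sketch)."""
    norm = np.linalg.norm(M, 1)
    s = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    X = M / 2.0**s
    E = term = np.eye(M.shape[0])
    for n in range(1, terms):
        term = term @ X / n
        E = E + term
    for _ in range(s):
        E = E @ E
    return E


sigma2, lam = 1.3, 0.7  # hypothetical OU kernel hyperparameters
delta = 2.5             # hypothetical exposure (averaging) time

# Augmented state [x, z] with dz/dt = x, so z(delta) is the integral of x
# over the exposure when z is reset to 0 at the exposure start.
F = np.array([[-lam, 0.0], [1.0, 0.0]])
L = np.array([[1.0], [0.0]])
Qc = 2.0 * lam * sigma2  # white-noise spectral density of the OU process

# Van Loan: one matrix exponential yields the transition Phi and the
# discrete process-noise covariance Qd over the exposure.
C = np.block([[-F, Qc * L @ L.T], [np.zeros((2, 2)), F.T]]) * delta
E = expm(C)
Phi = E[2:, 2:].T
Qd = Phi @ E[:2, 2:]

P0 = np.diag([sigma2, 0.0])  # x stationary, z reset to 0
P_end = Phi @ P0 @ Phi.T + Qd
var_avg_state_space = P_end[1, 1] / delta**2  # variance of the time average

# Closed form for the OU kernel: (1/delta^2) * double integral of k
u = lam * delta
var_avg_analytic = 2.0 * sigma2 * (u - 1.0 + np.exp(-u)) / u**2
```

The x-component of P_end also stays at sigma2, as it should for a stationary process. Once z is part of the state, the ordinary Kalman machinery can handle time-averaged data.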
- class smolgp.solvers.ParallelIntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)
Bases: equinox.Module
A solver that uses jax.lax.associative_scan to implement parallel Kalman filtering and RTS smoothing for integrated measurements.
- X: tinygp.helpers.JAXArray
- noise: tinygp.helpers.JAXArray
- state_coords: tinygp.helpers.JAXArray
- _state_coords: tinygp.helpers.JAXArray
- condition(y, return_v_S=True) → tinygp.helpers.JAXArray
Compute the Kalman predicted, filtered, and RTS smoothed means and covariances at each of the input coordinates, given the observations y.
- predict(X_test, conditioned_results) → tinygp.helpers.JAXArray
Make predictions at arbitrary coordinates X_test.
- Parameters:
X_test – The test coordinates.
conditioned_results – The output of self.condition().
observation_model – (optional) The observation model H for the test points; should be a function with the same interface as self.kernel.observation_model.
There are three cases:
- Retrodiction: smoothing from the first data point, using the prior as the prediction.
- Interpolation: filtering from the most recent data point and smoothing from the next future point.
- Extrapolation: predicting from the final filtered point.
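The RTS smoothing pass can also be written as a scan. For the mean, each smoothing step is an affine map of the later smoothed mean:

m_s[k] = G_k m_s[k+1] + (m_f[k] - G_k m_p[k+1])

Once the gains G_k are known, the whole backward pass is therefore a reverse scan; jax.lax.associative_scan accepts reverse=True for this. The NumPy sketch below checks this on a hypothetical scalar OU model. It keeps the forward filter and the covariance recursion sequential, and it does not model integrated measurements. It illustrates the scan structure only, not the internals of ParallelIntegratedStateSpaceSolver.

```python
import numpy as np


def parallel_scan(a, b):
    """Hillis-Steele inclusive scan composing affine maps x -> a*x + b in
    ceil(log2 N) vectorised rounds; returns the accumulated offsets."""
    a, b = a.copy(), b.copy()
    shift = 1
    while shift < a.size:
        a_new = a[shift:] * a[:-shift]
        b_new = a[shift:] * b[:-shift] + b[shift:]
        a[shift:], b[shift:] = a_new, b_new
        shift *= 2
    return b


sigma2, lam, noise = 1.3, 0.7, 0.2  # hypothetical OU model and noise variance
rng = np.random.default_rng(4)
t = np.sort(rng.uniform(0.0, 10.0, 64))
y = np.sin(t) + rng.normal(scale=0.3, size=t.size)

# Forward Kalman filter, kept sequential in this sketch
N = t.size
A = np.exp(-lam * np.diff(t, prepend=t[0]))  # A[0] = 1
mp, Pp, mf, Pf = (np.empty(N) for _ in range(4))
m, P = 0.0, sigma2
for k in range(N):
    m, P = A[k] * m, A[k] ** 2 * P + sigma2 * (1.0 - A[k] ** 2)
    mp[k], Pp[k] = m, P
    g = P / (P + noise)
    m, P = m + g * (y[k] - m), (1.0 - g) * P
    mf[k], Pf[k] = m, P

# RTS gains depend only on covariances, not on the smoothed means
G = Pf[:-1] * A[1:] / Pp[1:]

# Sequential RTS mean pass: ms[k] = G[k] * ms[k+1] + (mf[k] - G[k] * mp[k+1])
ms_seq = mf.copy()
for k in range(N - 2, -1, -1):
    ms_seq[k] = mf[k] + G[k] * (ms_seq[k + 1] - mp[k + 1])

# Same pass as a reverse scan, starting from the constant map x -> mf[-1]
a_rev = np.append(G, 0.0)[::-1]
b_rev = np.append(mf[:-1] - G * mp[1:], mf[-1])[::-1]
ms_par = parallel_scan(a_rev, b_rev)[::-1]
```

The first element of the reversed sequence has a = 0, so the scan does not need an initial state. The last smoothed mean equals the last filtered mean, and every earlier one follows by composition.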