Parallelized GP solvers on GPU
In An Introduction to State Space Gaussian Processes, we saw how the traditional Kalman filter and RTS smoother sequentially solve for the conditional filtered and smoothed distributions at each data point. As it turns out, these distributions have associative properties that let us reframe the Kalman/RTS algorithms as an all-prefix-sums problem, which parallel-scan algorithms solve efficiently in
\(\mathcal{O}(N/T + \log T)\) runtime complexity, for \(N\) data points and \(T\) parallel workers. This gives the usual parallel speedup factor of \(T\) when \(N \gg T\) (the usual scenario), while for \(T \gtrsim N\) the scaling can be as good as \(\mathcal{O}(\log N)\).
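To make the all-prefix-sums idea concrete, here is a minimal sketch on a toy problem: a scalar linear recurrence \(x_k = a_k x_{k-1} + b_k\), which has the same "each step depends on the previous one" structure as a filter. It is not the smolgp implementation or the actual Kalman/RTS operator. The function names (`combine`, `recurrence_sequential`, `recurrence_parallel`) are made up for illustration; the only library routines used are `jax.lax.scan` and `jax.lax.associative_scan`.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two affine maps x -> a*x + b, applying `earlier` first.

    Composition of affine maps is associative, which is the property a
    parallel scan needs.
    """
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2


def recurrence_sequential(a, b, x0):
    """Reference solution: x_k = a_k * x_{k-1} + b_k, one step at a time."""
    x0 = jnp.asarray(x0, dtype=a.dtype)

    def step(x, ab):
        a_k, b_k = ab
        x_new = a_k * x + b_k
        return x_new, x_new

    _, xs = jax.lax.scan(step, x0, (a, b))
    return xs


def recurrence_parallel(a, b, x0):
    """Same recurrence, solved as an all-prefix-sums problem."""
    x0 = jnp.asarray(x0, dtype=a.dtype)
    # Prefix k holds the composed map taking x_0 directly to x_k.
    A, B = jax.lax.associative_scan(combine, (a, b))
    return A * x0 + B


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (1000,), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (1000,))

xs_seq = recurrence_sequential(a, b, 0.3)
xs_par = recurrence_parallel(a, b, 0.3)
print(jnp.max(jnp.abs(xs_seq - xs_par)))  # should be at round-off level
```

The sequential version needs \(N\) dependent steps, whereas the associative scan combines elements pairwise in a tree, which is where the logarithmic depth comes from when enough workers are available.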
Särkkä and García-Fernández (2021) introduced this parallel method, which was extended to the case of integrated measurements in Yaghoobi and Särkkä (2025), though those authors use a different framework to handle the integrations than we do. The smolgp method of augmenting the state space (see Integrated Measurements) instead lets us use an elaboration of the Särkkä and García-Fernández (2021) method, which is described in Section 3.2.4 of Rubenzahl and Hattori et al. (2026).
Running on GPU
The parallel solvers are only significantly faster than their sequential counterparts when run on a GPU (see Benchmarks). Make sure your hardware supports jax[cuda], which you can install alongside smolgp with
uv add smolgp[cuda]
or uv add smolgp[cuda12] or uv add smolgp[cuda13] for a specific version.
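After installing, it can be worth confirming that JAX actually sees the GPU before benchmarking anything. The check below is a small sketch that uses only standard JAX calls (`jax.default_backend` and `jax.devices`); nothing in it is specific to smolgp.

```python
import jax

backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
devices = jax.devices()          # devices available on the default backend

print(f"JAX backend: {backend}")
print(f"Devices: {devices}")

if backend != "gpu":
    # JAX falls back to CPU when no usable CUDA installation is found.
    print("No GPU backend detected; the parallel solvers will run on CPU.")
```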
To use the parallel solver, simply build the GP object with solver=smolgp.solvers.ParallelStateSpaceSolver:
gp_smol = smolgp.GaussianProcess(
    kernel=smolgp.kernels.SHO(omega=2 * jnp.pi / 50, quality=5.0, sigma=1.0),
    X=t_train,
    noise=yerr_train**2,
    solver=smolgp.solvers.ParallelStateSpaceSolver,
)
gp_smol.log_probability(y_train)
Array(-74.10304301, dtype=float64)
gp_tiny = tinygp.GaussianProcess(kernel=kernel_tiny, X=t_train, diag=yerr_train**2)
gp_tiny.log_probability(y_train)
Array(-74.10304301, dtype=float64)
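The two log-probabilities above agree, so what remains to compare is runtime. When timing JAX code yourself, keep in mind that dispatch is asynchronous and that the first call includes compilation. The helper below is a hypothetical sketch (the name `time_call` is not part of smolgp or tinygp) that handles both, demonstrated on a stand-in function so that it runs on its own. With the objects defined above, the analogous call would presumably be `time_call(gp_smol.log_probability, y_train)`.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeats=10):
    """Median wall-clock seconds per call, excluding the first (warm-up) call."""
    # Warm-up: triggers any JIT compilation so it is not counted below.
    jax.block_until_ready(fn(*args))
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        # Block until the result is computed; JAX otherwise returns early.
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - start)
    return statistics.median(times)


# Stand-in workload (an assumption for illustration, not a GP solve).
stand_in = jax.jit(jnp.cumsum)
elapsed = time_call(stand_in, jnp.arange(10_000.0))
print(f"median time per call: {elapsed:.2e} s")
```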
Tip
For integrated data, instead use solver=smolgp.solvers.ParallelIntegratedStateSpaceSolver.