Parallelized GP solvers on GPU


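try:
    import smolgp
except ImportError:
    %pip install -q smolgp
    import smolgp

try:
    import tinygp
except ImportError:
    %pip install -q tinygp
    import tinygp

import jax
# Enable double precision first, before any JAX arrays (including PRNG keys) are created
jax.config.update("jax_enable_x64", True)
key = jax.random.PRNGKey(0)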


In An Introduction to State Space Gaussian Processes, we saw how the traditional Kalman filter and RTS smoother sequentially solve for the conditional filtered and smoothed distributions at each data point. As it turns out, these distributions can be combined with associative operations, which lets us reframe the Kalman/RTS algorithms as an all-prefix-sums problem. Such problems can be solved efficiently with parallel-scan algorithms in

\[ \mathcal{O}(N/T + \log T) \tag{1} \]

runtime complexity, for \(N\) data points and \(T\) parallel workers. When \(N \gg T\) (the usual scenario), the \(N/T\) term dominates, giving the usual parallel speedup factor of \(T\) over the \(\mathcal{O}(N)\) sequential algorithms. If instead \(T \gtrsim N\), you may see scaling as good as \(\mathcal{O}(\log N)\).
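To make the prefix-sum idea concrete, the sketch below applies `jax.lax.associative_scan` to a scalar linear recurrence \(x_k = a_k x_{k-1} + b_k\), the simplest instance of the associative-composition trick. Each element \((a_k, b_k)\) stands for an affine map, and composing two affine maps gives another affine map, so all \(N\) prefixes can be computed in \(\mathcal{O}(\log N)\) parallel depth and checked against a sequential `jax.lax.scan`. This is only an illustration of the principle, not smolgp's implementation; the parallel Kalman filter of Särkkä and García-Fernández (2021) uses matrix- and covariance-valued elements with the same structure.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def combine(earlier, later):
    # Each element (a, b) encodes the affine map x -> a * x + b.
    # Applying `earlier` and then `later` gives x -> a_l * (a_e * x + b_e) + b_l,
    # which is again affine; this composition is associative.
    a_e, b_e = earlier
    a_l, b_l = later
    return a_l * a_e, a_l * b_e + b_l


def sequential_recurrence(a, b, x0):
    # Reference O(N) loop: x_k = a_k * x_{k-1} + b_k
    def step(x, ab):
        a_k, b_k = ab
        x_new = a_k * x + b_k
        return x_new, x_new

    _, xs = jax.lax.scan(step, x0, (a, b))
    return xs


def parallel_recurrence(a, b, x0):
    # All prefix compositions in O(log N) parallel depth, then apply each one to x0
    A, B = jax.lax.associative_scan(combine, (a, b))
    return A * x0 + B


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (1000,), minval=0.5, maxval=1.0)
b = jax.random.normal(key_b, (1000,))
x0 = jnp.zeros((), dtype=a.dtype)

xs_seq = sequential_recurrence(a, b, x0)
xs_par = parallel_recurrence(a, b, x0)
print("max |sequential - parallel|:", jnp.max(jnp.abs(xs_seq - xs_par)))
```

Both versions compute the same sequence up to floating-point rounding; the parallel one trades a little extra total work for a much shallower dependency chain, which is what a GPU can exploit.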

Särkkä and García-Fernández (2021) introduced this parallel method, which Yaghoobi and Särkkä (2025) extended to the case of integrated measurements, though those authors handle the integrations with a different framework than we do. The smolgp approach of augmenting the state space (see Integrated Measurements) instead lets us use an extension of the Särkkä and García-Fernández (2021) method, which is described in Section 3.2.4 of Rubenzahl and Hattori et al. (2026).

Running on GPU

The parallel solvers are significantly faster than their sequential counterparts only when run on a GPU (see Benchmarks). Make sure your hardware supports jax[cuda], which you can install alongside smolgp with

uv add smolgp[cuda]

or uv add smolgp[cuda12] or uv add smolgp[cuda13] to target a specific CUDA version.
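To confirm that JAX can actually see your GPU after installing, you can query the default backend and the visible devices. A CPU-only install reports `cpu` here, in which case the parallel solver still runs, but without the GPU speedup.

```python
import jax

# "gpu" when a CUDA-enabled jaxlib finds a device, otherwise typically "cpu"
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)
```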


import jax.numpy as jnp
from scipy.interpolate import make_smoothing_spline

# True kernel for sampling the underlying process
kernel_tiny = tinygp.kernels.quasisep.SHO(omega=2*jnp.pi/50, quality=5.0, sigma=1.0)

def get_true_process(true_kernel, tmin=0, tmax=1000, dt=1):
    t = jnp.arange(tmin, tmax, dt)
    true_gp = tinygp.GaussianProcess(true_kernel, t)
    # NOTE: gp.sample adds small random noise for numerical stability
    y_sample = true_gp.sample(key=jax.random.PRNGKey(32)) 
    f = make_smoothing_spline(t, y_sample, lam=dt/6)
    return t, f

## True process
t_true, f = get_true_process(kernel_tiny, tmin=0, tmax=1000, dt=1)
y_true = f(t_true)

## Mock data
t_train = jnp.sort(jax.random.uniform(key, (50,), minval=0, maxval=1000))
yerr_train = jnp.full_like(t_train, 0.75)  # constant 1-sigma measurement uncertainty
y_train = f(t_train) + yerr_train * jax.random.normal(key, t_train.shape)
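The mock-data cell reuses `key` for both the sample times and the noise. That keeps the example reproducible, but `jax.random.normal` and `jax.random.uniform` called with the same key are derived from the same underlying random bits, so the noise is not independent of the sample times. A minimal sketch of the usual JAX pattern, which splits the key first:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Split once so the sample times and the noise come from independent streams
key_t, key_noise = jax.random.split(key)

t_train = jnp.sort(jax.random.uniform(key_t, (50,), minval=0, maxval=1000))
yerr_train = jnp.full_like(t_train, 0.75)
noise = yerr_train * jax.random.normal(key_noise, t_train.shape)
```

The training data would then be `y_train = f(t_train) + noise`, using the spline `f` from the cell above. Note that this produces different data, so the log-probabilities printed below would no longer match.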

To use the parallel solver, simply build the GP object with solver=smolgp.solvers.ParallelStateSpaceSolver:

gp_smol = smolgp.GaussianProcess(
    kernel=smolgp.kernels.SHO(omega=2*jnp.pi/50, quality=5.0, sigma=1.0),
    X=t_train,
    noise=yerr_train**2,
    solver=smolgp.solvers.ParallelStateSpaceSolver,
)
gp_smol.log_probability(y_train)
Array(-74.10304301, dtype=float64)
# For comparison, the same model evaluated with tinygp's quasiseparable solver
gp_tiny = tinygp.GaussianProcess(kernel=kernel_tiny, X=t_train, diag=yerr_train**2)
gp_tiny.log_probability(y_train)
Array(-74.10304301, dtype=float64)
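Since both solvers return the same log-probability, the interesting comparison is runtime. JAX compiles on the first call and dispatches work asynchronously, so a fair timing should jit the function, run a warm-up call, and wait on the result with `jax.block_until_ready`. A minimal sketch; `time_jitted` is a hypothetical helper for illustration, not part of smolgp:

```python
import time

import jax
import jax.numpy as jnp


def time_jitted(fn, *args, n_repeat=11):
    """Median wall-clock time (seconds) of a jitted call, excluding compilation.

    Hypothetical helper for illustration, not part of smolgp.
    """
    jitted = jax.jit(fn)
    # The first call triggers tracing and XLA compilation; keep it out of the timing
    jax.block_until_ready(jitted(*args))
    times = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so wait for the result before stopping the clock
        jax.block_until_ready(jitted(*args))
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]


print(time_jitted(jnp.cumsum, jnp.arange(100_000.0)))
```

For example, `time_jitted(lambda y: gp_smol.log_probability(y), y_train)` would time the parallel solver, assuming `log_probability` traces cleanly under `jax.jit`. On a GPU, repeating this for growing \(N\) is one way to observe the scaling discussed above.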

Tip

For integrated data, instead use solver=smolgp.solvers.ParallelIntegratedStateSpaceSolver.