smolgp#

smolgp is designed to be a drop-in extension of the tinygp library for building Gaussian Process (GP) models in Python. As such, it is also built on top of jax. The driving design philosophy is to match the API of tinygp as closely as possible. With only a few exceptions, any existing code that uses tinygp should work with smolgp after simply finding and replacing tiny with smol.

smolgp uses the state space representation of Gaussian Processes to implement solvers for GP regression and forecasting that run in time linear in the number of data points N (or in as little as O(log N) with parallelization on a GPU). It also implements “integrated” kernels that model time-averaged measurements, such as those from long-exposure astronomical observations; these can also be solved in linear time and are compatible with the parallel methods.
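To make the linear-time claim concrete, here is a minimal sketch of a scalar Kalman filter that evaluates the GP log-likelihood for the exponential kernel \(\sigma^2 \exp(-|\Delta t| / \ell)\) in a single pass over the data. It is a standalone illustration in plain Python, not smolgp's implementation; the function name and arguments are hypothetical, and it assumes a zero mean and time-sorted inputs.

```python
import math


def ou_kalman_log_likelihood(t, y, noise_var, sigma=1.0, ell=1.0):
    """Log-likelihood of y under a zero-mean GP with kernel
    sigma^2 * exp(-|dt| / ell), computed in O(N) with a Kalman filter.

    t must be sorted; noise_var holds the per-observation noise variances.
    """
    m, P = 0.0, sigma**2  # stationary prior mean and variance of the state
    log_prob = 0.0
    t_prev = t[0]
    for tk, yk, rk in zip(t, y, noise_var):
        # Predict: propagate the state across the gap since the last point.
        a = math.exp(-(tk - t_prev) / ell)
        m = a * m
        P = a * a * P + sigma**2 * (1.0 - a * a)
        # Innovation (residual) and its variance.
        v = yk - m
        S = P + rk
        log_prob += -0.5 * (v * v / S + math.log(2.0 * math.pi * S))
        # Update: condition the state on this observation.
        gain = P / S
        m = m + gain * v
        P = (1.0 - gain) * P
        t_prev = tk
    return log_prob
```

Each data point costs a fixed amount of work, so the total cost grows linearly with N, whereas a dense solve factorizes an N-by-N covariance matrix.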

The primary way that you interact with smolgp is to construct “kernel” functions using the building blocks provided in the kernels subpackage (see smolgp.kernels), and then pass the result to a GaussianProcess object, which does all the computations. Check out the Tutorials for a more complete introduction.

Submodules#

Attributes#

Classes#

GaussianProcess

An interface for designing a Gaussian Process regression model.

Package Contents#

class smolgp.GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)[source]#

Bases: equinox.Module

An interface for designing a Gaussian Process regression model.

Parameters:
  • kernel (Kernel) – The kernel function.

  • X (JAXArray) – The input coordinates — any PyTree compatible with kernel whose leading dimension has size N_data. For integrated kernels, pass (t, texp) where t is the array of exposure midpoints and texp is the array of exposure durations.

  • noise (JAXArray, optional) – Observation noise covariance matrices with shape (N, D, D), where N is the number of data points and D is the observation dimension (usually 1). Each slice noise[k] is the \(D \times D\) noise covariance for the k-th observation. A 1-D array of shape (N,) is interpreted as scalar per-observation variances. Defaults to \(\sqrt{\varepsilon_{\mathrm{machine}}} \cdot I\) for all observations.

  • mean (Callable, optional) – A callable or constant mean function evaluated as mean(X).

  • solver – Solver class for filtering and smoothing. If None (default), selected automatically based on the kernel type.
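The (t, texp) input for integrated kernels reflects the fact that a time-averaged measurement has different statistics from an instantaneous one. As a rough illustration (not smolgp's code), the sketch below computes the variance of the exposure average of a process with the exponential kernel \(\sigma^2 \exp(-|\Delta t| / \ell)\), once from the closed form and once by brute-force quadrature. The function names are hypothetical.

```python
import math


def averaged_variance_exact(texp, sigma=1.0, ell=1.0):
    """Variance of the mean of the process over an exposure of length texp."""
    if texp == 0:
        return sigma**2  # an instantaneous measurement
    r = texp / ell
    return 2.0 * sigma**2 * (r - 1.0 + math.exp(-r)) / r**2


def averaged_variance_quadrature(texp, sigma=1.0, ell=1.0, n=400):
    """Same quantity from a midpoint-rule double integral of the kernel."""
    h = texp / n
    pts = [(i + 0.5) * h for i in range(n)]
    total = sum(
        sigma**2 * math.exp(-abs(a - b) / ell) for a in pts for b in pts
    )
    return total / n**2
```

Because the kernel is stationary, the result depends only on the exposure duration and not on the midpoint t; longer exposures average away more of the variability.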

num_data: int#
dtype: jax.numpy.dtype#
kernel: tinygp.kernels.Kernel#
X: tinygp.helpers.JAXArray#
mean_function: tinygp.means.MeanBase#
mean: tinygp.helpers.JAXArray#
var: tinygp.helpers.JAXArray | None#
noise: tinygp.helpers.JAXArray#
solver: smolgp.solvers.StateSpaceSolver#
states: ConditionedStates#
property loc: tinygp.helpers.JAXArray#

If conditioned, this will be the mean at the data points. Otherwise, it is just the prior mean.

property variance: tinygp.helpers.JAXArray#

If conditioned, this will be the variance at the data points. Otherwise, it is just the prior variance.

abstractmethod property covariance: tinygp.helpers.JAXArray#

log_probability(y: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#

Compute the log probability of this multivariate normal

Parameters:

y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

Returns:

The marginal log probability of this multivariate normal model, evaluated at y.

condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) ConditionResult[source]#

Condition the model on observed data

Parameters:
  • y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

  • X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.

  • include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.

  • kernel (Kernel, optional) – A kernel to optionally specify the component kernel to be used for predicting after conditioning. See Multicomponent Kernels for an example.

Returns:

A named tuple where the first element log_probability is the log marginal probability of the model, and the second element gp is the GaussianProcess object describing the conditional distribution evaluated at X_test.
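For reference, the conditional distribution of a GP has the standard closed form below. The symbols are introduced here for illustration only: \(K = k(X, X)\), \(K_* = k(X, X_{\mathrm{test}})\), \(K_{**} = k(X_{\mathrm{test}}, X_{\mathrm{test}})\), \(R\) is the observation noise covariance, and \(m\) is the mean function. The state space solvers are expected to arrive at the same quantities through filtering and smoothing, without forming these dense matrices.

```latex
\[
\mu_* = m(X_{\mathrm{test}}) + K_*^\top (K + R)^{-1} \bigl(y - m(X)\bigr),
\qquad
\Sigma_* = K_{**} - K_*^\top (K + R)^{-1} K_*
\]
```

With include_mean set to False, the \(m(X_{\mathrm{test}})\) term would be left out of \(\mu_*\).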

predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray][source]#

Predict the GP model at new test points conditioned on observed data

Parameters:
  • X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.

  • y (JAXArray) – The observed data. Only needs to be given if the GP has not yet been conditioned. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

  • include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.

  • return_var (bool, optional) – If True, the variance of the predicted values at X_test will also be returned. Default is False.

  • return_cov (bool, optional) – If True, the covariance of the predicted values at X_test will be returned. If return_var is True, this flag will be ignored.

  • observation_model (Any, optional) – A function of X_test that defines the output observation model. Defaults to the observation model of the kernel.

  • return_full_state (bool, optional) – If True, return the full predicted state mean and covariance, rather than projecting to observation space. Default is False, i.e. the result is projected through kernel.observation_model.

  • kernel (int, optional) – If specified, the index of one kernel in a multi-component model (for example, a sum or product of kernels). The prediction is extracted for that component only and, if return_full_state is False, projected to observation space.

Returns:

The mean of the predictive model evaluated at X_test, with shape (N_test,) where N_test is the zeroth dimension of X_test. If either return_var or return_cov is True, the variance or covariance of the predicted process will also be returned with shape (N_test,) or (N_test, N_test) respectively.
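The difference between the full state and its projection can be sketched with a Matern-3/2 process, whose state is the pair [f, df/dt] and whose observation model is H = [1, 0]. The example below forecasts a given state mean m and covariance P across a gap dt. It is a plain-Python illustration under those assumptions, not smolgp's solver, and every name in it is hypothetical.

```python
import math


def matern32_transition(dt, ell=1.0):
    """Transition matrix A(dt) for the state [f, df/dt]."""
    lam = math.sqrt(3.0) / ell
    e = math.exp(-lam * dt)
    return [[e * (1.0 + lam * dt), e * dt],
            [-e * lam * lam * dt, e * (1.0 - lam * dt)]]


def matmul2(a, b):
    """Product of two 2x2 matrices stored as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


def transpose2(a):
    return [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]


def forecast(m, P, dt, sigma=1.0, ell=1.0, return_full_state=False):
    """Propagate state mean m and covariance P forward by dt."""
    lam = math.sqrt(3.0) / ell
    P_inf = [[sigma**2, 0.0], [0.0, lam**2 * sigma**2]]  # stationary covariance
    A = matern32_transition(dt, ell)
    At = transpose2(A)
    # Process noise chosen so that the stationary covariance is preserved.
    A_Pinf_At = matmul2(matmul2(A, P_inf), At)
    Q = [[P_inf[i][j] - A_Pinf_At[i][j] for j in range(2)] for i in range(2)]
    m_pred = [A[0][0] * m[0] + A[0][1] * m[1],
              A[1][0] * m[0] + A[1][1] * m[1]]
    A_P_At = matmul2(matmul2(A, P), At)
    P_pred = [[A_P_At[i][j] + Q[i][j] for j in range(2)] for i in range(2)]
    if return_full_state:
        return m_pred, P_pred
    # Project through H = [1, 0]: keep only the function value f.
    return m_pred[0], P_pred[0][0]
```

Far from the data the forecast relaxes to the prior: the projected mean tends to zero and the projected variance tends to sigma squared.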

sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) tinygp.helpers.JAXArray[source]#

Generate samples from the prior process

Parameters:
  • key – A jax random number key array.

  • shape (tuple, optional) – The number and shape of samples to generate.

Returns:

The sampled realizations from the process with shape (N_data,) + shape where N_data is the zeroth dimension of the X coordinates provided when instantiating this process.
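A state space model also lets prior samples be drawn sequentially, one time step at a time. The sketch below does this for the exponential kernel \(\sigma^2 \exp(-|\Delta t| / \ell)\) using only the standard library. It illustrates the idea rather than smolgp's sampler: the function name is hypothetical, a seed stands in for the jax key, and t is assumed to be non-empty and sorted.

```python
import math
import random


def sample_ou_prior(t, sigma=1.0, ell=1.0, seed=0):
    """Draw one prior realization at the sorted times t in O(N)."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, sigma)  # initial draw from the stationary distribution
    samples = [x]
    for t_prev, t_next in zip(t[:-1], t[1:]):
        a = math.exp(-(t_next - t_prev) / ell)
        # Decay the previous value and add the process noise for this gap.
        x = a * x + rng.gauss(0.0, sigma * math.sqrt(1.0 - a * a))
        samples.append(x)
    return samples
```

Each new value depends only on the previous one, so no N-by-N covariance matrix or Cholesky factor is needed.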

numpyro_dist(**kwargs: Any) tinygp.numpyro_support.TinyDistribution[source]#

Get the numpyro MultivariateNormal distribution for this process

abstractmethod _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) tinygp.helpers.JAXArray[source]#
_compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#

Compute the log-likelihood given v and S from the Kalman filter
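Assuming v and S denote the Kalman filter innovations \(v_k\) (the one-step-ahead residuals) and their covariances \(S_k\), with N observations of dimension D, the log-likelihood follows from the standard prediction error decomposition:

```latex
\[
\log p(y) = -\frac{1}{2} \sum_{k=1}^{N}
\left[ v_k^\top S_k^{-1} v_k + \log\det S_k + D \log 2\pi \right]
\]
```

Each term is the log density of a zero-mean normal with covariance \(S_k\) evaluated at \(v_k\), so the sum equals the marginal log probability returned by log_probability.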

get_component_mean(component: list | str, return_var: bool = False, **kwargs) Any[source]#

Get the predictive mean (and optionally the variance) of one component kernel, or of the sum of several, in a multi-component model, evaluated at self.X.

Parameters:
  • X (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object.

  • component (list | str) – The name(s) of the component kernel(s) to extract the mean for. If a list of names is provided, the joint mean and variance for that collection of kernels will be returned.

  • return_var (bool, optional) – If True, also return the variances of each component. Default is False.

Returns:

If return_var is False, component_mean (JAXArray). If return_var is True, component_mean (JAXArray) and component_var (JAXArray).

get_all_component_means(return_var: bool = False, **kwargs) Any[source]#

Get the predictive mean (and optionally variance) of each component kernel individually, evaluated at self.X.

Parameters:

return_var (bool, optional) – If True, also return the variances of each component. Default is False.

Returns:

If return_var is False, a list of JAX arrays containing the means of each component kernel evaluated at the data points. If return_var is True, a tuple where the first element is the list of means as before, and the second element is a list of JAX arrays containing the variances of each component kernel evaluated at the data points.
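What a per-component mean represents can be seen in a tiny dense example. For a sum of two kernels, the predictive mean of each component at the data points is that component's covariance applied to the same weight vector, so the component means add up to the total predictive mean. The sketch below uses two exponential kernels and two data points; it is a dense, plain-Python illustration with hypothetical names, not the state space computation smolgp performs.

```python
import math


def exp_kernel(t, sigma, ell):
    """Dense covariance matrix for sigma^2 * exp(-|dt| / ell)."""
    return [[sigma**2 * math.exp(-abs(a - b) / ell) for b in t] for a in t]


def solve2(M, y):
    """Solve the 2x2 linear system M x = y by Cramer's rule."""
    det = M[0][0] * M[1][1] - M[0][1] * M[1][0]
    return [(M[1][1] * y[0] - M[0][1] * y[1]) / det,
            (M[0][0] * y[1] - M[1][0] * y[0]) / det]


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


t = [0.0, 1.0]
y = [0.5, -0.3]
noise_var = [0.1, 0.1]

K_slow = exp_kernel(t, sigma=1.0, ell=5.0)  # slowly varying component
K_fast = exp_kernel(t, sigma=0.5, ell=0.2)  # rapidly varying component
K_sum = [[K_slow[i][j] + K_fast[i][j] for j in range(2)] for i in range(2)]
K_noisy = [[K_sum[i][j] + (noise_var[i] if i == j else 0.0)
            for j in range(2)] for i in range(2)]

alpha = solve2(K_noisy, y)          # shared weights (K + R)^{-1} y
mean_slow = matvec(K_slow, alpha)   # predictive mean of the slow component
mean_fast = matvec(K_fast, alpha)   # predictive mean of the fast component
mean_total = matvec(K_sum, alpha)   # predictive mean of the full process
```

The component variances do not add up in the same way, because the components are correlated once they are conditioned on the same data.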

smolgp.__version__#