smolgp#
smolgp is designed to be a drop-in extension of the tinygp
library for building Gaussian Process (GP) models in Python. As such, it is also built on top of
jax. The driving design philosophy is to match the API of tinygp
as closely as possible. With only a few exceptions, any existing code you have that uses tinygp should
work with smolgp simply by finding and replacing tiny with smol.
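As a rough illustration of that find-and-replace migration, the sketch below rewrites a small, hypothetical tinygp snippet using only the standard library. It touches only names that appear on this page (the kernels subpackage and GaussianProcess). Note that the signatures documented here still reference tinygp.helpers and tinygp.means, so a blanket replacement may be too aggressive for imports of those submodules; treat this as a starting point, not a guaranteed migration.

```python
import re

# Hypothetical tinygp-based source to migrate. Only names confirmed on this
# page (kernels, GaussianProcess) are used.
source = """\
import tinygp
from tinygp import kernels, GaussianProcess
"""

# Replace the whole word "tinygp" only, so unrelated identifiers that merely
# contain the substring are left untouched.
migrated = re.sub(r"\btinygp\b", "smolgp", source)
print(migrated)
```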
smolgp uses state-space representations of Gaussian Processes to implement solvers for
GP regression and forecasting that run in linear time, O(N), in the number of data points N
(or in as little as O(log N) time with parallelization on a GPU). It also implements
“integrated” kernels that can model time-averaged measurements, such as those from
long-exposure astronomical observations. These kernels can likewise be solved in linear time and are
compatible with the parallel methods.
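To make the linear-time claim concrete, here is a minimal NumPy sketch (not smolgp's implementation) of a Kalman filter that evaluates the log-likelihood of a zero-mean GP with the exponential kernel sigma^2 exp(-|dt| / ell) in a single pass over the data. The function names and test data are made up for illustration. The dense Cholesky-based computation is included only to check that both routes agree.

```python
import numpy as np

def ou_kalman_loglike(t, y, sigma, ell, noise_var):
    """O(N) log-likelihood for the kernel sigma^2 * exp(-|dt| / ell)."""
    m, P = 0.0, sigma**2                  # stationary prior on the state
    loglike = 0.0
    t_prev = t[0]
    for tk, yk in zip(t, y):
        a = np.exp(-(tk - t_prev) / ell)  # transition over the time gap
        m = a * m                         # predicted mean
        P = a**2 * P + sigma**2 * (1.0 - a**2)  # predicted variance
        v = yk - m                        # innovation
        S = P + noise_var                 # innovation variance
        loglike += -0.5 * (v**2 / S + np.log(2.0 * np.pi * S))
        gain = P / S                      # Kalman gain
        m, P = m + gain * v, (1.0 - gain) * P   # measurement update
        t_prev = tk
    return loglike

def dense_loglike(t, y, sigma, ell, noise_var):
    """O(N^3) reference: multivariate normal log-density via Cholesky."""
    K = sigma**2 * np.exp(-np.abs(t[:, None] - t[None, :]) / ell)
    K += noise_var * np.eye(len(t))
    L = np.linalg.cholesky(K)
    z = np.linalg.solve(L, y)
    return (-0.5 * z @ z - np.sum(np.log(np.diag(L)))
            - 0.5 * len(t) * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 10.0, 50))
y = np.sin(t) + 0.1 * rng.normal(size=t.size)
fast = ou_kalman_loglike(t, y, sigma=1.0, ell=2.0, noise_var=0.01)
slow = dense_loglike(t, y, sigma=1.0, ell=2.0, noise_var=0.01)
print(fast, slow)
```

The filter visits each data point once, which is where the linear scaling comes from; the sequential loop is the part that parallel methods would restructure.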
The primary way that you interact with smolgp is to construct
“kernel” functions using the building blocks provided in the kernels
subpackage (see smolgp.kernels), and then pass the result to a
GaussianProcess object, which performs all the computations. Check out the
Tutorials for a more complete introduction.
Submodules#
Attributes#
Classes#
GaussianProcess | An interface for designing a Gaussian Process regression model.
Package Contents#
- class smolgp.GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)[source]#
Bases: equinox.Module
An interface for designing a Gaussian Process regression model.
- Parameters:
kernel (Kernel) – The kernel function.
X (JAXArray) – The input coordinates: any PyTree compatible with kernel whose leading dimension has size N_data. For integrated kernels, pass (t, texp), where t is the array of exposure midpoints and texp is the array of exposure durations.
noise (JAXArray, optional) – Observation noise covariance matrices with shape (N, D, D), where N is the number of data points and D is the observation dimension (usually 1). Each slice noise[k] is the \(D \times D\) noise covariance for the k-th observation. A 1-D array of shape (N,) is interpreted as scalar per-observation variances. Defaults to \(\sqrt{\varepsilon_{\mathrm{machine}}} \cdot I\) for all observations.
mean (Callable, optional) – A callable or constant mean function evaluated as mean(X).
solver – Solver class for filtering and smoothing. If None (default), selected automatically based on the kernel type.
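The input and noise conventions described above can be sketched in plain NumPy. This is only a reading of the parameter descriptions, not smolgp's internal code: the array values are made up, and the machine epsilon shown assumes float64 (JAX uses float32 unless 64-bit mode is enabled).

```python
import numpy as np

N, D = 5, 1  # five observations, each scalar-valued

# Integrated-kernel style input: exposure midpoints and durations (made up).
t = np.linspace(0.0, 4.0, N)
texp = np.full(N, 0.5)
X = (t, texp)
t_start, t_end = t - texp / 2.0, t + texp / 2.0  # implied exposure windows

# Scalar per-observation variances, shape (N,), expanded to (N, D, D).
var = np.array([0.1, 0.2, 0.1, 0.3, 0.1])
noise = var[:, None, None] * np.eye(D)

# The documented default: sqrt(machine epsilon) times the identity,
# repeated for every observation (float64 assumed here).
eps = np.finfo(np.float64).eps
default_noise = np.sqrt(eps) * np.broadcast_to(np.eye(D), (N, D, D))

print(noise.shape, default_noise.shape)
```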
- num_data: int#
- dtype: jax.numpy.dtype#
- kernel: tinygp.kernels.Kernel#
- X: tinygp.helpers.JAXArray#
- mean_function: tinygp.means.MeanBase#
- mean: tinygp.helpers.JAXArray#
- var: tinygp.helpers.JAXArray | None#
- noise: tinygp.helpers.JAXArray#
- solver: smolgp.solvers.StateSpaceSolver#
- states: ConditionedStates#
- property loc: tinygp.helpers.JAXArray#
If conditioned, this will be the mean at the data points. Otherwise, it is just the prior mean.
- property variance: tinygp.helpers.JAXArray#
If conditioned, this will be the variance at the data points. Otherwise, it is just the prior variance.
- property covariance: tinygp.helpers.JAXArray#
- Abstractmethod:
- log_probability(y: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#
Compute the log probability of this multivariate normal
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
- Returns:
The marginal log probability of this multivariate normal model, evaluated at y.
- condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) ConditionResult[source]#
Condition the model on observed data
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
kernel (Kernel, optional) – The component kernel to use for predicting after conditioning. See Multicomponent Kernels for an example.
- Returns:
A named tuple where the first element log_probability is the log marginal probability of the model, and the second element gp is the GaussianProcess object describing the conditional distribution evaluated at X_test.
- predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray][source]#
Predict the GP model at new test points conditioned on observed data
- Parameters:
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.
y (JAXArray) – The observed data. Only needs to be given if the GP has not yet been conditioned. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
return_var (bool, optional) – If True, the variance of the predicted values at X_test will be returned. Default is False.
return_cov (bool, optional) – If True, the covariance of the predicted values at X_test will be returned. If return_var is True, this flag will be ignored.
observation_model (Any, optional) – A function of X_test that defines the output observation model. Defaults to that of the kernel.
return_full_state (bool, optional) – If True, return the full predicted state mean and covariance rather than projecting to observation space. Default is False, i.e. the result is projected through kernel.observation_model.
kernel (int, optional) – If specified, the index of the kernel in a multi-component model (for example, a sum or product of kernels) whose prediction should be extracted and, if return_full_state is False, projected.
- Returns:
The mean of the predictive model evaluated at X_test, with shape (N_test,), where N_test is the zeroth dimension of X_test. If either return_var or return_cov is True, the variance or covariance of the predicted process will also be returned, with shape (N_test,) or (N_test, N_test) respectively.
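For orientation, the quantities that predict returns correspond to the standard GP conditional mean and variance. The sketch below computes them with dense linear algebra in NumPy for a zero-mean GP; it is a reference formulation under that assumption, not how smolgp evaluates predictions (the page describes state-space filtering and smoothing). The kernel and helper names are hypothetical.

```python
import numpy as np

def kernel(x1, x2, sigma=1.0, ell=1.0):
    """Exponential kernel, used here purely as an example."""
    return sigma**2 * np.exp(-np.abs(x1[:, None] - x2[None, :]) / ell)

def dense_predict(X, y, X_test, noise_var):
    """Conditional mean and variance at X_test for a zero-mean GP."""
    K = kernel(X, X) + noise_var * np.eye(len(X))
    K_star = kernel(X_test, X)                  # shape (N_test, N_data)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    mean = K_star @ alpha                       # shape (N_test,)
    W = np.linalg.solve(L, K_star.T)            # shape (N_data, N_test)
    prior_var = np.diag(kernel(X_test, X_test))
    var = prior_var - np.sum(W**2, axis=0)      # shape (N_test,)
    return mean, var

X = np.linspace(0.0, 5.0, 11)
y = np.sin(X)
X_test = np.linspace(0.0, 5.0, 31)
mean, var = dense_predict(X, y, X_test, noise_var=1e-8)
mean_at_data, var_at_data = dense_predict(X, y, X, noise_var=1e-8)
print(mean.shape, var.shape)
```

With near-zero noise, predicting at the training coordinates recovers the observations and the predictive variance collapses toward zero, which is a useful sanity check for any solver.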
- sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) tinygp.helpers.JAXArray[source]#
Generate samples from the prior process
- Parameters:
key – A jax random number key array.
shape (tuple, optional) – The number and shape of samples to generate.
- Returns:
The sampled realizations from the process with shape (N_data,) + shape, where N_data is the zeroth dimension of the X coordinates provided when instantiating this process.
- numpyro_dist(**kwargs: Any) tinygp.numpyro_support.TinyDistribution[source]#
Get the numpyro MultivariateNormal distribution for this process
- abstractmethod _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) tinygp.helpers.JAXArray[source]#
- _compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#
Compute the log-likelihood given v and S from the Kalman filter
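The relationship between Kalman filter outputs and the log-likelihood can be written down directly. In the sketch below, v is assumed to hold the innovations with shape (N, D) and S the innovation covariances with shape (N, D, D); these shapes and the function name are assumptions for illustration and may differ from what the private method actually receives.

```python
import numpy as np

def log_prob_from_innovations(v, S):
    """Sum of Gaussian log-densities N(v_k | 0, S_k) over all steps k."""
    n_steps, dim = v.shape
    L = np.linalg.cholesky(S)                      # batched Cholesky factors
    z = np.linalg.solve(L, v[..., None])[..., 0]   # solve L_k z_k = v_k
    quad = np.sum(z**2)                            # sum_k v_k^T S_k^{-1} v_k
    logdet = 2.0 * np.sum(np.log(np.diagonal(L, axis1=-2, axis2=-1)))
    return -0.5 * (quad + logdet + n_steps * dim * np.log(2.0 * np.pi))

# Two scalar steps, small enough to check by hand.
v = np.array([[1.0], [2.0]])
S = np.array([[[2.0]], [[4.0]]])
value = log_prob_from_innovations(v, S)
by_hand = -0.5 * (1.0 / 2.0 + 4.0 / 4.0
                  + np.log(2.0) + np.log(4.0)
                  + 2.0 * np.log(2.0 * np.pi))
print(value, by_hand)
```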
- get_component_mean(component: list | str, return_var: bool = False, **kwargs) Any[source]#
Get the predictive mean (and optionally variance) of a particular component kernel, or of a sum of component kernels, in a multi-component model, evaluated at self.X
- Parameters:
X (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object.
component (list | str) – The name(s) of the component kernel(s) to extract the mean for. If a list of names is provided, the joint mean and variance for that collection of kernels will be returned.
return_var (bool, optional) – If True, also return the variances of each component. Default is False.
- Returns:
If return_var is False: component_mean (JAXArray).
If return_var is True: component_mean (JAXArray), component_var (JAXArray).
- get_all_component_means(return_var: bool = False, **kwargs) Any[source]#
Get the predictive mean (and optionally variance) of each component kernel individually, evaluated at self.X
- Parameters:
return_var (bool, optional) – If True, also return the variances of each component. Default is False.
- Returns:
If return_var is False, a list of JAX arrays containing the means of each component kernel evaluated at the data points. If return_var is True, a tuple where the first element is the list of means as before, and the second element is a list of JAX arrays containing the variances of each component kernel evaluated at the data points.
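The idea behind per-component means can be illustrated with dense linear algebra: for a sum kernel, each component's predictive mean uses that component's covariance against the shared solve with the total covariance, and the component means add up to the full predictive mean. The NumPy sketch below assumes a zero-mean GP and two made-up exponential kernels; it is not smolgp's state-space implementation.

```python
import numpy as np

def k_exp(x1, x2, sigma, ell):
    """Exponential kernel used for both example components."""
    return sigma**2 * np.exp(-np.abs(x1[:, None] - x2[None, :]) / ell)

rng = np.random.default_rng(1)
X = np.linspace(0.0, 10.0, 40)
y = np.sin(X) + 0.3 * np.cos(5.0 * X) + 0.05 * rng.normal(size=X.size)

# Two hypothetical components: a slowly varying and a rapidly varying one.
K_slow = k_exp(X, X, sigma=1.0, ell=5.0)
K_fast = k_exp(X, X, sigma=0.3, ell=0.3)
K_total = K_slow + K_fast + 0.05**2 * np.eye(X.size)

alpha = np.linalg.solve(K_total, y)        # shared solve with the full model
mean_slow = K_slow @ alpha                 # mean attributed to each component
mean_fast = K_fast @ alpha
mean_total = (K_slow + K_fast) @ alpha     # full predictive mean at X

# Per-component predictive variance: diag(K_i - K_i K_total^{-1} K_i).
A_slow = np.linalg.solve(K_total, K_slow)
var_slow = np.diag(K_slow) - np.einsum("ij,ji->i", K_slow, A_slow)

print(mean_slow.shape, var_slow.shape)
```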
- smolgp.__version__#