gp#

Classes#

ConditionedStates

An object to hold the conditioned means and variances

PredictedStates

An object to hold the full predictive states

GaussianProcess

An interface for designing a Gaussian Process regression model.

ConditionResult

The result of conditioning a GaussianProcess on data

Functions#

assign_unique_kernel_names(...)

Return a new kernel where duplicated leaf kernel names are made unique by appending _1, _2, etc.

_default_jitter(→ tinygp.helpers.JAXArray)

Default to adding some amount of jitter to the diagonal, just in case.

Module Contents#

smolgp.gp.assign_unique_kernel_names(kernel: smolgp.kernels.StateSpaceModel) smolgp.kernels.StateSpaceModel[source]#

Return a new kernel where duplicated leaf kernel names are made unique by appending _1, _2, etc.

For example, if the original kernel has three components named “SHO”, “Matern”, and “Matern”, they will be renamed to “SHO”, “Matern_1”, and “Matern_2”. This is useful for ensuring that the component kernels can be uniquely identified when making predictions at test points or when extracting component contributions.
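The renaming rule described above can be sketched on a plain list of names. This is only an illustration of the rule, not smolgp's implementation: the helper `make_names_unique` is hypothetical and works on strings rather than on a `StateSpaceModel`.

```python
from collections import Counter


def make_names_unique(names):
    """Append _1, _2, ... to names that occur more than once (hypothetical helper)."""
    totals = Counter(names)  # how often each name appears overall
    seen = Counter()         # how often each name has been emitted so far
    unique = []
    for name in names:
        if totals[name] > 1:
            seen[name] += 1
            unique.append(f"{name}_{seen[name]}")
        else:
            # Names that are already unique are left untouched
            unique.append(name)
    return unique


renamed = make_names_unique(["SHO", "Matern", "Matern"])
print(renamed)
```

With the three component names from the example above, only the duplicated "Matern" entries receive a suffix.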

class smolgp.gp.ConditionedStates(X, t_states: tinygp.helpers.JAXArray, instid: tinygp.helpers.JAXArray, obsid: tinygp.helpers.JAXArray, stateid: tinygp.helpers.JAXArray, m_pred: tinygp.helpers.JAXArray, P_pred: tinygp.helpers.JAXArray, m_filt: tinygp.helpers.JAXArray, P_filt: tinygp.helpers.JAXArray, m_smooth: tinygp.helpers.JAXArray, P_smooth: tinygp.helpers.JAXArray)[source]#

Bases: equinox.Module

An object to hold the conditioned means and variances

  • X – len(N) data coordinates

  • t_states – len(K) time coordinates of all states

  • instid – len(N) instrument ID for each measurement

  • obsid – len(K) observation IDs corresponding to the measurement at each state

  • stateid – len(K) state IDs corresponding to each state (0 for exposure-start, 1 for exposure-end)

  • predicted_mean, predicted_cov – len(K) Kalman predicted state

  • filtered_mean, filtered_cov – len(K) Kalman filtered state

  • smoothed_mean, smoothed_cov – len(K) RTS smoothed state

X: tinygp.helpers.JAXArray#
t_states: tinygp.helpers.JAXArray#
instid: tinygp.helpers.JAXArray#
obsid: tinygp.helpers.JAXArray#
stateid: tinygp.helpers.JAXArray#
predicted_mean: tinygp.helpers.JAXArray#
filtered_mean: tinygp.helpers.JAXArray#
smoothed_mean: tinygp.helpers.JAXArray#
predicted_cov: tinygp.helpers.JAXArray#
filtered_cov: tinygp.helpers.JAXArray#
smoothed_cov: tinygp.helpers.JAXArray#
__call__()[source]#
project_at_data(observation_model)[source]#

Project the states that have measurements (e.g. exposure-ends) and sort the result back into the original order of the data.
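The bookkeeping implied by obsid and stateid can be sketched with NumPy. This assumes that states with stateid == 1 carry the measurements and that obsid gives each state's index in the original data order. The helper `gather_at_data` is hypothetical and skips the projection through the observation model, so it illustrates the idea rather than the actual method.

```python
import numpy as np


def gather_at_data(values, obsid, stateid, num_data):
    """Pick the per-state values at exposure-ends and reorder them like the data.

    Hypothetical helper; assumes stateid == 1 marks states with a measurement.
    """
    values = np.asarray(values)
    obsid = np.asarray(obsid)
    stateid = np.asarray(stateid)
    has_measurement = stateid == 1
    out = np.empty((num_data,) + values.shape[1:], dtype=values.dtype)
    # Scatter each selected state into the slot of the observation it belongs to
    out[obsid[has_measurement]] = values[has_measurement]
    return out


# Two exposures whose states are time-sorted, with observation 1 taken first
values = np.array([10.0, 11.0, 20.0, 21.0])  # one value per state
obsid = np.array([1, 1, 0, 0])
stateid = np.array([0, 1, 0, 1])             # 0 = exposure-start, 1 = exposure-end
at_data = gather_at_data(values, obsid, stateid, num_data=2)
```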

class smolgp.gp.PredictedStates(t_states: tinygp.helpers.JAXArray, m: tinygp.helpers.JAXArray, P: tinygp.helpers.JAXArray, kernel: smolgp.kernels.StateSpaceModel)[source]#

Bases: equinox.Module

An object to hold the full predictive states

  • t_states – time coordinates at each state

  • mean – predictive mean vector for each state

  • cov – predictive covariance for each state

t_states: tinygp.helpers.JAXArray#
mean: tinygp.helpers.JAXArray#
cov: tinygp.helpers.JAXArray#
kernel: smolgp.kernels.StateSpaceModel#
project_mean(observation_model) tinygp.helpers.JAXArray[source]#

The projected mean at self.t_states given an observation model.

project_variance(observation_model) tinygp.helpers.JAXArray[source]#

The projected variance at self.t_states given an observation model.

property loc: tinygp.helpers.JAXArray#

The overall mean at the predicted states

property variance: tinygp.helpers.JAXArray#

The overall variance at the predicted states

get_component(component: str | list[str], return_var: bool = False) PredictedStates[source]#

Extract the predicted states corresponding to a component kernel

get_all_components(return_var: bool = False) dict[str, Any][source]#

Extract the predicted mean/variance corresponding to each component kernel

class smolgp.gp.GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)[source]#

Bases: equinox.Module

An interface for designing a Gaussian Process regression model.

Parameters:
  • kernel (Kernel) – The kernel function.

  • X (JAXArray) – The input coordinates — any PyTree compatible with kernel whose leading dimension has size N_data. For integrated kernels, pass (t, texp) where t is the array of exposure midpoints and texp is the array of exposure durations.

  • noise (JAXArray, optional) – Observation noise covariance matrices with shape (N, D, D), where N is the number of data points and D is the observation dimension (usually 1). Each slice noise[k] is the \(D \times D\) noise covariance for the k-th observation. A 1-D array of shape (N,) is interpreted as scalar per-observation variances. Defaults to \(\sqrt{\varepsilon_{\mathrm{machine}}} \cdot I\) for all observations.

  • mean (Callable, optional) – A callable or constant mean function evaluated as mean(X).

  • solver – Solver class for filtering and smoothing. If None (default), selected automatically based on the kernel type.
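The two noise shapes accepted above can be related with a small NumPy sketch. It assumes an observation dimension of D = 1 for the 1-D case; the helper `as_noise_covariances` is hypothetical and is not part of smolgp.

```python
import numpy as np


def as_noise_covariances(noise):
    """Return noise with shape (N, D, D) (hypothetical helper, not smolgp API)."""
    noise = np.asarray(noise, dtype=float)
    if noise.ndim == 1:
        # Scalar per-observation variances: one 1x1 covariance per data point
        return noise[:, None, None]
    if noise.ndim == 3 and noise.shape[1] == noise.shape[2]:
        # Already a stack of D x D covariance matrices
        return noise
    raise ValueError("expected noise with shape (N,) or (N, D, D)")


variances = [0.1, 0.2, 0.3]
stacked = as_noise_covariances(variances)
print(stacked.shape)
```

Here `stacked[k]` is the 1 × 1 covariance for the k-th observation, matching the description of noise[k] above.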

num_data: int#
dtype: jax.numpy.dtype#
kernel: tinygp.kernels.Kernel#
X: tinygp.helpers.JAXArray#
mean_function: tinygp.means.MeanBase#
mean: tinygp.helpers.JAXArray#
var: tinygp.helpers.JAXArray | None#
noise: tinygp.helpers.JAXArray#
solver: smolgp.solvers.StateSpaceSolver#
states: ConditionedStates#
property loc: tinygp.helpers.JAXArray#

If conditioned, this will be the mean at the data points. Otherwise, it is just the prior mean.

property variance: tinygp.helpers.JAXArray#

If conditioned, this will be the variance at the data points. Otherwise, it is just the prior variance.

property covariance: tinygp.helpers.JAXArray#
Abstractmethod:

log_probability(y: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#

Compute the log probability of this multivariate normal

Parameters:

y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

Returns:

The marginal log probability of this multivariate normal model, evaluated at y.

condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) ConditionResult[source]#

Condition the model on observed data

Parameters:
  • y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

  • X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.

  • include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.

  • kernel (Kernel, optional) – A kernel to optionally specify the component kernel to be used for predicting after conditioning. See Multicomponent Kernels for an example.

Returns:

A named tuple where the first element log_probability is the log marginal probability of the model, and the second element gp is the GaussianProcess object describing the conditional distribution evaluated at X_test.

predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray][source]#

Predict the GP model at new test points conditioned on observed data

Parameters:
  • X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.

  • y (JAXArray) – The observed data. Only needs to be given if the GP has not yet been conditioned. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.

  • include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.

  • return_var (bool, optional) – If True, the variance of the predicted values at X_test will be returned. Default is False.

  • return_cov (bool, optional) – If True, the covariance of the predicted values at X_test will be returned. If return_var is True, this flag will be ignored.

  • observation_model (Any, optional) – A function of X_test that defines the output observation model. Defaults to that of the kernel.

  • return_full_state (bool, optional) – If True, return the full predicted state mean and covariance, rather than projecting to observation space. Default is False, i.e. the result is projected through kernel.observation_model.

  • kernel (int, optional) – If specified, the index of the component kernel in a multi-component model (for example, a sum or product of kernels) whose prediction should be extracted and, if return_full_state is False, projected.

Returns:

The mean of the predictive model evaluated at X_test, with shape (N_test,) where N_test is the zeroth dimension of X_test. If either return_var or return_cov is True, the variance or covariance of the predicted process will also be returned with shape (N_test,) or (N_test, N_test) respectively.

sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) tinygp.helpers.JAXArray[source]#

Generate samples from the prior process

Parameters:
  • key – A jax random number key array.

  • shape (tuple, optional) – The number and shape of samples to generate.

Returns:

The sampled realizations from the process with shape (N_data,) + shape where N_data is the zeroth dimension of the X coordinates provided when instantiating this process.

numpyro_dist(**kwargs: Any) tinygp.numpyro_support.TinyDistribution[source]#

Get the numpyro MultivariateNormal distribution for this process

abstractmethod _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) tinygp.helpers.JAXArray[source]#
_compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#

Compute the log-likelihood given v and S from the Kalman filter
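For reference, the standard way to turn Kalman filter innovations into a log-likelihood is to sum the Gaussian log densities of each innovation v_k under its covariance S_k. The NumPy sketch below shows that textbook formula. It assumes v has shape (N, D) and S has shape (N, D, D), and it is not taken from smolgp's source.

```python
import numpy as np


def innovation_log_prob(v, S):
    """Sum of log N(v_k; 0, S_k) over all filter steps (textbook sketch)."""
    v = np.asarray(v, dtype=float)  # innovations, shape (N, D)
    S = np.asarray(S, dtype=float)  # innovation covariances, shape (N, D, D)
    D = v.shape[-1]
    # Solve S_k alpha_k = v_k for every step, then form v_k^T S_k^{-1} v_k
    alpha = np.linalg.solve(S, v[..., None])[..., 0]
    quad = np.sum(v * alpha, axis=-1)
    # slogdet is batched and returns (sign, log|det S_k|) per step
    _, logdet = np.linalg.slogdet(S)
    return -0.5 * np.sum(quad + logdet + D * np.log(2.0 * np.pi))


# Two scalar observations with zero innovation and unit innovation variance
v = np.zeros((2, 1))
S = np.ones((2, 1, 1))
log_prob = innovation_log_prob(v, S)
```

For these inputs each step contributes the log density of a standard normal at zero, so the total is minus the logarithm of 2π.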

get_component_mean(component: list | str, return_var: bool = False, **kwargs) Any[source]#

Get the predictive mean (and optionally variance) of a particular component kernel, or of a sum of component kernels, in a multi-component model evaluated at self.X

Parameters:
  • X (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object.

  • component (list | str) – The name(s) of the component kernel(s) to extract the mean for. If a list of names is provided, the joint mean and variance for that collection of kernels will be returned.

  • return_var (bool, optional) – If True, also return the variances of each component. Default is False.

Returns:

If return_var is False, component_mean (JAXArray). If return_var is True, component_mean (JAXArray) and component_var (JAXArray).

get_all_component_means(return_var: bool = False, **kwargs) Any[source]#

Get the predictive mean (and optionally variance) of each component kernel individually, evaluated at self.X

Parameters:

return_var (bool, optional) – If True, also return the variances of each component. Default is False.

Returns:

If return_var is False, a list of JAX arrays containing the means of each component kernel evaluated at the data points. If return_var is True, a tuple where the first element is the list of means as before, and the second element is a list of JAX arrays containing the variances of each component kernel evaluated at the data points.

class smolgp.gp.ConditionResult[source]#

Bases: NamedTuple

The result of conditioning a GaussianProcess on data

This has two entries, log_probability and gp, that are described below.

log_probability: tinygp.helpers.JAXArray#

The log probability of the conditioned model

In other words, this is the marginal likelihood for the kernel parameters, given the observed data, or the multivariate normal log probability evaluated at the given data.

gp: GaussianProcess#

A GaussianProcess describing the conditional distribution

This will have a mean and covariance conditioned on the observed data, but it is otherwise a fully functional GP that can sample from or condition further (although that’s probably not going to be very efficient).

smolgp.gp._default_jitter(reference: tinygp.helpers.JAXArray) tinygp.helpers.JAXArray[source]#

Default to adding some amount of jitter to the diagonal, just in case. We use sqrt(eps) for the dtype of the mean function because that seems to give sensible results in general.
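To give a sense of scale, the sqrt(eps) rule can be evaluated with NumPy for the two common floating-point dtypes. The helper `default_jitter_sketch` is a hypothetical stand-in that mirrors the description above; it is not the library function.

```python
import numpy as np


def default_jitter_sketch(reference):
    """Square root of machine epsilon for the dtype of `reference` (hypothetical)."""
    dtype = np.asarray(reference).dtype
    # np.finfo only accepts floating-point dtypes; integer inputs raise an error
    return np.sqrt(np.finfo(dtype).eps)


jitter64 = default_jitter_sketch(np.zeros(3, dtype=np.float64))
jitter32 = default_jitter_sketch(np.zeros(3, dtype=np.float32))
print(jitter64, jitter32)
```

In double precision this is roughly 1.5e-8, and in single precision roughly 3.5e-4, so the default jitter is noticeably larger when working in float32.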