gp
Classes
ConditionedStates | An object to hold the conditioned means and variances
PredictedStates | An object to hold the full predictive states
GaussianProcess | An interface for designing a Gaussian Process regression model.
ConditionResult | The result of conditioning a GaussianProcess on data
Functions
assign_unique_kernel_names | Return a new kernel where duplicated leaf kernel names are made unique by appending _1, _2, etc.
| Default to adding some amount of jitter to the diagonal, just in case,
Module Contents
- smolgp.gp.assign_unique_kernel_names(kernel: smolgp.kernels.StateSpaceModel) → smolgp.kernels.StateSpaceModel
Return a new kernel where duplicated leaf kernel names are made unique by appending _1, _2, etc.
For example, if the original kernel has three components named “SHO”, “Matern”, and “Matern”, they will be renamed to “SHO”, “Matern_1”, and “Matern_2”. This is useful for ensuring that the component kernels can be uniquely identified when making predictions at test points or when extracting component contributions.
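The renaming rule can be sketched on a plain list of strings. This is a minimal illustration, not smolgp's implementation: `make_unique_names` is a hypothetical helper that works on names only, whereas `assign_unique_kernel_names` takes and returns a kernel.

```python
from __future__ import annotations

from collections import Counter


def make_unique_names(names: list[str]) -> list[str]:
    """Suffix repeated names with _1, _2, ... (hypothetical helper, not smolgp's code)."""
    totals = Counter(names)  # how often each name occurs in the whole list
    seen = Counter()         # running count for each repeated name
    unique = []
    for name in names:
        if totals[name] > 1:
            seen[name] += 1
            unique.append(f"{name}_{seen[name]}")
        else:
            unique.append(name)  # names that occur once are left unchanged
    return unique


print(make_unique_names(["SHO", "Matern", "Matern"]))  # ['SHO', 'Matern_1', 'Matern_2']
```

Only duplicated names receive a suffix, which reproduces the “SHO”, “Matern_1”, “Matern_2” example above.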
- class smolgp.gp.ConditionedStates(X, t_states: tinygp.helpers.JAXArray, instid: tinygp.helpers.JAXArray, obsid: tinygp.helpers.JAXArray, stateid: tinygp.helpers.JAXArray, m_pred: tinygp.helpers.JAXArray, P_pred: tinygp.helpers.JAXArray, m_filt: tinygp.helpers.JAXArray, P_filt: tinygp.helpers.JAXArray, m_smooth: tinygp.helpers.JAXArray, P_smooth: tinygp.helpers.JAXArray)
Bases: equinox.Module
An object to hold the conditioned means and variances.
X: len(N) data coordinates
t_states: len(K) time coordinates of all states
instid: len(N) instrument ID for each measurement
obsid: len(K) observation IDs corresponding to the measurement at each state
stateid: len(K) state IDs corresponding to each state (0 for exposure start, 1 for exposure end)
predicted_mean / predicted_cov: len(K) Kalman-predicted state means and covariances
filtered_mean / filtered_cov: len(K) Kalman-filtered state means and covariances
smoothed_mean / smoothed_cov: len(K) RTS-smoothed state means and covariances
- X: tinygp.helpers.JAXArray
- t_states: tinygp.helpers.JAXArray
- instid: tinygp.helpers.JAXArray
- obsid: tinygp.helpers.JAXArray
- stateid: tinygp.helpers.JAXArray
- predicted_mean: tinygp.helpers.JAXArray
- filtered_mean: tinygp.helpers.JAXArray
- smoothed_mean: tinygp.helpers.JAXArray
- predicted_cov: tinygp.helpers.JAXArray
- filtered_cov: tinygp.helpers.JAXArray
- smoothed_cov: tinygp.helpers.JAXArray
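To make the index arrays concrete, the sketch below builds a hypothetical `t_states`, `obsid`, and `stateid` for an integrated kernel. It assumes that each exposure contributes one start state and one end state and that all states are sorted by time; this layout is our assumption for illustration and is not read from smolgp's solver. With overlapping exposures, the start and end states of different observations would interleave.

```python
import numpy as np

t = np.array([0.0, 1.0, 2.5])     # exposure midpoints (N = 3)
texp = np.array([0.2, 0.4, 0.2])  # exposure durations
n = len(t)

# Assumed layout: one start state (stateid 0) and one end state (stateid 1) per exposure
t_states = np.concatenate([t - texp / 2, t + texp / 2])
obsid = np.concatenate([np.arange(n), np.arange(n)])
stateid = np.concatenate([np.zeros(n, dtype=int), np.ones(n, dtype=int)])

# Sort every state by time; obsid and stateid are permuted along with t_states
order = np.argsort(t_states, kind="stable")
t_states, obsid, stateid = t_states[order], obsid[order], stateid[order]

print(obsid)    # [0 0 1 1 2 2]
print(stateid)  # [0 1 0 1 0 1]
```

Under these assumptions, K = 2N, and `obsid` maps each state back to the measurement it belongs to.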
- class smolgp.gp.PredictedStates(t_states: tinygp.helpers.JAXArray, m: tinygp.helpers.JAXArray, P: tinygp.helpers.JAXArray, kernel: smolgp.kernels.StateSpaceModel)
Bases: equinox.Module
An object to hold the full predictive states.
t_states: time coordinates at each state
mean: predictive mean vector for each state
cov: predictive covariance for each state
- t_states: tinygp.helpers.JAXArray
- mean: tinygp.helpers.JAXArray
- cov: tinygp.helpers.JAXArray
- kernel: smolgp.kernels.StateSpaceModel
- project_mean(observation_model) → tinygp.helpers.JAXArray
The projected mean at self.t_states given an observation model.
- project_variance(observation_model) → tinygp.helpers.JAXArray
The projected variance at self.t_states given an observation model.
- property loc: tinygp.helpers.JAXArray
The overall mean at the predicted states.
- property variance: tinygp.helpers.JAXArray
The overall variance at the predicted states.
- get_component(component: str | list[str], return_var: bool = False) → PredictedStates
Extract the predicted states corresponding to a component kernel (or a list of component kernels).
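`project_mean`, `project_variance`, `loc`, `variance`, and `get_component` all map a state vector into observation space. The NumPy sketch below shows the standard linear projection \(H m_k\) and \(H P_k H^\top\) for a hypothetical two-component model whose state vector stacks the components' states and whose observation model places their blocks side by side, the usual construction for sum kernels in state-space form. The array layout (`m` of shape `(K, d)`, `P` of shape `(K, d, d)`) and the helper `project` are assumptions, not a description of smolgp internals.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 5          # number of predicted states
d1, d2 = 2, 1  # state sizes of two hypothetical components

m = rng.normal(size=(K, d1 + d2))     # stacked state means, one row per state
A = rng.normal(size=(K, d1 + d2, d1 + d2))
P = A @ np.swapaxes(A, -1, -2)        # symmetric positive semi-definite covariances

H1 = np.array([1.0, 0.0])             # component 1 observes its first state entry
H2 = np.array([1.0])                  # component 2 observes its only state entry
H = np.concatenate([H1, H2])          # sum of components: blocks side by side


def project(m, P, H):
    """Return (H m_k, H P_k H^T) for every state k."""
    return m @ H, np.einsum("i,kij,j->k", H, P, H)


total_mean, total_var = project(m, P, H)
mean1, var1 = project(m[:, :d1], P[:, :d1, :d1], H1)
mean2, var2 = project(m[:, d1:], P[:, d1:, d1:], H2)

# Component means add up to the total mean
print(np.allclose(total_mean, mean1 + mean2))  # True
```

Variances do not add up in the same way: `total_var` also contains the cross term \(2 H_1 P_{12} H_2^\top\), which is why extracting a joint variance for several components differs from summing their individual variances.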
- class smolgp.gp.GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)
Bases: equinox.Module
An interface for designing a Gaussian Process regression model.
- Parameters:
kernel (Kernel) – The kernel function.
X (JAXArray) – The input coordinates: any PyTree compatible with kernel whose leading dimension has size N_data. For integrated kernels, pass (t, texp), where t is the array of exposure midpoints and texp is the array of exposure durations.
noise (JAXArray, optional) – Observation noise covariance matrices with shape (N, D, D), where N is the number of data points and D is the observation dimension (usually 1). Each slice noise[k] is the \(D \times D\) noise covariance for the k-th observation. A 1-D array of shape (N,) is interpreted as scalar per-observation variances. Defaults to \(\sqrt{\varepsilon_{\mathrm{machine}}} \cdot I\) for all observations.
mean (Callable, optional) – A callable or constant mean function evaluated as mean(X).
solver – Solver class for filtering and smoothing. If None (default), it is selected automatically based on the kernel type.
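The accepted noise shapes can be summarized with a small normalization routine. This is a sketch under stated assumptions: `normalize_noise` is a hypothetical helper, not part of smolgp, and expanding a scalar variance to a scaled \(D \times D\) identity when D > 1 is our own assumption (the text above only covers scalar per-observation variances).

```python
import numpy as np


def normalize_noise(noise, n_data, obs_dim=1, dtype=np.float64):
    """Return observation noise as an (N, D, D) stack (hypothetical helper)."""
    eye = np.eye(obs_dim, dtype=dtype)
    if noise is None:
        # Default: sqrt(machine epsilon) on the diagonal for every observation
        jitter = np.sqrt(np.finfo(dtype).eps)
        return np.broadcast_to(jitter * eye, (n_data, obs_dim, obs_dim))
    noise = np.asarray(noise, dtype=dtype)
    if noise.ndim == 1:
        # One scalar variance per observation, expanded to a scaled (D, D) identity
        return noise[:, None, None] * eye
    return noise  # assumed to already have shape (N, D, D)


print(normalize_noise([0.1, 0.2, 0.3], n_data=3).shape)  # (3, 1, 1)
print(normalize_noise(None, n_data=4).shape)             # (4, 1, 1)
```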
- num_data: int
- dtype: jax.numpy.dtype
- kernel: tinygp.kernels.Kernel
- X: tinygp.helpers.JAXArray
- mean_function: tinygp.means.MeanBase
- mean: tinygp.helpers.JAXArray
- var: tinygp.helpers.JAXArray | None
- noise: tinygp.helpers.JAXArray
- solver: smolgp.solvers.StateSpaceSolver
- states: ConditionedStates
- property loc: tinygp.helpers.JAXArray
If conditioned, this is the mean at the data points; otherwise, it is the prior mean.
- property variance: tinygp.helpers.JAXArray
If conditioned, this is the variance at the data points; otherwise, it is the prior variance.
- abstractmethod property covariance: tinygp.helpers.JAXArray
- log_probability(y: tinygp.helpers.JAXArray) → tinygp.helpers.JAXArray
Compute the log probability of this multivariate normal.
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
- Returns:
The marginal log probability of this multivariate normal model, evaluated at y.
- condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) → ConditionResult
Condition the model on observed data.
- Parameters:
y (JAXArray) – The observed data. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
kernel (Kernel, optional) – A kernel to optionally specify the component kernel to be used for predicting after conditioning. See Multicomponent Kernels for an example.
- Returns:
A named tuple where the first element, log_probability, is the log marginal probability of the model, and the second element, gp, is the GaussianProcess object describing the conditional distribution evaluated at X_test.
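`condition` produces both a log marginal likelihood and a conditioned process. In state-space form, both typically come from a forward Kalman filter (predicted and filtered states) followed by a backward Rauch-Tung-Striebel smoother, mirroring the fields of `ConditionedStates`. The plain-NumPy sketch below is not smolgp's solver. It assumes a one-dimensional exponential (Matern-1/2) kernel \(k(\tau) = \sigma^2 e^{-|\tau|/\ell}\) with scalar noise, whose transition has the closed form \(a_k = e^{-\Delta t_k/\ell}\), and it checks the result against the dense \(O(N^3)\) formulas. Both function names are hypothetical.

```python
import numpy as np


def ou_condition(t, y, sigma2, ell, noise):
    """Kalman filter + RTS smoother for k(tau) = sigma2 * exp(-|tau| / ell).

    Returns the log marginal likelihood and the smoothed mean/variance at t.
    """
    n = len(t)
    a = np.ones(n)
    a[1:] = np.exp(-np.diff(t) / ell)   # a[k]: state transition from t[k-1] to t[k]
    m_pred, P_pred = np.zeros(n), np.zeros(n)
    m_filt, P_filt = np.zeros(n), np.zeros(n)
    log_prob = 0.0
    m, P = 0.0, sigma2                  # stationary prior at the first time
    for k in range(n):
        if k > 0:                       # predict step
            m, P = a[k] * m, a[k] ** 2 * P + sigma2 * (1 - a[k] ** 2)
        m_pred[k], P_pred[k] = m, P
        v, S = y[k] - m, P + noise      # innovation and its variance (H = 1)
        log_prob -= 0.5 * (np.log(2 * np.pi * S) + v**2 / S)
        gain = P / S                    # update step
        m, P = m + gain * v, P - gain * P
        m_filt[k], P_filt[k] = m, P
    m_s, P_s = m_filt.copy(), P_filt.copy()
    for k in range(n - 2, -1, -1):      # backward Rauch-Tung-Striebel pass
        G = P_filt[k] * a[k + 1] / P_pred[k + 1]
        m_s[k] = m_filt[k] + G * (m_s[k + 1] - m_pred[k + 1])
        P_s[k] = P_filt[k] + G**2 * (P_s[k + 1] - P_pred[k + 1])
    return log_prob, m_s, P_s


def dense_condition(t, y, sigma2, ell, noise):
    """O(N^3) reference: the same quantities from the full covariance matrix."""
    K = sigma2 * np.exp(-np.abs(t[:, None] - t[None, :]) / ell)
    C = K + noise * np.eye(len(t))
    alpha = np.linalg.solve(C, y)
    _, logdet = np.linalg.slogdet(C)
    log_prob = -0.5 * (y @ alpha + logdet + len(t) * np.log(2 * np.pi))
    return log_prob, K @ alpha, np.diag(K - K @ np.linalg.solve(C, K))


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 10, 25))
y = np.sin(t) + 0.1 * rng.normal(size=t.size)
print(ou_condition(t, y, 1.0, 2.0, 0.05)[0], dense_condition(t, y, 1.0, 2.0, 0.05)[0])
```

Both routes agree to numerical precision, but the sequential one costs \(O(N)\) for this one-dimensional state instead of \(O(N^3)\).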
- predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) → tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]
Predict the GP model at new test points conditioned on observed data.
- Parameters:
X_test (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object. If it is not provided, X will be used by default, so the predictions will be made at the data coordinates.
y (JAXArray, optional) – The observed data. Only needs to be given if the GP has not yet been conditioned. This should have the shape (N_data,), where N_data was the zeroth axis of the X data provided when instantiating this object.
include_mean (bool, optional) – If True (default), the predicted values will include the mean function evaluated at X_test.
return_var (bool, optional) – If True, the variance of the predicted values at X_test will also be returned. Default is False.
return_cov (bool, optional) – If True, the covariance of the predicted values at X_test will be returned. If return_var is True, this flag will be ignored.
observation_model (Any, optional) – Optionally provide a function of X_test that defines the output observation model. By default, the kernel's observation model is used.
return_full_state (bool, optional) – If True, return the full predicted state mean and covariance rather than projecting to observation space. Default is False, i.e. the result is projected through kernel.observation_model.
kernel (int, optional) – If specified, the index of the kernel in a multi-component model (for example, a sum or product of kernels) for which to extract the prediction and, if return_full_state is False, project it.
- Returns:
The mean of the predictive model evaluated at X_test, with shape (N_test,), where N_test is the zeroth dimension of X_test. If either return_var or return_cov is True, the variance or covariance of the predicted process will also be returned with shape (N_test,) or (N_test, N_test), respectively.
- sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) → tinygp.helpers.JAXArray
Generate samples from the prior process.
- Parameters:
key – A jax random number key array.
shape (tuple, optional) – The number and shape of samples to generate.
- Returns:
The sampled realizations from the process with shape (N_data,) + shape, where N_data is the zeroth dimension of the X coordinates provided when instantiating this process.
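For a state-space prior, samples can be drawn by stepping the state forward in time rather than factorizing an \(N \times N\) covariance matrix. The sketch below illustrates the idea for a hypothetical exponential (Matern-1/2) kernel \(k(\tau) = \sigma^2 e^{-|\tau|/\ell}\) and returns draws with the data axis first, as in the shape convention above. It is not smolgp's `_sample`: it uses a NumPy `Generator` in place of a JAX key, and `sample_exp_prior` is a hypothetical name.

```python
import numpy as np


def sample_exp_prior(rng, t, sigma2, ell, n_samples):
    """Exact prior draws for k(tau) = sigma2 * exp(-|tau| / ell), stepping the state in time."""
    x = np.empty((len(t), n_samples))
    x[0] = np.sqrt(sigma2) * rng.standard_normal(n_samples)  # stationary initial state
    for k in range(1, len(t)):
        a = np.exp(-(t[k] - t[k - 1]) / ell)                   # transition over the gap
        q = sigma2 * (1.0 - a**2)                              # process-noise variance
        x[k] = a * x[k - 1] + np.sqrt(q) * rng.standard_normal(n_samples)
    return x  # shape (N_data, n_samples)


rng = np.random.default_rng(42)
t = np.array([0.0, 0.5, 1.3, 3.0])
draws = sample_exp_prior(rng, t, sigma2=1.0, ell=1.0, n_samples=20000)
print(draws.shape)  # (4, 20000)
```

Each step costs \(O(1)\) per sample here, so drawing a realization is linear in the number of data points.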
- numpyro_dist(**kwargs: Any) → tinygp.numpyro_support.TinyDistribution
Get the numpyro MultivariateNormal distribution for this process.
- abstractmethod _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) → tinygp.helpers.JAXArray
- _compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) → tinygp.helpers.JAXArray
Compute the log-likelihood given v and S from the Kalman filter.
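For reference, the standard prediction-error decomposition writes the Gaussian log-likelihood in terms of the Kalman innovations \(v_k\) and their covariances \(S_k\). We assume, without having checked smolgp's source, that this is the quantity `_compute_log_prob` evaluates for point observations. Here \(m_k^{-}\) and \(P_k^{-}\) are the predicted state mean and covariance at step \(k\), \(H\) is the observation model, \(R_k\) is the noise covariance noise[k], and the mean function is taken to be already subtracted from \(y_k\):

```latex
\log p(\mathbf{y}) = -\frac{1}{2} \sum_{k=1}^{N} \Big[ \log\det\!\left(2\pi S_k\right) + v_k^{\top} S_k^{-1} v_k \Big],
\qquad v_k = y_k - H\, m_k^{-}, \qquad S_k = H\, P_k^{-} H^{\top} + R_k .
```

For \(D = 1\), this reduces to \(-\tfrac{1}{2}\sum_k \left[\log(2\pi S_k) + v_k^2 / S_k\right]\).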
- get_component_mean(component: list | str, return_var: bool = False, **kwargs) → Any
Get the predictive mean (and optionally the variance) of a particular component kernel, or of a sum of component kernels, in a multi-component model, evaluated at self.X.
- Parameters:
X (JAXArray, optional) – The coordinates where the prediction should be evaluated. This should have a data type compatible with the X data provided when instantiating this object.
component (list | str) – The name(s) of the component kernel(s) to extract the mean for. If a list of names is provided, the joint mean and variance for that collection of kernels will be returned.
return_var (bool, optional) – If True, also return the variances of each component. Default is False.
- Returns:
If return_var is False: component_mean (JAXArray).
If return_var is True: component_mean (JAXArray) and component_var (JAXArray).
- get_all_component_means(return_var: bool = False, **kwargs) → Any
Get the predictive mean (and optionally the variance) of each component kernel individually, evaluated at self.X.
- Parameters:
return_var (bool, optional) – If True, also return the variances of each component. Default is False.
- Returns:
If return_var is False, a list of JAX arrays containing the means of each component kernel evaluated at the data points. If return_var is True, a tuple whose first element is the list of means as before and whose second element is a list of JAX arrays containing the variances of each component kernel evaluated at the data points.
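The relationship between per-component means and the overall mean can be checked with a dense reference computation. The sketch assumes an additive model \(f = f_1 + f_2\) with two hypothetical kernels (an exponential and a periodic one), zero mean, and Gaussian noise. It does not call smolgp, and `component_means` is a hypothetical helper.

```python
import numpy as np


def exp_kernel(a, b):
    return 1.0 * np.exp(-np.abs(a - b) / 2.0)


def periodic_kernel(a, b):
    return 0.5 * np.exp(-2.0 * np.sin(np.pi * np.abs(a - b) / 3.0) ** 2)


def component_means(t, y, kernels, noise):
    """Posterior mean of each additive component at the data points (dense reference)."""
    grams = [k(t[:, None], t[None, :]) for k in kernels]
    C = sum(grams) + noise * np.eye(len(t))  # covariance of y under the sum kernel
    alpha = np.linalg.solve(C, y)
    return [K_i @ alpha for K_i in grams]    # E[f_i | y] = K_i C^{-1} y


rng = np.random.default_rng(1)
t = np.sort(rng.uniform(0, 12, 30))
y = np.sin(2 * np.pi * t / 3.0) + 0.3 * t / 12 + 0.1 * rng.normal(size=t.size)
means = component_means(t, y, [exp_kernel, periodic_kernel], noise=0.01)
```

Because \(\sum_i K_i C^{-1} y = K C^{-1} y\), the component means add up to the posterior mean under the full sum kernel. Component variances do not add up in the same way, since the posterior couples the components (their posterior cross-covariance is \(-K_1 C^{-1} K_2\)).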
- class smolgp.gp.ConditionResult
Bases: NamedTuple
The result of conditioning a GaussianProcess on data.
This has two entries, log_probability and gp, that are described below.
- log_probability: tinygp.helpers.JAXArray
The log probability of the conditioned model.
In other words, this is the marginal likelihood for the kernel parameters, given the observed data, or the multivariate normal log probability evaluated at the given data.
- gp: GaussianProcess
A GaussianProcess describing the conditional distribution.
This will have a mean and covariance conditioned on the observed data, but it is otherwise a fully functional GP that can be sampled from or conditioned further (although that is probably not going to be very efficient).