from __future__ import annotations
from collections.abc import Sequence
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
NamedTuple,
)
import equinox as eqx
import jax
import jax.numpy as jnp
from tinygp import kernels, means
from tinygp.helpers import JAXArray
from smolgp.kernels import StateSpaceModel, Sum, Product
from smolgp.kernels.base import extract_leaf_kernels
from smolgp.kernels.integrated import IntegratedStateSpaceModel
from smolgp.solvers import StateSpaceSolver
from smolgp.solvers import ParallelStateSpaceSolver
from smolgp.solvers.integrated import IntegratedStateSpaceSolver
from smolgp.solvers.integrated import ParallelIntegratedStateSpaceSolver
if TYPE_CHECKING:
from tinygp.numpyro_support import TinyDistribution
import dataclasses
[docs]
def assign_unique_kernel_names(kernel: StateSpaceModel) -> StateSpaceModel:
    """Return a kernel whose leaf kernel names are all unique.

    Leaf kernels that share a name get a ``_1``, ``_2``, ... suffix, numbered
    in the order they are visited (``kernel1`` before ``kernel2`` at every
    ``Sum``/``Product`` node). For example, leaves named "SHO", "Matern" and
    "Matern" become "SHO", "Matern_1" and "Matern_2". This lets component
    kernels be identified unambiguously when predicting at test points or
    extracting component contributions.

    If a generated name would clash with a leaf that already carries that
    name (e.g. leaves "Matern", "Matern", "Matern_1"), the suffix is advanced
    until the name is free, so the result is always unique.

    Args:
        kernel: The (possibly composite) state space kernel.

    Returns:
        ``kernel`` itself when its leaf names are already unique; otherwise a
        rebuilt ``Sum``/``Product`` tree whose duplicated leaves are renamed
        via ``dataclasses.replace``.
    """
    from collections import Counter

    leaves = extract_leaf_kernels(kernel)
    counts = Counter(leaf.name for leaf in leaves)

    # Early exit if all names are unique (no duplicates)
    if all(c == 1 for c in counts.values()):
        return kernel

    # Names that appear exactly once are kept verbatim, so they must never be
    # generated for a renamed leaf. Generated names are added as handed out.
    taken = {name for name, c in counts.items() if c == 1}
    # Next suffix to try for each duplicated base name.
    next_suffix = {name: 1 for name, c in counts.items() if c > 1}

    def _fresh_name(base: str) -> str:
        # Smallest unused "<base>_<idx>" at or after the current counter.
        idx = next_suffix[base]
        candidate = f"{base}_{idx}"
        while candidate in taken:
            idx += 1
            candidate = f"{base}_{idx}"
        next_suffix[base] = idx + 1
        taken.add(candidate)
        return candidate

    def _rename(k: StateSpaceModel) -> StateSpaceModel:
        # Rebuild composite nodes; arguments are evaluated left to right,
        # so kernel1's leaves are numbered before kernel2's.
        if isinstance(k, Sum):
            return Sum(_rename(k.kernel1), _rename(k.kernel2))
        if isinstance(k, Product):
            return Product(_rename(k.kernel1), _rename(k.kernel2))
        # Leaf: rename only duplicated names; single occurrences are unchanged.
        if counts[k.name] > 1:
            return dataclasses.replace(k, name=_fresh_name(k.name))
        return k

    return _rename(kernel)
[docs]
class ConditionedStates(eqx.Module):
    """Container for the latent states of a GP conditioned on data.

    Attributes:
        X: The data coordinates the states were conditioned on (any pytree
            whose leaves share a leading axis of length ``N``).
        t_states: Sortable time coordinate of each latent state (``K`` states).
        instid: Instrument IDs (all zero when the solver is not integrated).
        obsid: ID of the measurement each state corresponds to.
        stateid: State IDs (0 for exposure-start, 1 for exposure-end, i.e.
            the states that carry a measurement).
        predicted_mean / predicted_cov: Kalman predicted states.
        filtered_mean / filtered_cov: Kalman filtered states.
        smoothed_mean / smoothed_cov: RTS smoothed states.
    """

    X: JAXArray
    t_states: JAXArray
    instid: JAXArray
    obsid: JAXArray
    stateid: JAXArray
    predicted_mean: JAXArray
    filtered_mean: JAXArray
    smoothed_mean: JAXArray
    predicted_cov: JAXArray
    filtered_cov: JAXArray
    smoothed_cov: JAXArray

    def __init__(
        self,
        X: JAXArray,
        t_states: JAXArray,
        instid: JAXArray,
        obsid: JAXArray,
        stateid: JAXArray,
        m_pred: JAXArray,
        P_pred: JAXArray,
        m_filt: JAXArray,
        P_filt: JAXArray,
        m_smooth: JAXArray,
        P_smooth: JAXArray,
    ):
        self.X = X
        self.t_states = t_states
        self.instid = instid
        self.obsid = obsid
        self.stateid = stateid
        self.predicted_mean = m_pred
        self.predicted_cov = P_pred
        self.filtered_mean = m_filt
        self.filtered_cov = P_filt
        self.smoothed_mean = m_smooth
        self.smoothed_cov = P_smooth

    def __call__(self):
        """Repackage the states in the layout returned by ``solver.condition``.

        Returns:
            ``(state_coords, ((m_pred, P_pred), (m_filt, P_filt),
            (m_smooth, P_smooth)), None)`` where ``state_coords`` is
            ``(t_states, instid, obsid, stateid)``. The last slot (likelihood
            terms) is always ``None`` here.
        """
        state_coords = (self.t_states, self.instid, self.obsid, self.stateid)
        packaged_results = (
            (self.predicted_mean, self.predicted_cov),
            (self.filtered_mean, self.filtered_cov),
            (self.smoothed_mean, self.smoothed_cov),
        )
        # This should match the output of solver.condition
        return state_coords, packaged_results, None

    def project_at_data(self, observation_model):
        """Project the measured smoothed states into observation space.

        Only states with ``stateid == 1`` (e.g. exposure ends) carry a
        measurement. Those states are reordered by ``obsid`` so the output
        follows the ordering of ``X``, then projected through
        ``observation_model``.

        Args:
            observation_model: Callable mapping one element of ``X`` to its
                observation matrix ``H``.

        Returns:
            ``(mu, var)`` with ``mu = H m`` and ``var = H P H^T`` for each
            data point, with singleton dimensions squeezed out.
        """

        @jax.jit
        def project(x, m, P):
            H = observation_model(x)
            return H @ m, H @ P @ H.T

        # Number of data points = length of the leading axis of X. X may be a
        # pytree (e.g. (t, texp) for integrated kernels) and its leaves may be
        # multi-dimensional, so read the leading axis of the first leaf rather
        # than the last axis of a stacked array.
        N = jnp.shape(jax.tree_util.tree_leaves(self.X)[0])[0]
        ends_idx = jnp.nonzero(self.stateid == 1, size=N)[0]
        # Restore the original data ordering.
        idx = ends_idx[jnp.argsort(self.obsid[ends_idx])]
        m_sel = jnp.take(self.smoothed_mean, idx, axis=0)
        P_sel = jnp.take(self.smoothed_cov, idx, axis=0)
        mu, var = jax.vmap(project)(self.X, m_sel, P_sel)
        return mu.squeeze(), var.squeeze()
[docs]
class PredictedStates(eqx.Module):
    """Full predictive latent states at a set of test coordinates.

    Attributes:
        t_states: Coordinates of each predicted state (the ``X_test`` the
            prediction was made at; may be a pytree such as ``(t, texp)``).
        mean: Predictive state mean for each coordinate.
        cov: Predictive state covariance for each coordinate.
        kernel: Kernel whose observation model projects the states into
            observation space.
    """

    t_states: JAXArray
    mean: JAXArray
    cov: JAXArray
    kernel: StateSpaceModel

    def __init__(
        self,
        t_states: JAXArray,
        m: JAXArray,
        P: JAXArray,
        kernel: StateSpaceModel,
    ):
        self.t_states = t_states
        self.mean = m
        self.cov = P
        self.kernel = kernel

    def project_mean(self, observation_model) -> JAXArray:
        """Return ``H m`` at each of ``self.t_states`` (squeezed).

        Args:
            observation_model: Callable mapping one coordinate to ``H``.
        """
        H = jax.vmap(observation_model)(self.t_states)
        mu = jax.vmap(lambda H_i, m_i: H_i @ m_i)(H, self.mean)
        return mu.squeeze()

    def project_variance(self, observation_model) -> JAXArray:
        """Return ``H P H^T`` at each of ``self.t_states`` (squeezed).

        Args:
            observation_model: Callable mapping one coordinate to ``H``.
        """
        H = jax.vmap(observation_model)(self.t_states)
        var = jax.vmap(lambda H_i, P_i: H_i @ P_i @ H_i.T)(H, self.cov)
        return var.squeeze()

    @property
    def loc(self) -> JAXArray:
        """The overall mean at the predicted states (full observation model)."""
        # Use the same observation model as GaussianProcess.predict so that
        # this matches the projected prediction.
        return self.project_mean(self.kernel.observation_model)

    @property
    def variance(self) -> JAXArray:
        """The overall variance at the predicted states (full observation model)."""
        return self.project_variance(self.kernel.observation_model)

    def _component_observation_model(self, names: list[str]):
        """Build an observation model summing the given components' models."""
        if not names:
            raise ValueError("At least one component name must be given.")

        def H_comp(X):
            H = self.kernel.observation_model(X, component=names[0])
            for name in names[1:]:
                H = H + self.kernel.observation_model(X, component=name)
            return H

        return H_comp

    def get_component(
        self,
        component: str | list[str],
        return_var: bool = False,
    ) -> JAXArray | tuple[JAXArray, JAXArray]:
        """Project the predicted states onto one or more component kernels.

        Args:
            component: Name of a leaf kernel, or a list of names whose
                observation models are summed (their joint contribution).
            return_var: If ``True``, also return the projected variance.

        Returns:
            The component mean, or ``(mean, variance)`` if ``return_var``.

        Raises:
            ValueError: If ``component`` is an empty list.
        """
        names = [component] if isinstance(component, str) else list(component)
        H_comp = self._component_observation_model(names)
        component_mean = self.project_mean(H_comp)
        if not return_var:
            # Skip the variance projection when it is not requested.
            return component_mean
        return component_mean, self.project_variance(H_comp)

    def get_all_components(self, return_var: bool = False) -> dict[str, Any]:
        """Project the predicted states onto every leaf kernel individually.

        Args:
            return_var: If ``True``, each value is ``(mean, variance)``.

        Returns:
            A dict keyed by leaf kernel name.
        """
        return {
            leaf.name: self.get_component(leaf.name, return_var=return_var)
            for leaf in extract_leaf_kernels(self.kernel)
        }
[docs]
class GaussianProcess(eqx.Module):
    r"""An interface for designing a state space Gaussian Process regression model.

    :param kernel: The kernel function.
    :type kernel: Kernel
    :param X: The input coordinates — any PyTree compatible with ``kernel``
        whose leading dimension has size ``N_data``.
        For integrated kernels, pass ``(t, texp)`` where ``t`` is the array of
        exposure midpoints and ``texp`` is the array of exposure durations.
    :type X: JAXArray
    :param noise: Observation noise covariance matrices with shape
        ``(N, D, D)``, where ``N`` is the number of data points and ``D`` is
        the observation dimension (usually 1). Each slice ``noise[k]`` is the
        :math:`D \times D` noise covariance for the ``k``-th observation.
        A 1-D array of shape ``(N,)`` is interpreted as scalar per-observation
        variances, and a scalar as one variance shared by all observations.
        Defaults to :math:`\sqrt{\varepsilon_{\mathrm{machine}}}
        \cdot I` for all observations.
    :type noise: JAXArray, optional
    :param mean: A callable or constant mean function evaluated as
        ``mean(X)``. It sets :attr:`loc` of the unconditioned process. Note
        that it is currently not subtracted from the data by
        :meth:`log_probability` or :meth:`condition`.
    :type mean: Callable, optional
    :param solver: Solver for filtering and smoothing. ``None`` (default)
        selects one from the kernel type. A solver *class* is instantiated as
        ``solver(kernel, X, noise, **solver_kwargs)``. An already
        instantiated solver is used as-is.
    :param mean_value: Precomputed mean at ``X``; skips evaluating ``mean``.
    :param variance_value: Precomputed variance at ``X`` (set on the
        conditioned processes returned by :meth:`condition`).
    :param covariance_value: Currently unused.
    :param states: Conditioned states (set on the conditioned processes
        returned by :meth:`condition`).
    :param use_unique_names: If ``True`` (default), duplicated leaf kernel
        names are made unique with :func:`assign_unique_kernel_names`.
    :param solver_kwargs: Extra keyword arguments forwarded to the solver
        constructor when the solver is instantiated here.
    """

    num_data: int = eqx.field(static=True)
    dtype: jnp.dtype = eqx.field(static=True)
    kernel: kernels.Kernel
    X: JAXArray
    mean_function: means.MeanBase
    mean: JAXArray
    var: JAXArray | None
    noise: JAXArray
    solver: StateSpaceSolver
    states: ConditionedStates | None

    def __init__(
        self,
        kernel: kernels.Kernel,
        X: JAXArray,
        *,
        noise: JAXArray | None = None,
        mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
        solver: Any | None = None,
        mean_value: JAXArray | None = None,
        variance_value: JAXArray | None = None,
        covariance_value: Any | None = None,
        states: ConditionedStates | None = None,
        use_unique_names: bool = True,
        **solver_kwargs: Any,
    ):
        # Make duplicated leaf-kernel names unique so that components can be
        # addressed unambiguously by name later on.
        self.kernel = (
            assign_unique_kernel_names(kernel) if use_unique_names else kernel
        )

        # Inspect the leaves to choose a default solver. (Named ``leaves`` so
        # the module-level ``kernels`` import is not shadowed.)
        leaves = extract_leaf_kernels(self.kernel)
        is_integrated = any(isinstance(k, IntegratedStateSpaceModel) for k in leaves)
        is_instantaneous = all(isinstance(k, StateSpaceModel) for k in leaves)

        # If using an integrated solver, ensure X has both coords and bin sizes
        if is_integrated:
            assert isinstance(X, tuple) and len(X) > 1, (
                "IntegratedStateSpaceSolver requires both the data coordinates (e.g. times)"
                " and bin sizes (e.g. exposure times). These should be passed as X=(t, texp)"
                " where t is the midpoint of each measurement and texp is the exposure time"
                " (i.e. each measurement is over the interval [t - texp/2, t + texp/2])."
            )

        # Data coordinates (or tuple of coordinates)
        self.X = X

        # Mean function
        if isinstance(mean, means.MeanBase):
            self.mean_function = mean
        elif mean is None:
            self.mean_function = means.Mean(jnp.zeros(()))
        else:
            self.mean_function = means.Mean(mean)
        if mean_value is None:
            mean_value = jax.vmap(self.mean_function)(self.X)
        self.num_data = mean_value.shape[0]
        self.dtype = mean_value.dtype
        self.mean = mean_value
        self.var = variance_value
        self.states = states
        if self.mean.ndim > 2:
            raise ValueError(
                f"Invalid mean shape: expected ndim = 1 or 2, got ndim={self.mean.ndim}"
            )

        # Observation noise, normalised to shape (N, D, D).
        if noise is None:
            jitter = _default_jitter(self.mean)
            noise = jnp.full((self.num_data, 1, 1), jitter, dtype=self.dtype)
        elif jnp.ndim(noise) == 0:
            # A single variance shared by every observation -> (N, 1, 1)
            noise = jnp.broadcast_to(jnp.asarray(noise), (self.num_data, 1, 1))
        elif jnp.ndim(noise) == 1:
            # Scalar variance per observation -> (N, 1, 1)
            noise = jnp.asarray(noise)[:, None, None]
        self.noise = noise

        # Set up the solver
        # TODO: add parallel flag and if so use ParallelIntegratedStateSpaceSolver?
        if solver is None:
            if is_integrated:
                solver = IntegratedStateSpaceSolver
            elif is_instantaneous:
                solver = StateSpaceSolver
            else:
                raise ValueError(
                    "Must provide a solver if the kernel is not "
                    "a StateSpaceModel or IntegratedStateSpaceModel"
                )
        if isinstance(solver, type):
            # Any solver class (built-in or user-defined): build it for the
            # same (uniquely named) kernel the GP uses.
            self.solver = solver(
                self.kernel,
                self.X,
                self.noise,
                **solver_kwargs,
            )
        else:
            # A pre-instantiated solver is passed (e.g. like condGP)
            self.solver = solver

    @property
    def loc(self) -> JAXArray:
        """
        If conditioned, this will be the mean at the data points
        (or at ``X_test``). Otherwise, it is just the prior mean.
        """
        return self.mean

    @property
    def variance(self) -> JAXArray | None:
        """
        If conditioned, this will be the variance at the data points
        (or at ``X_test``). Otherwise it is ``variance_value`` as passed to
        the constructor, which is ``None`` unless given explicitly.
        """
        return self.var

    @property
    def covariance(self) -> JAXArray:
        """Full covariance of the process (not implemented yet)."""
        # TODO: Eq. 12.55 in Sarkka & Solin 2019
        # if G = states.smoothing_gains exists, otherwise
        # I guess we raise an error that its not conditioned?
        raise NotImplementedError

    def log_probability(self, y: JAXArray) -> JAXArray:
        """Compute the log probability of this multivariate normal.

        Args:
            y (JAXArray): The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object. It is passed to
                the Kalman filter unchanged (the mean function is not
                subtracted).

        Returns:
            The marginal log probability of this multivariate normal model,
            evaluated at ``y``.
        """
        _, _, _, _, v, S = self.solver.Kalman(y, return_v_S=True)
        return self._compute_log_prob(v, S)

    def _check_test_coords(self, X_test: JAXArray) -> None:
        """Raise ValueError unless X_test matches X in structure and trailing dims."""
        msg = (
            "`X_test` must have the same tree structure as the input `X`, "
            "and all but the leading dimension must have matching sizes"
        )
        # Compare structures first so a mismatch raises our message instead of
        # an opaque error from tree_map.
        if jax.tree_util.tree_structure(self.X) != jax.tree_util.tree_structure(X_test):
            raise ValueError(msg)
        matches = jax.tree_util.tree_map(
            lambda a, b: (
                jnp.ndim(a) == jnp.ndim(b) and jnp.shape(a)[1:] == jnp.shape(b)[1:]
            ),
            self.X,
            X_test,
        )
        if not all(jax.tree_util.tree_leaves(matches)):
            raise ValueError(msg)

    def _component_observation_model(self, component=None):
        """Return an observation model restricted to the given component(s).

        ``component`` may be ``None`` (full model), a leaf-kernel name, a
        kernel object (its ``.name`` is used), or a list of these, in which
        case the individual observation models are summed.
        """
        if component is None:
            return self.kernel.observation_model
        items = list(component) if isinstance(component, (list, tuple)) else [component]
        if not items:
            raise ValueError("At least one component must be given.")
        names = [c if isinstance(c, str) else c.name for c in items]

        def H_comp(X):
            H = self.kernel.observation_model(X, component=names[0])
            for name in names[1:]:
                H = H + self.kernel.observation_model(X, component=name)
            return H

        return H_comp

    @staticmethod
    def _project(observation_model, X, m, P):
        """Project state means/covariances at X through H = observation_model(x)."""
        H = jax.vmap(observation_model)(X)
        mu = jax.vmap(lambda H_i, m_i: H_i @ m_i)(H, m).squeeze()
        var = jax.vmap(lambda H_i, P_i: H_i @ P_i @ H_i.T)(H, P).squeeze()
        return mu, var

    def condition(
        self,
        y: JAXArray,
        X_test: JAXArray | None = None,
        *,
        include_mean: bool = True,
        kernel: kernels.Kernel | str | None = None,
    ) -> ConditionResult:
        """Condition the model on observed data.

        Args:
            y (JAXArray): The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object. If
                it is not provided, ``X`` will be used by default, so the
                predictions will be made at the data coordinates.
            include_mean (bool, optional): Currently unused; the mean function
                is not added to the predictions.
            kernel (Kernel or str, optional): A component kernel (or its name)
                whose observation model is used for the predictions, zeroing
                out the other components. See :ref:`multicomponent` for an
                example.

        Returns:
            A named tuple where the first element ``log_probability`` is the log
            marginal probability of the model, and the second element ``gp`` is
            the :class:`GaussianProcess` object describing the conditional
            distribution evaluated at ``X_test``.

        Raises:
            ValueError: If ``X_test`` does not match the structure of ``X``.
        """
        if X_test is not None:
            self._check_test_coords(X_test)

        ## Condition on the data and return likelihood ingredients
        conditioned_results = self.solver.condition(y, return_v_S=True)
        state_coords, conditioned_states, (v, S) = conditioned_results
        (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        ) = conditioned_states

        if isinstance(
            self.solver,
            (IntegratedStateSpaceSolver, ParallelIntegratedStateSpaceSolver),
        ):
            t_states, instid, obsid, stateid = state_coords
        else:
            # If not integrated, t_states = X and id arrays are 'defaulted'
            t_states = self.kernel.coord_to_sortable(state_coords)
            instid = jnp.zeros_like(t_states, dtype=int)
            obsid = jnp.arange(len(t_states), dtype=int)
            stateid = jnp.ones_like(t_states, dtype=int)  # all "have data"

        # Keep the conditioned states so subsequent calls to predict() on the
        # returned GP can reuse them.
        states = ConditionedStates(
            self.X,
            t_states,
            instid,
            obsid,
            stateid,
            m_predicted,
            P_predicted,
            m_filtered,
            P_filtered,
            m_smoothed,
            P_smoothed,
        )

        ## Likelihood (v and S are already restricted to the "at the data" states)
        log_prob = self._compute_log_prob(v, S)

        # Full observation model, or only the requested component's.
        observation_model = self._component_observation_model(kernel)

        if X_test is None:
            # Project the conditioned states at the data points.
            X_test = self.X
            mu, var = states.project_at_data(observation_model)
        else:
            # solver.predict returns *state* moments; project them into
            # observation space, exactly as predict() does.
            m_test, P_test = self.solver.predict(X_test, conditioned_results)
            mu, var = self._project(observation_model, X_test, m_test, P_test)

        ## Create the conditioned GP (kernel already carries unique names)
        condGP = GaussianProcess(
            kernel=self.kernel,
            X=X_test,
            noise=self.noise,
            solver=self.solver,
            mean_value=mu,
            variance_value=var,
            states=states,
            use_unique_names=False,
        )
        return ConditionResult(log_probability=log_prob, gp=condGP)

    def predict(
        self,
        X_test: JAXArray | None = None,
        y: JAXArray | None = None,
        *,
        return_full_state: bool = False,
        kernel: str | kernels.Kernel | None = None,
        return_var: bool = False,
        observation_model: Any | None = None,
    ) -> JAXArray | tuple[JAXArray, JAXArray] | PredictedStates:
        """Predict the GP model at new test points conditioned on observed data.

        Args:
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object. If
                it is not provided, the predictions are made at the data
                coordinates.
            y (JAXArray, optional): The observed data. Only needs to be given if
                the GP has not yet been conditioned. This should have the shape
                ``(N_data,)``.
            return_full_state (bool, optional): If ``True``, return the full
                predicted state mean and covariance instead of projecting to
                observation space. With ``X_test`` given this is a
                :class:`PredictedStates`; at the data points it is the raw
                smoothed mean (and covariance if ``return_var``).
            kernel (str or Kernel, optional): A component kernel (or its name)
                in a multi-component model to project the prediction onto.
                Takes precedence over ``observation_model``.
            return_var (bool, optional): If ``True``, also return the variance
                of the predicted values. Default is ``False``.
            observation_model (Any, optional): A function of a single
                coordinate returning the observation matrix to project with.
                Defaults to the kernel's observation model.

        Returns:
            The predictive mean with shape ``(N_test,)``, or ``(mean, var)`` if
            ``return_var`` is ``True``.

        Raises:
            AssertionError: If the GP is not conditioned and ``y`` is None.
        """
        if self.states is None:
            # Need to condition the GP first
            assert y is not None, (
                "The GP has not been conditioned yet, and no data array `y` was given."
            )
            _, condGP = self.condition(y)
            return condGP.predict(
                X_test,
                return_full_state=return_full_state,
                kernel=kernel,
                return_var=return_var,
                observation_model=observation_model,
            )

        if X_test is None:
            # Predict at the data points
            if return_full_state:
                mu, var = self.states.smoothed_mean, self.states.smoothed_cov
            elif kernel is None and observation_model is None:
                # already computed by condition()
                mu, var = self.loc, self.var
            else:
                H = (
                    self._component_observation_model(kernel)
                    if kernel is not None
                    else observation_model
                )
                mu, var = self.states.project_at_data(H)
        else:
            # Predicting at new test points
            mean, variance = self.solver.predict(X_test, self.states())
            if return_full_state:
                return PredictedStates(
                    t_states=X_test, m=mean, P=variance, kernel=self.kernel
                )
            if kernel is not None:
                H_test = self._component_observation_model(kernel)
            elif observation_model is not None:
                H_test = observation_model
            else:
                H_test = self.kernel.observation_model
            mu, var = self._project(H_test, X_test, mean, variance)

        if return_var:
            return mu, var
        return mu

    ## TODO: how to define the sample function?
    def sample(
        self,
        key: jax.random.KeyArray,
        shape: Sequence[int] | None = None,
    ) -> JAXArray:
        """Generate samples from the prior process (not implemented yet).

        Args:
            key: A ``jax`` random number key array.
            shape (tuple, optional): The number and shape of samples to
                generate.

        Returns:
            The sampled realizations from the process with shape ``(N_data,) +
            shape``. Currently this raises ``NotImplementedError``.
        """
        # jit's static arguments must be hashable, so normalise to a tuple.
        return self._sample(key, None if shape is None else tuple(shape))

    def numpyro_dist(self, **kwargs: Any) -> TinyDistribution:
        """Get the numpyro MultivariateNormal distribution for this process."""
        from tinygp.numpyro_support import TinyDistribution

        return TinyDistribution(self, **kwargs)

    @partial(jax.jit, static_argnums=(2,))
    def _sample(
        self,
        key: jax.random.KeyArray,
        shape: Sequence[int] | None,
    ) -> JAXArray:
        # TODO: implement sampling for state space models, e.g.
        # https://www.stats.ox.ac.uk/~doucet/doucet_simulationconditionalgaussian.pdf
        raise NotImplementedError

    @jax.jit
    def _compute_log_prob(self, v: JAXArray, S: JAXArray) -> JAXArray:
        """Gaussian log likelihood from the Kalman innovations.

        Computes ``-0.5 * sum_k [v_k^T S_k^{-1} v_k + log|S_k| + d log(2 pi)]``
        via a Cholesky factor of each ``S_k``; a non-finite result is mapped
        to ``-inf``.

        Args:
            v: Innovations; indexed as rank 2 ``(T, d)`` below.
            S: Innovation covariances ``(T, d, d)``.
        """
        L = jax.vmap(jnp.linalg.cholesky)(S)  # [T, D, D]
        w = jax.scipy.linalg.solve_triangular(L, v[..., None], lower=True)
        w = jnp.squeeze(w, axis=-1)
        quad = jnp.sum(w**2, axis=1)  # v_k^T S_k^{-1} v_k
        logdetS = 2.0 * jnp.sum(jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)), axis=1)
        d = v.shape[1]
        log_probs = quad + logdetS + d * jnp.log(2.0 * jnp.pi)
        loglike = -0.5 * jnp.sum(log_probs)
        return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)

    def get_component_mean(
        self,
        component: list | str,
        return_var: bool = False,
        **kwargs,
    ) -> Any:
        """Predictive mean (and variance) of one or more component kernels at ``self.X``.

        Args:
            component (list | str): The name(s) of the component kernel(s)
                to extract. If a list of names is given, the joint mean and
                variance of that collection of kernels is returned.
            return_var (bool, optional): If ``True``, also return the variance.
                Default is ``False``.
            **kwargs: Accepted for compatibility; unused.

        Returns:
            ``component_mean``, or ``(component_mean, component_var)`` if
            ``return_var`` is ``True``.

        Raises:
            ValueError: If the GP is not conditioned, or no component is given.
        """
        if self.states is None:
            raise ValueError(
                "The GP must be conditioned before getting component means."
            )
        H_comp = self._component_observation_model(
            [component] if isinstance(component, str) else component
        )
        component_mean, component_var = self.states.project_at_data(H_comp)
        if return_var:
            return component_mean, component_var
        return component_mean

    def get_all_component_means(self, return_var: bool = False, **kwargs) -> Any:
        """Predictive mean (and optionally variance) of each leaf kernel at ``self.X``.

        Args:
            return_var (bool, optional): If ``True``, also return the variances.
                Default is ``False``.
            **kwargs: Forwarded to :meth:`get_component_mean`.

        Returns:
            A dict keyed by leaf kernel name whose values are the component
            means, or ``(mean, variance)`` tuples if ``return_var`` is ``True``.

        Raises:
            ValueError: If the GP has not been conditioned.
        """
        if self.states is None:
            raise ValueError(
                "The GP must be conditioned before getting component means."
            )
        return {
            leaf.name: self.get_component_mean(
                component=leaf.name, return_var=return_var, **kwargs
            )
            for leaf in extract_leaf_kernels(self.kernel)
        }
[docs]
class ConditionResult(NamedTuple):
    """The result of conditioning a :class:`GaussianProcess` on data.

    This has two entries, ``log_probability`` and ``gp``. ``gp`` holds the
    conditioned states. Its ``loc`` and ``variance`` are the posterior mean
    and variance at the prediction coordinates, and its
    :meth:`GaussianProcess.predict` reuses those states to predict at new
    points. Sampling from it is not implemented yet.
    """

    #: The log probability of the conditioned model, i.e. the marginal
    #: likelihood of the observed data for the kernel parameters (the
    #: multivariate normal log probability evaluated at the given data).
    log_probability: JAXArray
    #: A :class:`GaussianProcess` describing the conditional distribution.
    gp: GaussianProcess
[docs]
def _default_jitter(reference: JAXArray) -> JAXArray:
"""Default to adding some amount of jitter to the diagonal, just in case,
we use sqrt(eps) for the dtype of the mean function because that seems to
give sensible results in general.
"""
return jnp.sqrt(jnp.finfo(reference.dtype).eps)