Source code for smolgp.gp

from __future__ import annotations
from collections.abc import Sequence
from functools import partial
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    NamedTuple,
)

import equinox as eqx
import jax
import jax.numpy as jnp

from tinygp import kernels, means
from tinygp.helpers import JAXArray

from smolgp.kernels import StateSpaceModel, Sum, Product
from smolgp.kernels.base import extract_leaf_kernels
from smolgp.kernels.integrated import IntegratedStateSpaceModel

from smolgp.solvers import StateSpaceSolver
from smolgp.solvers import ParallelStateSpaceSolver
from smolgp.solvers.integrated import IntegratedStateSpaceSolver
from smolgp.solvers.integrated import ParallelIntegratedStateSpaceSolver

if TYPE_CHECKING:
    from tinygp.numpyro_support import TinyDistribution

import dataclasses


def assign_unique_kernel_names(kernel: StateSpaceModel) -> StateSpaceModel:
    """Return a kernel whose leaf kernels all carry distinct names.

    Leaf names that occur more than once are disambiguated by appending
    ``_1``, ``_2``, ... in the order the leaves are visited (left operand
    before right operand).  For example, leaves named ``"SHO"``,
    ``"Matern"`` and ``"Matern"`` become ``"SHO"``, ``"Matern_1"`` and
    ``"Matern_2"``.  Unique names are what allows a component to be
    addressed by name when projecting predictions onto a single component.

    A suffix is skipped when the resulting name is already taken by another
    leaf (e.g. leaves ``"A"``, ``"A"``, ``"A_1"`` become ``"A_2"``,
    ``"A_3"``, ``"A_1"``), so the returned kernel never contains duplicates.

    Args:
        kernel: A leaf kernel or an arbitrarily nested ``Sum``/``Product``
            of leaf kernels.

    Returns:
        ``kernel`` itself (same object) when every leaf name is already
        unique; otherwise a rebuilt kernel tree in which only the
        duplicated leaves have been replaced by renamed copies.
    """
    leaves = extract_leaf_kernels(kernel)

    # Number of leaves carrying each name
    counts: dict[str, int] = {}
    for leaf in leaves:
        counts[leaf.name] = counts.get(leaf.name, 0) + 1

    # Early exit if all names are unique (no duplicates)
    if all(count == 1 for count in counts.values()):
        return kernel

    # Names that are (or have become) final and must not be produced again:
    # initially the leaves that keep their name, later also every new name.
    taken = {name for name, count in counts.items() if count == 1}
    # Next suffix to try for each duplicated name
    next_suffix = {name: 1 for name, count in counts.items() if count > 1}

    def _fresh_name(base: str) -> str:
        suffix = next_suffix[base]
        candidate = f"{base}_{suffix}"
        while candidate in taken:
            suffix += 1
            candidate = f"{base}_{suffix}"
        next_suffix[base] = suffix + 1
        taken.add(candidate)
        return candidate

    def _rename(k: StateSpaceModel) -> StateSpaceModel:
        # Composite nodes are rebuilt from their renamed children; the left
        # child is processed first so suffixes follow the left-to-right order.
        if isinstance(k, Sum):
            return Sum(_rename(k.kernel1), _rename(k.kernel2))
        if isinstance(k, Product):
            return Product(_rename(k.kernel1), _rename(k.kernel2))
        # Leaf: only duplicated names are touched
        if counts[k.name] > 1:
            return dataclasses.replace(k, name=_fresh_name(k.name))
        return k

    return _rename(kernel)
class ConditionedStates(eqx.Module):
    """Container for the state means and covariances obtained by conditioning.

    Attributes:
        X: Data coordinates (length ``N``).
        t_states: Time coordinate of every state (length ``K``).
        instid: Instrument ID of each measurement.
        obsid: Observation ID of the measurement associated with each state
            (length ``K``).
        stateid: State ID of each state (length ``K``): 0 for
            exposure-start, 1 for exposure-end.
        predicted_mean, predicted_cov: Kalman predicted states.
        filtered_mean, filtered_cov: Kalman filtered states.
        smoothed_mean, smoothed_cov: RTS smoothed states.
    """

    X: JAXArray
    t_states: JAXArray
    instid: JAXArray
    obsid: JAXArray
    stateid: JAXArray
    predicted_mean: JAXArray
    filtered_mean: JAXArray
    smoothed_mean: JAXArray
    predicted_cov: JAXArray
    filtered_cov: JAXArray
    smoothed_cov: JAXArray

    def __init__(
        self,
        X,
        t_states: JAXArray,
        instid: JAXArray,
        obsid: JAXArray,
        stateid: JAXArray,
        m_pred: JAXArray,
        P_pred: JAXArray,
        m_filt: JAXArray,
        P_filt: JAXArray,
        m_smooth: JAXArray,
        P_smooth: JAXArray,
    ):
        # Coordinates / bookkeeping
        self.X = X
        self.t_states = t_states
        self.instid = instid
        self.obsid = obsid
        self.stateid = stateid
        # (mean, covariance) pairs for each stage of the filter/smoother
        self.predicted_mean, self.predicted_cov = m_pred, P_pred
        self.filtered_mean, self.filtered_cov = m_filt, P_filt
        self.smoothed_mean, self.smoothed_cov = m_smooth, P_smooth

    def __call__(self):
        """Repackage the stored values as ``(state_coords, states, None)``.

        The layout is meant to match the output of ``solver.condition``;
        the last entry is always ``None`` here.
        """
        coords = (self.t_states, self.instid, self.obsid, self.stateid)
        predicted = (self.predicted_mean, self.predicted_cov)
        filtered = (self.filtered_mean, self.filtered_cov)
        smoothed = (self.smoothed_mean, self.smoothed_cov)
        return coords, (predicted, filtered, smoothed), None

    def project_at_data(self, observation_model):
        """Project the smoothed states that carry a measurement onto the data.

        The states with ``stateid == 1`` (e.g. exposure-ends) are selected,
        re-ordered by ``obsid`` so they line up with the original data
        order, and projected through ``observation_model(X)``.

        Returns:
            ``(mu, var)``: the squeezed projected means and variances.
        """

        @jax.jit
        def _to_observation_space(x, mean, cov):
            H = observation_model(x)
            return H @ mean, H @ cov @ H.T

        # Number of data points; needed as a static size for `nonzero`
        n_data = jnp.array(self.X).shape[-1]
        measured = jnp.nonzero(self.stateid == 1, size=n_data)[0]
        data_order = measured[jnp.argsort(self.obsid[measured])]

        means = jnp.take(self.smoothed_mean, data_order, axis=0)
        covs = jnp.take(self.smoothed_cov, data_order, axis=0)

        mu, var = jax.vmap(_to_observation_space)(self.X, means, covs)
        return mu.squeeze(), var.squeeze()
class PredictedStates(eqx.Module):
    """Full predictive states at a set of test coordinates.

    Attributes:
        t_states: Time coordinate of each state.
        mean: Predictive mean vector of each state.
        cov: Predictive covariance of each state.
        kernel: The kernel whose observation model is used for projections.
    """

    t_states: JAXArray
    mean: JAXArray
    cov: JAXArray
    kernel: StateSpaceModel

    def __init__(
        self,
        t_states: JAXArray,
        m: JAXArray,
        P: JAXArray,
        kernel: StateSpaceModel,
    ):
        self.t_states = t_states
        self.mean = m
        self.cov = P
        self.kernel = kernel

    def project_mean(self, observation_model) -> JAXArray:
        """Project the state means at ``self.t_states`` through an observation model.

        Args:
            observation_model: Callable mapping a single coordinate to the
                observation matrix ``H``.

        Returns:
            ``H @ mean`` for every state, squeezed.
        """
        H = jax.vmap(observation_model)(self.t_states)
        mu = jax.vmap(lambda H_i, m_i: H_i @ m_i)(H, self.mean)
        return mu.squeeze()

    def project_variance(self, observation_model) -> JAXArray:
        """Project the state covariances at ``self.t_states`` through an observation model.

        Args:
            observation_model: Callable mapping a single coordinate to the
                observation matrix ``H``.

        Returns:
            ``H @ cov @ H.T`` for every state, squeezed.
        """
        H = jax.vmap(observation_model)(self.t_states)
        var = jax.vmap(lambda H_i, P_i: H_i @ P_i @ H_i.T)(H, self.cov)
        return var.squeeze()

    # NOTE(review): the two properties below project through
    # ``kernel.observation_matrix`` whereas every other projection in this
    # module uses ``kernel.observation_model`` -- confirm that
    # ``observation_matrix`` is a callable of the coordinate as well.
    @property
    def loc(self) -> JAXArray:
        """The overall mean at the predicted states."""
        return self.project_mean(self.kernel.observation_matrix)

    @property
    def variance(self) -> JAXArray:
        """The overall variance at the predicted states."""
        return self.project_variance(self.kernel.observation_matrix)

    def get_component(
        self,
        component: str | list[str],
        return_var: bool = False,
    ) -> JAXArray | tuple[JAXArray, JAXArray]:
        """Project the predicted states onto one or several component kernels.

        Args:
            component: Name of a component kernel, or a list of names whose
                observation models are summed.
            return_var: If ``True``, also return the projected variance.

        Returns:
            The projected mean, or ``(mean, variance)`` when ``return_var``
            is ``True``.
        """
        names = [component] if isinstance(component, str) else component

        # Effective observation model for the desired component(s)
        def H_comp(X):
            H = self.kernel.observation_model(X, component=names[0])
            for name in names[1:]:
                # Out-of-place addition: never mutate an array handed out by
                # the kernel.
                H = H + self.kernel.observation_model(X, component=name)
            return H

        component_mean = self.project_mean(H_comp)
        if not return_var:
            return component_mean
        # The variance projection is only evaluated when it was asked for
        return component_mean, self.project_variance(H_comp)

    def get_all_components(self, return_var: bool = False) -> dict[str, Any]:
        """Project the predicted states onto every leaf kernel separately.

        Args:
            return_var: If ``True``, store ``(mean, variance)`` per component
                instead of just the mean.

        Returns:
            A dict keyed by leaf-kernel name.
        """
        return {
            leaf.name: self.get_component(leaf.name, return_var=return_var)
            for leaf in extract_leaf_kernels(self.kernel)
        }
class GaussianProcess(eqx.Module):
    r"""An interface for designing a Gaussian Process regression model.

    :param kernel: The kernel function.
    :type kernel: Kernel
    :param X: The input coordinates — any PyTree compatible with ``kernel``
        whose leading dimension has size ``N_data``. For integrated kernels,
        pass ``(t, texp)`` where ``t`` is the array of exposure midpoints and
        ``texp`` is the array of exposure durations.
    :type X: JAXArray
    :param noise: Observation noise covariance matrices with shape
        ``(N, D, D)``, where ``N`` is the number of data points and ``D`` is
        the observation dimension (usually 1). Each slice ``noise[k]`` is the
        :math:`D \times D` noise covariance for the ``k``-th observation. A
        1-D array of shape ``(N,)`` is interpreted as scalar per-observation
        variances. Defaults to
        :math:`\sqrt{\varepsilon_{\mathrm{machine}}} \cdot I` for all
        observations.
    :type noise: JAXArray, optional
    :param mean: A callable or constant mean function evaluated as
        ``mean(X)``.
    :type mean: Callable, optional
    :param solver: Solver used for filtering and smoothing. Either one of the
        solver classes (it is then instantiated with the kernel, ``X``,
        ``noise`` and ``**solver_kwargs``) or an already instantiated solver,
        which is used as is. If ``None`` (default), the class is selected
        automatically based on the kernel type.
    :param mean_value: Precomputed mean at ``X``; if omitted the mean
        function is evaluated at ``X``.
    :param variance_value: Precomputed variance at ``X`` (``None`` for an
        unconditioned process).
    :param covariance_value: Currently unused.
    :param states: Conditioned states, set on the process returned by
        :meth:`condition`.
    :param use_unique_names: If ``True`` (default), duplicated leaf-kernel
        names are made unique with :func:`assign_unique_kernel_names`.
    """

    num_data: int = eqx.field(static=True)
    dtype: jnp.dtype = eqx.field(static=True)
    kernel: kernels.Kernel
    X: JAXArray
    mean_function: means.MeanBase
    mean: JAXArray
    var: JAXArray | None
    noise: JAXArray
    solver: StateSpaceSolver
    states: ConditionedStates | None

    def __init__(
        self,
        kernel: kernels.Kernel,
        X: JAXArray,
        *,
        noise: JAXArray | None = None,
        mean: means.MeanBase | Callable[[JAXArray], JAXArray] | JAXArray | None = None,
        solver: Any | None = None,
        mean_value: JAXArray | None = None,
        variance_value: JAXArray | None = None,
        covariance_value: Any | None = None,
        states: ConditionedStates | None = None,
        use_unique_names: bool = True,
        **solver_kwargs: Any,
    ):
        # First, assign unique kernel names if needed
        if use_unique_names:
            self.kernel = assign_unique_kernel_names(kernel)
        else:
            self.kernel = kernel

        # Check if the kernel contains any integrated components
        leaves = extract_leaf_kernels(self.kernel)
        is_integrated = any(isinstance(k, IntegratedStateSpaceModel) for k in leaves)
        is_instantaneous = all(isinstance(k, StateSpaceModel) for k in leaves)

        # If using an integrated solver, ensure X has both coords and bin sizes
        if is_integrated:
            assert isinstance(X, tuple) and len(X) > 1, (
                "IntegratedStateSpaceSolver requires both the data coordinates (e.g. times)"
                " and bin sizes (e.g. exposure times). These should be passed as X=(t, texp)"
                " where t is the midpoint of each measurement and texp is the exposure time"
                " (i.e. each measurement is over the interval [t - texp/2, t + texp/2])."
            )

        # Data coordinates (or tuple of coordinates)
        self.X = X

        # Mean function
        if isinstance(mean, means.MeanBase):
            self.mean_function = mean
        elif mean is None:
            self.mean_function = means.Mean(jnp.zeros(()))
        else:
            self.mean_function = means.Mean(mean)
        if mean_value is None:
            mean_value = jax.vmap(self.mean_function)(self.X)
        self.num_data = mean_value.shape[0]
        self.dtype = mean_value.dtype
        self.mean = mean_value
        self.var = variance_value
        self.states = states
        if self.mean.ndim > 2:
            raise ValueError(
                f"Invalid mean shape: expected ndim = 1 or 2, got ndim={self.mean.ndim}"
            )

        # Observation noise: shape (N, D, D)
        # A 1-D array of shape (N,) is treated as scalar per-obs variance -> (N, 1, 1)
        if noise is None:
            jitter = _default_jitter(self.mean)
            noise = jnp.full((self.num_data, 1, 1), jitter, dtype=self.dtype)
        elif jnp.ndim(noise) == 1:
            noise = jnp.asarray(noise)[:, None, None]
        self.noise = noise

        # Set up the solver
        # TODO: add parallel flag and if so use ParallelIntegratedStateSpaceSolver?
        if solver is None:
            if is_integrated:
                solver = IntegratedStateSpaceSolver
            elif is_instantaneous:
                solver = StateSpaceSolver
            else:
                raise ValueError(
                    "Must provide a solver if the kernel is not "
                    "a StateSpaceModel or IntegratedStateSpaceModel"
                )

        solver_classes = (
            StateSpaceSolver,
            IntegratedStateSpaceSolver,
            ParallelStateSpaceSolver,
            ParallelIntegratedStateSpaceSolver,
        )
        # Identity comparison: a solver *instance* must never be compared
        # with ``==`` against the classes.
        if any(solver is cls for cls in solver_classes):
            # A solver type (uninstantiated) was selected or passed. Build it
            # from ``self.kernel`` so that the solver sees the same (uniquely
            # named) kernel as the rest of this object.
            self.solver = solver(
                self.kernel,
                self.X,
                self.noise,
                **solver_kwargs,
            )
        else:
            # A pre-instantiated solver was passed (e.g. by `condition`)
            self.solver = solver

    @property
    def loc(self) -> JAXArray:
        """The stored mean at ``X``.

        For a process returned by :meth:`condition` this is the conditioned
        mean; otherwise it is the mean function evaluated at ``X``.
        """
        return self.mean

    @property
    def variance(self) -> JAXArray | None:
        """The stored variance at ``X``.

        For a process returned by :meth:`condition` this is the conditioned
        variance; otherwise it is whatever was passed as ``variance_value``
        (``None`` by default).
        """
        return self.var

    @property
    def covariance(self) -> JAXArray:
        """Not implemented yet."""
        # TODO: Eq. 12.55 in Sarkka & Solin 2019
        # if G = states.smoothing_gains exists, otherwise
        # I guess we raise an error that its not conditioned?
        # return self.covariance_value
        raise NotImplementedError

    def log_probability(self, y: JAXArray) -> JAXArray:
        """Compute the log probability of this multivariate normal

        Args:
            y (JAXArray): The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the
                ``X`` data provided when instantiating this object.

        Returns:
            The marginal log probability of this multivariate normal model,
            evaluated at ``y``.
        """
        _, _, _, _, v, S = self.solver.Kalman(y, return_v_S=True)
        return self._compute_log_prob(v, S)

    def condition(
        self,
        y: JAXArray,
        X_test: JAXArray | None = None,
        *,
        include_mean: bool = True,
        kernel: kernels.Kernel | None = None,  # TODO: select a component kernel
    ) -> ConditionResult:
        """Condition the model on observed data

        Args:
            y (JAXArray): The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the
                ``X`` data provided when instantiating this object.
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object.
                If it is not provided, ``X`` will be used by default, so the
                predictions will be made at the data coordinates.
            include_mean (bool, optional): Accepted for API compatibility;
                currently not used by this implementation.
            kernel (Kernel, optional): A component kernel; if given, the
                conditioned mean and variance are projected onto the
                component named ``kernel.name`` only.

        Returns:
            A named tuple where the first element ``log_probability`` is the
            log marginal probability of the model, and the second element
            ``gp`` is the :class:`GaussianProcess` object describing the
            conditional distribution evaluated at ``X_test``.

        Raises:
            ValueError: If ``X_test`` does not match the tree structure and
                trailing dimensions of ``X``.
        """
        # If X_test is provided, we need to check that the tree structure
        # matches that of the input data, and that the shapes are all compatible
        # (i.e. the dimension of the inputs must match). This is slightly
        # convoluted since we need to support arbitrary pytrees.
        if X_test is not None:
            matches = jax.tree_util.tree_map(
                lambda a, b: (
                    jnp.ndim(a) == jnp.ndim(b) and jnp.shape(a)[1:] == jnp.shape(b)[1:]
                ),
                self.X,
                X_test,
            )
            if not jax.tree_util.tree_reduce(lambda a, b: a and b, matches):
                raise ValueError(
                    "`X_test` must have the same tree structure as the input `X`, "
                    "and all but the leading dimension must have matching sizes"
                )

        ## Condition on the data and return likelihood ingredients
        conditioned_results = self.solver.condition(y, return_v_S=True)

        ## unpack into prediction at the states
        state_coords, conditioned_states, (v, S) = conditioned_results
        (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        ) = conditioned_states

        if isinstance(
            self.solver,
            (IntegratedStateSpaceSolver, ParallelIntegratedStateSpaceSolver),
        ):
            t_states, instid, obsid, stateid = state_coords
        else:
            # If not integrated, t_states = X and id arrays are 'defaulted'
            t_states = self.kernel.coord_to_sortable(state_coords)
            instid = jnp.zeros_like(t_states, dtype=int)
            obsid = jnp.arange(len(t_states), dtype=int)
            stateid = jnp.ones_like(t_states, dtype=int)  # all "have data"

        # Save the conditioned state values to a new GP object
        # so we can use them to make quick predictions at test
        # points with subsequent calls to self.predict
        states = ConditionedStates(
            self.X,
            t_states,
            instid,
            obsid,
            stateid,
            m_predicted,
            P_predicted,
            m_filtered,
            P_filtered,
            m_smoothed,
            P_smoothed,
        )

        ## Grab likelihood (v and S will already be
        ## filtered down to the "at the data" states)
        log_prob = self._compute_log_prob(v, S)

        ## Observation model used to go from state space to data space
        if kernel is None:
            # If no component kernel passed, use the full model
            observation_model = self.kernel.observation_model
        else:
            # Otherwise use the observation model of the passed
            # kernel, where we zero out all the other components
            observation_model = lambda X: self.kernel.observation_model(
                X, component=kernel.name
            )

        if X_test is not None:
            # The solver returns full state means/covariances at X_test;
            # project them to observation space exactly as `predict` does so
            # that the conditioned GP stores observation-space values.
            m_test, P_test = self.solver.predict(X_test, conditioned_results)
            H = jax.vmap(observation_model)(X_test)
            mu = jax.vmap(lambda H_i, m_i: H_i @ m_i)(H, m_test)
            var = jax.vmap(lambda H_i, P_i: H_i @ P_i @ H_i.T)(H, P_test)

            def _drop_unit_trailing_axes(a):
                # Like `.squeeze()` but keeps the leading (test-point) axis
                # even when there is a single test point.
                axes = tuple(i for i in range(1, a.ndim) if a.shape[i] == 1)
                return jnp.squeeze(a, axis=axes)

            mu = _drop_unit_trailing_axes(mu)
            var = _drop_unit_trailing_axes(var)
        else:
            # Otherwise, project the conditioned states
            # (at the data points) to observation space
            X_test = self.X
            mu, var = states.project_at_data(observation_model)

        ## Create the conditioned GP
        condGP = GaussianProcess(
            kernel=self.kernel,
            X=X_test,
            noise=self.noise,
            # mean=self.mean,
            solver=self.solver,
            mean_value=mu,
            variance_value=var,
            states=states,
        )

        # Return the likelihood and conditioned GP
        return ConditionResult(log_probability=log_prob, gp=condGP)

    def predict(
        self,
        X_test: JAXArray | None = None,
        y: JAXArray | None = None,
        *,
        return_full_state: bool = False,
        kernel: kernels.Kernel | str | None = None,
        # include_mean: bool = True,
        return_var: bool = False,
        # return_cov: bool = False,
        observation_model: Any | None = None,
    ) -> JAXArray | tuple[JAXArray, JAXArray] | PredictedStates:
        """Predict the GP model at new test points conditioned on observed data

        Args:
            X_test (JAXArray, optional): The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object.
                If it is not provided, the prediction is made at the data
                coordinates.
            y (JAXArray, optional): The observed data. Only needs to be given
                if the GP has not yet been conditioned. This should have the
                shape ``(N_data,)``, where ``N_data`` was the zeroth axis of
                the ``X`` data provided when instantiating this object.
            return_full_state (bool, optional): If ``True``, return the full
                state mean and covariance rather than projecting to
                observation space. With ``X_test`` given this is a
                :class:`PredictedStates` object; without ``X_test`` it is the
                smoothed state mean (and covariance if ``return_var``).
                Default is ``False``.
            kernel (Kernel | str, optional): A component of a
                multi-component model (for example, a sum or product of
                kernels), given either by name or as an object with a
                ``name`` attribute. If specified, the prediction is projected
                onto that component only; this takes precedence over
                ``observation_model``.
            return_var (bool, optional): If ``True``, the variance of the
                predicted values is returned as well. Default is ``False``.
            observation_model (Any, optional): A function of a test
                coordinate defining the output observation model. Only used
                when ``X_test`` is given and ``kernel`` is ``None``; the
                default is the observation model of the kernel.

        Returns:
            The predictive mean, or ``(mean, variance)`` if ``return_var`` is
            ``True``; a :class:`PredictedStates` object if ``X_test`` is
            given and ``return_full_state`` is ``True``.
        """
        if self.states is None:
            # Need to condition the GP first
            assert y is not None, (
                "The GP has not been conditioned yet, and no data array `y` was given."
            )
            _, condGP = self.condition(y)
            return condGP.predict(
                X_test,
                return_full_state=return_full_state,
                kernel=kernel,
                return_var=return_var,
                # return_cov=return_cov,
                observation_model=observation_model,
            )

        if X_test is None:
            # If no X_test given, predict at the data points
            if return_full_state:
                mu = self.states.smoothed_mean
                var = self.states.smoothed_cov
            elif kernel is None:
                # already computed when this GP was conditioned
                mu, var = self.loc, self.var
            else:
                # extract component kernel & project
                name = kernel if isinstance(kernel, str) else kernel.name
                H_comp = lambda X: self.kernel.observation_model(X, component=name)
                mu, var = self.states.project_at_data(H_comp)
        else:
            # Predicting at new test points
            mean, variance = self.solver.predict(X_test, self.states())
            if return_full_state:
                return PredictedStates(
                    t_states=X_test, m=mean, P=variance, kernel=self.kernel
                )

            if kernel is not None:
                name = kernel if isinstance(kernel, str) else kernel.name
                H_test = lambda X: self.kernel.observation_model(X, component=name)
            elif observation_model is not None:
                H_test = observation_model
            else:
                H_test = self.kernel.observation_model

            H = jax.vmap(H_test)(X_test)
            mu = jax.vmap(lambda H_i, m: H_i @ m)(H, mean).squeeze()
            var = jax.vmap(lambda H_i, P: H_i @ P @ H_i.T)(H, variance).squeeze()

        if return_var:
            return mu, var
        # if return_cov:
        #     return mu, var
        return mu

    ## TODO: how to define the sample function?
    def sample(
        self,
        key: jax.random.KeyArray,
        shape: Sequence[int] | None = None,
    ) -> JAXArray:
        """Generate samples from the prior process

        Sampling is not implemented yet; this currently raises
        ``NotImplementedError`` (see :meth:`_sample`).

        Args:
            key: A ``jax`` random number key array.
            shape (tuple, optional): The number and shape of samples to
                generate.

        Returns:
            The sampled realizations from the process with shape
            ``(N_data,) + shape`` where ``N_data`` is the zeroth dimension of
            the ``X`` coordinates provided when instantiating this process.
        """
        return self._sample(key, shape)

    def numpyro_dist(self, **kwargs: Any) -> TinyDistribution:
        """Get the numpyro MultivariateNormal distribution for this process"""
        from tinygp.numpyro_support import TinyDistribution

        return TinyDistribution(self, **kwargs)

    @partial(jax.jit, static_argnums=(2,))
    def _sample(
        self,
        key: jax.random.KeyArray,
        shape: Sequence[int] | None,
    ) -> JAXArray:
        """Placeholder for the sampling implementation."""
        raise NotImplementedError
        ## TODO: implement sampling for state space model
        ## fast method to try: https://www.stats.ox.ac.uk/~doucet/doucet_simulationconditionalgaussian.pdf
        ##
        ## or alternatively call the tinygp version? copied below:
        # if shape is None:
        #     shape = (self.num_data,)
        # else:
        #     shape = (self.num_data,) + tuple(shape)
        # normal_samples = jax.random.normal(key, shape=shape, dtype=self.dtype)
        # return self.mean + jnp.moveaxis(
        #     self.solver.dot_triangular(normal_samples), 0, -1
        # )

    @jax.jit
    def _compute_log_prob(self, v: JAXArray, S: JAXArray) -> JAXArray:
        """Compute the log-likelihood given v and S from the Kalman filter

        Args:
            v: Second output requested from the Kalman filter via
                ``return_v_S``, indexed as ``(T, D)`` here.
            S: The matching ``(T, D, D)`` matrices; each ``S[k]`` is
                Cholesky-factorised.

        Returns:
            The summed Gaussian log-likelihood, or ``-inf`` if the result is
            not finite.
        """
        ## More readable version:
        # def llh(k):
        #     v_k, S_k = v[k], S[k]
        #     L_k = jnp.linalg.cholesky(S_k)
        #     w = jax.scipy.linalg.solve_triangular(L_k, v_k, lower=True)
        #     quad = jnp.dot(w, w)
        #     logdetS_k = 2.0 * jnp.sum(jnp.log(jnp.diag(L_k)))
        #     d = v_k.shape[0]
        #     return quad + logdetS_k + d*jnp.log(2*jnp.pi)
        # loglike = -0.5 * jnp.sum(jax.vmap(llh)(jnp.arange(len(v))))
        L = jax.vmap(jnp.linalg.cholesky)(S)  # [T, D, D]
        w = jax.scipy.linalg.solve_triangular(L, v[..., None], lower=True)
        w = jnp.squeeze(w, axis=-1)
        quad = jnp.sum(w**2, axis=1)
        logdetS = 2.0 * jnp.sum(jnp.log(jnp.diagonal(L, axis1=-2, axis2=-1)), axis=1)
        d = v.shape[1]
        log_probs = quad + logdetS + d * jnp.log(2.0 * jnp.pi)
        loglike = -0.5 * jnp.sum(log_probs)
        # A failed factorisation yields NaN; report it as zero probability
        return jnp.where(jnp.isfinite(loglike), loglike, -jnp.inf)

    def get_component_mean(
        self,
        component: list | str,
        return_var: bool = False,
        **kwargs,
    ) -> Any:
        """
        Get the predictive mean (and variance) of a particular (or sum of)
        component kernel in a multi-component model evaluated at self.X

        Args:
            component (list | str): The name(s) of the component kernel(s)
                to extract the mean for. If a list of names is provided, the
                joint mean and variance for that collection of kernels will
                be returned.
            return_var (bool, optional): If ``True``, also return the
                variance of the component. Default is ``False``.
            **kwargs: Accepted and ignored.

        Returns:
            If ``return_var`` is ``False``: component_mean (JAXArray)
            If ``return_var`` is ``True``: (component_mean, component_var)

        Raises:
            ValueError: If the GP has not been conditioned.
        """
        if self.states is None:
            raise ValueError(
                "The GP must be conditioned before getting component means."
            )

        names = [component] if isinstance(component, str) else component

        ## Get effective observation model for the desired component(s)
        def H_comp(X):
            H = self.kernel.observation_model(X, component=names[0])
            for name in names[1:]:
                # Out-of-place addition: never mutate an array handed out by
                # the kernel.
                H = H + self.kernel.observation_model(X, component=name)
            return H

        ## Project at data
        component_mean, component_var = self.states.project_at_data(H_comp)

        if return_var:
            return component_mean, component_var
        return component_mean

    def get_all_component_means(self, return_var: bool = False, **kwargs) -> Any:
        """
        Get the predictive mean (and optionally variance) of each
        component kernel individually, evaluated at self.X

        Args:
            return_var (bool, optional): If ``True``, also return the
                variances of each component. Default is ``False``.
            **kwargs: Forwarded to :meth:`get_component_mean`.

        Returns:
            A dict keyed by component-kernel name. Each value is the mean of
            that component at the data points, or a ``(mean, variance)``
            tuple if ``return_var`` is ``True``.

        Raises:
            ValueError: If the GP has not been conditioned.
        """
        if self.states is None:
            raise ValueError(
                "The GP must be conditioned before getting component means."
            )

        ## Loop through the leaf kernels and project each component
        results = {}
        for leaf in extract_leaf_kernels(self.kernel):
            mu, var = self.get_component_mean(
                component=leaf.name, return_var=True, **kwargs
            )
            results[leaf.name] = (mu, var) if return_var else mu
        return results
class ConditionResult(NamedTuple):
    """Outcome of conditioning a :class:`GaussianProcess` on data.

    A two-field named tuple: ``log_probability`` followed by ``gp``.
    """

    #: Log probability of the conditioned model, i.e. the marginal
    #: likelihood for the kernel parameters given the observed data (the
    #: multivariate normal log probability evaluated at the given data).
    log_probability: JAXArray

    #: A :class:`GaussianProcess` describing the conditional distribution:
    #: it carries the mean and variance conditioned on the observed data
    #: together with the conditioned states.
    gp: GaussianProcess
[docs] def _default_jitter(reference: JAXArray) -> JAXArray: """Default to adding some amount of jitter to the diagonal, just in case, we use sqrt(eps) for the dtype of the mean function because that seems to give sensible results in general. """ return jnp.sqrt(jnp.finfo(reference.dtype).eps)