Source code for smolgp.kernels.integrated

# TODO:
# 1. copy base.py kernel object style
# 2. add integrated_transition_matrix and integrated_process_noise
# 3. add attribute/property for num_insts
# 4. define each of the usual matrix components to be the augmented version
#    e.g. stationary_covariance --> BlockDiag(sho.stationary_covariance, identity)

# in the solver, user will have passed t, texp, instid, and y
# from there, stateid will get auto-created according to t and texp

"""
These kernels are compatible with :class:`smolgp.solvers.integrated.IntegratedStateSpaceSolver`,
which uses Bayesian filtering and smoothing algorithms to perform scalable GP
inference. (see :mod:`smolgp.solvers` for more technical details).
On GPU, a performance boost may be observed for large datasets by using the
:class:`smolgp.solvers.parallel.ParallelStateSpaceSolver` class.

Like the quasisep kernels, these methods are experimental, so you may find
the documentation patchy in places. You are encouraged to `open issues or
pull requests <https://github.com/smolgp-dev/smolgp/issues>`_ as you find gaps.
"""

from __future__ import annotations

__all__ = [
    "IntegratedSHO",
    "IntegratedExp",
    "IntegratedMatern32",
    "IntegratedMatern52",
    "IntegratedCosine",
]

import equinox as eqx
import jax
import jax.numpy as jnp

from tinygp.helpers import JAXArray

import smolgp.kernels
from smolgp.kernels import StateSpaceModel
from smolgp.helpers import Phibar_from_VanLoan


class IntegratedStateSpaceModel(StateSpaceModel):
    r"""Base class for a :class:`StateSpaceModel` augmented with an integral state.

    Augments a base model with latent state :math:`x` by appending an integral
    state :math:`z`, forming the joint state :math:`[x;\, z]`. This enables
    modeling of time-averaged measurements such as long-exposure observations.

    Coordinates for an integrated model must be a tuple ``(t, delta, instid)``:

    - ``t``: measurement coordinate (e.g. the midpoint time of each exposure)
    - ``delta``: integration range (e.g. exposure time); each measurement
      spans :math:`[t - \delta/2,\; t + \delta/2]`
    - ``instid``: integer index identifying which instrument/dataset each
      measurement belongs to (supports overlapping multi-instrument datasets)
    """

    base_model: StateSpaceModel  # the base (non-integrated) SSM
    num_insts: int = eqx.field(static=True)  # number of integral states

    @property
    def d(self) -> int:
        """The dimension of the base (non-integrated) state space model"""
        return self.base_model.dimension

    @property
    def dimension(self) -> int:
        """The dimension of the augmented state space model"""
        return self.d + self.num_insts

    def coord_to_sortable(self, X: JAXArray) -> JAXArray:
        """
        A helper function used to convert coordinates to sortable 1-D values

        If X is a tuple, e.g. of (time, delta, instid), this assumes the first coordinate is the sortable one
        """
        if isinstance(X, tuple):
            return X[0]
        else:
            return X

    def design_matrix(self) -> JAXArray:
        """The augmented design (also called the feedback) matrix for the process, $F$"""
        F = self.base_model.design_matrix()
        F_aug = jnp.zeros((self.dimension, self.dimension))
        F_aug = F_aug.at[: self.d, : self.d].set(F)
        for i in range(self.num_insts):
            F_aug = F_aug.at[self.d + i, 0].set(1.0)
        return F_aug

    def stationary_covariance(self) -> JAXArray:
        """The augmented stationary covariance of the process, Pinf"""
        Pinf = self.base_model.stationary_covariance()
        Pinf_aug = jnp.diag(jnp.ones(self.dimension)).at[: self.d, : self.d].set(Pinf)
        return Pinf_aug

    def observation_matrix(self, X: JAXArray) -> JAXArray:
        """The augmented observation model for the process, $H$"""

        ## TODO: make sure this works for multivariate data, e.g. like:
        # H_base = self.base_model.observation_model(t)
        # H_z = H_base/delta # observe the average value over exposure
        # H_aug = jnp.zeros((H_base.shape[0], self.dimension))
        # H_aug = jax.lax.dynamic_update_slice(H_aug, H_z, (self.d*(1+instid),))
        ## Below is hardcoded for 1-D data

        def H_integral(t: JAXArray, delta: JAXArray, instid: int) -> JAXArray:
            """Observation model for integral state"""
            H_z = jnp.array([1.0 / delta])
            H_aug = jnp.zeros(self.dimension)
            H_aug = jax.lax.dynamic_update_slice(H_aug, H_z, (self.d + instid,))
            return H_aug

        def H_latent(t: JAXArray, instid: int) -> JAXArray:
            """Observation model for latent (non-integral) state"""
            # H_x = self.base_model.observation_model(X) # TODO: use this to get the shapes right
            H_x = jnp.zeros(self.d).at[0].set(1)  # hardcoded 1-D version for now
            H_aug = jnp.zeros(self.dimension)
            H_aug = jax.lax.dynamic_update_slice(H_aug, H_x, (0,))
            return H_aug

        if isinstance(X, tuple) or isinstance(X, list):
            # Observing integral state (z) with exposure time (delta)
            t, delta, instid = X
            H_aug = jax.lax.cond(
                delta > 0,
                lambda _: H_integral(t, delta, instid),
                lambda _: H_latent(t, instid),
                operand=None,
            )
        else:
            # default to latent state if no exposure time provided
            H_aug = H_latent(X, instid=0)

        return jnp.array([H_aug])

    def noise(self) -> JAXArray:
        """The spectral density of the white noise process, $Q_c$"""
        return self.base_model.noise()

    def noise_effect_matrix(self) -> JAXArray:
        """The augmented noise effect matrix, $L$"""
        L = self.base_model.noise_effect_matrix()
        L_aug = jnp.vstack([L] + [0.0] * self.num_insts)
        return L_aug

    def integrated_transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """
        The integrated transition matrix between two states at coordinates X1 and X2, $A_k$

        By default uses the Van Loan method to compute Phibar = ∫0^dt exp(F s) ds

        Overload this method if you wish to define the integrated transition matrix analytically.
        """
        F = self.base_model.design_matrix()
        t1 = self.coord_to_sortable(X1)
        t2 = self.coord_to_sortable(X2)
        dt = t2 - t1
        return Phibar_from_VanLoan(F, dt)

    def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """
        The augmented transition matrix between two states at coordinates X1 and X2, $A_k$
        """
        t1 = self.coord_to_sortable(X1)
        t2 = self.coord_to_sortable(X2)
        PHI = self.base_model.transition_matrix(t1, t2)
        INTPHI = self.integrated_transition_matrix(t1, t2)[0, :]
        PHIAUG = jnp.eye(self.dimension)
        PHIAUG = PHIAUG.at[: self.d, : self.d].set(PHI)
        for i in range(self.num_insts):
            PHIAUG = PHIAUG.at[self.d + i : self.d + i + 1, : self.d].set(INTPHI)
        return PHIAUG

    def integrated_process_noise(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """
        Computes the submatrices Qaug12, Qaug21, and Qaug22
        needed to assemble the augmented process noise matrix.

        By default uses the Van Loan method to compute these submatrices.
        Overload this method if you wish to define these submatrices analytically.
        """

        t1 = self.coord_to_sortable(X1)
        t2 = self.coord_to_sortable(X2)
        dt = t2 - t1
        F = self.base_model.design_matrix()
        L = self.base_model.noise_effect_matrix()
        Qc = self.base_model.noise()

        vanloan = smolgp.helpers.VanLoan(F, L, Qc, dt)
        F3 = vanloan["F3"]
        H2 = vanloan["H2"]
        K1 = vanloan["K1"]

        M = F3.T @ H2
        F3TK1 = F3.T @ K1
        W = F3TK1 + F3TK1.T

        Qaug12 = M[:, :1]
        Qaug21 = Qaug12.T
        Qaug22 = W[:1, :1]
        return Qaug12, Qaug21, Qaug22

    # @partial(
    #     jax.jit,
    #     static_argnames=("force_numerical"),
    # )
    def process_noise(
        self, X1: JAXArray, X2: JAXArray, force_numerical: bool = False
    ) -> JAXArray:
        """
        The augmented process noise matrix $Q_k$

        Default behavior computes Q from the Van Loan
        matrix exponential involving F, L, and Qc

        Overload this method if you wish to define the
        integrated process noise analytically.
        """

        t1 = self.coord_to_sortable(X1)
        t2 = self.coord_to_sortable(X2)
        dt = t2 - t1
        if force_numerical:
            Qaug12, Qaug21, Qaug22 = super(type(self), self).integrated_process_noise(
                X1, X2
            )
        else:
            Qaug12, Qaug21, Qaug22 = self.integrated_process_noise(X1, X2)
        Qbase = self.base_model.process_noise(0, dt)
        QAUG = jnp.tile(Qaug22, (self.dimension, self.dimension))
        QAUG = QAUG.at[: self.d, : self.d].set(Qbase)
        for i in range(self.num_insts):
            QAUG = QAUG.at[: self.d, self.d + i : self.d + i + 1].set(Qaug12)
            QAUG = QAUG.at[self.d + i : self.d + i + 1, : self.d].set(Qaug21)
        return QAUG

    def reset_matrix(self, instid: int = 0) -> JAXArray:
        """
        The reset matrix, RESET_k,for instrument `instid` (0-indexed)

        By default, resets only the integral states to zero.
        Overload this method if you wish to define a different reset behavior.
        """
        diag = jnp.ones(self.dimension)
        diag = jax.lax.dynamic_update_slice(diag, jnp.array([0.0]), (self.d + instid,))
        return jnp.diag(diag)


[docs] class IntegratedSHO(IntegratedStateSpaceModel): r"""The :class:`~smolgp.kernels.SHO` kernel integrated over a finite time range :math:`\delta`. Models the time-averaged version of the damped, driven stochastic harmonic oscillator kernel (see :class:`~smolgp.kernels.SHO`). Each measurement is the average of the latent GP over an exposure window of length :math:`\delta` centred on the observation time. Args: omega: The natural frequency :math:`\omega_0`. quality: The quality factor :math:`Q`. sigma (optional): The amplitude :math:`\sigma`. Defaults to 1. Specifying it here provides a slight performance boost over multiplying the kernel by a scalar after construction. num_insts (optional): Number of distinct instrument datasets. Defaults to 1. """ omega: JAXArray | float quality: JAXArray | float sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(())) eta: JAXArray | float def __init__( self, omega: JAXArray | float, quality: JAXArray | float, sigma: JAXArray | float = jnp.ones(()), num_insts: int = 1, name: str = "IntegratedSHO", **kwargs, ): self.num_insts = num_insts self.name = name # SHO parameters self.omega = omega self.quality = quality self.sigma = sigma self.eta = jnp.sqrt(jnp.abs(1 - 1 / (4 * self.quality**2))) # Base model self.base_model = smolgp.kernels.SHO( omega=self.omega, quality=self.quality, sigma=self.sigma )
[docs] def integrated_transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: """The integrated transition matrix Phibar for the SHO process""" # Shorthand notations n = self.eta w = self.omega q = self.quality a = -0.5 * w / q b = n * w a2plusb2 = jnp.square(a) + jnp.square(b) A = 1 / (2 * n * q) B = 1 / (n * w) # = 1/b C = -w / n def critical(t1: JAXArray, t2: JAXArray) -> JAXArray: ## TODO: returning numerical result until we do this integral by hand F = self.base_model.design_matrix() return Phibar_from_VanLoan(F, t2 - t1) def underdamped(t1: JAXArray, t2: JAXArray) -> JAXArray: ## General integral from t1->t2: # def Int_ecos(t): # return jnp.exp(a * t) * (a * jnp.cos(b * t) + b * jnp.sin(b * t)) # def Int_esin(t): # return jnp.exp(a * t) * (a * jnp.sin(b * t) - b * jnp.cos(b * t)) # Ic = Int_ecos(t2) - Int_ecos(t1) # Is = Int_esin(t2) - Int_esin(t1) # Phibar11 = Ic + A * Is # Phibar12 = B * Is # Phibar21 = C * Is # Phibar22 = Ic - A * Is ## Paper version: hardcoded for t1=0, dt=t2-t1 dt = t2 - t1 arg = b * dt exp = jnp.exp(a * dt) sin = jnp.sin(arg) cos = jnp.cos(arg) asin = a * sin bsin = b * sin acos = a * cos bcos = b * cos Ic = acos + bsin Is = asin - bcos Phibar11 = exp * (Ic + A * Is) - (a - A * b) Phibar12 = B * (Is * exp + b) Phibar21 = C * (Is * exp + b) Phibar22 = exp * (Ic - A * Is) - (a + A * b) return jnp.array([[Phibar11, Phibar12], [Phibar21, Phibar22]]) / a2plusb2 def overdamped(t1: JAXArray, t2: JAXArray) -> JAXArray: ## TODO: returning numerical result until we do this integral by hand F = self.base_model.design_matrix() return Phibar_from_VanLoan(F, t2 - t1) # Return the appropriate form based on quality factor t1 = self.coord_to_sortable(X1) t2 = self.coord_to_sortable(X2) return jax.lax.cond( jnp.allclose(q, 0.5), critical, lambda t1, t2: jax.lax.cond(q > 0.5, underdamped, overdamped, t1, t2), t1, t2, )
[docs] def integrated_process_noise(self, X1: JAXArray, X2: JAXArray) -> JAXArray: """The integrated process noise submatrices for the SHO process""" t1 = self.coord_to_sortable(X1) t2 = self.coord_to_sortable(X2) dt = t2 - t1 n = self.eta w = self.omega q = self.quality a = -0.5 * w / q b = n * w sigma2 = jnp.square(self.sigma) A = 1 / (2 * n * q) def critical(dt: JAXArray) -> JAXArray: # TODO: returning numerical result until we do this integral by hand return super(type(self), self).integrated_process_noise(0, dt) def underdamped(dt: JAXArray) -> JAXArray: x = a * dt arg = b * dt w2 = jnp.square(w) q2 = jnp.square(q) q4 = jnp.square(q2) exp = jnp.exp(x) exp2 = jnp.exp(2 * x) exp2m1 = jnp.expm1(2 * x) sin = jnp.sin(arg) cos = jnp.cos(arg) sinsq = jnp.square(jnp.sin(arg)) sin2 = jnp.sin(2 * arg) cos2 = jnp.cos(2 * arg) A2 = jnp.square(A) iQ12_1 = jnp.square(exp * (cos + A * sin) - 1) / (q * w) iQ12_2 = A * exp * (4 * sin - exp * sin2) - 2 * A2 * exp2 * sinsq + exp2m1 part1 = 8 * q * w * dt + 4 * q2 - 12 part2 = A2 * exp2 * (cos2 - 16 * q4) part3_1 = 16 * exp * (cos + (1 - 2 * q2) * A * sin) part3_2 = exp2 * ((1 - 3 * A2) / A * sin2 - 3 * cos2) part3 = part3_1 + part3_2 iQ22 = 1 / (4 * q2 * w2) * (part1 + part2 + part3) iQ22 = jnp.maximum(iQ22, 0.0) # prevent underflows at dt=0 Qaug12 = sigma2 * jnp.array([[iQ12_1], [iQ12_2]]) Qaug22 = sigma2 * jnp.array([[iQ22]]) # Prevent underflows Qaug12 = jnp.where(jnp.abs(Qaug12) < 1e-14, jnp.zeros_like(Qaug12), Qaug12) Qaug22 = jnp.where(jnp.abs(Qaug22) < 1e-14, jnp.zeros_like(Qaug22), Qaug22) Qaug21 = Qaug12.T return Qaug12, Qaug21, Qaug22 def overdamped(dt: JAXArray) -> JAXArray: ## TODO: returning numerical result until we do this integral by hand return super(type(self), self).integrated_process_noise(0, dt) return jax.lax.cond( jnp.allclose(q, 0.5), critical, lambda dt: jax.lax.cond(q > 0.5, underdamped, overdamped, dt), dt, )
## TODO: is there a way to automate this? aka make a generic IntegratedKernel class... ## Default constructions for all kernels in smolgp.kernels.base ## IntegratedStateSpaceModel parent class will handle the augmentation ## All component matrices will be auto-generated numerically (e.g. A via expm, Q via Van Loan)
[docs] class IntegratedExp(IntegratedStateSpaceModel): r"""The :class:`~smolgp.kernels.Exp` (Ornstein–Uhlenbeck / Matérn-1/2) kernel integrated over a finite time range :math:`\delta`. Args: scale: The length scale :math:`\ell`. sigma (optional): The amplitude :math:`\sigma`. Defaults to 1. num_insts (optional): Number of distinct instrument datasets. Defaults to 1. """ scale: JAXArray | float sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(())) lam: JAXArray | float def __init__( self, scale: JAXArray | float, sigma: JAXArray | float = jnp.ones(()), num_insts: int = 1, name: str = "IntegratedExp", **kwargs, ): self.scale = scale self.sigma = sigma self.name = name self.num_insts = num_insts self.base_model = smolgp.kernels.Exp(scale=self.scale, sigma=self.sigma) self.lam = self.base_model.lam
[docs] class IntegratedMatern32(IntegratedStateSpaceModel): r"""The :class:`~smolgp.kernels.Matern32` kernel integrated over a finite time range :math:`\delta`. Args: scale: The length scale :math:`\ell`. sigma (optional): The amplitude :math:`\sigma`. Defaults to 1. num_insts (optional): Number of distinct instrument datasets. Defaults to 1. """ scale: JAXArray | float sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(())) lam: JAXArray | float def __init__( self, scale: JAXArray | float, sigma: JAXArray | float = jnp.ones(()), num_insts: int = 1, name: str = "IntegratedMatern32", **kwargs, ): self.scale = scale self.sigma = sigma self.name = name self.num_insts = num_insts self.base_model = smolgp.kernels.Matern32(scale=self.scale, sigma=self.sigma) self.lam = self.base_model.lam
[docs] class IntegratedMatern52(IntegratedStateSpaceModel): r"""The :class:`~smolgp.kernels.Matern52` kernel integrated over a finite time range :math:`\delta`. Args: scale: The length scale :math:`\ell`. sigma (optional): The amplitude :math:`\sigma`. Defaults to 1. num_insts (optional): Number of distinct instrument datasets. Defaults to 1. """ scale: JAXArray | float sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(())) lam: JAXArray | float def __init__( self, scale: JAXArray | float, sigma: JAXArray | float = jnp.ones(()), num_insts: int = 1, name: str = "IntegratedMatern52", **kwargs, ): self.scale = scale self.sigma = sigma self.name = name self.num_insts = num_insts self.base_model = smolgp.kernels.Matern52(scale=self.scale, sigma=self.sigma) self.lam = self.base_model.lam
[docs] class IntegratedCosine(IntegratedStateSpaceModel): r"""The :class:`~smolgp.kernels.Cosine` kernel integrated over a finite time range :math:`\delta`. Args: scale: The period :math:`\ell`. sigma (optional): The amplitude :math:`\sigma`. Defaults to 1. num_insts (optional): Number of distinct instrument datasets. Defaults to 1. """ scale: JAXArray | float sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(())) omega: JAXArray | float def __init__( self, scale: JAXArray | float, sigma: JAXArray | float = jnp.ones(()), num_insts: int = 1, name: str = "IntegratedCosine", **kwargs, ): self.scale = scale self.sigma = sigma self.name = name self.num_insts = num_insts self.base_model = smolgp.kernels.Cosine(scale=self.scale, sigma=self.sigma) self.omega = self.base_model.omega