Source code for smolgp.solvers.parallel.solver

from __future__ import annotations
from typing import Any

import jax
import jax.numpy as jnp
import equinox as eqx

from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.solver import QuasisepSolver
from smolgp.kernels.base import StateSpaceModel

from smolgp.solvers.parallel.kalman import ParallelKalmanFilter
from smolgp.solvers.parallel.rts import ParallelRTSSmoother


class ParallelStateSpaceSolver(eqx.Module):
    """
    A solver that uses ``jax.lax.associative_scan`` to implement parallel
    Kalman filtering and RTS smoothing.

    Attributes:
        X: The input coordinates.
        kernel: The state space model describing the GP prior.
        noise: Observation noise covariance, shape ``(N, D, D)``: one
            ``D x D`` block per time step.
    """

    X: JAXArray
    kernel: StateSpaceModel
    noise: JAXArray  # shape (N, D, D): observation noise covariance per time step

    def __init__(
        self,
        kernel: StateSpaceModel,
        X: JAXArray,
        noise: JAXArray,
    ):
        """Build a :class:`ParallelStateSpaceSolver` for a given kernel and coordinates

        Args:
            kernel: The kernel function.
            X: The input coordinates.
            noise: Observation noise covariance array of shape ``(N, D, D)``.
        """
        self.kernel = kernel
        self.X = X
        self.noise = noise

    def normalization(self) -> JAXArray:
        """Return the normalization computed by tinygp's ``QuasisepSolver``.

        This is not (yet) implemented in state space form. Instead, a
        ``QuasisepSolver`` is built from ``self.kernel``, ``self.X`` and the
        per-observation noise variances, and its ``normalization()`` is
        returned.

        NOTE(review): the quasisep fallback presumably only makes sense for
        scalar observations (``D == 1``); confirm before using with ``D > 1``.
        """
        # TODO: do we want/can we implement this in state space? for now,
        # fall back to quasisep

        class _NoiseAdapter:
            """Minimal stand-in for a tinygp noise object exposing ``diagonal()``."""

            def __init__(self, n):
                self._n = n

            def diagonal(self):
                # ``noise`` is (N, D, D). The quasisep solver needs the
                # per-observation variances: the diagonal of each per-step
                # block, flattened in time order -> shape (N * D,), which is
                # (N,) for scalar observations. Indexing ``[0, 0, :]`` would
                # instead return only the first row of the first block.
                return jnp.diagonal(self._n, axis1=-2, axis2=-1).reshape(-1)

        return QuasisepSolver(
            self.kernel, self.X, _NoiseAdapter(self.noise)
        ).normalization()

    def Kalman(self, y, return_v_S=True) -> Any:
        """Wrapper for the parallel Kalman filter used with this solver.

        Args:
            y: Observations. A 1-D array of shape ``(N,)`` is reshaped to
                ``(N, 1)``; any other array is passed through unchanged
                (assumed to already be ``(N, D)`` -- TODO confirm).
            return_v_S: Forwarded to :class:`ParallelKalmanFilter`.

        Returns:
            The output of :class:`ParallelKalmanFilter`, which
            :meth:`condition` unpacks as ``(m_filtered, P_filtered,
            m_predicted, P_predicted)`` plus ``(v, S)`` when ``return_v_S``
            is true.
        """
        # The (N, D, D) noise array is passed to the filter as the per-step
        # observation covariance; observations must be (N, D).
        y_nd = y[:, None] if y.ndim == 1 else y
        return ParallelKalmanFilter(
            self.kernel, self.X, y_nd, self.noise, return_v_S=return_v_S
        )

    def RTS(self, kalman_results) -> Any:
        """Wrapper for the parallel RTS smoother used with this solver.

        Args:
            kalman_results: The filtered ``(means, covariances)`` produced by
                :meth:`Kalman`.
        """
        return ParallelRTSSmoother(self.kernel, self.X, kalman_results)

    def condition(self, y, return_v_S=False) -> Any:
        """
        Compute the Kalman predicted, filtered, and RTS smoothed means and
        covariances at each of the input coordinates.

        Args:
            y: Observations (see :meth:`Kalman`).
            return_v_S: If true, also return the filter's ``v`` and ``S``
                outputs (presumably the innovations and their covariances).

        Returns:
            ``(t_states, conditioned_states, v_S)`` where ``t_states`` is
            ``kernel.coord_to_sortable(X)``, ``conditioned_states`` is
            ``((m_predicted, P_predicted), (m_filtered, P_filtered),
            (m_smoothed, P_smoothed))``, and ``v_S`` is ``(v, S)`` or ``None``.
        """
        # Kalman filtering. Forward the caller's flag so the number of values
        # the filter returns matches the unpacking below; always requesting
        # (v, S) broke the default ``return_v_S=False`` path.
        kalman_results = self.Kalman(y, return_v_S=return_v_S)
        if return_v_S:
            m_filtered, P_filtered, m_predicted, P_predicted, v, S = kalman_results
            v_S = (v, S)
        else:
            m_filtered, P_filtered, m_predicted, P_predicted = kalman_results
            v_S = None

        # RTS smoothing (the first returned item is not needed here)
        _, m_smoothed, P_smoothed = self.RTS((m_filtered, P_filtered))

        # Pack up results and return
        t_states = self.kernel.coord_to_sortable(self.X)
        conditioned_states = (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        )
        return t_states, conditioned_states, v_S

    @jax.jit
    def predict(self, X_test, conditioned_results) -> tuple[JAXArray, JAXArray]:
        """
        Algorithm for making predictions at arbitrary coordinates X_test

        Args:
            X_test : The test coordinates.
            conditioned_results : The output of self.condition

        Returns:
            pred_mean : Predicted state means at X_test, one per test point
            pred_var : Predicted state covariance matrices at X_test, one
                per test point

        Each test point is handled by one of three cases, chosen by where it
        falls relative to the conditioned coordinates (assumed sorted):
            1. Retrodiction (before the first data point): RTS smoothing
               from the first data point, using the prior as the prediction
            2. Interpolation: Kalman prediction from the most recent data
               point, then RTS smoothing from the next future point
            3. Extrapolation (after the last data point): Kalman prediction
               from the final filtered point
        """
        # Unpack conditioned results
        state_coords, conditioned_states, _ = conditioned_results
        (
            (m_predicted, P_predicted),
            (m_filtered, P_filtered),
            (m_smoothed, P_smoothed),
        ) = conditioned_states
        # ``condition`` stores ``kernel.coord_to_sortable(self.X)`` here -- the
        # same kind of object used for ``t_test`` below. Also accept a
        # tuple/list whose first entry holds the sortable state times.
        if isinstance(state_coords, (tuple, list)):
            t_states = state_coords[0]
        else:
            t_states = state_coords

        # Unpack test coordinates
        t_test = self.kernel.coord_to_sortable(X_test)
        K = len(t_states)  # number of states
        M = len(t_test)  # number of test points

        # Prior covariance for retrodiction
        Pinf = self.kernel.stationary_covariance()
        if not isinstance(Pinf, JAXArray):  # if multicomponent model
            Pinf = Pinf.to_dense()  # needs to be array form here

        # Prior mean for retrodiction
        # TODO: mean function of base kernel
        m0 = jnp.zeros(self.kernel.dimension)

        # Index of the nearest future state for each test point.
        # TODO: we assume t_states is sorted here; should we enforce that?
        k_nexts = jnp.searchsorted(t_states, t_test, side="right")

        # Which method to use for each test point
        past = k_nexts <= 0  # Retrodiction
        future = k_nexts >= K  # Forecast
        during = ~past & ~future  # Interpolate
        cases = past.astype(int) * 0 + during.astype(int) * 1 + future.astype(int) * 2

        # Shorthand for the model matrices
        A = self.kernel.transition_matrix
        Q = self.kernel.process_noise

        def kalman_predict(k, ktest):
            """Kalman prediction from the filtered (not smoothed) state k to test point ktest."""
            dt = X_test[ktest] - self.X[k]
            m_k = m_filtered[k]
            P_k = P_filtered[k]
            A_star = A(0, dt)  # transition matrix from t_k to t_star
            Q_star = Q(0, dt)  # process noise from t_k to t_star
            m_star_pred = A_star @ m_k
            P_star_pred = A_star @ P_k @ A_star.T + Q_star
            # No Kalman update since there is no data at t_star
            return m_star_pred, P_star_pred

        def rts_smooth(k_next, ktest, m_star_pred, P_star_pred):
            """
            RTS smooth the prediction at ktest using the nearest future
            (smoothed) state k_next. m_star_pred and P_star_pred are the
            predicted mean and covariance at ktest.
            """
            # Predicted (no Kalman update) and smoothed state at the next data point
            m_pred_next = m_predicted[k_next]
            P_pred_next = P_predicted[k_next]
            m_hat_next = m_smoothed[k_next]
            P_hat_next = P_smoothed[k_next]

            # Transition matrix from t_star to the next data point
            A_k = A(0, self.X[k_next] - X_test[ktest])

            # Smoothing gain G = P_star_pred A^T P_pred_next^{-1}, computed
            # with a linear solve instead of an explicit inverse for stability
            G_k = jnp.linalg.solve(P_pred_next.T, (P_star_pred @ A_k.T).T).T

            m_star_hat = m_star_pred + G_k @ (m_hat_next - m_pred_next)
            P_star_hat = P_star_pred + G_k @ (P_hat_next - P_pred_next) @ G_k.T
            return m_star_hat, P_star_hat

        def retrodict(ktest):
            """Reverse-extrapolate from the first data point to t_star."""
            return rts_smooth(0, ktest, m0, Pinf)

        def interpolate(ktest):
            """Interpolate between the nearest data points."""
            k_next = k_nexts[ktest]
            k_prev = k_next - 1
            # 1. Kalman predict from the most recent data point (in the past)
            m_star_pred, P_star_pred = kalman_predict(k_prev, ktest)
            # 2. RTS smooth from the next nearest data point (in the future)
            return rts_smooth(k_next, ktest, m_star_pred, P_star_pred)

        def extrapolate(ktest):
            """Kalman predict from the last data point to t_star."""
            return kalman_predict(-1, ktest)

        def predict_point(ktest):
            """Dispatch one test point to retrodiction, interpolation, or extrapolation."""
            return jax.lax.switch(
                cases[ktest], (retrodict, interpolate, extrapolate), ktest
            )

        # Calculate predictions for all test points
        ktests = jnp.arange(0, M, 1)
        pred_mean, pred_var = jax.vmap(predict_point)(ktests)
        return pred_mean, pred_var