Source code for smolgp.solvers.integrated.kalman

from __future__ import annotations

import jax
import jax.numpy as jnp


def IntegratedKalmanFilter(
    kernel, X, y, t_states, obsid, instid, stateid, R, return_v_S=False
):
    """
    Run the integrated Kalman filter for an integrated state space kernel.

    Thin wrapper around ``integrated_kalman_filter``: the model components
    (transition matrix, process noise, observation model and reset matrix)
    are read off ``kernel``, the prior is set to a zero mean with the
    kernel's stationary covariance, and everything is forwarded.

    Parameters:
        kernel     : IntegratedStateSpaceModel kernel
        X          : Array of size N, data coordinates
                     (e.g. (time, texp, instid))
        y          : Array of size (N, D), measurements at the data
                     coordinates
        t_states   : Array of size K, sorted time coordinate of all
                     states (exposure starts and ends)
        obsid      : Which observation (0,...,N-1) is being made at each
                     state k; the filter indexes it by state, so it
                     needs one entry per element of ``t_states``
        instid     : Array of size N, which instrument (0,...,Ninst-1)
                     recorded observation n
        stateid    : Array of size K, 0 for exposure-start,
                     1 for exposure-end
        R          : Observation noise covariance, shape (N, D, D)
        return_v_S : Whether to also return the innovation and its
                     covariance (for likelihood computation)

    Returns:
        m_filtered : filtered means
        P_filtered : filtered covariances
        m_predicted: predicted means
        P_predicted: predicted covariances
        v, S       : only when ``return_v_S`` is true, the innovations
                     and innovation covariances exactly as returned by
                     ``integrated_kalman_filter``
    """
    # Prior state: zero mean and the kernel's stationary covariance.
    # TODO: use the mean function of the base kernel instead of zeros.
    prior_mean = jnp.zeros(kernel.dimension)
    prior_cov = kernel.stationary_covariance()

    results = integrated_kalman_filter(
        kernel.transition_matrix,
        kernel.process_noise,
        kernel.observation_model,
        R,
        kernel.reset_matrix,
        X,
        y,
        t_states,
        obsid,
        instid,
        stateid,
        prior_mean,
        prior_cov,
    )

    if return_v_S:
        return results
    # Drop the trailing (v, S) pair; keep the four filter outputs.
    return results[:4]
@jax.jit
def integrated_kalman_filter(
    A_aug, Q_aug, H_aug, R, RESET, X, y, t_states, obsid, instid, stateid, m0, P0
):
    """
    Jax implementation of the integrated Kalman filter algorithm

    See Section 3.2.1 in Rubenzahl & Hattori et al. (in prep) for
    detailed description of the algorithm and notation.

    The filter scans over every state k = 0,...,K-1 (K = len(t_states)).
    Each step predicts the state forward by the time lag to the previous
    state, then either resets (exposure start, ``stateid[k] == 0``) or
    conditions on the observation ``obsid[k]`` (any other ``stateid``).

    Parameters:
        A_aug    : callable, invoked as ``A_aug(0, Delta)``; transition
                   matrix for a time lag ``Delta``
        Q_aug    : callable, invoked as ``Q_aug(0, Delta)``; process
                   noise for a time lag ``Delta``
        H_aug    : callable, vmapped over ``X`` to get one observation
                   matrix per data point
        R        : observation noise covariance, indexed as ``R[n]``
        RESET    : callable, invoked as ``RESET(instid[n])``; matrix
                   applied to the predicted mean/covariance at an
                   exposure start
        X        : data coordinates, leading axis over observations
        y        : measurements, indexed as ``y[n]``
        t_states : time coordinate of every state
        obsid    : observation index n for each state k
        instid   : instrument index for each observation n
        stateid  : 0 marks an exposure start; 1 marks an exposure end
        m0, P0   : prior mean and covariance (initial scan carry)

    NOTE: the function is jitted with no static arguments, so the
    callables must themselves be valid JAX pytrees (for example
    ``jax.tree_util.Partial`` objects); plain Python functions are
    rejected by ``jax.jit``.

    Returns:
        m_filtered : filtered means, stacked over states
        P_filtered : filtered covariances, stacked over states
        m_predicted: predicted means, stacked over states
        P_predicted: predicted covariances, stacked over states
        v_sel      : innovations at the states with ``stateid == 1``
        S_sel      : innovation covariances at those same states
    """
    # One observation matrix per data point.
    H = jax.vmap(H_aug)(X)

    # No jit on the scan body: lax.scan traces it once and the enclosing
    # function is already compiled.
    def step(carry, k):
        # Unpack previous state and covariance
        m_prev, P_prev = carry

        # The first state has no predecessor, so it uses the prior with a
        # zero time lag. At k == 0 the index k - 1 wraps to the last
        # element, but that difference is discarded by the where.
        Delta = jnp.where(k > 0, t_states[k] - t_states[k - 1], 0.0)
        n = obsid[k]

        # Transition matrix and process noise for this lag
        A_prev = A_aug(0, Delta)
        Q_prev = Q_aug(0, Delta)

        # Predict step is the same for exposure starts and ends
        m_pred = A_prev @ m_prev
        P_pred = A_prev @ P_prev @ A_prev.T + Q_prev

        # Update at the end of the exposure: condition on observation n
        def update_end():
            Hk = H[n]
            y_pred = Hk @ m_pred  # predicted observation
            v_k = y[n] - y_pred  # "innovation" or "surprise" term
            S_k = Hk @ P_pred @ Hk.T + R[n]  # uncertainty in predicted observation
            # Kalman gain K = P H^T S^-1, obtained from a linear solve
            # rather than an explicit inverse.
            K_k = jnp.linalg.solve(S_k.T, (P_pred @ Hk.T).T).T
            m_k = m_pred + K_k @ v_k  # conditioned state estimate
            P_k = P_pred - K_k @ S_k @ K_k.T  # conditioned covariance estimate
            return m_k, P_k, m_pred, P_pred, v_k, S_k

        # Update at the start of the exposure, aka reset its z to zero
        def update_start():
            Reset = RESET(instid[n])
            m_k = Reset @ m_pred
            P_k = Reset @ P_pred @ Reset.T
            # No data at an exposure start: emit zero placeholders. The
            # products below are used only for their shape and dtype, so
            # both cond branches return matching structures.
            Hk = H[n]
            v_k = jnp.zeros_like(Hk @ m_pred)
            S_k = jnp.zeros_like(Hk @ P_pred @ Hk.T)
            return m_k, P_k, m_pred, P_pred, v_k, S_k

        m_k, P_k, m_pred, P_pred, v_k, S_k = jax.lax.cond(
            stateid[k] == 0, update_start, update_end
        )

        return (m_k, P_k), (m_k, P_k, m_pred, P_pred, v_k, S_k)

    # Initialize carry with prior state and covariance
    init_carry = (m0, P0)

    # Run the filter over all time steps, unpack, and return results
    _, outputs = jax.lax.scan(step, init_carry, jnp.arange(len(t_states)))
    m_filtered, P_filtered, m_predicted, P_predicted, v, S = outputs

    # Only return v, S at exposure ends (where there is data). The static
    # size keeps the output shape known under jit.
    ends_idx = jnp.nonzero(stateid == 1, size=y.shape[0])[0]
    v_sel = jnp.take(v, ends_idx, axis=0)
    S_sel = jnp.take(S, ends_idx, axis=0)

    return m_filtered, P_filtered, m_predicted, P_predicted, v_sel, S_sel