Source code for smolgp.solvers.rts

from __future__ import annotations

import jax
import jax.numpy as jnp


def RTSSmoother(kernel, X, kalman_results):
    """
    Convenience entry point for the RTS smoother

    Pulls the transition-matrix callable and the sortable coordinates
    out of the kernel, then hands them to ``rts_smoother`` together with
    the unpacked Kalman filter results.

    Parameters:
        kernel: StateSpaceModel kernel; must provide ``transition_matrix``
                and ``coord_to_sortable``
        X: data coordinates, e.g. time or (time, texp, instid)
        kalman_results: output from Kalman filter, in the order
                (m_filtered, P_filtered, m_predicted, P_predicted)

    Returns:
        m_smooth: smoothed means
        P_smooth: smoothed covariances
    """
    return rts_smoother(
        kernel.transition_matrix,
        kernel.coord_to_sortable(X),
        *kalman_results,
    )
@jax.jit
def rts_smoother(A, t, m_filtered, P_filtered, m_predicted, P_predicted):
    """
    Rauch-Tung-Striebel (RTS) backward smoothing pass, implemented in JAX

    Follows Theorem 8.2 (pdf page 156) in "Bayesian Filtering and Smoothing"
    by Simo Särkkä; see there for the derivation and the notation.

    Parameters:
        A: callable, invoked as ``A(0, dt)`` to obtain the transition
           matrix across a time-lag ``dt``
        t: sortable coordinates, one entry per data point
        m_filtered: filtered means, indexed by time step
        P_filtered: filtered covariances, indexed by time step
        m_predicted: predicted means (superscript minus), indexed by time step
        P_predicted: predicted covariances (superscript minus), indexed by
           time step

    Returns:
        m_smooth: smoothed means, in time order
        P_smooth: smoothed covariances, in time order
    """
    # The last state is already "smoothed", so only len(t) - 1 updates are needed
    n_updates = len(t) - 1

    def backward_update(smoothed_next, idx):
        """
        One backward recursion of the smoother, for time step ``idx``

        Parameters:
            smoothed_next: (mean, covariance) of the smoothed state at idx + 1
            idx: index of the time step being smoothed

        Returns:
            - (mean, covariance) of the smoothed state at idx, carried on to
              the next (earlier) step
            - the same pair again, collected by the scan as output
        """
        nxt = idx + 1
        mean_next, cov_next = smoothed_next

        # Transition matrix over the time-lag separating the two states
        transition = A(0, t[nxt] - t[idx])

        # Smoothing gain G = P_k A_k^T (P^-_{k+1})^{-1}, obtained through a
        # linear solve on the transposed system instead of an explicit
        # inverse (more stable)
        cross_cov = P_filtered[idx] @ transition.T
        gain = jnp.linalg.solve(P_predicted[nxt].T, cross_cov.T).T

        # Correct the filtered estimate with the information from the future
        mean_idx = m_filtered[idx] + gain @ (mean_next - m_predicted[nxt])
        cov_idx = P_filtered[idx] + gain @ (cov_next - P_predicted[nxt]) @ gain.T

        smoothed = (mean_idx, cov_idx)
        return smoothed, smoothed

    # The recursion is seeded with the final filtered state
    last_mean = m_filtered[-1]
    last_cov = P_filtered[-1]

    # reverse=True visits the indices from high to low while storing each
    # output at its own index, so the stacked results come out in time order
    _, (m_head, P_head) = jax.lax.scan(
        backward_update,
        (last_mean, last_cov),
        jnp.arange(n_updates),
        reverse=True,
    )

    # Append the final state, which the scan never updates
    m_smooth = jnp.concatenate([m_head, last_mean[None, :]], axis=0)
    P_smooth = jnp.concatenate([P_head, last_cov[None, :, :]], axis=0)
    return m_smooth, P_smooth