"""Source code for ``smolgp.solvers.integrated.parallel.rts``: parallel RTS
smoother for integrated measurements."""

from __future__ import annotations

import jax
import jax.numpy as jnp


def ParallelIntegratedRTSSmoother(
    kernel,
    t_states,
    stateid,
    instid,
    kalman_results,
):
    """
    Wrapper for the parallel RTS smoother with integrated measurements.

    Parameters:
        kernel: StateSpaceModel kernel; its ``transition_matrix``,
            ``process_noise`` and ``reset_matrix`` attributes are used
        t_states: time coordinates of the states
        stateid: exposure start/end indicators (0 marks an exposure start)
        instid: instrument IDs, passed to ``kernel.reset_matrix``
        kalman_results: output from the Kalman filter, a 4-tuple
            ``(m_pred, P_pred, m_filter, P_filter)``

    Returns:
        m_smoothed: smoothed means
        P_smoothed: smoothed covariances
    """
    m_pred, P_pred, m_filter, P_filter = kalman_results

    asso_params = make_associative_params(
        kernel.transition_matrix,
        kernel.process_noise,
        kernel.reset_matrix,
        t_states,
        stateid,
        instid,
        m_pred,
        P_pred,
        m_filter,
        P_filter,
    )

    # The scan yields (E, g, L); only g (means) and L (covariances) are the
    # smoothed moments. E, the accumulated gain product, is not needed here.
    _, m_smoothed, P_smoothed = parallel_integrated_rts_smoother(asso_params)
    return m_smoothed, P_smoothed
@jax.jit
def make_associative_params(
    Phi_aug,
    Q_aug,
    RESET,
    t_states,
    stateid,
    instid,
    m_pred,
    P_pred,
    m_filter,
    P_filter,
):
    """Build the per-step elements (E, g, L) consumed by the parallel RTS scan.

    Follows Section 4B of Sarkka & Garcia-Fernandez (2020). For every step k
    except the last, with dt = t_states[k+1] - t_states[k], transition F and
    process noise Q = Q_aug(0, dt):

        E_k = P_k F^T (F P_k F^T + Q)^{-1}
        g_k = m_k - E_k F m_k
        L_k = P_k - E_k F P_k

    If stateid[k] == 0 (exposure start), F = Phi_aug(0, dt) @ RESET(instid[k])
    and (m_k, P_k) are the *predicted* moments. Otherwise F = Phi_aug(0, dt)
    and (m_k, P_k) are the *filtered* moments. The final element is
    (0, m_filter[-1], P_filter[-1]).

    Parameters:
        Phi_aug: callable, invoked as Phi_aug(0, dt), returning a transition matrix
        Q_aug: callable, invoked as Q_aug(0, dt), returning a process-noise covariance
        RESET: callable, invoked as RESET(instid[k]), returning a reset matrix
        t_states: state times; the leading axis of every array below indexes them
        stateid: exposure start/end indicators (0 = start)
        instid: instrument IDs
        m_pred, P_pred: predicted means / covariances from the Kalman filter
        m_filter, P_filter: filtered means / covariances from the Kalman filter

    Returns:
        Tuple (E, g, L), each stacked along the leading axis with one entry
        per element of t_states.

    NOTE(review): the three callables are arguments of a jitted function, so
    they presumably must be valid pytrees (e.g. Equinox bound methods or
    ``jax.tree_util.Partial``). Confirm against the kernel implementation.
    """

    def _element(Phi_dt, Q_dt, Reset, mp, Pp, mf, Pf, sid):
        # Exposure-start states are propagated through the reset matrix and
        # use the predicted moments; every other state uses the filtered ones.
        # Q_dt is the same in both cases, so it is kept outside the branch.
        Phik, mk, Pk = jax.lax.cond(
            sid == 0,
            lambda: (Phi_dt @ Reset, mp, Pp),
            lambda: (Phi_dt, mf, Pf),
        )
        S = Phik @ Pk @ Phik.T + Q_dt
        cross = Pk @ Phik.T
        # E = cross @ S^{-1}, computed as a solve against the SPD matrix S.
        E = jax.scipy.linalg.solve(S.T, cross.T, assume_a="pos").T
        g = mk - E @ (Phik @ mk)
        L = Pk - E @ Phik @ Pk
        return E, g, L

    dts = jnp.diff(t_states)
    Phis = jax.vmap(Phi_aug, in_axes=(None, 0))(0, dts)
    Qs = jax.vmap(Q_aug, in_axes=(None, 0))(0, dts)
    Resets = jax.vmap(RESET)(instid[:-1])

    head = slice(None, -1)
    body = jax.vmap(_element)(
        Phis,
        Qs,
        Resets,
        m_pred[head],
        P_pred[head],
        m_filter[head],
        P_filter[head],
        stateid[head],
    )

    # The last state has no successor: E_N = 0, g_N = m_N^f, L_N = P_N^f.
    tail = (jnp.zeros_like(P_filter[-1]), m_filter[-1], P_filter[-1])
    return tuple(
        jnp.concatenate([part, last[None]], axis=0)
        for part, last in zip(body, tail)
    )
def _combine_per_pair(left, right):
    """Apply the RTS-smoother associative operator to two elements.

    Each element is a tuple ``(E, g, L)``. The result is
    ``(E_r @ E_l, E_r @ g_l + g_r, E_r @ L_l @ E_r.T + L_r)``, where ``_l``
    refers to ``left`` and ``_r`` to ``right``.
    """
    # The operands look swapped relative to the paper's a_i (x) a_j. The
    # scan in parallel_integrated_rts_smoother uses reverse=True, and
    # jax.lax.associative_scan implements that by flipping the sequence, so
    # `left` holds the later-time (already accumulated) element and `right`
    # the earlier one.
    E_late, g_late, L_late = left
    E_early, g_early, L_early = right

    E_comb = E_early @ E_late
    g_comb = E_early @ g_late + g_early
    L_comb = E_early @ L_late @ E_early.T + L_early
    return (E_comb, g_comb, L_comb)
@jax.jit
def parallel_integrated_rts_smoother(asso_params):
    """
    Parallel RTS smoother for integrated measurements, implemented in JAX.

    Runs a reverse-time associative scan over the per-step elements
    ``(E, g, L)``. See Section 4B of Sarkka & Garcia-Fernandez (2020) for the
    algorithm and notation, and Section 3.2.4 of Rubenzahl & Hattori et al.
    (2025) for the integrated case.

    Total runtime (span) complexity is O(N/T + log T), where N is the number
    of time steps and T is the number of parallel threads.

    Parameters:
        asso_params: tuple ``(E, g, L)`` of arrays stacked along the leading
            (time) axis, as produced by ``make_associative_params``.

    Returns:
        Tuple ``(E, g, L)`` with the same shapes after the scan.
        ``ParallelIntegratedRTSSmoother`` uses ``g`` as the smoothed means
        and ``L`` as the smoothed covariances.
    """
    # associative_scan applies its operator to whole slices stacked along the
    # leading axis, so the single-pair combine is vectorised with vmap.
    batched_combine = jax.vmap(_combine_per_pair)
    # reverse=True: smoothing information flows backward from the last state.
    return jax.lax.associative_scan(batched_combine, asso_params, reverse=True)