from __future__ import annotations
import jax
import jax.numpy as jnp
[docs]
def IntegratedRTSSmoother(kernel, t_states, obsid, instid, stateid, kalman_results):
    """
    Wrapper for the jitted ``integrated_rts_smoother`` function.

    Pulls the transition and reset matrices out of ``kernel`` and runs the
    integrated RTS smoother on the output of the Kalman filter.

    Parameters:
        kernel : IntegratedStateSpaceModel kernel; its ``transition_matrix``
            and ``reset_matrix`` attributes are forwarded to the smoother
        t_states : Array of size K, sorted time coordinate of all states
            (exposure starts and ends)
        obsid : Array of size K, which observation (0,...,N-1) is being made
            at each state k
        instid : Array of size N, which instrument (0,...,Ninst-1) recorded
            observation n
        stateid : Array of size K, 0 for exposure-start, 1 for exposure-end
        kalman_results : output from Kalman filter, the 4-sequence
            (m_filtered, P_filtered, m_predicted, P_predicted)

    Returns:
        m_smooth : smoothed means, one row per state
        P_smooth : smoothed covariances, one per state
    """
    # Model components.
    # NOTE(review): these are passed as arguments to a jax.jit-compiled
    # function, so they must be valid JAX argument types (e.g. pytree-
    # compatible bound methods) -- confirm for the kernel class in use.
    A_aug = kernel.transition_matrix
    RESET = kernel.reset_matrix
    return integrated_rts_smoother(A_aug, RESET, t_states, obsid, instid, stateid, *kalman_results)
[docs]
@jax.jit
def integrated_rts_smoother(
    A_aug,
    RESET,
    t_states,
    obsid,
    instid,
    stateid,
    m_filtered,
    P_filtered,
    m_predicted,
    P_predicted,
):
    """
    Jax implementation of the integrated RTS smoothing algorithm.

    See Section 3.2.2 in Rubenzahl & Hattori et al. (in prep)
    for detailed description of the algorithm and notation.

    Parameters:
        A_aug : callable, invoked as ``A_aug(0, Delta)`` to build the
            transition matrix across a time step ``Delta``. Because this
            function is jitted, it must be a valid JAX argument (e.g. a
            pytree-compatible callable such as ``jax.tree_util.Partial``).
        RESET : callable, invoked as ``RESET(i)`` with an instrument index to
            build the reset matrix used at an exposure start. Same jit
            requirement as ``A_aug``.
        t_states : Array of size K, sorted time coordinate of all states
        obsid : Array of size K, observation index of each state
        instid : Array indexed by observation, instrument of each observation
        stateid : Array of size K; 0 marks an exposure-start state, any other
            value is treated as an exposure-end state
        m_filtered, P_filtered : filtered means (K, d) and covariances (K, d, d)
        m_predicted, P_predicted : predicted means (K, d) and covariances
            (K, d, d); at an exposure start these are the pre-reset values

    Returns:
        m_smooth : Array (K, d), smoothed means
        P_smooth : Array (K, d, d), smoothed covariances. The last state is
            returned as its filtered value; exposure-start rows hold the
            smoothed pre-reset state.
    """

    def step(carry, k):
        # Smoothed mean/covariance at state k + 1, from the previous scan step
        m_hat_next, P_hat_next = carry
        # Kalman filter outputs, unpacked for notational consistency
        m_k = m_filtered[k]
        P_k = P_filtered[k]
        m_pred_next = m_predicted[k + 1]  # has superscript minus
        P_pred_next = P_predicted[k + 1]  # has superscript minus
        # Transition matrix across the gap to the next state
        Delta = t_states[k + 1] - t_states[k]
        A_k = A_aug(0, Delta)

        def rts_update(m_base, P_base, F):
            """RTS update of (m_base, P_base), which maps to state k + 1 via F."""
            # G = P_base F^T (P_pred_next)^{-1}, computed with a linear solve
            # rather than an explicit inverse.
            G_k = jnp.linalg.solve(P_pred_next.T, (P_base @ F.T).T).T
            m_hat_k = m_base + G_k @ (m_hat_next - m_pred_next)
            P_hat_k = P_base + G_k @ (P_hat_next - P_pred_next) @ G_k.T
            return m_hat_k, P_hat_k

        def smooth_start():
            """RTS smooth an exposure-start state (pre-reset quantities)"""
            # The pre-reset state reaches k + 1 through the instrument's reset
            # followed by the transition, hence the combined map A_k @ Reset.
            Reset = RESET(instid[obsid[k]])
            return rts_update(m_predicted[k], P_predicted[k], A_k @ Reset)

        def smooth_end():
            """RTS smooth an exposure-end state"""
            return rts_update(m_k, P_k, A_k)

        m_hat_k, P_hat_k = jax.lax.cond(stateid[k] == 0, smooth_start, smooth_end)
        return (m_hat_k, P_hat_k), (m_hat_k, P_hat_k)

    # Start smoothing from the final filtered state (smoothed == filtered there)
    init_carry = (m_filtered[-1], P_filtered[-1])
    K = len(t_states)  # number of states; the scan performs K - 1 steps
    # Run backward over k = K-2, ..., 0. With reverse=True the scan visits xs
    # from last to first but stacks outputs in the original order, so output
    # row k is already the smoothed state at time index k.
    _, (m_smooth_head, P_smooth_head) = jax.lax.scan(
        step, init_carry, jnp.arange(K - 1), reverse=True
    )
    # Append the final (filtered = smoothed) state to complete the time series
    m_smooth = jnp.concatenate([m_smooth_head, m_filtered[-1:]], axis=0)
    P_smooth = jnp.concatenate([P_smooth_head, P_filtered[-1:]], axis=0)
    return m_smooth, P_smooth