from __future__ import annotations
import jax
import jax.numpy as jnp
[docs]
def ParallelIntegratedKalmanFilter(
    kernel,
    X,
    y,
    t_states,
    obsid,
    instid,
    stateid,
    R,
    return_v_S=False,
):
    """
    Run the parallel integrated Kalman filter for a state-space kernel.

    Convenience wrapper: reads the model callables off ``kernel``, builds the
    associative-scan elements, runs the scan, and then derives the one-step
    predictions and innovations from the filtered moments.

    Parameters:
        kernel     : IntegratedStateSpaceModel kernel; this function reads its
                     observation_model, transition_matrix, process_noise,
                     reset_matrix, dimension and stationary_covariance()
        X          : Data coordinates of the N observations
                     (e.g. (time, texp, instid))
        y          : Array of size (N, D), measurements at the data coordinates
        t_states   : Array of size K, sorted time coordinate of all states
                     (exposure starts and ends)
        obsid      : Array of size K, which observation (0,...,N-1) each
                     state k belongs to
        instid     : Instrument index (0,...,Ninst-1) handed to the reset
                     matrix; it is sliced per state (alongside stateid) when
                     the scan elements are built, so it needs one entry per
                     state
        stateid    : Array of size K, 0 for exposure-start, 1 for exposure-end
        R          : Observation noise covariance, shape (N, D, D)
        return_v_S : Whether to also return the innovation and its covariance
                     (for likelihood computation)

    Returns:
        m_filtered  : filtered means, one per state
        P_filtered  : filtered covariances, one per state
        m_predicted : predicted means, one per state
        P_predicted : predicted covariances, one per state
        v, S        : innovations and their covariances at the exposure-end
                      states; only returned when return_v_S is true
    """
    # Model pieces supplied by the kernel.
    obs_model = kernel.observation_model
    transition = kernel.transition_matrix
    process_noise = kernel.process_noise
    reset = kernel.reset_matrix

    # Prior: zero mean with the kernel's stationary covariance.
    prior_mean = jnp.zeros(kernel.dimension)
    prior_cov = kernel.stationary_covariance()

    elements = make_associative_params(
        transition,
        obs_model,
        process_noise,
        reset,
        R,
        X,
        y,
        t_states,
        obsid,
        instid,
        stateid,
        prior_mean,
        prior_cov,
    )

    # Of the scanned prefixes, the b and C components are the filtered
    # means and covariances; A, eta and J are not needed past this point.
    _, m_filt, P_filt, _, _ = parallel_integrated_kalman_filter(elements)

    m_pred, P_pred, v, S = postprocess(
        transition,
        process_noise,
        obs_model,
        R,
        X,
        y,
        t_states,
        obsid,
        stateid,
        m_filt,
        P_filt,
        prior_mean,
        prior_cov,
    )

    moments = (m_filt, P_filt, m_pred, P_pred)
    if return_v_S:
        return moments + (v, S)
    return moments
[docs]
@jax.jit
def make_associative_params(
    Phi_aug,
    H_aug,
    Q_aug,
    RESET,
    R,
    X,
    y,
    t_states,
    obsid,
    instid,
    stateid,
    m0,
    P0,
):
    """
    Build the per-state elements (A, b, C, eta, J) for the parallel Kalman scan.

    See Eqns. 10, 11, 12 from Sarkka & Garcia-Fernandez (2020).

    The first state is built from the prior (m0, P0). Every later state k
    uses the interval t_states[k] - t_states[k-1] and is dispatched on
    stateid[k]: 0 yields a propagate-then-reset element with no measurement,
    any other value yields a measurement-update element for observation
    obsid[k].

    Parameters:
        Phi_aug  : callable returning the transition matrix; called as
                   Phi_aug(0, dt)
        H_aug    : callable returning the observation matrix for one data
                   coordinate; mapped over the leading axis of X
        Q_aug    : callable returning the process noise; called as
                   Q_aug(0, dt)
        RESET    : callable returning the reset matrix for one instid entry
        R        : observation noise covariances, indexed by obsid
        X        : data coordinates
        y        : measurements, indexed by obsid
        t_states : state times
        obsid    : per-state observation index
        instid   : per-state index passed to RESET
        stateid  : per-state flag selecting the element type (see above)
        m0, P0   : prior mean and covariance

    Returns:
        Tuple (A, b, C, eta, J); each entry is stacked along a new leading
        axis with one slot per state, the prior-based element first.
    """
    # Observation matrices for every datum, computed once; measurement
    # elements look theirs up through obsid.
    H_by_obs = jax.vmap(H_aug)(X)
    n_state = H_by_obs[0].shape[-1]

    def propagate_only(step):
        """Element for a state without a measurement: propagate, then reset."""
        Phi_k, Q_k, reset_k, _ = step
        return (
            reset_k @ Phi_k,
            jnp.zeros(Phi_k.shape[-1]),
            reset_k @ Q_k @ reset_k.T,
            jnp.zeros(n_state),
            jnp.zeros_like(Phi_k),
        )

    def measurement_update(step):
        """Element for a state that carries the measurement of observation n."""
        Phi_k, Q_k, _, n = step
        H_n = H_by_obs[n]
        y_n = jnp.atleast_1d(y[n])
        S_n = H_n @ Q_k @ H_n.T + R[n]
        # Gain K = Q H^T S^{-1}, obtained by solving the transposed system
        # S^T K^T = (Q H^T)^T instead of forming an explicit inverse.
        gain = jnp.linalg.solve(S_n.T, (Q_k @ H_n.T).T).T
        shrink = jnp.eye(Phi_k.shape[-1]) - gain @ H_n
        cov = shrink @ Q_k
        PhiT_HT = Phi_k.T @ H_n.T
        return (
            shrink @ Phi_k,
            jnp.squeeze(gain @ y_n),
            0.5 * (cov + cov.T),  # average with the transpose to symmetrize
            PhiT_HT @ jnp.linalg.solve(S_n, y_n),
            PhiT_HT @ jnp.linalg.solve(S_n, H_n @ Phi_k),
        )

    def select_element(sid, step):
        """Pick the element type for one state from its stateid."""
        return jax.lax.cond(sid == 0, propagate_only, measurement_update, step)

    # --- first state: push the prior through a zero-length transition + reset
    reset_0 = RESET(instid[0])
    Phi_0 = Phi_aug(0, 0)
    T_0 = reset_0 @ Phi_0
    first = (
        reset_0,
        jnp.squeeze(T_0 @ m0),
        T_0 @ P0 @ T_0.T,  # no Q term: assumes Q(0, 0) = 0
        jnp.zeros(n_state),
        jnp.zeros_like(Phi_0),
    )

    # --- remaining states: one element per interval between consecutive states
    dts = jnp.diff(t_states)
    over_dt = (None, 0)
    steps = (
        jax.vmap(Phi_aug, in_axes=over_dt)(0, dts),
        jax.vmap(Q_aug, in_axes=over_dt)(0, dts),
        jax.vmap(RESET)(instid[1:]),
        obsid[1:],
    )
    rest = jax.vmap(select_element)(stateid[1:], steps)

    # Prepend the first element to each of the five stacked components.
    return tuple(
        jnp.concatenate([head[jnp.newaxis, ...], tail], axis=0)
        for head, tail in zip(first, rest)
    )
[docs]
def _combine_per_pair(left, right):
    """
    Associative combination of two filtering elements.

    See Eqn. 13 & 14 of Sarkka & Garcia-Fernandez (2020) for
    the algorithm and notation.

    Parameters:
        left  : element (A, b, C, eta, J) covering the earlier states
        right : element (A, b, C, eta, J) covering the later states

    Returns:
        The combined element (A, b, C, eta, J).
    """
    A_l, b_l, C_l, eta_l, J_l = left
    A_r, b_r, C_r, eta_r, J_r = right

    eye = jnp.eye(A_l.shape[-1])
    # The two linear systems that appear in the combination rules.
    fwd = eye + C_l @ J_r  # used for A, b, C
    bwd = eye + J_r @ C_l  # used for eta, J

    A_out = A_r @ jnp.linalg.solve(fwd, A_l)
    b_out = A_r @ jnp.linalg.solve(fwd, b_l + C_l @ eta_r) + b_r
    C_out = A_r @ jnp.linalg.solve(fwd, C_l) @ A_r.T + C_r
    eta_out = A_l.T @ jnp.linalg.solve(bwd, eta_r - J_r @ b_l) + eta_l
    J_out = A_l.T @ jnp.linalg.solve(bwd, J_r) @ A_l + J_l

    # Average each of C and J with its transpose to keep them symmetric.
    C_out = 0.5 * (C_out + C_out.T)
    J_out = 0.5 * (J_out + J_out.T)
    return (A_out, b_out, C_out, eta_out, J_out)
[docs]
@jax.jit
def parallel_integrated_kalman_filter(asso_params):
    """
    Jax implementation of the parallel Kalman filter algorithm
    for integrated measurements.

    Runs an associative scan over the per-state elements, combining
    neighbouring elements with _combine_per_pair.

    See Section 4A of Sarkka & Garcia-Fernandez (2020) for
    a detailed description of the algorithm and notation,
    and section 3.2.4 of Rubenzahl & Hattori et al. (2025)
    for the integrated measurement case.

    Total runtime (span) complexity is O(N/T + logT) where N is the
    number of time steps and T is the number of parallel threads.

    Parameters:
        asso_params : the five stacked element components (A, b, C, eta, J),
                      one slot per state along the leading axis

    Returns:
        Tuple (A, b, C, eta, J) of the scanned (prefix-combined) elements.
    """
    # The scan hands the operator batches of elements, so the pairwise
    # rule is vmapped over the leading axis.
    batched_combine = jax.vmap(_combine_per_pair)
    scanned = jax.lax.associative_scan(batched_combine, asso_params)
    A_pre, b_pre, C_pre, eta_pre, J_pre = scanned
    return (A_pre, b_pre, C_pre, eta_pre, J_pre)
[docs]
@jax.jit
def _calc_kf_predictions(
    Phi_aug,
    Q_aug,
    H_aug,
    t_states,
    b,
    C,
    m0,
    P0,
):
    """
    One-step-ahead predicted means and covariances for every state.

    State 0 is predicted from the prior (m0, P0) over a zero-length interval;
    state k > 0 is predicted from the filtered moments of state k-1 over the
    interval t_states[k] - t_states[k-1].

    Parameters:
        Phi_aug  : callable returning the transition matrix; called as
                   Phi_aug(0, dt)
        Q_aug    : callable returning the process noise; called as
                   Q_aug(0, dt)
        H_aug    : accepted for signature compatibility; not used here
        t_states : state times
        b, C     : filtered means and covariances, one per state
        m0, P0   : prior mean and covariance

    Returns:
        (m_pred, P_pred), stacked along the leading axis, one per state.
    """
    dts = jnp.diff(t_states)

    def per_state(model_fn):
        """Evaluate model_fn at a zero interval, then at every inter-state interval."""
        at_zero = model_fn(0, 0)
        at_steps = jax.vmap(model_fn, in_axes=(None, 0))(0, dts)
        return jnp.concatenate([at_zero[jnp.newaxis, ...], at_steps], axis=0)

    Phi_steps = per_state(Phi_aug)
    Q_steps = per_state(Q_aug)

    # Moments each prediction starts from: the prior, then the filtered
    # result of the preceding state.
    m_from = jnp.concatenate([m0[jnp.newaxis, ...], b[:-1]], axis=0)
    P_from = jnp.concatenate([P0[jnp.newaxis, ...], C[:-1]], axis=0)

    def predict_mean(Phi_k, m_k):
        return Phi_k @ m_k

    def predict_cov(Phi_k, P_k, Q_k):
        return Phi_k @ P_k @ Phi_k.T + Q_k

    m_pred = jax.vmap(predict_mean)(Phi_steps, m_from)
    P_pred = jax.vmap(predict_cov)(Phi_steps, P_from, Q_steps)
    return (m_pred, P_pred)
[docs]
@jax.jit
def _calc_vS(
    H_aug,
    R,
    m_pred,
    P_pred,
    X,
    y,
    stateid,
    obsid,
):
    """
    Innovations and innovation covariances at the measurement (end) states.

    Parameters:
        H_aug   : callable returning the observation matrix for one data
                  coordinate
        R       : observation noise covariances, shape (N, D, D)
        m_pred  : predicted means, one per state
        P_pred  : predicted covariances, one per state
        X       : data coordinates; an array or pytree of arrays whose
                  leading axis indexes the N observations
        y       : measurements, shape (N, D)
        stateid : per-state flag; states with value 1 carry a measurement
        obsid   : per-state observation index

    Returns:
        v : innovations y - H m_pred, in end-state order
        S : innovation covariances H P_pred H^T + R, in end-state order
    """
    # jnp.nonzero needs a static output size to be traced under jit; the
    # number of observations is used for it.
    # NOTE(review): this presumes one end state per observation. With fewer
    # end states the result is padded (with index 0 by default per the
    # jnp.nonzero docs), so the trailing rows would repeat state 0 -- confirm
    # callers always satisfy this.
    ends_idx = jnp.nonzero(
        stateid == 1,
        size=y.shape[0],
    )[0]  # end states where there is a measurement
    obsidx_in_ends_order = jnp.take(obsid, ends_idx)

    mm = jnp.take(m_pred, ends_idx, axis=0)
    Pm = jnp.take(P_pred, ends_idx, axis=0)

    # Gather along the leading (observation) axis of every coordinate leaf.
    # Without axis=0, jnp.take flattens the leaf first, which selects the
    # wrong entries whenever a coordinate has more than one dimension.
    Xm = jax.tree.map(
        lambda leaf: jnp.take(leaf, obsidx_in_ends_order, axis=0),
        X,
    )
    Hm = jax.vmap(H_aug)(Xm)
    ym = jnp.take(y, obsidx_in_ends_order, axis=0)  # (N_ends, D)
    Rm = jnp.take(R, obsidx_in_ends_order, axis=0)  # (N_ends, D, D)

    y_pred = jax.vmap(
        lambda H, m: H @ m,
        in_axes=(0, 0),
    )(Hm, mm)
    v = ym - y_pred  # both (N_ends, D)

    S = jax.vmap(
        lambda H, P, R: H @ P @ H.T + R,
        in_axes=(0, 0, 0),
    )(
        Hm,
        Pm,
        Rm,
    )
    return (v, S)
[docs]
@jax.jit
def postprocess(
    Phi_aug,
    Q_aug,
    H_aug,
    R,
    X,
    y,
    t_states,
    obsid,
    stateid,
    b,
    C,
    m0,
    P0,
):
    """
    Derive predictions and innovations from the filtered moments.

    First computes the one-step predicted means and covariances with
    _calc_kf_predictions, then feeds them to _calc_vS to obtain the
    innovations and their covariances.

    Parameters:
        Phi_aug, Q_aug, H_aug : model callables (transition matrix, process
                                noise, observation matrix)
        R        : observation noise covariances
        X, y     : data coordinates and measurements
        t_states : state times
        obsid    : per-state observation index
        stateid  : per-state start/end flag
        b, C     : filtered means and covariances
        m0, P0   : prior mean and covariance

    Returns:
        Tuple (m_pred, P_pred, v, S).
    """
    m_pred, P_pred = _calc_kf_predictions(
        Phi_aug, Q_aug, H_aug, t_states, b, C, m0, P0
    )
    innovation, innovation_cov = _calc_vS(
        H_aug, R, m_pred, P_pred, X, y, stateid, obsid
    )
    return (m_pred, P_pred, innovation, innovation_cov)