from __future__ import annotations
import jax
import jax.numpy as jnp
[docs]
def ParallelKalmanFilter(kernel, X, y, R, return_v_S=False):
    """
    Wrapper for the parallel Kalman filter.

    Builds the associative elements from the kernel's state-space model,
    runs the parallel scan, and then recovers the one-step predictions.
    The prior is a zero mean with the kernel's stationary covariance.

    Parameters:
        kernel: StateSpaceModel kernel
        X: data coordinates, e.g. time or (time, texp, instid)
        y: observations, shape (N, D)
        R: observation noise covariance, shape (N, D, D)
        return_v_S: if True, also return the innovations and their
            covariances

    Returns:
        m_filt: filtered means
        P_filt: filtered covariances
        m_pred: one-step predicted means
        P_pred: one-step predicted covariances
        v: innovations ``y - H @ m_pred`` (only if ``return_v_S``)
        S: innovation covariances ``H @ P_pred @ H.T + R``
            (only if ``return_v_S``)
    """
    # State-space ingredients supplied by the kernel.
    observe = kernel.observation_model
    transition = kernel.transition_matrix
    noise = kernel.process_noise

    # Prior on the state: zero mean, stationary covariance.
    prior_mean = jnp.zeros(kernel.dimension)
    prior_cov = kernel.stationary_covariance()

    # Sortable (time-like) coordinate used to form the step sizes.
    times = kernel.coord_to_sortable(X)

    elements = make_associative_params(
        transition, observe, noise, R, times, y, prior_mean, prior_cov
    )

    # After the scan, the b and C components hold the filtered moments.
    _, m_filt, P_filt, _, _ = parallel_kalman_filter(elements)

    m_pred, P_pred, v, S = postprocess(
        transition,
        observe,
        noise,
        R,
        X,
        times,
        y,
        m_filt,
        P_filt,
        prior_mean,
        prior_cov,
    )

    outputs = (m_filt, P_filt, m_pred, P_pred)
    if return_v_S:
        outputs += (v, S)
    return outputs
[docs]
@jax.jit
def make_associative_params(Phi, H, Q, R, t, y, m0, P0):
    """Generate the associative parameters needed for parallel Kalman

    See Eqns. 10, 11, 12 from Sarkka & Garcia-Fernandez (2020)

    Parameters:
        Phi: callable ``Phi(t0, t1)`` giving the transition matrix
        H: callable giving the observation matrix
        Q: callable ``Q(t0, t1)`` giving the process noise covariance
        R: observation noise covariances, indexed along the first axis
        t: sortable (time-like) coordinate of each observation
        y: observations, indexed along the first axis
        m0: prior state mean
        P0: prior state covariance

    Returns:
        Tuple ``(A, b, C, eta, J)``; each entry stacks the per-step
        element along a new leading axis of length ``len(t)``. ``b`` and
        ``eta`` always keep a trailing state axis, including when the
        state is one-dimensional.

    Note:
        The function is jitted with no static arguments, so ``Phi``,
        ``H`` and ``Q`` have to be callables that JAX can treat as
        pytrees (for example ``jax.tree_util.Partial`` objects).
    """

    def make_first_params(
        Phi,
        H,
        m0,
        P0,
        y0,
        r0,
    ):
        # First element: a regular Kalman update starting from the prior.
        Phi0 = Phi(0, 0)
        n = Phi0.shape[-1]
        H0 = H(0)  # this is sort of unnecessary but we'll keep it for now
        m = Phi0 @ m0
        P = Phi0 @ P0 @ Phi0.T  # Q(0,0) = 0
        S = H0 @ P @ H0.T + r0
        # K = P H^T S^{-1}, computed with a solve instead of an inverse.
        K = jnp.linalg.solve(S.T, (P @ H0.T).T).T

        A = jnp.zeros_like(Phi0)
        # Reshape to (n,) rather than squeeze: squeeze would turn a
        # length-1 state vector into a 0-d scalar and break the
        # downstream matrix products for one-dimensional states.
        b = jnp.reshape(m + K @ (y0 - H0 @ m), (n,))
        C = P - K @ S @ K.T

        _M = Phi0.T @ H0.T
        eta = jnp.reshape(_M @ jnp.linalg.solve(S, jnp.atleast_1d(y0)), (n,))
        J = _M @ jnp.linalg.solve(S, _M.T)
        return (A, b, C, eta, J)

    def make_generic_params(
        Phi,
        H,
        Q,
        t_delta,
        y,
        r,
    ):
        # Element for a step of size t_delta between two observations.
        Phi_dt = Phi(0, t_delta)
        n = Phi_dt.shape[-1]
        I = jnp.eye(n)
        Hk = H(t_delta)  # TODO: this is wrong, pass data coordinate here
        Q_dt = Q(0, t_delta)
        S = Hk @ Q_dt @ Hk.T + r
        K = jnp.linalg.solve(S.T, (Q_dt @ Hk.T).T).T

        A = (I - K @ Hk) @ Phi_dt
        # Same (n,) reshape as in make_first_params, for the same reason.
        b = jnp.reshape(K @ jnp.atleast_1d(y), (n,))
        C = (I - K @ Hk) @ Q_dt

        _M = Phi_dt.T @ Hk.T
        eta = jnp.reshape(_M @ jnp.linalg.solve(S, jnp.atleast_1d(y)), (n,))
        J = _M @ jnp.linalg.solve(S, _M.T)
        return (A, b, C, eta, J)

    first = make_first_params(Phi, H, m0, P0, y[0], R[0])

    # Remaining elements depend only on the step size and the datum, so
    # they are built in one vmapped call over steps 1..N-1.
    t_delta = jnp.diff(t)
    rest = jax.vmap(
        make_generic_params,
        in_axes=(None, None, None, 0, 0, 0),
    )(Phi, H, Q, t_delta, y[1:], R[1:])

    # Put the first element in front of the vmapped ones, per component.
    A_all, b_all, C_all, eta_all, J_all = (
        jnp.concatenate([head[jnp.newaxis, ...], tail], axis=0)
        for head, tail in zip(first, rest)
    )
    return (A_all, b_all, C_all, eta_all, J_all)
[docs]
@jax.jit
def _combine_per_pair(left, right):
    """Combine two associative elements into one.

    See Eqn. 13 & 14 of Sarkka & Garcia-Fernandez (2020) for
    the algorithm and notation.

    Parameters:
        left: tuple ``(A, b, C, eta, J)`` for the earlier element
        right: tuple ``(A, b, C, eta, J)`` for the later element

    Returns:
        Tuple ``(A, b, C, eta, J)`` of the combined element, with ``C``
        and ``J`` explicitly symmetrised.
    """
    A1, b1, C1, eta1, J1 = left
    A2, b2, C2, eta2, J2 = right

    eye = jnp.eye(A1.shape[-1])
    # The two system matrices that appear in the combination rules.
    fwd = eye + C1 @ J2
    bwd = eye + J2 @ C1

    def symmetrize(M):
        # Average with the transpose to remove round-off asymmetry.
        return 0.5 * (M + M.T)

    A12 = A2 @ jnp.linalg.solve(fwd, A1)
    b12 = A2 @ jnp.linalg.solve(fwd, b1 + C1 @ eta2) + b2
    C12 = symmetrize(A2 @ jnp.linalg.solve(fwd, C1) @ A2.T + C2)

    eta12 = A1.T @ jnp.linalg.solve(bwd, eta2 - J2 @ b1) + eta1
    J12 = symmetrize(A1.T @ jnp.linalg.solve(bwd, J2) @ A1 + J1)

    return (A12, b12, C12, eta12, J12)
[docs]
@jax.jit
def parallel_kalman_filter(asso_params):
    """
    Jax implementation of the parallel Kalman filter algorithm

    See Section 4A of Sarkka & Garcia-Fernandez (2020) for
    a detailed description of the algorithm and notation.

    Total runtime (span) complexity is ~O(logN) where N is the number
    of time steps.

    Parameters:
        asso_params: tuple ``(A, b, C, eta, J)`` of stacked per-step
            associative elements

    Returns:
        Tuple ``(A, b, C, eta, J)`` holding the scanned (prefix-combined)
        elements.
    """
    # _combine_per_pair works on single elements; vmap lifts it so it can
    # be applied to the stacked slices the scan passes to its operator.
    batched_combine = jax.vmap(_combine_per_pair)
    scanned = jax.lax.associative_scan(batched_combine, asso_params)
    A, b, C, eta, J = scanned
    return (A, b, C, eta, J)
[docs]
@jax.jit
def postprocess(Phi, H, Q, R, X, t, y, b, C, m0, P0):
    """Recover one-step predictions and innovations from filtered moments.

    Parameters:
        Phi: callable ``Phi(t0, t1)`` giving the transition matrix
        H: callable giving the observation matrix at a data coordinate
        Q: callable ``Q(t0, t1)`` giving the process noise covariance
        R: observation noise covariances, one per time step
        X: data coordinates at which ``H`` is evaluated
        t: sortable (time-like) coordinate of each observation
        y: observations
        b: filtered means, one per time step
        C: filtered covariances, one per time step
        m0: prior state mean
        P0: prior state covariance

    Returns:
        m_pred: predicted means ``Phi @ m_prev``
        P_pred: predicted covariances ``Phi @ P_prev @ Phi.T + Q``
        v: innovations ``y - H @ m_pred``
        S: innovation covariances ``H @ P_pred @ H.T + R``
    """

    def _prepend(first, rest):
        # Stack ``first`` in front of ``rest`` along the leading axis.
        return jnp.concatenate([first[jnp.newaxis, ...], rest], axis=0)

    def _propagate_mean(F, m):
        return F @ m

    def _propagate_cov(F, P, Qk):
        return F @ P @ F.T + Qk

    def _project_mean(Hk, m):
        return Hk @ m

    def _innovation_cov(Hk, P, Rk):
        return Hk @ P @ Hk.T + Rk

    steps = jnp.diff(t)
    eye = jnp.eye(b.shape[-1])

    # Per-step transition and process noise; the first step uses the
    # identity transition and zero process noise.
    Phi_all = _prepend(eye, jax.vmap(Phi, in_axes=(None, 0))(0, steps))
    Q_all = _prepend(jnp.zeros_like(eye), jax.vmap(Q, in_axes=(None, 0))(0, steps))

    # Observation matrices evaluated at every data coordinate.
    H_all = jax.vmap(H, in_axes=(0,))(X)

    # Moments each prediction starts from: the prior for step 0, then
    # the filtered moments of the preceding step.
    m_prev = _prepend(m0, b[:-1])
    P_prev = _prepend(P0, C[:-1])

    m_pred = jax.vmap(_propagate_mean)(Phi_all, m_prev)
    P_pred = jax.vmap(_propagate_cov)(Phi_all, P_prev, Q_all)

    y_pred = jax.vmap(_project_mean, in_axes=(0, 0))(H_all, m_pred)
    v = y - y_pred  # both (N, D)
    S = jax.vmap(_innovation_cov)(H_all, P_pred, R)

    return (m_pred, P_pred, v, S)