solver
======

.. py:module:: smolgp.solvers.solver


Classes
-------

.. autoapisummary::

   smolgp.solvers.solver.StateSpaceSolver


Module Contents
---------------

.. py:class:: StateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that implements Kalman filtering and RTS smoothing for state space GPs.

   Given a :class:`~smolgp.kernels.StateSpaceModel` kernel and a set of observed
   coordinates, this solver computes the Kalman filtered and
   Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances using
   ``jax.lax.scan`` for efficient JIT-compiled sequential computation.

   Predictions at arbitrary test coordinates are handled by :meth:`predict`,
   which dispatches among retrodiction, interpolation, and extrapolation
   depending on whether each test point falls before, between, or after
   the observed data.
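   The forward-filter/backward-smoother structure described above maps naturally onto a scan: each step consumes one observation and carries the latest state estimate to the next step. The sketch below is illustrative only and is not smolgp's implementation. It assumes a scalar Ornstein-Uhlenbeck (Matérn-1/2) model with hypothetical hyperparameters ``sigma2`` and ``ell``. So that it runs with NumPy alone, it uses a pure-Python ``scan`` that follows the carry/output semantics documented for ``jax.lax.scan``.

```python
import math
import numpy as np

def scan(f, init, xs):
    """Pure-Python loop with the carry/output semantics documented for jax.lax.scan."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

# Scalar Ornstein-Uhlenbeck (Matern-1/2) model; sigma2 and ell are illustrative values.
sigma2, ell = 1.0, 0.5

def transition(dt):
    """Exact discretisation over a gap dt: returns (A, Q)."""
    a = math.exp(-dt / ell)
    return a, sigma2 * (1.0 - a * a)

def filter_step(carry, inputs):
    m, P, t_prev = carry
    t, y, R = inputs
    A, Q = transition(t - t_prev)
    m_pred, P_pred = A * m, A * P * A + Q      # predict to the current time
    K = P_pred / (P_pred + R)                  # Kalman gain (observation matrix H = 1)
    m_filt = m_pred + K * (y - m_pred)         # update with the observation y
    P_filt = (1.0 - K) * P_pred
    return (m_filt, P_filt, t), np.array([m_pred, P_pred, m_filt, P_filt])

def smoother_step(carry, inputs):
    m_s_next, P_s_next = carry
    m_f, P_f, m_p_next, P_p_next, dt = inputs
    A, _ = transition(dt)
    G = P_f * A / P_p_next                     # RTS smoother gain
    m_s = m_f + G * (m_s_next - m_p_next)
    P_s = P_f + G * (P_s_next - P_p_next) * G
    return (m_s, P_s), np.array([m_s, P_s])

t = np.array([0.0, 0.3, 0.7, 1.5])
y = np.array([0.2, -0.1, 0.4, 0.0])
R = np.full(4, 0.01)

# Forward pass, starting from the stationary prior (mean 0, variance sigma2) at t[0]
_, fwd = scan(filter_step, (0.0, sigma2, t[0]), zip(t, y, R))
m_pred, P_pred, m_filt, P_filt = fwd.T

# Backward pass from the last filtered state, visiting k = N-2, ..., 0
back_inputs = zip(m_filt[:-1][::-1], P_filt[:-1][::-1],
                  m_pred[1:][::-1], P_pred[1:][::-1], np.diff(t)[::-1])
_, bwd = scan(smoother_step, (m_filt[-1], P_filt[-1]), back_inputs)
m_smooth = np.append(bwd[::-1, 0], m_filt[-1])
P_smooth = np.append(bwd[::-1, 1], P_filt[-1])
```

   In JAX, the same step functions could be written with ``jax.numpy`` and passed to ``jax.lax.scan``, using ``reverse=True`` for the backward pass instead of reversing the inputs by hand. Because the model is linear-Gaussian, the smoothed means and variances coincide with the dense GP posterior at the observed points.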

   :param kernel: The kernel function; must be a
       :class:`~smolgp.kernels.StateSpaceModel` instance.
   :type kernel: StateSpaceModel
   :param X: The input coordinates with leading dimension of size ``N``.
   :type X: JAXArray
   :param noise: Observation noise covariance array of shape ``(N, D, D)``,
       where ``D`` is the observation dimension.
   :type noise: JAXArray
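   Because ``noise`` holds one ``(D, D)`` matrix per observation, independent per-point uncertainties must be expanded into that shape. A minimal NumPy sketch follows, assuming hypothetical 1-sigma errors ``yerr`` and ``yerr_2d``; the same expressions work with ``jax.numpy``.

```python
import numpy as np

# Hypothetical 1-sigma errors for N = 4 scalar observations (D = 1)
yerr = np.array([0.1, 0.2, 0.1, 0.3])
noise_scalar = (yerr**2)[:, None, None]  # shape (4, 1, 1): one 1x1 covariance per point

# Hypothetical errors for D = 2 independently measured channels, shape (N, D)
yerr_2d = np.array([[0.1, 0.5], [0.2, 0.5], [0.1, 0.4], [0.3, 0.6]])
# Place each row's variances on the diagonal of a 2x2 matrix -> shape (4, 2, 2)
noise_diag = np.stack([np.diag(v) for v in yerr_2d**2])
```

   Either array could then be passed as the ``noise`` argument, for example ``StateSpaceSolver(kernel, X, noise_scalar)`` for scalar observations. Correlated errors between channels would go in the off-diagonal entries.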

   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel

      The kernel defining the state space model.


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray

      The observed input coordinates.


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray

      Per-observation noise covariance matrices, shape ``(N, D, D)``.


   .. py:attribute:: t_states
      :type:  tinygp.helpers.JAXArray

      Sortable scalar coordinates derived from ``X`` via
      :meth:`~smolgp.kernels.StateSpaceModel.coord_to_sortable`.


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=False) -> Any

      Wrapper for the Kalman filter used by this solver.
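      In standard Kalman-filter notation, ``v`` and ``S`` denote the innovation (the observation minus its predicted value) and the innovation covariance. We assume, without confirmation from this docstring, that ``return_v_S`` refers to these quantities. The NumPy sketch below uses hypothetical helpers ``kalman_update`` and ``innovation_loglike``. It shows one measurement update in matrix form and how the innovations give the exact Gaussian log marginal likelihood through the prediction-error decomposition, for an illustrative scalar Ornstein-Uhlenbeck model.

```python
import numpy as np

def kalman_update(m_pred, P_pred, y, H, R):
    """One measurement update; returns the filtered state plus innovation v and its covariance S."""
    v = y - H @ m_pred                       # innovation
    S = H @ P_pred @ H.T + R                 # innovation covariance
    K = np.linalg.solve(S, H @ P_pred).T     # Kalman gain P_pred H^T S^{-1} (S, P_pred symmetric)
    m_filt = m_pred + K @ v
    P_filt = P_pred - K @ S @ K.T
    return m_filt, P_filt, v, S

def innovation_loglike(v, S):
    """Contribution of one innovation to log p(y_1, ..., y_N)."""
    _, logdet = np.linalg.slogdet(2.0 * np.pi * S)
    return -0.5 * (v @ np.linalg.solve(S, v) + logdet)

# Illustrative scalar OU model: state dimension 1, observation dimension D = 1
sigma2, ell, R = 1.0, 0.5, 0.01
t = np.array([0.0, 0.3, 0.7, 1.5])
y = np.array([0.2, -0.1, 0.4, 0.0])
H, I1 = np.eye(1), np.eye(1)

m, P, t_prev, loglike = np.zeros(1), sigma2 * I1, t[0], 0.0
for t_k, y_k in zip(t, y):
    a = np.exp(-(t_k - t_prev) / ell)
    # Predict with the exact OU transition A = a, Q = sigma2 (1 - a^2)
    m_pred, P_pred = a * m, a * a * P + sigma2 * (1.0 - a * a) * I1
    m, P, v, S = kalman_update(m_pred, P_pred, np.array([y_k]), H, R * I1)
    loglike += innovation_loglike(v, S)
    t_prev = t_k
```

      The accumulated ``loglike`` matches the dense GP log-likelihood built from the full ``N``-by-``N`` covariance, but the recursion costs ``O(N)`` rather than ``O(N^3)``.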



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used by this solver.
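      A single backward step of the RTS smoother in matrix form is sketched below. It assumes a Matérn-3/2 kernel, whose two-dimensional state (value and time derivative) has a closed-form transition matrix. ``matern32_transition`` and ``rts_step`` are hypothetical helpers, not smolgp functions, and the state values are made up for illustration.

```python
import numpy as np

def matern32_transition(dt, sigma2=1.0, ell=1.0):
    """Closed-form (A, Q) for a Matern-3/2 state [f, df/dt] over a gap dt (illustrative)."""
    lam = np.sqrt(3.0) / ell
    A = np.exp(-lam * dt) * np.array([[1.0 + lam * dt, dt],
                                      [-lam**2 * dt, 1.0 - lam * dt]])
    Pinf = np.diag([sigma2, lam**2 * sigma2])   # stationary state covariance
    Q = Pinf - A @ Pinf @ A.T                    # process noise accumulated over dt
    return A, Q

def rts_step(m_f, P_f, m_pred_next, P_pred_next, m_s_next, P_s_next, A):
    """One backward RTS step: combine the filtered state at t_k with the smoothed state at t_{k+1}."""
    # Smoother gain G = P_f A^T P_pred_next^{-1}, computed with a solve instead of an inverse
    G = np.linalg.solve(P_pred_next, A @ P_f).T
    m_s = m_f + G @ (m_s_next - m_pred_next)
    P_s = P_f + G @ (P_s_next - P_pred_next) @ G.T
    return m_s, P_s

# Hypothetical filtered state at t_k and a gap of 0.4 to the next observation
A, Q = matern32_transition(0.4)
m_f = np.array([0.3, -0.2])
P_f = np.array([[0.2, 0.05], [0.05, 0.6]])
m_pred_next, P_pred_next = A @ m_f, A @ P_f @ A.T + Q

# Hypothetical smoothed state at t_{k+1}, more certain than its one-step prediction
m_s_next, P_s_next = np.array([0.5, 0.1]), 0.5 * P_pred_next
m_s, P_s = rts_step(m_f, P_f, m_pred_next, P_pred_next, m_s_next, P_s_next, A)
```

      If the later data carry no new information (the future smoothed state equals its prediction), the step returns the filtered state unchanged; otherwise the smoothed covariance can only shrink relative to the filtered one.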



   .. py:method:: condition(y, return_v_S=False) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, Kalman-filtered, and RTS-smoothed
      means and covariances at each of the input coordinates ``X``.



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Make predictions at arbitrary test coordinates ``X_test``.

      :param X_test: The test coordinates, with leading dimension ``M`` and
          otherwise the same shape as ``self.X``.
      :type X_test: JAXArray
      :param conditioned_results: The output of :meth:`condition`.
      :type conditioned_results: tuple

      :returns: A pair ``(pred_mean, pred_var)`` of arrays with leading
                dimension ``M = len(X_test)``, giving the predicted state means
                and covariances at each test coordinate.
      :rtype: tuple

      Each test point is handled by one of three cases depending on its
      position relative to the observed data:

      1. **Retrodiction** — test point precedes all observations: smoothed
         backward from the first data point using the stationary prior.
      2. **Interpolation** — test point falls between two observations:
         Kalman-predicted forward from the nearest past point, then
         RTS-smoothed backward from the nearest future point.
      3. **Extrapolation** — test point follows all observations:
         Kalman-predicted forward from the final filtered state.
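      The three cases can be sketched with two primitives: a Kalman prediction over a gap, and an RTS-style correction from a later smoothed state. The sketch assumes a Matérn-3/2 state space model and takes the stationary prior to be zero mean with the stationary covariance ``Pinf``. All function names are hypothetical, the state values are made up, and this is not smolgp's ``predict``.

```python
import numpy as np

def matern32(dt, sigma2=1.0, ell=1.0):
    """Closed-form transition A, process noise Q, and stationary covariance Pinf (illustrative)."""
    lam = np.sqrt(3.0) / ell
    A = np.exp(-lam * dt) * np.array([[1.0 + lam * dt, dt],
                                      [-lam**2 * dt, 1.0 - lam * dt]])
    Pinf = np.diag([sigma2, lam**2 * sigma2])
    return A, Pinf - A @ Pinf @ A.T, Pinf

def predict_forward(m, P, dt):
    """Kalman prediction (no measurement update) over a gap dt >= 0."""
    A, Q, _ = matern32(dt)
    return A @ m, A @ P @ A.T + Q

def smooth_back(m_star, P_star, dt_next, m_s_next, P_s_next):
    """RTS-style correction of the state at t* using the smoothed state at t* + dt_next."""
    A, Q, _ = matern32(dt_next)
    m_pred, P_pred = A @ m_star, A @ P_star @ A.T + Q
    G = np.linalg.solve(P_pred, A @ P_star).T    # gain P_star A^T P_pred^{-1}
    return m_star + G @ (m_s_next - m_pred), P_star + G @ (P_s_next - P_pred) @ G.T

def retrodict(t_star, t_first, m_s_first, P_s_first):
    # Case 1: stationary prior at t*, corrected by the smoothed state at the first observation
    _, _, Pinf = matern32(0.0)
    return smooth_back(np.zeros(2), Pinf, t_first - t_star, m_s_first, P_s_first)

def interpolate(t_star, t_prev, m_f_prev, P_f_prev, t_next, m_s_next, P_s_next):
    # Case 2: predict from the nearest past filtered state, then smooth from the nearest future one
    m_star, P_star = predict_forward(m_f_prev, P_f_prev, t_star - t_prev)
    return smooth_back(m_star, P_star, t_next - t_star, m_s_next, P_s_next)

def extrapolate(t_star, t_last, m_f_last, P_f_last):
    # Case 3: pure Kalman prediction from the final filtered state
    return predict_forward(m_f_last, P_f_last, t_star - t_last)

# Hypothetical filtered state at t = 0.6 and smoothed state at the next observation t = 1.0
m_f, P_f = np.array([0.3, -0.2]), np.array([[0.2, 0.05], [0.05, 0.6]])
m_s, P_s = np.array([0.5, 0.1]), np.array([[0.1, 0.0], [0.0, 0.3]])
m_mid, P_mid = interpolate(0.8, 0.6, m_f, P_f, 1.0, m_s, P_s)
```

      As sanity checks, interpolating exactly at the next observation returns its smoothed state, and test points far before or after the data revert to the stationary prior.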



