solvers
=======

.. py:module:: smolgp.solvers

.. autoapi-nested-parse::

   In ``smolgp``, "solvers" provide a swappable low-level interface for the
   Bayesian filtering and smoothing algorithms required for GP conditioning.
   New solvers can be contributed as external packages or pull requests to the
   `smolgp GitHub repository <https://github.com/smolgp-dev/smolgp>`_.

   The four built-in solvers are:

   1. :class:`StateSpaceSolver`: Standard Kalman filter and RTS smoother for
      instantaneous kernels (see :mod:`smolgp.kernels.base`). This is the
      default solver.

   2. :class:`IntegratedStateSpaceSolver`: Kalman filter and RTS smoother for
      integrated (time-averaged) measurement kernels
      (see :mod:`smolgp.kernels.integrated`).

   3. :class:`ParallelStateSpaceSolver`: GPU-parallelised version of
      :class:`StateSpaceSolver` with :math:`O(\log N)` parallel depth (instead
      of :math:`N` sequential steps) on compatible hardware.

   4. :class:`ParallelIntegratedStateSpaceSolver`: GPU-parallelised version of
      :class:`IntegratedStateSpaceSolver`.

   All solvers are exact up to numerical precision.

   Users generally do not need to instantiate solvers directly; :class:`~smolgp.GaussianProcess`
   selects the appropriate solver automatically based on the kernel type.

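All four built-in solvers documented on this page expose the same methods (``normalization``, ``Kalman``, ``RTS``, ``condition`` and ``predict``), so a contributed solver can be thought of as implementing a common interface. The sketch below writes that interface down as a :class:`typing.Protocol`. The name ``SolverLike``, the use of a protocol, and the toy ``DummySolver`` are illustrative assumptions and are not part of the ``smolgp`` API; the method signatures simply mirror the ones listed under Package Contents.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SolverLike(Protocol):
    """Hypothetical interface mirroring the methods shared by the built-in solvers."""

    def normalization(self) -> Any: ...

    def Kalman(self, y: Any, return_v_S: bool = False) -> Any: ...

    def RTS(self, kalman_results: Any) -> Any: ...

    def condition(self, y: Any, return_v_S: bool = False) -> Any: ...

    def predict(self, X_test: Any, conditioned_results: Any) -> Any: ...


class DummySolver:
    """Toy stand-in (hypothetical) that only shows a structural match; it does no filtering."""

    def normalization(self):
        return 0.0

    def Kalman(self, y, return_v_S=False):
        return y

    def RTS(self, kalman_results):
        return kalman_results

    def condition(self, y, return_v_S=False):
        # Same call order as the real solvers: filter forward, then smooth backward.
        return self.RTS(self.Kalman(y, return_v_S))

    def predict(self, X_test, conditioned_results):
        return conditioned_results


# runtime_checkable protocols only check that the methods exist, not their signatures.
print(isinstance(DummySolver(), SolverLike))  # True
```

Because the check is structural, any class providing these methods passes it; it says nothing about whether the filtering is correct.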


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/smolgp/solvers/integrated/index
   /autoapi/smolgp/solvers/kalman/index
   /autoapi/smolgp/solvers/parallel/index
   /autoapi/smolgp/solvers/rts/index
   /autoapi/smolgp/solvers/solver/index


Classes
-------

.. autoapisummary::

   smolgp.solvers.StateSpaceSolver
   smolgp.solvers.ParallelStateSpaceSolver
   smolgp.solvers.IntegratedStateSpaceSolver
   smolgp.solvers.ParallelIntegratedStateSpaceSolver


Package Contents
----------------

.. py:class:: StateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that implements Kalman filtering and RTS smoothing for state space GPs.

   Given a :class:`~smolgp.kernels.StateSpaceModel` kernel and a set of observed
   coordinates, this solver computes the Kalman filtered and
   Rauch-Tung-Striebel (RTS) smoothed posterior means and covariances using
   ``jax.lax.scan`` for efficient JIT-compiled sequential computation.

   Predictions at arbitrary test coordinates are handled by :meth:`predict`,
   which dispatches among retrodiction, interpolation, and extrapolation
   depending on whether each test point falls before, between, or after
   the observed data.

   :param kernel: The kernel function; must be a
       :class:`~smolgp.kernels.StateSpaceModel` instance.
   :type kernel: StateSpaceModel
   :param X: The input coordinates with leading dimension of size ``N``.
   :type X: JAXArray
   :param noise: Observation noise covariance array of shape ``(N, D, D)``,
       where ``D`` is the observation dimension.
   :type noise: JAXArray

   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray

      The observed input coordinates.


   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel

      The kernel defining the state space model.


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray

      Per-observation noise covariance matrices, shape ``(N, D, D)``.


   .. py:attribute:: t_states
      :type:  tinygp.helpers.JAXArray

      Sortable scalar coordinates derived from ``X`` via
      :meth:`~smolgp.kernels.StateSpaceModel.coord_to_sortable`.


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=False) -> Any

      Wrapper for the Kalman filter used by this solver.



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used by this solver.



   .. py:method:: condition(y, return_v_S=False) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, filtered, and RTS-smoothed means and
      covariances at each of the input coordinates.



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Algorithm for making predictions at arbitrary coordinates ``X_test``.

      :param X_test: The test coordinates, with the same per-point shape as
          ``self.X``; the leading dimension ``M`` may differ from ``N``.
      :type X_test: JAXArray
      :param conditioned_results: The output of :meth:`condition`.
      :type conditioned_results: tuple

      :returns: A pair ``(pred_mean, pred_var)`` of arrays with leading
                dimension ``M = len(X_test)``, giving the predicted state means
                and covariances at each test coordinate.
      :rtype: tuple

      Each test point is handled by one of three cases depending on its
      position relative to the observed data:

      1. **Retrodiction**: the test point precedes all observations; the state
         is smoothed backward from the first data point using the stationary
         prior.
      2. **Interpolation**: the test point falls between two observations; the
         state is Kalman-predicted forward from the nearest past point, then
         RTS-smoothed backward from the nearest future point.
      3. **Extrapolation**: the test point follows all observations; the state
         is Kalman-predicted forward from the final filtered state.

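The forward Kalman pass and backward RTS pass described for :class:`StateSpaceSolver` can be illustrated without ``smolgp`` on the simplest state space GP: an Ornstein-Uhlenbeck (Matérn-1/2) kernel :math:`k(\tau) = \sigma^2 e^{-|\tau|/\ell}` with a scalar state. The NumPy sketch below uses plain Python loops rather than ``jax.lax.scan``, and the helper ``ou_kalman_rts`` and its arguments are hypothetical. It starts the filter from the stationary prior, accumulates the log-likelihood from the innovations, smooths backward, and compares the result with the dense :math:`O(N^3)` GP posterior, which is what "exact up to numerical precision" means in practice.

```python
import numpy as np


def ou_kalman_rts(t, y, sigma2, ell, noise_var):
    """Kalman filter + RTS smoother for an OU (Matern-1/2) GP with a scalar state.

    Hypothetical helper. Assumes sorted 1-D inputs ``t`` and white observation
    noise of variance ``noise_var``. Returns smoothed means, smoothed variances
    and the log marginal likelihood.
    """
    n = len(t)
    m_pred, P_pred = np.zeros(n), np.zeros(n)
    m_filt, P_filt = np.zeros(n), np.zeros(n)
    A = np.ones(n)  # A[k] propagates the state from t[k-1] to t[k]
    loglike = 0.0
    m, P = 0.0, sigma2  # stationary prior for the first state
    for k in range(n):
        if k > 0:  # Kalman prediction step
            A[k] = np.exp(-(t[k] - t[k - 1]) / ell)
            m = A[k] * m
            P = A[k] ** 2 * P + sigma2 * (1.0 - A[k] ** 2)  # plus process noise Q
        m_pred[k], P_pred[k] = m, P
        v = y[k] - m  # innovation
        S = P + noise_var  # innovation variance
        K = P / S  # Kalman gain
        m, P = m + K * v, P - K * S * K  # Kalman update step
        m_filt[k], P_filt[k] = m, P
        loglike -= 0.5 * (v**2 / S + np.log(2.0 * np.pi * S))
    # Backward RTS pass, starting from the last filtered state.
    m_s, P_s = m_filt.copy(), P_filt.copy()
    for k in range(n - 2, -1, -1):
        G = P_filt[k] * A[k + 1] / P_pred[k + 1]  # smoother gain
        m_s[k] = m_filt[k] + G * (m_s[k + 1] - m_pred[k + 1])
        P_s[k] = P_filt[k] + G**2 * (P_s[k + 1] - P_pred[k + 1])
    return m_s, P_s, loglike


rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 10.0, 50))
y = np.sin(t) + 0.1 * rng.normal(size=t.size)
sigma2, ell, noise_var = 1.0, 2.0, 0.01
m_s, P_s, ll = ou_kalman_rts(t, y, sigma2, ell, noise_var)

# Dense O(N^3) reference: the usual GP posterior at the training inputs.
Kmat = sigma2 * np.exp(-np.abs(t[:, None] - t[None, :]) / ell)
C = Kmat + noise_var * np.eye(t.size)
mean_dense = Kmat @ np.linalg.solve(C, y)
var_dense = np.diag(Kmat - Kmat @ np.linalg.solve(C, Kmat))
_, logdet = np.linalg.slogdet(C)
ll_dense = -0.5 * (y @ np.linalg.solve(C, y) + logdet + t.size * np.log(2.0 * np.pi))
print(np.max(np.abs(m_s - mean_dense)), np.max(np.abs(P_s - var_dense)), ll - ll_dense)
```

The innovations ``v`` and their variances ``S`` are what make the log-likelihood available in a single forward pass; in the dense version the same number needs a Cholesky-type factorisation of the full covariance.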

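The three-way split in :meth:`StateSpaceSolver.predict` depends only on where each test point sits relative to the sorted observation coordinates. Below is a minimal NumPy sketch of that bookkeeping, assuming scalar, sorted coordinates (for example ``t_states``). The helper ``classify_test_points`` is hypothetical, and treating a test point that coincides with an observation as lying after it is a tie-breaking choice made here, not one taken from ``smolgp``.

```python
import numpy as np


def classify_test_points(t_obs, t_test):
    """Label each test point as retrodiction, interpolation, or extrapolation.

    ``t_obs`` must be sorted. Also returns, for every test point, the index of
    the first observation strictly after it (len(t_obs) if there is none).
    """
    t_obs = np.asarray(t_obs)
    t_test = np.asarray(t_test)
    # side="right": a test point equal to an observation counts as following it.
    idx_next = np.searchsorted(t_obs, t_test, side="right")
    labels = np.where(
        idx_next == 0,
        "retrodiction",  # before all observations: smooth back from the first point
        np.where(idx_next == len(t_obs), "extrapolation", "interpolation"),
    )
    return labels, idx_next


t_obs = np.array([1.0, 2.0, 4.0])
labels, idx_next = classify_test_points(t_obs, [0.5, 1.5, 3.0, 4.0, 5.0])
print(labels.tolist())
print(idx_next.tolist())
```

For an interpolated point, ``idx_next - 1`` is the nearest past observation to predict forward from and ``idx_next`` is the nearest future observation to smooth back from.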

.. py:class:: ParallelStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that uses ``jax.lax.associative_scan`` to implement parallel
   Kalman filtering and RTS smoothing. It is the parallel counterpart of
   :class:`StateSpaceSolver`.


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=True) -> Any

      Wrapper for the Kalman filter used by this solver.



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used by this solver.



   .. py:method:: condition(y, return_v_S=False) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, filtered, and RTS-smoothed means and
      covariances at each of the input coordinates.



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Algorithm for making predictions at arbitrary coordinates ``X_test``.

      :param X_test: The test coordinates.
      :param conditioned_results: The output of :meth:`condition`.

      :returns: A pair ``(pred_mean, pred_var)`` giving the predicted means
                and variances of the states at ``X_test``.
      :rtype: tuple

      There are three cases:

      1. **Retrodiction**: smoothing from the first data point, using the
         prior as the prediction.
      2. **Interpolation**: filtering from the most recent data point and
         smoothing from the next future point.
      3. **Extrapolation**: predicting from the final filtered point.

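The parallel solvers rely on the filtering and smoothing recursions being expressible through an associative operator, so that ``jax.lax.associative_scan`` can evaluate every prefix in :math:`O(\log N)` parallel rounds instead of :math:`N` sequential steps. The NumPy sketch below shows the idea on the simplest such recursion, a scalar linear recurrence :math:`x_k = a_k x_{k-1} + b_k`, whose steps are affine maps that compose associatively. It uses a hand-written Hillis-Steele scan rather than JAX, and ``combine``, ``parallel_prefix`` and ``sequential`` are illustrative names, not ``smolgp`` functions; the actual parallel Kalman and RTS elements carry matrices and covariances and are more involved.

```python
import numpy as np


def combine(earlier, later):
    """Compose affine maps x -> a*x + b: apply ``earlier`` first, then ``later``."""
    a1, b1 = earlier
    a2, b2 = later
    return a2 * a1, a2 * b1 + b2


def parallel_prefix(a, b):
    """Hillis-Steele inclusive scan over affine maps.

    Uses ceil(log2 N) rounds; within a round every update is independent, which
    is what a parallel device exploits. (Simple, but not work-efficient.)
    """
    a, b = a.copy(), b.copy()
    n, offset, rounds = len(a), 1, 0
    while offset < n:
        # Element i absorbs the partial composition that ends at i - offset.
        new_a, new_b = combine((a[:-offset], b[:-offset]), (a[offset:], b[offset:]))
        a[offset:], b[offset:] = new_a, new_b
        offset *= 2
        rounds += 1
    return a, b, rounds


def sequential(a, b, x0):
    """Reference O(N) loop: x_k = a_k * x_{k-1} + b_k."""
    out, x = np.empty_like(a), x0
    for k in range(len(a)):
        x = a[k] * x + b[k]
        out[k] = x
    return out


rng = np.random.default_rng(1)
n = 1000
a = np.exp(-rng.uniform(0.0, 0.5, n))  # e.g. OU transition factors exp(-dt/ell)
b = rng.normal(size=n)
x0 = 0.3
A, B, rounds = parallel_prefix(a, b)
x_par = A * x0 + B  # every prefix map applied to the initial value at once
x_seq = sequential(a, b, x0)
print(rounds, np.allclose(x_par, x_seq))  # 10 True
```

The total work of this particular scan is :math:`O(N \log N)`; the benefit comes only from the depth dropping to :math:`\lceil \log_2 N \rceil` rounds when enough parallel hardware is available.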


.. py:class:: IntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that uses ``jax.lax.scan`` to implement sequential Kalman
   filtering and RTS smoothing for integrated (time-averaged) measurements.


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: state_coords
      :type:  tinygp.helpers.JAXArray


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=False) -> Any

      Wrapper for the Kalman filter used by this solver.



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used by this solver.



   .. py:method:: condition(y, return_v_S=False) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, filtered, and RTS-smoothed means and
      covariances at each of the input coordinates.



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Algorithm for making predictions at arbitrary coordinates ``X_test``.

      :param X_test: The test coordinates.
      :param conditioned_results: The output of :meth:`condition`.

      :returns: A pair ``(pred_mean, pred_var)`` giving the predicted means
                and variances of the states at ``X_test``.
      :rtype: tuple

      There are three cases:

      1. **Retrodiction**: smoothing from the first data point, using the
         prior as the prediction.
      2. **Interpolation**: filtering from the most recent data point and
         smoothing from the next future point.
      3. **Extrapolation**: predicting from the final filtered point.

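To make "integrated (time-averaged) measurement" concrete: if an observation is the average of the process over an exposure of length :math:`T`, its prior variance is the kernel averaged over the exposure window in both arguments. For an OU (Matérn-1/2) kernel :math:`k(\tau) = \sigma^2 e^{-|\tau|/\ell}` this has the closed form :math:`2\sigma^2 (r - 1 + e^{-r}) / r^2` with :math:`r = T/\ell`. The NumPy sketch below checks that formula against brute-force averaging of the kernel on a grid. It only illustrates the measurement model; it does not show how :class:`IntegratedStateSpaceSolver` represents the integral inside the state, and both function names are hypothetical.

```python
import numpy as np


def ou_averaged_variance(T, sigma2, ell):
    """Closed-form variance of (1/T) * integral_0^T f(t) dt for an OU kernel.

    Kernel: k(tau) = sigma2 * exp(-|tau| / ell).
    """
    r = T / ell
    return 2.0 * sigma2 * (r - 1.0 + np.exp(-r)) / r**2


def ou_averaged_variance_bruteforce(T, sigma2, ell, n=1000):
    """Average the kernel over an n x n midpoint grid covering the exposure."""
    s = (np.arange(n) + 0.5) * T / n
    K = sigma2 * np.exp(-np.abs(s[:, None] - s[None, :]) / ell)
    return K.mean()


sigma2, ell = 1.5, 2.0
exposures = [0.1, 1.0, 5.0]
exact = np.array([ou_averaged_variance(T, sigma2, ell) for T in exposures])
approx = np.array([ou_averaged_variance_bruteforce(T, sigma2, ell) for T in exposures])
print(exact)
print(approx)
```

Short exposures (:math:`T \ll \ell`) leave the variance close to :math:`\sigma^2`, while long exposures average the fluctuations down, which is why an integrated measurement cannot simply be treated as an instantaneous one.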


.. py:class:: ParallelIntegratedStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that uses ``jax.lax.associative_scan`` to implement parallel
   Kalman filtering and RTS smoothing for integrated measurements. It is the
   parallel counterpart of :class:`IntegratedStateSpaceSolver`.


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: state_coords
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: _state_coords
      :type:  tinygp.helpers.JAXArray


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=True) -> Any

      Wrapper for the Kalman filter used by this solver.



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used by this solver.



   .. py:method:: condition(y, return_v_S=True) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, filtered, and RTS-smoothed means and
      covariances at each of the input coordinates.



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Algorithm for making predictions at arbitrary coordinates ``X_test``.

      :param X_test: The test coordinates.
      :param conditioned_results: The output of :meth:`condition`.
      :param observation_model: (optional) Observation model ``H`` for the
          test points; should be a function with the same interface as
          ``self.kernel.observation_model``.

      There are three cases:

      1. **Retrodiction**: smoothing from the first data point, using the
         prior as the prediction.
      2. **Interpolation**: filtering from the most recent data point and
         smoothing from the next future point.
      3. **Extrapolation**: predicting from the final filtered point.

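The ``observation_model`` argument concerns how predicted states are mapped to observable quantities. Assuming this follows the usual linear-Gaussian rule, a predicted state mean :math:`m` and covariance :math:`P` give an observation-space mean :math:`H m` and covariance :math:`H P H^\top`. The NumPy sketch below applies that projection to predicted means of shape ``(M, S)`` and covariances of shape ``(M, S, S)`` with a constant ``(D, S)`` matrix ``H``. In ``smolgp`` the observation model is a function rather than a fixed matrix, and ``project_to_observations`` is a hypothetical helper. The example uses a two-component state ``[f, df/dt]``, as in a Matérn-3/2 model, to show how a different ``H`` returns either the function or its derivative from the same state predictions.

```python
import numpy as np


def project_to_observations(H, pred_mean, pred_var):
    """Map predicted state moments to observation space: H m and H P H^T.

    pred_mean: (M, S) state means; pred_var: (M, S, S) state covariances;
    H: (D, S) observation matrix. Returns (M, D) means and (M, D, D) covariances.
    """
    mean_y = np.einsum("ds,ms->md", H, pred_mean)
    var_y = np.einsum("ds,mst,et->mde", H, pred_var, H)
    return mean_y, var_y


# Two test points with a hypothetical state [f, df/dt].
pred_mean = np.array([[1.0, 0.5], [2.0, -1.0]])
pred_var = np.array([[[0.2, 0.05], [0.05, 0.3]],
                     [[0.1, 0.0], [0.0, 0.4]]])
H_value = np.array([[1.0, 0.0]])  # picks out f
H_deriv = np.array([[0.0, 1.0]])  # picks out df/dt
mu_f, var_f = project_to_observations(H_value, pred_mean, pred_var)
mu_d, var_d = project_to_observations(H_deriv, pred_mean, pred_var)
print(mu_f.ravel(), var_f.ravel())
print(mu_d.ravel(), var_d.ravel())
```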


