solver
======

.. py:module:: smolgp.solvers.parallel.solver


Classes
-------

.. autoapisummary::

   smolgp.solvers.parallel.solver.ParallelStateSpaceSolver


Module Contents
---------------

.. py:class:: ParallelStateSpaceSolver(kernel: smolgp.kernels.base.StateSpaceModel, X: tinygp.helpers.JAXArray, noise: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   A solver that uses ``jax.lax.associative_scan`` to implement
   parallel Kalman filtering and Rauch-Tung-Striebel (RTS) smoothing.
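
   The scan primitive named above can be illustrated on a simpler problem.
   The following is a minimal sketch, not this solver's implementation: it
   runs the affine recurrence ``x_k = A_k x_{k-1} + b_k`` with
   ``jax.lax.associative_scan`` and compares it against a sequential
   ``jax.lax.scan``. The names ``combine`` and ``sequential`` are hypothetical,
   and a real parallel Kalman filter combines richer elements than the
   ``(A, b)`` pairs assumed here.

   .. code-block:: python

      import jax
      import jax.numpy as jnp


      def combine(left, right):
          # Compose two affine maps x -> A x + b, applying `left` first.
          # associative_scan passes batched elements, so matmul is batched.
          A_l, b_l = left
          A_r, b_r = right
          A = A_r @ A_l
          b = jnp.einsum("...ij,...j->...i", A_r, b_l) + b_r
          return A, b


      def sequential(A, b):
          # Reference implementation: one step at a time, starting from x = 0.
          def step(x, Ab):
              A_k, b_k = Ab
              x = A_k @ x + b_k
              return x, x

          _, xs = jax.lax.scan(step, jnp.zeros(b.shape[-1]), (A, b))
          return xs


      n, d = 8, 2
      A = 0.5 * jax.random.normal(jax.random.PRNGKey(0), (n, d, d))
      b = jax.random.normal(jax.random.PRNGKey(1), (n, d))

      # Prefix compositions of the affine maps, computed in parallel.
      # With x = 0 as the starting state, b_cum[k] is the state after step k.
      A_cum, b_cum = jax.lax.associative_scan(combine, (A, b))

   The scan is valid because composing affine maps is associative, which lets
   the prefix results be computed in logarithmic depth instead of one step at
   a time.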


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: kernel
      :type:  smolgp.kernels.base.StateSpaceModel


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:method:: normalization() -> tinygp.helpers.JAXArray


   .. py:method:: Kalman(y, return_v_S=True) -> Any

      Wrapper for the Kalman filter used with this solver.



   .. py:method:: RTS(kalman_results) -> Any

      Wrapper for the RTS smoother used with this solver.



   .. py:method:: condition(y, return_v_S=False) -> tinygp.helpers.JAXArray

      Compute the Kalman-predicted, filtered, and RTS-smoothed
      means and covariances at each of the input coordinates.
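
      To make the three kinds of output concrete, here is a minimal
      sequential sketch for a scalar state. It is not this solver's parallel
      implementation. The function name ``kalman_rts_scalar`` and its
      parameters (``a`` transition, ``q`` process variance, ``r`` measurement
      variance, ``m0`` and ``P0`` prior moments) are assumptions made for
      illustration, and the prior is assumed to sit one step before the
      first observation.

      .. code-block:: python

         import jax
         import jax.numpy as jnp


         def kalman_rts_scalar(y, a, q, r, m0, P0):
             """Sequential scalar Kalman filter and RTS smoother (illustrative)."""

             def filter_step(carry, y_k):
                 m, P = carry
                 m_pred = a * m              # predicted mean
                 P_pred = a * P * a + q      # predicted variance
                 S = P_pred + r              # innovation variance
                 K = P_pred / S              # Kalman gain
                 m_filt = m_pred + K * (y_k - m_pred)
                 P_filt = (1.0 - K) * P_pred
                 return (m_filt, P_filt), (m_pred, P_pred, m_filt, P_filt)

             init = (jnp.asarray(m0, y.dtype), jnp.asarray(P0, y.dtype))
             _, (m_pred, P_pred, m_filt, P_filt) = jax.lax.scan(filter_step, init, y)

             def smooth_step(carry, inputs):
                 m_next, P_next = carry      # smoothed moments at step k + 1
                 m_f, P_f, m_p_next, P_p_next = inputs
                 G = P_f * a / P_p_next      # smoother gain
                 m_s = m_f + G * (m_next - m_p_next)
                 P_s = P_f + G * (P_next - P_p_next) * G
                 return (m_s, P_s), (m_s, P_s)

             # Backward pass: the last smoothed state equals the last filtered state.
             inputs = (m_filt[:-1], P_filt[:-1], m_pred[1:], P_pred[1:])
             last = (m_filt[-1], P_filt[-1])
             _, (m_s, P_s) = jax.lax.scan(smooth_step, last, inputs, reverse=True)
             m_smooth = jnp.concatenate([m_s, m_filt[-1:]])
             P_smooth = jnp.concatenate([P_s, P_filt[-1:]])
             return (m_pred, P_pred), (m_filt, P_filt), (m_smooth, P_smooth)


         y = jnp.array([0.9, 1.1, 1.0, 1.3])
         (m_pred, P_pred), (m_filt, P_filt), (m_smooth, P_smooth) = kalman_rts_scalar(
             y, a=1.0, q=0.1, r=0.5, m0=0.0, P0=1.0
         )

      In this sketch the smoothed variances never exceed the filtered ones,
      since smoothing also uses the observations that come after each point.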



   .. py:method:: predict(X_test, conditioned_results) -> tinygp.helpers.JAXArray

      Make predictions at arbitrary coordinates ``X_test``.

      :param X_test: The test coordinates.
      :param conditioned_results: The output of ``self.condition``.

      :returns: ``pred_mean``, the predicted means of the states at ``X_test``,
                and ``pred_var``, the predicted variances of the states at ``X_test``.

      There are three cases:

      1. Retrodiction: smoothing from the first data point,
         using the prior as the prediction.
      2. Interpolation: filtering from the most recent data point
         and smoothing from the next future point.
      3. Extrapolation: predicting from the final filtered point.
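
      The three cases can be told apart by where each test coordinate falls
      relative to the sorted training coordinates. The following is a minimal
      sketch under that assumption. The helper ``classify_test_points`` is
      hypothetical, and its handling of test points that coincide with a
      training coordinate (``side="right"``) is an arbitrary choice that may
      differ from what this solver does.

      .. code-block:: python

         import jax.numpy as jnp


         def classify_test_points(X, X_test):
             # Number of training coordinates <= each test coordinate.
             idx = jnp.searchsorted(X, X_test, side="right")
             n = X.shape[0]
             # 0 = retrodiction (before the first data point)
             # 1 = interpolation (between two data points)
             # 2 = extrapolation (at or after the final data point)
             return jnp.where(idx == 0, 0, jnp.where(idx == n, 2, 1))


         X = jnp.array([0.0, 1.0, 2.0])
         X_test = jnp.array([-1.0, 0.5, 1.5, 3.0])
         labels = classify_test_points(X, X_test)  # -> [0, 1, 1, 2]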



