smolgp
======

.. py:module:: smolgp

.. autoapi-nested-parse::

   ``smolgp`` is designed to be a drop-in extension of the `tinygp <https://github.com/dfm/tinygp>`_
   library for building Gaussian Process (GP) models in Python. Like ``tinygp``, it is built on top of
   `jax <https://github.com/google/jax>`_. The driving design philosophy is to match the API of ``tinygp``
   as closely as possible. With only a few exceptions, any existing code that uses ``tinygp`` should
   work with ``smolgp`` after simply finding and replacing ``tiny`` with ``smol``.


   ``smolgp`` uses the state space representations of Gaussian Processes to implement
   linear-time solvers for GP regression and forecasting (or as fast as
   :math:`\mathcal{O}(\log N)` time when parallelized on a GPU). It also implements
   "integrated" kernels that model time-averaged measurements, such as those from
   long-exposure astronomical observations; these are also solved in linear time and are
   compatible with the parallel methods.
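
   To make "time-averaged" concrete, here is a minimal NumPy sketch, assuming each
   measurement is the uniform average of the process over its exposure window
   ``[t - texp/2, t + texp/2]``. It brute-forces the covariance of two such averages for
   an exponential (Matérn-1/2) kernel by quadrature. ``exp_kernel`` and ``integrated_cov``
   are hypothetical helpers; this is not how ``smolgp`` evaluates integrated kernels (per
   the description above, it works in state space), only an illustration of the quantity
   they represent.

   .. code-block:: python

      import numpy as np


      # Hypothetical helpers for illustration only; not part of smolgp.
      def exp_kernel(tau, sigma=1.0, ell=1.0):
          """Exponential (Matern-1/2) kernel k(tau) = sigma^2 exp(-|tau| / ell)."""
          return sigma**2 * np.exp(-np.abs(tau) / ell)


      def integrated_cov(t1, texp1, t2, texp2, n=400, **kernel_kwargs):
          """Covariance of two exposure-averaged measurements, by the midpoint rule.

          Assumes each measurement is the uniform average of the process over
          [t - texp / 2, t + texp / 2].
          """
          # Midpoints of n equal sub-intervals spanning each exposure window
          s1 = t1 - texp1 / 2 + (np.arange(n) + 0.5) * texp1 / n
          s2 = t2 - texp2 / 2 + (np.arange(n) + 0.5) * texp2 / n
          # The grid mean approximates the double integral of k(s - s') / (texp1 * texp2)
          return exp_kernel(s1[:, None] - s2[None, :], **kernel_kwargs).mean()


      # A 1-unit exposure averages away part of the variance: k(0) = 1, but
      print(integrated_cov(0.0, 1.0, 0.0, 1.0))  # approximately 2 * exp(-1) = 0.7358
      # Very short exposures recover the instantaneous kernel value
      print(integrated_cov(0.0, 1e-6, 2.0, 1e-6), exp_kernel(2.0))

   For this kernel the exact variance of an average over an exposure of length :math:`T`
   is :math:`2\sigma^2(\ell/T)^2\left(T/\ell - 1 + e^{-T/\ell}\right)`, which the
   quadrature reproduces closely.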

   The primary way to interact with ``smolgp`` is to construct
   "kernel" functions using the building blocks provided in the ``kernels``
   subpackage (see :mod:`smolgp.kernels`), and then pass them to a
   :class:`GaussianProcess` object, which does all the computations. Check out the
   :ref:`tutorials` for a more complete introduction.



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/smolgp/gp/index
   /autoapi/smolgp/helpers/index
   /autoapi/smolgp/kernels/index
   /autoapi/smolgp/solvers/index


Attributes
----------

.. autoapisummary::

   smolgp.__version__


Classes
-------

.. autoapisummary::

   smolgp.GaussianProcess


Package Contents
----------------

.. py:class:: GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)

   Bases: :py:obj:`equinox.Module`


   An interface for designing a Gaussian Process regression model.

   :param kernel: The kernel function.
   :type kernel: Kernel
   :param X: The input coordinates: any PyTree compatible with ``kernel``
       whose leading dimension has size ``N_data``.
       For integrated kernels, pass ``(t, texp)`` where ``t`` is the array of
       exposure midpoints and ``texp`` is the array of exposure durations.
   :type X: JAXArray
   :param noise: Observation noise covariance matrices with shape
       ``(N, D, D)``, where ``N`` is the number of data points and ``D`` is
       the observation dimension (usually 1). Each slice ``noise[k]`` is the
       :math:`D \times D` noise covariance for the ``k``-th observation.
       A 1-D array of shape ``(N,)`` is interpreted as scalar per-observation
       variances. Defaults to :math:`\sqrt{\varepsilon_{\mathrm{machine}}}
       \cdot I` for all observations.
   :type noise: JAXArray, optional
   :param mean: A callable or constant mean function evaluated as
       ``mean(X)``.
   :type mean: Callable, optional
   :param solver: Solver class for filtering and smoothing. If ``None``
       (default), selected automatically based on the kernel type.
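
   A minimal NumPy sketch of the ``noise`` shape conventions above, assuming float64
   (the default jitter depends on the dtype, and JAX defaults to float32).
   ``as_noise_cov`` is a hypothetical helper written to mirror the documented rules, not
   a ``smolgp`` function. It also shows the ``(t, texp)`` input structure for integrated
   kernels. Note that ``noise`` holds variances, so per-point uncertainties are squared.

   .. code-block:: python

      import numpy as np


      # Hypothetical helper for illustration only; not part of smolgp.
      def as_noise_cov(noise, num_data, dim=1, dtype=np.float64):
          """Coerce ``noise`` to an (N, D, D) stack of covariance matrices."""
          if noise is None:
              # Default: sqrt(machine epsilon) times the identity for every observation
              jitter = np.sqrt(np.finfo(dtype).eps)
              return np.broadcast_to(jitter * np.eye(dim, dtype=dtype), (num_data, dim, dim))
          noise = np.asarray(noise, dtype=dtype)
          if noise.ndim == 1:
              # Shape (N,): one scalar variance per observation -> (N, 1, 1)
              return noise[:, None, None]
          return noise  # assumed to already have shape (N, D, D)


      # Integrated-kernel inputs as described above: X = (t, texp)
      t = np.linspace(0.0, 10.0, 5)   # exposure midpoints
      texp = np.full_like(t, 0.5)     # exposure durations
      X = (t, texp)

      yerr = np.full_like(t, 0.1)                 # per-point standard deviations
      print(as_noise_cov(yerr**2, len(t)).shape)  # (5, 1, 1)
      print(as_noise_cov(None, len(t)).shape)     # (5, 1, 1)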


   .. py:attribute:: num_data
      :type:  int


   .. py:attribute:: dtype
      :type:  jax.numpy.dtype


   .. py:attribute:: kernel
      :type:  tinygp.kernels.Kernel


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: mean_function
      :type:  tinygp.means.MeanBase


   .. py:attribute:: mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: var
      :type:  tinygp.helpers.JAXArray | None


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: solver
      :type:  smolgp.solvers.StateSpaceSolver


   .. py:attribute:: states
      :type:  ConditionedStates


   .. py:property:: loc
      :type: tinygp.helpers.JAXArray


      If conditioned, this is the mean at the data points;
      otherwise, it is just the prior mean.


   .. py:property:: variance
      :type: tinygp.helpers.JAXArray


      If conditioned, this is the variance at the data points;
      otherwise, it is just the prior variance.


   .. py:property:: covariance
      :type: tinygp.helpers.JAXArray
      :abstractmethod:



   .. py:method:: log_probability(y: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Compute the log probability of ``y`` under this multivariate normal.

      :param y: The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray

      :returns: The marginal log probability of this multivariate normal model,
                evaluated at ``y``.
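
      For small datasets, a brute-force dense computation is a handy cross-check on
      ``log_probability``. The NumPy sketch below evaluates the multivariate normal log
      density with an :math:`\mathcal{O}(N^3)` Cholesky factorisation.
      ``dense_log_probability`` is hypothetical, and forming ``K`` as the kernel matrix
      plus the noise variance on the diagonal is an assumption about how you would
      reproduce your own model's covariance; for the same model, a linear-time
      state-space result is expected to agree up to floating-point error.

      .. code-block:: python

         import numpy as np


         # Hypothetical reference implementation; not part of smolgp.
         def dense_log_probability(y, mean, K):
             """Log density of y under N(mean, K) via a Cholesky factorisation."""
             r = np.asarray(y) - mean
             L = np.linalg.cholesky(K)                   # K = L L^T
             alpha = np.linalg.solve(L, r)               # alpha^T alpha = r^T K^{-1} r
             log_det = 2.0 * np.sum(np.log(np.diag(L)))  # log|K|
             return -0.5 * (alpha @ alpha + log_det + len(r) * np.log(2.0 * np.pi))


         # Exponential kernel on a grid plus white observation noise of variance 0.1
         t = np.linspace(0.0, 5.0, 50)
         K = np.exp(-np.abs(t[:, None] - t[None, :])) + 0.1 * np.eye(len(t))
         y = np.sin(t)
         print(dense_log_probability(y, np.zeros_like(t), K))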



   .. py:method:: condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) -> ConditionResult

      Condition the model on observed data.

      :param y: The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray
      :param X_test: The coordinates where the prediction
                     should be evaluated. This should have a data type compatible
                     with the ``X`` data provided when instantiating this object. If
                     it is not provided, ``X`` will be used by default, so the
                     predictions will be made at the data coordinates.
      :type X_test: JAXArray, optional
      :param include_mean: If ``True`` (default), the predicted
                           values will include the mean function evaluated at ``X_test``.
      :type include_mean: bool, optional
      :param kernel: A kernel to optionally specify the component
                     kernel to be used for predicting after conditioning. See
                     :ref:`multicomponent` for an example.
      :type kernel: Kernel, optional

      :returns: A named tuple where the first element ``log_probability`` is the log
                marginal probability of the model, and the second element ``gp`` is
                the :class:`GaussianProcess` object describing the conditional
                distribution evaluated at ``X_test``.
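
      For reference, conditioning a zero-mean GP with homoscedastic noise gives the
      textbook posterior :math:`\mu_* = K_*^\top (K + \sigma_n^2 I)^{-1} y` and
      :math:`\Sigma_* = K_{**} - K_*^\top (K + \sigma_n^2 I)^{-1} K_*`. The dense NumPy
      sketch below writes these out for an exponential kernel. ``exp_kernel`` and
      ``dense_condition`` are hypothetical helpers, the mean function is taken to be zero
      (``include_mean`` would add it back), and ``smolgp``'s solvers reach the posterior
      through filtering and smoothing instead of these :math:`\mathcal{O}(N^3)` matrix
      operations.

      .. code-block:: python

         import numpy as np


         # Hypothetical helpers for illustration only; not part of smolgp.
         def exp_kernel(t1, t2, sigma=1.0, ell=1.0):
             """Pairwise exponential-kernel matrix between two 1-D coordinate arrays."""
             return sigma**2 * np.exp(-np.abs(t1[:, None] - t2[None, :]) / ell)


         def dense_condition(t, y, t_test, noise_var):
             """Posterior mean and covariance of a zero-mean GP at t_test."""
             K = exp_kernel(t, t) + noise_var * np.eye(len(t))  # data covariance
             K_s = exp_kernel(t, t_test)                         # (N, N_test)
             K_ss = exp_kernel(t_test, t_test)                   # (N_test, N_test)
             L = np.linalg.cholesky(K)
             A = np.linalg.solve(L, K_s)                         # L^{-1} K_s
             b = np.linalg.solve(L, y)                           # L^{-1} y
             mean = A.T @ b                                      # K_s^T K^{-1} y
             cov = K_ss - A.T @ A                                # K_ss - K_s^T K^{-1} K_s
             return mean, cov


         t = np.linspace(0.0, 10.0, 30)
         y = np.sin(t)
         mu, cov = dense_condition(t, y, np.array([2.5, 7.5]), noise_var=0.01)
         print(mu, np.diag(cov))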



   .. py:method:: predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) -> tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Predict the GP model at new test points conditioned on observed data.

      :param X_test: The coordinates where the prediction
                     should be evaluated. This should have a data type compatible
                     with the ``X`` data provided when instantiating this object. If
                     it is not provided, ``X`` will be used by default, so the
                     predictions will be made at the data coordinates.
      :type X_test: JAXArray, optional
      :param y: The observed data. Only needs to be given if the GP
                has not yet been conditioned. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray, optional
      :param return_var: If ``True``, the variance of the predicted values at
                         ``X_test`` is also returned. Default is ``False``.
      :type return_var: bool, optional
      :param observation_model: Optionally, a function of ``X_test`` that defines
                                the output observation model. If ``None`` (default),
                                the kernel's own observation model is used.
      :type observation_model: Any, optional
      :param return_full_state: If ``True``, return the full predicted state
                                mean and covariance rather than projecting them to
                                observation space. Default is ``False``, in which
                                case the result is projected through
                                ``kernel.observation_model``.
      :type return_full_state: bool, optional
      :param kernel: If specified, the index of the component kernel in a
                     multi-component model (for example, a sum or product of
                     kernels) whose prediction should be extracted and, unless
                     ``return_full_state`` is ``True``, projected to observation
                     space.
      :type kernel: int, optional

      :returns: The mean of the predictive model evaluated at ``X_test``, with shape
                ``(N_test,)`` where ``N_test`` is the zeroth dimension of
                ``X_test``. If ``return_var`` is ``True``, the variance of the
                predicted process is also returned, with shape ``(N_test,)``.
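
      A minimal sketch of what the default projection does, assuming a linear
      observation model given by a matrix ``H``: for each test point, the state mean
      :math:`m_n` and covariance :math:`P_n` map to the mean :math:`H m_n` and variance
      :math:`H P_n H^\top`. The state values below are made up, the two-dimensional state
      ``[f, df/dt]`` is the usual Matérn-3/2 choice, and the code is plain NumPy rather
      than ``smolgp``'s internals.

      .. code-block:: python

         import numpy as np

         # Made-up full-state estimates (what return_full_state=True would expose)
         # for a 2-D state [f(t), df/dt] at N_test = 4 points
         m = np.array([[0.1, 1.0], [0.5, 0.8], [0.8, 0.2], [0.9, -0.3]])  # (N_test, 2)
         P = np.tile(np.array([[0.04, 0.01], [0.01, 0.09]]), (4, 1, 1))   # (N_test, 2, 2)

         # Hypothetical observation model that reads off the first component, f(t)
         H = np.array([[1.0, 0.0]])                                        # (1, 2)

         mean = np.einsum("ij,nj->ni", H, m)[:, 0]             # H m_n     -> (N_test,)
         var = np.einsum("ij,njk,lk->nil", H, P, H)[:, 0, 0]   # H P_n H^T -> (N_test,)
         print(mean)  # [0.1 0.5 0.8 0.9]
         print(var)   # [0.04 0.04 0.04 0.04]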



   .. py:method:: sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) -> tinygp.helpers.JAXArray

      Generate samples from the prior process.

      :param key: A ``jax`` random number key array.
      :type key: jax.random.KeyArray
      :param shape: The number and shape of samples to generate.
      :type shape: tuple, optional

      :returns: The sampled realizations from the process with shape ``(N_data,) +
                shape`` where ``N_data`` is the zeroth dimension of the ``X``
                coordinates provided when instantiating this process.
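
      A minimal NumPy sketch of the ``(N_data,) + shape`` output convention, using the
      exponential kernel because its state-space form is a first-order autoregression
      that can be sampled exactly in one linear-time pass. It uses a
      ``numpy.random.Generator`` in place of a ``jax`` key, and ``sample_exp_prior`` is a
      hypothetical function, not ``smolgp``'s ``sample``.

      .. code-block:: python

         import numpy as np


         # Hypothetical function for illustration only; not smolgp's ``sample``.
         def sample_exp_prior(rng, t, shape=(), sigma=1.0, ell=1.0):
             """Draw exponential-kernel GP prior samples with shape (len(t),) + shape.

             Uses the exact AR(1) discretisation of the kernel's state-space model,
             so the cost grows linearly with len(t). ``t`` must be sorted.
             """
             shape = tuple(shape)
             size = shape if shape else None             # None -> scalar draws
             out = np.empty((len(t),) + shape)
             out[0] = sigma * rng.standard_normal(size)  # stationary initial state
             for k in range(1, len(t)):
                 a = np.exp(-(t[k] - t[k - 1]) / ell)    # state transition
                 q = sigma**2 * (1.0 - a**2)             # process-noise variance
                 out[k] = a * out[k - 1] + np.sqrt(q) * rng.standard_normal(size)
             return out


         rng = np.random.default_rng(42)
         t = np.sort(rng.uniform(0.0, 10.0, 100))
         print(sample_exp_prior(rng, t).shape)              # (100,)
         print(sample_exp_prior(rng, t, shape=(3,)).shape)  # (100, 3)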



   .. py:method:: numpyro_dist(**kwargs: Any) -> tinygp.numpyro_support.TinyDistribution

      Get the NumPyro ``MultivariateNormal`` distribution for this process.



   .. py:method:: _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) -> tinygp.helpers.JAXArray
      :abstractmethod:



   .. py:method:: _compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Compute the log-likelihood from the innovations ``v`` and their covariances ``S``
      produced by the Kalman filter.
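
      The standard way to turn Kalman-filter outputs into a log-likelihood is the
      prediction-error decomposition: with innovations :math:`v_k` and innovation
      covariances :math:`S_k`,
      :math:`\log p(y) = -\tfrac{1}{2}\sum_k \left[v_k^\top S_k^{-1} v_k + \log|S_k| + D\log 2\pi\right]`.
      The NumPy sketch below implements that sum, together with a scalar Kalman filter
      for an exponential kernel that produces ``v`` and ``S`` in a single linear-time
      pass. Both functions are hypothetical illustrations, not ``smolgp``'s
      implementation.

      .. code-block:: python

         import numpy as np


         # Hypothetical helpers for illustration only; not part of smolgp.
         def compute_log_prob(v, S):
             """Sum of log N(v_k | 0, S_k); v has shape (N, D), S has shape (N, D, D)."""
             D = v.shape[-1]
             _, logdet = np.linalg.slogdet(S)                          # log|S_k|
             Sinv_v = np.linalg.solve(S, v[..., None])[..., 0]         # S_k^{-1} v_k
             quad = np.einsum("ni,ni->n", v, Sinv_v)                   # v_k^T S_k^{-1} v_k
             return -0.5 * np.sum(quad + logdet + D * np.log(2.0 * np.pi))


         def kalman_innovations(t, y, noise_var, sigma=1.0, ell=1.0):
             """Forward Kalman filter for an exponential-kernel GP (scalar state)."""
             m, P = 0.0, sigma**2                    # stationary prior on the state
             v = np.empty((len(t), 1))
             S = np.empty((len(t), 1, 1))
             for k in range(len(t)):
                 if k > 0:                           # predict step
                     a = np.exp(-(t[k] - t[k - 1]) / ell)
                     m, P = a * m, a**2 * P + sigma**2 * (1.0 - a**2)
                 v[k, 0] = y[k] - m                  # innovation (observation matrix = 1)
                 S[k, 0, 0] = P + noise_var          # innovation variance
                 gain = P / S[k, 0, 0]               # update step
                 m, P = m + gain * v[k, 0], (1.0 - gain) * P
             return v, S


         rng = np.random.default_rng(0)
         t = np.sort(rng.uniform(0.0, 10.0, 200))
         y = np.sin(t) + 0.1 * rng.standard_normal(200)
         v, S = kalman_innovations(t, y, noise_var=0.01)
         print(compute_log_prob(v, S))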



   .. py:method:: get_component_mean(component: list | str, return_var: bool = False, **kwargs) -> Any

      Get the predictive mean (and optionally the variance) of a single
      component kernel, or of the sum of several, in a multi-component model,
      evaluated at ``self.X``.

      :param X: The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object.
      :type X: JAXArray, optional
      :param component: The name(s) of the component kernel(s)
                        to extract the mean for. If a list of names is provided,
                        the joint mean and variance for that collection of kernels
                        will be returned.
      :type component: list | str
      :param return_var: If ``True``, also return the variances
                         of each component. Default is ``False``.
      :type return_var: bool, optional

      :returns: If ``return_var`` is ``False``, ``component_mean`` (JAXArray).
                If ``return_var`` is ``True``, both ``component_mean`` and
                ``component_var`` (JAXArrays).



   .. py:method:: get_all_component_means(return_var: bool = False, **kwargs) -> Any

      Get the predictive mean (and optionally the variance) of each
      component kernel individually, evaluated at ``self.X``.

      :param return_var: If ``True``, also return the variances
                         of each component. Default is ``False``.
      :type return_var: bool, optional

      :returns: If ``return_var`` is ``False``, a list of JAX arrays containing the
                means of each component kernel evaluated at the data points.
                If ``return_var`` is ``True``, a tuple where the first element is
                the list of means as before, and the second element is a list of
                JAX arrays containing the variances of each component kernel
                evaluated at the data points.
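
      For an additive model :math:`k = k_1 + k_2 + \dots`, the standard result is that
      the posterior mean of component :math:`i` at the data points is
      :math:`K_i C^{-1} y`, where :math:`C` is the data covariance (all components plus
      noise), with posterior variance
      :math:`\mathrm{diag}(K_i - K_i C^{-1} K_i)`; the component means then sum to the
      posterior mean of the total signal. The dense NumPy sketch below illustrates this.
      ``component_means`` and the two example kernels are hypothetical, and ``smolgp``
      is built to obtain these quantities from its state-space solver rather than by
      forming :math:`N \times N` matrices.

      .. code-block:: python

         import numpy as np


         # Hypothetical helper for illustration only; not part of smolgp.
         def component_means(kernels, t, y, noise_var, return_var=False):
             """Posterior mean (and variance) of each additive component at t, densely."""
             tau = t[:, None] - t[None, :]
             Ks = [k(tau) for k in kernels]                # one covariance per component
             C = sum(Ks) + noise_var * np.eye(len(t))      # data covariance incl. noise
             alpha = np.linalg.solve(C, y)
             means = [K @ alpha for K in Ks]               # K_i C^{-1} y
             if not return_var:
                 return means
             variances = [np.diag(K - K @ np.linalg.solve(C, K)) for K in Ks]
             return means, variances


         def slow(tau):  # long-timescale component
             return np.exp(-np.abs(tau) / 5.0)


         def fast(tau):  # short-timescale component
             return 0.3**2 * np.exp(-np.abs(tau) / 0.3)


         t = np.linspace(0.0, 20.0, 80)
         y = np.sin(0.3 * t) + 0.3 * np.sin(5.0 * t)
         (m_slow, m_fast), (v_slow, v_fast) = component_means(
             [slow, fast], t, y, noise_var=0.01, return_var=True
         )
         # The component means add up to the posterior mean of the total signal
         [total] = component_means([lambda tau: slow(tau) + fast(tau)], t, y, noise_var=0.01)
         print(np.allclose(m_slow + m_fast, total))  # True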



.. py:data:: __version__

