gp
==

.. py:module:: smolgp.gp


Classes
-------

.. autoapisummary::

   smolgp.gp.ConditionedStates
   smolgp.gp.PredictedStates
   smolgp.gp.GaussianProcess
   smolgp.gp.ConditionResult


Functions
---------

.. autoapisummary::

   smolgp.gp.assign_unique_kernel_names
   smolgp.gp._default_jitter


Module Contents
---------------

.. py:function:: assign_unique_kernel_names(kernel: smolgp.kernels.StateSpaceModel) -> smolgp.kernels.StateSpaceModel

   Return a new kernel in which duplicated leaf kernel names are made unique by appending ``_1``, ``_2``, etc.

   For example, if the original kernel has three components
   named "SHO", "Matern", and "Matern", they will be renamed
   to "SHO", "Matern_1", and "Matern_2". This is useful for
   ensuring that the component kernels can be uniquely identified
   when making predictions at test points or when extracting
   component contributions.
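
   The renaming rule can be sketched on a flat list of leaf names. This is a
   minimal illustration rather than the library's implementation: the helper
   ``make_unique`` is hypothetical and works on plain strings instead of a
   :class:`~smolgp.kernels.StateSpaceModel` tree.

   .. code-block:: python

      from collections import Counter


      def make_unique(names: list[str]) -> list[str]:
          """Append _1, _2, ... to every name that occurs more than once (sketch)."""
          totals = Counter(names)  # how often each name appears overall
          seen = Counter()  # how many copies of each name have been renamed so far
          out = []
          for name in names:
              if totals[name] > 1:
                  seen[name] += 1
                  out.append(f"{name}_{seen[name]}")
              else:
                  out.append(name)  # names that are already unique stay untouched
          return out


      print(make_unique(["SHO", "Matern", "Matern"]))  # ['SHO', 'Matern_1', 'Matern_2']

   This sketch does not guard against a renamed entry colliding with a name
   that already carries a suffix, such as an existing ``Matern_1``.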


.. py:class:: ConditionedStates(X, t_states: tinygp.helpers.JAXArray, instid: tinygp.helpers.JAXArray, obsid: tinygp.helpers.JAXArray, stateid: tinygp.helpers.JAXArray, m_pred: tinygp.helpers.JAXArray, P_pred: tinygp.helpers.JAXArray, m_filt: tinygp.helpers.JAXArray, P_filt: tinygp.helpers.JAXArray, m_smooth: tinygp.helpers.JAXArray, P_smooth: tinygp.helpers.JAXArray)

   Bases: :py:obj:`equinox.Module`


   An object that holds the Kalman predicted, filtered, and RTS-smoothed state
   means and covariances.

   - ``X``: data coordinates, length ``N``
   - ``t_states``: time coordinates of all states, length ``K``
   - ``instid``: instrument ID for each measurement, length ``N``
   - ``obsid``: ID of the observation associated with each state, length ``K``
   - ``stateid``: state ID for each state (0 for exposure start, 1 for exposure end), length ``K``
   - ``predicted_mean``, ``predicted_cov``: Kalman predicted states, length ``K``
   - ``filtered_mean``, ``filtered_cov``: Kalman filtered states, length ``K``
   - ``smoothed_mean``, ``smoothed_cov``: RTS-smoothed states, length ``K``
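
   For integrated kernels, ``X = (t, texp)`` holds exposure midpoints and
   durations, so each exposure has a natural start time ``t - texp/2`` and end
   time ``t + texp/2``. The sketch below assumes that layout, with all states
   sorted in time and therefore ``K = 2N``. ``build_state_grid`` is a
   hypothetical NumPy helper that shows how ``t_states``, ``obsid`` and
   ``stateid`` could line up; it is not part of smolgp.

   .. code-block:: python

      import numpy as np


      def build_state_grid(t, texp):
          """Hypothetical layout: one start (stateid 0) and one end (stateid 1)
          state per exposure, sorted by time."""
          t = np.asarray(t, dtype=float)
          texp = np.asarray(texp, dtype=float)
          n = t.shape[0]
          times = np.concatenate([t - 0.5 * texp, t + 0.5 * texp])
          obsid = np.concatenate([np.arange(n), np.arange(n)])  # measurement index
          stateid = np.concatenate([np.zeros(n, dtype=int), np.ones(n, dtype=int)])
          order = np.argsort(times, kind="stable")  # time-order all 2N states
          return times[order], obsid[order], stateid[order]


      t_states, obsid, stateid = build_state_grid(t=[1.0, 0.0], texp=[0.2, 0.4])
      # t_states -> [-0.2, 0.2, 0.9, 1.1]; obsid -> [1, 1, 0, 0]; stateid -> [0, 1, 0, 1]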


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: t_states
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: instid
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: obsid
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: stateid
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: predicted_mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: filtered_mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: smoothed_mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: predicted_cov
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: filtered_cov
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: smoothed_cov
      :type:  tinygp.helpers.JAXArray


   .. py:method:: __call__()


   .. py:method:: project_at_data(observation_model)

      Project the states that carry a measurement (e.g. exposure ends) through
      ``observation_model`` and sort them back into the original order of the data.
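
      A minimal sketch of the selection and reordering step, assuming (as the
      ``stateid`` description of this class states) that measurements attach to
      exposure-end states with ``stateid == 1``, and that the observation model
      is a plain matrix ``H``. ``measurement_index`` is a hypothetical helper,
      not part of smolgp.

      .. code-block:: python

         import numpy as np


         def measurement_index(obsid, stateid):
             """Indices of the measured states (stateid == 1), in data order."""
             obsid = np.asarray(obsid)
             measured = np.flatnonzero(np.asarray(stateid) == 1)
             # States are time-sorted; sorting their obsid restores the data order.
             return measured[np.argsort(obsid[measured], kind="stable")]


         obsid = np.array([1, 1, 0, 0])
         stateid = np.array([0, 1, 0, 1])
         m_smooth = np.arange(8.0).reshape(4, 2)  # (K, d) smoothed state means
         H = np.array([[1.0, 0.0]])               # (D, d) observation matrix

         idx = measurement_index(obsid, stateid)  # -> array([3, 1])
         mean_at_data = m_smooth[idx] @ H.T       # -> [[6.], [2.]]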



.. py:class:: PredictedStates(t_states: tinygp.helpers.JAXArray, m: tinygp.helpers.JAXArray, P: tinygp.helpers.JAXArray, kernel: smolgp.kernels.StateSpaceModel)

   Bases: :py:obj:`equinox.Module`


   An object that holds the full predicted states.

   - ``t_states``: time coordinates of each state
   - ``mean``: predicted mean vector of each state
   - ``cov``: predicted covariance matrix of each state


   .. py:attribute:: t_states
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: cov
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: kernel
      :type:  smolgp.kernels.StateSpaceModel


   .. py:method:: project_mean(observation_model) -> tinygp.helpers.JAXArray

      The mean at ``self.t_states``, projected into observation space by ``observation_model``.



   .. py:method:: project_variance(observation_model) -> tinygp.helpers.JAXArray

      The variance at ``self.t_states``, projected into observation space by ``observation_model``.
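
      For a linear observation matrix ``H`` of shape ``(D, d)``, projecting a
      stack of states amounts to ``H m_k`` for the mean and ``H P_k H^T`` for the
      covariance. The sketch below assumes that matrix form (smolgp's
      ``observation_model`` may instead be a callable); the function names are
      illustrative only.

      .. code-block:: python

         import numpy as np


         def project_mean(m, H):
             # m: (K, d) state means -> (K, D) means in observation space
             return m @ H.T


         def project_variance(P, H):
             # P: (K, d, d) state covariances -> (K, D, D), i.e. H @ P_k @ H.T per state
             return H @ P @ H.T


         # Two-dimensional state (value, derivative) where only the value is observed
         H = np.array([[1.0, 0.0]])
         m = np.array([[1.5, -0.2]])
         P = np.array([[[2.0, 0.5], [0.5, 3.0]]])
         print(project_mean(m, H))      # [[1.5]]
         print(project_variance(P, H))  # [[[2.]]]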



   .. py:property:: loc
      :type: tinygp.helpers.JAXArray


      The overall mean at the predicted states


   .. py:property:: variance
      :type: tinygp.helpers.JAXArray


      The overall variance at the predicted states


   .. py:method:: get_component(component: str | list[str], return_var: bool = False) -> PredictedStates

      Extract the predicted states corresponding to a component kernel, or to a
      collection of components when a list of names is given.



   .. py:method:: get_all_components(return_var: bool = False) -> dict[str, Any]

      Extract the predicted mean (and optionally the variance) of each component
      kernel, returned as a dictionary keyed by component name.
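
      For a sum of kernels, the usual state-space construction stacks the
      component states into one vector, so a component's contribution can be
      read off a block of the mean and covariance. The sketch below assumes that
      block layout; the ``dims`` and ``H_parts`` dictionaries are hypothetical
      stand-ins for information the kernel would provide. Selecting several
      names at once keeps the cross-covariance between them, which is why the
      joint variance is not simply the sum of the individual variances.

      .. code-block:: python

         import numpy as np


         def component_mean_var(m, P, dims, H_parts, names):
             """Mean and covariance of one or more stacked components (sketch)."""
             if isinstance(names, str):
                 names = [names]
             offsets, start = {}, 0
             for key, d in dims.items():  # dims: name -> state size, in stacking order
                 offsets[key] = np.arange(start, start + d)
                 start += d
             idx = np.concatenate([offsets[n] for n in names])
             H = np.concatenate([H_parts[n] for n in names], axis=1)  # (D, d_sub)
             m_sub = m[:, idx]                          # (K, d_sub)
             P_sub = P[:, idx[:, None], idx[None, :]]   # (K, d_sub, d_sub)
             return m_sub @ H.T, H @ P_sub @ H.T


         dims = {"SHO": 2, "Matern": 1}
         H_parts = {"SHO": np.array([[1.0, 0.0]]), "Matern": np.array([[1.0]])}
         m = np.array([[1.0, 2.0, 3.0]])
         P = np.array([[[1.0, 0.0, 0.5], [0.0, 2.0, 0.0], [0.5, 0.0, 4.0]]])

         mean, var = component_mean_var(m, P, dims, H_parts, ["SHO", "Matern"])
         # mean -> [[4.]], var -> [[[6.]]]  (1 + 4 + 2 * 0.5 cross-covariance)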



.. py:class:: GaussianProcess(kernel: GaussianProcess.__init__.kernels, X: tinygp.helpers.JAXArray, *, noise: tinygp.helpers.JAXArray | None = None, mean: tinygp.means.MeanBase | Callable[[tinygp.helpers.JAXArray], tinygp.helpers.JAXArray] | tinygp.helpers.JAXArray | None = None, solver: Any | None = None, mean_value: tinygp.helpers.JAXArray | None = None, variance_value: tinygp.helpers.JAXArray | None = None, covariance_value: Any | None = None, states: tinygp.helpers.JAXArray | None = None, use_unique_names: bool = True, **solver_kwargs: Any)

   Bases: :py:obj:`equinox.Module`


   An interface for designing a Gaussian Process regression model.

   :param kernel: The kernel function.
   :type kernel: Kernel
   :param X: The input coordinates, given as any PyTree compatible with
       ``kernel`` whose leading dimension has size ``N_data``.
       For integrated kernels, pass ``(t, texp)`` where ``t`` is the array of
       exposure midpoints and ``texp`` is the array of exposure durations.
   :type X: JAXArray
   :param noise: Observation noise covariance matrices with shape
       ``(N, D, D)``, where ``N`` is the number of data points and ``D`` is
       the observation dimension (usually 1). Each slice ``noise[k]`` is the
       :math:`D \times D` noise covariance for the ``k``-th observation.
       A 1-D array of shape ``(N,)`` is interpreted as scalar per-observation
       variances. Defaults to :math:`\sqrt{\varepsilon_{\mathrm{machine}}}
       \cdot I` for all observations.
   :type noise: JAXArray, optional
   :param mean: A callable or constant mean function evaluated as
       ``mean(X)``.
   :type mean: Callable, optional
   :param solver: Solver class for filtering and smoothing. If ``None``
       (default), selected automatically based on the kernel type.
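
   The accepted ``noise`` shapes can be summarized with a small normalization
   helper. This is a sketch under the assumption that a 1-D input means one
   isotropic variance per observation (``noise[k] * I``); ``normalize_noise`` is
   hypothetical and does not handle the ``None`` default.

   .. code-block:: python

      import numpy as np


      def normalize_noise(noise, n_data, obs_dim=1):
          """Return noise covariances with shape (N, D, D) (hypothetical helper)."""
          noise = np.asarray(noise, dtype=float)
          if noise.ndim == 1:
              if noise.shape[0] != n_data:
                  raise ValueError("1-D noise must have one variance per data point")
              # One scalar variance per observation -> noise[k] * I_D
              return noise[:, None, None] * np.eye(obs_dim)
          if noise.shape != (n_data, obs_dim, obs_dim):
              raise ValueError(f"expected shape {(n_data, obs_dim, obs_dim)}, got {noise.shape}")
          return noise


      print(normalize_noise([0.1, 0.2, 0.3], n_data=3).shape)  # (3, 1, 1)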


   .. py:attribute:: num_data
      :type:  int


   .. py:attribute:: dtype
      :type:  jax.numpy.dtype


   .. py:attribute:: kernel
      :type:  tinygp.kernels.Kernel


   .. py:attribute:: X
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: mean_function
      :type:  tinygp.means.MeanBase


   .. py:attribute:: mean
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: var
      :type:  tinygp.helpers.JAXArray | None


   .. py:attribute:: noise
      :type:  tinygp.helpers.JAXArray


   .. py:attribute:: solver
      :type:  smolgp.solvers.StateSpaceSolver


   .. py:attribute:: states
      :type:  ConditionedStates


   .. py:property:: loc
      :type: tinygp.helpers.JAXArray


      If the process has been conditioned, this is the conditioned mean at the
      data points; otherwise, it is the prior mean.


   .. py:property:: variance
      :type: tinygp.helpers.JAXArray


      If the process has been conditioned, this is the conditioned variance at
      the data points; otherwise, it is the prior variance.


   .. py:property:: covariance
      :type: tinygp.helpers.JAXArray
      :abstractmethod:



   .. py:method:: log_probability(y: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Compute the log probability of this multivariate normal distribution.

      :param y: The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray

      :returns: The marginal log probability of this multivariate normal model,
                evaluated at ``y``.
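
      For reference, with :math:`\boldsymbol{\mu}` the mean evaluated at ``X``,
      :math:`K` the covariance of the data (kernel plus ``noise``), and :math:`N`
      the number of data points, the returned value is the standard multivariate
      normal log density:

      .. math::

         \log p(\mathbf{y}) = -\frac{1}{2} (\mathbf{y} - \boldsymbol{\mu})^\top K^{-1} (\mathbf{y} - \boldsymbol{\mu}) - \frac{1}{2} \log \det K - \frac{N}{2} \log 2\pi

      A state-space solver can evaluate this sequentially from the Kalman filter
      innovations (see ``_compute_log_prob``) without ever forming :math:`K`.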



   .. py:method:: condition(y: tinygp.helpers.JAXArray, X_test: tinygp.helpers.JAXArray | None = None, *, include_mean: bool = True, kernel: tinygp.kernels.Kernel | None = None) -> ConditionResult

      Condition the model on observed data

      :param y: The observed data. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray
      :param X_test: The coordinates where the prediction
                     should be evaluated. This should have a data type compatible
                     with the ``X`` data provided when instantiating this object. If
                     it is not provided, ``X`` will be used by default, so the
                     predictions will be made at the data coordinates.
      :type X_test: JAXArray, optional
      :param include_mean: If ``True`` (default), the predicted
                           values will include the mean function evaluated at ``X_test``.
      :type include_mean: bool, optional
      :param kernel: A kernel to optionally specify the component
                     kernel to be used for predicting after conditioning. See
                     :ref:`multicomponent` for an example.
      :type kernel: Kernel, optional

      :returns: A named tuple where the first element ``log_probability`` is the log
                marginal probability of the model, and the second element ``gp`` is
                the :class:`GaussianProcess` object describing the conditional
                distribution evaluated at ``X_test``.
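
      Conditioning a state-space GP combines a forward Kalman filter, whose
      innovations give the log probability, with a backward Rauch-Tung-Striebel
      (RTS) smoother, consistent with the predicted, filtered, and smoothed states
      stored in :class:`ConditionedStates`. The sketch below illustrates this for
      the simplest case, a zero-mean scalar Ornstein-Uhlenbeck (Matern-1/2) kernel
      :math:`k(\tau) = \sigma^2 e^{-|\tau|/\ell}` with white-noise variance ``r``,
      and checks the result against the dense GP formulas. It is a standalone
      NumPy illustration with hypothetical names, not smolgp's solver.

      .. code-block:: python

         import numpy as np


         def ou_filter_smoother(t, y, sigma, ell, r):
             """Kalman filter + RTS smoother for an OU prior observed with noise r."""
             n = len(t)
             A = np.ones(n)  # A[k] propagates the state from t[k-1] to t[k]
             m_pred, P_pred = np.empty(n), np.empty(n)
             m_filt, P_filt = np.empty(n), np.empty(n)
             m, P, log_prob = 0.0, sigma**2, 0.0  # stationary prior at t[0]
             for k in range(n):
                 if k > 0:  # predict step
                     A[k] = np.exp(-(t[k] - t[k - 1]) / ell)
                     m, P = A[k] * m, A[k] ** 2 * P + sigma**2 * (1.0 - A[k] ** 2)
                 m_pred[k], P_pred[k] = m, P
                 v, S = y[k] - m, P + r  # innovation and its variance
                 log_prob -= 0.5 * (v**2 / S + np.log(2.0 * np.pi * S))
                 gain = P / S  # update step
                 m, P = m + gain * v, (1.0 - gain) * P
                 m_filt[k], P_filt[k] = m, P
             m_s, P_s = m_filt.copy(), P_filt.copy()
             for k in range(n - 2, -1, -1):  # backward RTS pass
                 G = P_filt[k] * A[k + 1] / P_pred[k + 1]
                 m_s[k] = m_filt[k] + G * (m_s[k + 1] - m_pred[k + 1])
                 P_s[k] = P_filt[k] + G**2 * (P_s[k + 1] - P_pred[k + 1])
             return log_prob, m_s, P_s


         t = np.array([0.0, 0.3, 1.1, 1.5, 2.8])
         y = np.array([0.2, -0.1, 0.5, 0.7, -0.3])
         sigma, ell, r = 1.3, 0.8, 0.05
         log_prob, mean, var = ou_filter_smoother(t, y, sigma, ell, r)

         # Dense O(N^3) reference computed from the full covariance matrix
         K = sigma**2 * np.exp(-np.abs(t[:, None] - t[None, :]) / ell)
         C = K + r * np.eye(len(t))
         alpha = np.linalg.solve(C, y)
         dense_log_prob = -0.5 * (y @ alpha + np.linalg.slogdet(C)[1] + len(t) * np.log(2 * np.pi))
         dense_mean = K @ alpha
         dense_var = np.diag(K - K @ np.linalg.solve(C, K))

      Both routes give the same log probability, mean, and variance; the
      recursive one scales linearly with the number of data points.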



   .. py:method:: predict(X_test: tinygp.helpers.JAXArray | None = None, y: tinygp.helpers.JAXArray | None = None, *, return_full_state: bool = False, kernel: int | None = None, return_var: bool = False, observation_model: Any | None = None) -> tinygp.helpers.JAXArray | tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Predict the GP model at new test points conditioned on observed data

      :param X_test: The coordinates where the prediction
                     should be evaluated. This should have a data type compatible
                     with the ``X`` data provided when instantiating this object. If
                     it is not provided, ``X`` will be used by default, so the
                     predictions will be made at the data coordinates.
      :type X_test: JAXArray, optional
      :param y: The observed data. Only needs to be given if the GP
                has not yet been conditioned. This should have the shape
                ``(N_data,)``, where ``N_data`` was the zeroth axis of the ``X``
                data provided when instantiating this object.
      :type y: JAXArray, optional
      :param return_var: If ``True``, also return the variance of the
                         predicted values at ``X_test``. Default is ``False``.
      :type return_var: bool, optional
      :param observation_model: An optional function of ``X_test`` that
                                defines the output observation model. Defaults to
                                the kernel's observation model.
      :type observation_model: Any, optional
      :param return_full_state: If ``True``, return the full predicted state
                                mean and covariance rather than projecting to observation space.
                                Default is ``False``, i.e. the result is projected through
                                ``kernel.observation_model``.
      :type return_full_state: bool, optional
      :param kernel: If specified, the index of the component kernel in a
                     multi-component model (for example, a sum or product of kernels)
                     for which to extract the prediction, projected to observation
                     space if ``return_full_state`` is ``False``.
      :type kernel: int, optional

      :returns: The mean of the predictive model evaluated at ``X_test``, with shape
                ``(N_test,)`` where ``N_test`` is the zeroth dimension of
                ``X_test``. If ``return_var`` is ``True``, the variance of the
                predicted process, also with shape ``(N_test,)``, is returned as well.



   .. py:method:: sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None = None) -> tinygp.helpers.JAXArray

      Generate samples from the prior process

      :param key: A ``jax`` random number key array.
      :type key: jax.random.KeyArray
      :param shape: The number and shape of samples to generate.
      :type shape: tuple, optional

      :returns: The sampled realizations from the process with shape ``(N_data,) +
                shape`` where ``N_data`` is the zeroth dimension of the ``X``
                coordinates provided when instantiating this process.
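
      For a state-space kernel, a prior draw can be generated sequentially by
      propagating the state through the transition model and adding process
      noise. The sketch below does this for a scalar Ornstein-Uhlenbeck
      (Matern-1/2) kernel using NumPy's random generator instead of a ``jax``
      key, and returns the documented output shape ``(N_data,) + shape``.
      ``sample_ou_prior`` is hypothetical, not part of smolgp.

      .. code-block:: python

         import numpy as np


         def sample_ou_prior(t, sigma, ell, rng, shape=()):
             """Draw OU prior realizations on sorted times t, shape (len(t),) + shape."""
             t = np.asarray(t, dtype=float)
             shape = tuple(shape)
             x = np.empty((t.shape[0],) + shape)
             x[0] = sigma * rng.standard_normal(shape)  # stationary start, var sigma^2
             for k in range(1, t.shape[0]):
                 a = np.exp(-(t[k] - t[k - 1]) / ell)  # state transition
                 q = sigma**2 * (1.0 - a**2)           # process-noise variance
                 x[k] = a * x[k - 1] + np.sqrt(q) * rng.standard_normal(shape)
             return x


         rng = np.random.default_rng(0)
         draws = sample_ou_prior([0.0, 0.5, 2.0], sigma=1.0, ell=1.0, rng=rng, shape=(20000,))
         print(draws.shape)  # (3, 20000)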



   .. py:method:: numpyro_dist(**kwargs: Any) -> tinygp.numpyro_support.TinyDistribution

      Get the numpyro MultivariateNormal distribution for this process



   .. py:method:: _sample(key: jax.random.KeyArray, shape: collections.abc.Sequence[int] | None) -> tinygp.helpers.JAXArray
      :abstractmethod:



   .. py:method:: _compute_log_prob(v: tinygp.helpers.JAXArray, S: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Compute the log-likelihood from the Kalman filter innovations ``v`` and their covariances ``S``.
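
      With innovations :math:`v_k` and innovation covariances :math:`S_k`, the
      log-likelihood follows from the prediction-error decomposition
      :math:`\sum_k -\tfrac{1}{2}\left[v_k^\top S_k^{-1} v_k + \log\det(2\pi S_k)\right]`.
      A minimal NumPy sketch, assuming ``v`` has shape ``(N, D)`` and ``S`` has
      shape ``(N, D, D)``; the function name is illustrative.

      .. code-block:: python

         import numpy as np


         def log_prob_from_innovations(v, S):
             """Sum of per-step Gaussian log densities of the filter innovations."""
             v = np.asarray(v, dtype=float)  # (N, D)
             S = np.asarray(S, dtype=float)  # (N, D, D)
             D = v.shape[-1]
             Sinv_v = np.linalg.solve(S, v[..., None])[..., 0]  # S_k^{-1} v_k
             quad = np.sum(v * Sinv_v, axis=-1)                 # v_k^T S_k^{-1} v_k
             _, logdet = np.linalg.slogdet(S)                   # log det S_k
             return -0.5 * np.sum(quad + logdet + D * np.log(2.0 * np.pi))


         # A single zero innovation with unit variance gives -0.5 * log(2 pi), about -0.9189
         print(log_prob_from_innovations(np.zeros((1, 1)), np.ones((1, 1, 1))))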



   .. py:method:: get_component_mean(component: list | str, return_var: bool = False, **kwargs) -> Any

      Get the predictive mean (and optionally the variance) of a single
      component kernel, or of the sum of several components, in a
      multi-component model, evaluated at ``self.X``.

      :param X: The coordinates where the prediction
                should be evaluated. This should have a data type compatible
                with the ``X`` data provided when instantiating this object.
      :type X: JAXArray, optional
      :param component: The name(s) of the component kernel(s)
                        to extract the mean for. If a list of names is provided,
                        the joint mean and variance for that collection of kernels
                        will be returned.
      :type component: list | str
      :param return_var: If ``True``, also return the variances
                         of each component. Default is ``False``.
      :type return_var: bool, optional

      :returns: If ``return_var`` is ``False``, ``component_mean`` (JAXArray).
                If ``return_var`` is ``True``, ``component_mean`` and
                ``component_var`` (both JAXArray).



   .. py:method:: get_all_component_means(return_var: bool = False, **kwargs) -> Any

      Get the predictive mean (and optionally the variance) of each
      component kernel individually, evaluated at ``self.X``.

      :param return_var: If ``True``, also return the variances
                         of each component. Default is ``False``.
      :type return_var: bool, optional

      :returns: If ``return_var`` is ``False``, a list of JAX arrays containing the
                means of each component kernel evaluated at the data points.
                If ``return_var`` is ``True``, a tuple where the first element is
                the list of means as before, and the second element is a list of
                JAX arrays containing the variances of each component kernel
                evaluated at the data points.



.. py:class:: ConditionResult

   Bases: :py:obj:`NamedTuple`


   The result of conditioning a :class:`GaussianProcess` on data

   This has two entries, ``log_probability`` and ``gp``, that are described
   below.


   .. py:attribute:: log_probability
      :type:  tinygp.helpers.JAXArray

      The log probability of the conditioned model

      In other words, this is the log marginal likelihood for the kernel parameters,
      given the observed data, or the multivariate normal log probability
      evaluated at the given data.


   .. py:attribute:: gp
      :type:  GaussianProcess

      A :class:`GaussianProcess` describing the conditional distribution

      This will have a mean and covariance conditioned on the observed data, but
      it is otherwise a fully functional GP that can sample from or condition
      further (although that's probably not going to be very efficient).


.. py:function:: _default_jitter(reference: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

   Return a default amount of jitter to add to the diagonal as a numerical
   safeguard. It uses ``sqrt(eps)`` for the dtype of the mean function, which
   tends to give sensible results in general.
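
   The dtype dependence can be sketched with NumPy's ``finfo``. The
   ``default_jitter`` function below is an illustrative stand-in rather than
   the library's code, and it assumes a floating-point ``reference``.

   .. code-block:: python

      import numpy as np


      def default_jitter(reference):
          """sqrt(machine epsilon) for the floating dtype of `reference` (sketch)."""
          dtype = np.asarray(reference).dtype
          return np.sqrt(np.finfo(dtype).eps).astype(dtype)


      print(default_jitter(np.zeros(3, dtype=np.float64)))  # about 1.49e-08
      print(default_jitter(np.zeros(3, dtype=np.float32)))  # about 3.45e-04

   Single-precision inputs therefore receive a jitter roughly four orders of
   magnitude larger than double-precision ones.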


