Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dissociate posterior updates from prediction errors #139

Merged
merged 17 commits into from
Nov 28, 2023
Merged
176 changes: 93 additions & 83 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,137 +7,146 @@
:depth: 5

API
+++
###

Updates functions
-----------------
*****************

Update functions are the heart of probabilistic networks as they shape the propagation of beliefs in the neural hierarchy. The library implements the standard variational updates for value and volatility coupling, as described in Weber et al. (2023).

The `updates` module contains the update functions used during the belief propagation. Update functions are available through three sub-modules, organized according to their functional roles. We usually dissociate the first updates, triggered top-down (from the leaves to the roots of the networks), that are prediction steps and recover the current state of inference. The second updates are the prediction error, signalling the divergence between the prediction and the new observation (for input nodes), or state (for state nodes). Interleaved with these steps are posterior update steps, where a node receives prediction errors from the child nodes and estimates new statistics.


The `updates` module contains the update function used both for prediction and prediction-error steps during the belief propagation. These functions call intenaly the more specific update function listed in the prediction and prediction-error sub-modules.
Posterior updates
=================

Updating binary nodes
=====================
Update the sufficient statistics of a state node after receiving prediction errors from children nodes. The prediction errors from all the children below the node should be computed before calling the posterior update step.

Core functionnalities to update *binary* nodes.
Binary nodes
------------

.. currentmodule:: pyhgf.updates.binary
.. currentmodule:: pyhgf.updates.posterior.binary

.. autosummary::
:toctree: generated/pyhgf.updates.binary
:toctree: generated/pyhgf.updates.posterior.binary

binary_node_prediction_error
binary_node_prediction
binary_input_prediction_error
binary_node_update_infinite
binary_node_update_finite

Updating continuous nodes
=========================

Core functionnalities to update *continuous* nodes.
Categorical nodes
-----------------

.. currentmodule:: pyhgf.updates.continuous
.. currentmodule:: pyhgf.updates.posterior.categorical

.. autosummary::
:toctree: generated/pyhgf.updates.continuous

continuous_node_prediction_error
continuous_node_prediction
continuous_input_prediction_error
:toctree: generated/pyhgf.updates.posterior.categorical

Updating categorical nodes
==========================
categorical_input_update

Core functionnalities to update *categorical* nodes.
Continuous nodes
----------------

.. currentmodule:: pyhgf.updates.categorical
.. currentmodule:: pyhgf.updates.posterior.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.categorical
:toctree: generated/pyhgf.updates.posterior.continuous

categorical_input_update
posterior_update_mean_continuous_node
posterior_update_precision_continuous_node
continuous_node_update
continuous_node_update_ehgf
continuous_node_update_missing_observations
continuous_blank_update

Prediction error steps
======================
Prediction steps
================

Propagate prediction errors to the value and volatility parents of a given node.
Compute the expectation for future observation given the influence of parent nodes. The prediction step are executed for all nodes, top-down, before any observation.

Binary nodes
~~~~~~~~~~~~
------------

.. currentmodule:: pyhgf.updates.prediction_error.inputs.binary
.. currentmodule:: pyhgf.updates.prediction.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.inputs.binary
:toctree: generated/pyhgf.updates.prediction.binary

prediction_error_input_value_parent
input_surprise_inf
input_surprise_reg
binary_state_node_prediction

.. currentmodule:: pyhgf.updates.prediction_error.nodes.binary
Continuous nodes
----------------

.. currentmodule:: pyhgf.updates.prediction.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.binary
:toctree: generated/pyhgf.updates.prediction.continuous

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_value_parent
predict_mean
predict_precision
continuous_node_prediction

Continuous nodes
~~~~~~~~~~~~~~~~
Prediction error steps
======================

Updating continuous input nodes.
Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of a given node.

.. currentmodule:: pyhgf.updates.prediction_error.inputs.continuous
Inputs
------

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.inputs.continuous
Binary inputs
^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction_error.inputs.binary

prediction_error_input_precision_value_parent
prediction_error_input_precision_volatility_parent
prediction_error_input_mean_volatility_parent
prediction_error_input_mean_value_parent
.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.inputs.binary

binary_input_prediction_error_infinite_precision
binary_input_prediction_error_finite_precision

Updating continuous state nodes.
Continuous inputs
^^^^^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction_error.nodes.continuous
.. currentmodule:: pyhgf.updates.prediction_error.inputs.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction_error.nodes.continuous
:toctree: generated/pyhgf.updates.prediction_error.inputs.continuous

prediction_error_mean_value_parent
prediction_error_precision_value_parent
prediction_error_precision_volatility_parent
prediction_error_mean_volatility_parent
continuous_input_volatility_prediction_error
continuous_input_value_prediction_error
continuous_input_prediction_error

Prediction steps
================
State nodes
-----------

Compute the expectation for future observation given the influence of parent nodes.

Binary nodes
~~~~~~~~~~~~
Binary state nodes
^^^^^^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction.binary
.. currentmodule:: pyhgf.updates.prediction_error.nodes.binary

.. autosummary::
:toctree: generated/pyhgf.updates.prediction.binary
:toctree: generated/pyhgf.updates.prediction_error.nodes.binary

predict_binary_state_node
binary_state_node_prediction_error

Continuous nodes
~~~~~~~~~~~~~~~~
Continuous state nodes
^^^^^^^^^^^^^^^^^^^^^^

.. currentmodule:: pyhgf.updates.prediction.continuous
.. currentmodule:: pyhgf.updates.prediction_error.nodes.continuous

.. autosummary::
:toctree: generated/pyhgf.updates.prediction.continuous
:toctree: generated/pyhgf.updates.prediction_error.nodes.continuous

predict_mean
predict_precision
continuous_node_value_prediction_error
continuous_node_volatility_prediction_error
continuous_node_prediction_error

Distribution
------------
************

The Herarchical Gaussian Filter as a PyMC distribution. This distribution can be
The Hierarchical Gaussian Filter as a PyMC distribution. This distribution can be
embedded in models using PyMC>=5.0.0.

.. currentmodule:: pyhgf.distribution
Expand All @@ -150,11 +159,11 @@ embedded in models using PyMC>=5.0.0.
HGFDistribution

Model
-----
*****

The main class used to create a standard Hierarchical Gaussian Filter for binary or
The main class is used to create a standard Hierarchical Gaussian Filter for binary or
continuous inputs, with two or three levels. This class wraps the previous JAX modules
and create a standard node structure for these models.
and creates a standard node structure for these models.

.. currentmodule:: pyhgf.model

Expand All @@ -164,9 +173,9 @@ and create a standard node structure for these models.
HGF

Plots
-----
*****

Plotting functionnalities to visualize parameters trajectories and correlations after
Plotting functionalities to visualize parameters trajectories and correlations after
observing new data.

.. currentmodule:: pyhgf.plots
Expand All @@ -180,9 +189,9 @@ observing new data.
plot_nodes

Response
--------
********

A collection of responses functions. A response function is simply a callable taking at
A collection of response functions. A response function is simply a callable taking at
least the HGF instance as input after observation and returning surprise.

.. currentmodule:: pyhgf.response
Expand All @@ -195,7 +204,7 @@ least the HGF instance as input after observation and returning surprise.
first_level_binary_surprise

Networks
--------
********

Utilities for manipulating networks of probabilistic nodes.

Expand All @@ -211,7 +220,7 @@ Utilities for manipulating networks of probabilistic nodes.
get_update_sequence

Math
----
****

Math functions and probability densities.

Expand All @@ -224,4 +233,5 @@ Math functions and probability densities.
sigmoid
binary_surprise
gaussian_surprise
dirichlet_kullback_leibler
dirichlet_kullback_leibler
binary_surprise_finite_precision
Loading
Loading