class Network:
    """A predictive coding neural network.

    This is the core class to define and manipulate neural networks, that consists in
    1. attributes, 2. structure and 3. update sequences.

    Attributes
    ----------
    attributes :
        The attributes of the probabilistic nodes.
    edges :
        The edges of the probabilistic nodes as a tuple of
        :py:class:`pyhgf.typing.AdjacencyLists`. The tuple has the same length as the
        node number. For each node, the index lists the value/volatility
        parents/children.
    inputs :
        Information on the input nodes.
    node_trajectories :
        The dynamic of the node's beliefs after updating.
    update_sequence :
        The sequence of update functions that are applied during the belief
        propagation step.
    scan_fn :
        The function that is passed to :py:func:`jax.lax.scan`. This is a
        pre-parametrized version of :py:func:`pyhgf.networks.beliefs_propagation`.

    """
+ defcreate_belief_propagation_fn(self,overwrite:bool=True)->"Network":
+"""Create the belief propagation function.
+
+ .. note:
+ This step is called by default when using py:meth:`input_data`.
+
+ Parameters
+ ----------
+ overwrite :
+ If `True` (default), create a new belief propagation function and ignore
+ preexisting values. Otherwise, do not create a new function if the attribute
+ `scan_fn` is already defined.
+
+ """
+ # create the network structure from edges and inputs
+ self.inputs=Inputs(self.inputs.idx,self.inputs.kind)
+ self.structure=(self.inputs,self.edges)
+
+ # create the update sequence if it does not already exist
+ ifself.update_sequenceisNone:
+ self.set_update_sequence()
+
+ # create the belief propagation function
+ # this function is used by scan to loop over observations
+ ifself.scan_fnisNone:
+ self.scan_fn=Partial(
+ beliefs_propagation,
+ update_sequence=self.update_sequence,
+ structure=self.structure,
+ )
+ else:
+ ifoverwrite:
+ self.scan_fn=Partial(
+ beliefs_propagation,
+ update_sequence=self.update_sequence,
+ structure=self.structure,
+ )
+
+ returnself
+
+ defcache_belief_propagation_fn(self)->"Network":
+"""Blank call to the belief propagation function.
+
+ .. note:
+ This step is called by default when using py:meth:`input_data`. It can
+ sometimes be convenient to call this step independently to chache the JITed
+ function before fitting the model.
+
+ """
+ ifself.scan_fnisNone:
+ self=self.create_belief_propagation_fn()
+
+ # blanck call to cache the JIT-ed functions
+ _=scan(
+ self.scan_fn,
+ self.attributes,
+ (
+ jnp.ones((1,len(self.inputs.idx))),
+ jnp.ones((1,1)),
+ jnp.ones((1,1)),
+ ),
+ )
+
+ returnself
+
+ definput_data(
+ self,
+ input_data:np.ndarray,
+ time_steps:Optional[np.ndarray]=None,
+ observed:Optional[np.ndarray]=None,
+ ):
+"""Add new observations.
+
+ Parameters
+ ----------
+ input_data :
+ 2d array of new observations (time x features).
+ time_steps :
+ Time vector (optional). If `None`, the time vector will default to
+ `np.ones(len(input_data))`. This vector is automatically transformed
+ into a time steps vector.
+ observed :
+ A 2d boolean array masking `input_data`. In case of missing inputs, (i.e.
+ `observed` is `0`), the input node will have value and volatility set to
+ `0.0`. If the parent(s) of this input receive prediction error from other
+ children, they simply ignore this one. If they are not receiving other
+ prediction errors, they are updated by keeping the same mean by decreasing
+ the precision as a function of time to reflect the evolution of the
+ underlying Gaussian Random Walk.
+ .. warning::
+ Missing inputs are missing observations from the agent's perspective and
+ should not be used to handle missing data points that were observed (e.g.
+ missing in the event log, or rejected trials).
+
+ """
+ ifself.scan_fnisNone:
+ self=self.create_belief_propagation_fn()
+ iftime_stepsisNone:
+ time_steps=np.ones((len(input_data),1))# time steps vector
+ else:
+ time_steps=time_steps[...,jnp.newaxis]
+ ifinput_data.ndim==1:
+ input_data=input_data[...,jnp.newaxis]
+
+ # is it observation or missing inputs
+ ifobservedisNone:
+ observed=np.ones(input_data.shape,dtype=int)
+
+ # this is where the model loops over the whole input time series
+ # at each time point, the node structure is traversed and beliefs are updated
+ # using precision-weighted prediction errors
+ last_attributes,node_trajectories=scan(
+ self.scan_fn,self.attributes,(input_data,time_steps,observed)
+ )
+
+ # trajectories of the network attributes a each time point
+ self.node_trajectories=node_trajectories
+ self.last_attributes=last_attributes
+
+ returnself
+
+ definput_custom_sequence(
+ self,
+ update_branches:Tuple[UpdateSequence],
+ branches_idx:np.array,
+ input_data:np.ndarray,
+ time_steps:Optional[np.ndarray]=None,
+ observed:Optional[np.ndarray]=None,
+ ):
+"""Add new observations with custom update sequences.
+
+ This method should be used when the update sequence should be adapted to the
+ input data. (e.g. in the case of missing/null observations that should not
+ trigger node update).
+
+ .. note::
+ When the dynamic adaptation of the update sequence is not required, it is
+ recommended to use :py:meth:`pyhgf.model.HGF.input_data` instead as this
+ might result in performance improvement.
+
+ Parameters
+ ----------
+ update_branches :
+ A tuple of UpdateSequence listing the possible update sequences.
+ branches_idx :
+ The branches indexes (integers). Should have the same length as the input
+ data.
+ input_data :
+ 2d array of new observations (time x features).
+ time_steps :
+ Time vector (optional). If `None`, the time vector will default to
+ `np.ones(len(input_data))`. This vector is automatically transformed
+ into a time steps vector.
+ observed :
+ A 2d boolean array masking `input_data`. In case of missing inputs, (i.e.
+ `observed` is `0`), the input node will have value and volatility set to
+ `0.0`. If the parent(s) of this input receive prediction error from other
+ children, they simply ignore this one. If they are not receiving other
+ prediction errors, they are updated by keeping the same mean be decreasing
+ the precision as a function of time to reflect the evolution of the
+ underlying Gaussian Random Walk.
+ .. warning::
+ Missing inputs are missing observations from the agent's perspective and
+ should not be used to handle missing data points that were observed (e.g.
+ missing in the event log, or rejected trials).
+
+ """
+ iftime_stepsisNone:
+ time_steps=np.ones(len(input_data))# time steps vector
+
+ # concatenate data and time
+ iftime_stepsisNone:
+ time_steps=np.ones((len(input_data),1))# time steps vector
+ else:
+ time_steps=time_steps[...,jnp.newaxis]
+ ifinput_data.ndim==1:
+ input_data=input_data[...,jnp.newaxis]
+
+ # is it observation or missing inputs
+ ifobservedisNone:
+ observed=np.ones(input_data.shape,dtype=int)
+
+ # create the update functions that will be scanned
+ branches_fn=[
+ Partial(
+ beliefs_propagation,
+ update_sequence=seq,
+ structure=self.structure,
+ )
+ forseqinupdate_branches
+ ]
+
+ # create the function that will be scanned
+ defswitching_propagation(attributes,scan_input):
+ data,idx=scan_input
+ returnswitch(idx,branches_fn,attributes,data)
+
+ # wrap the inputs
+ scan_input=(input_data,time_steps,observed),branches_idx
+
+ # scan over the input data and apply the switching belief propagation functions
+ _,node_trajectories=scan(switching_propagation,self.attributes,scan_input)
+
+ # the node structure at each value updates
+ self.node_trajectories=node_trajectories
+
+ # because some of the input nodes might not have been updated, here we manually
+ # insert the input data to the input node (without triggering updates)
+ foridx,inpinzip(self.inputs.idx,range(input_data.shape[1])):
+ self.node_trajectories[idx]["values"]=input_data[inp]
+
+ returnself
+
+ defget_network(self)->NetworkParameters:
+"""Return the attributes, structure and update sequence defining the network."""
+ ifself.scan_fnisNone:
+ self=self.create_belief_propagation_fn()
+
+ assertself.update_sequenceisnotNone
+
+ returnself.attributes,self.structure,self.update_sequence
+
+ defset_update_sequence(self,update_type:str="eHGF")->"Network":
+"""Generate an update sequence from the network's structure.
+
+ See :py:func:`pyhgf.networks.get_update_sequence` for more details.
+
+ Parameters
+ ----------
+ update_type :
+ The type of update to perform for volatility coupling. Can be `"eHGF"`
+ (defaults) or `"standard"`. The eHGF update step was proposed as an
+ alternative to the original definition in that it starts by updating the
+ mean and then the precision of the parent node, which generally reduces the
+ errors associated with impossible parameter space and improves sampling.
+
+ """
+ self.update_sequence=tuple(
+ get_update_sequence(network=self,update_type=update_type)
+ )
+
+ returnself
+
+ defadd_nodes(
+ self,
+ kind:str="continuous-state",
+ n_nodes:int=1,
+ node_parameters:Dict={},
+ value_children:Optional[Union[List,Tuple,int]]=None,
+ value_parents:Optional[Union[List,Tuple,int]]=None,
+ volatility_children:Optional[Union[List,Tuple,int]]=None,
+ volatility_parents:Optional[Union[List,Tuple,int]]=None,
+ coupling_fn:Tuple[Optional[Callable],...]=(None,),
+ **additional_parameters,
+ ):
+"""Add new input/state node(s) to the neural network.
+
+ Parameters
+ ----------
+ kind :
+ The kind of node to create. If `"continuous-state"` (default), the node will
+ be a regular state node that can have value and/or volatility
+ parents/children. If `"binary-state"`, the node should be the
+ value parent of a binary input. State nodes filtering distribution from the
+ exponential family can be created using the `"ef-"` prefix (e.g.
+ `"ef-normal"` for a univariate normal distribution). Note that only a few
+ distributions are implemented at the moment.
+
+ In addition to state nodes, four types of input nodes are supported:
+ - `generic-input`: receive a value or an array and pass it to the parent
+ nodes.
+ - `continuous-input`: receive a continuous observation as input.
+ - `binary-input` receives a single boolean as observation. The parameters
+ provided to the binary input node contain: 1. `binary_precision`, the binary
+ input precision, which defaults to `jnp.inf`. 2. `eta0`, the lower bound of
+ the binary process, which defaults to `0.0`. 3. `eta1`, the higher bound of
+ the binary process, which defaults to `1.0`.
+ - `categorical-input` receives a boolean array as observation. The
+ parameters provided to the categorical input node contain: 1.
+ `n_categories`, the number of categories implied by the categorical state.
+
+ .. note::
+ When using a categorical state node, the `binary_parameters` can be used to
+ parametrize the implied collection of binary HGFs.
+
+ .. note:
+ When using `categorical-input`, the implied `n` binary HGFs are
+ automatically created with a shared volatility parent at the third level,
+ resulting in a network with `3n + 2` nodes in total.
+
+ n_nodes :
+ The number of nodes to create (defaults to `1`).
+ node_parameters :
+ Dictionary of parameters. The default values are automatically inferred
+ from the node type. Different values can be provided by passing them in the
+ dictionary, which will overwrite the defaults.
+ value_children :
+ Indexes to the node's value children. The index can be passed as an integer
+ or a list of integers, in case of multiple children. The coupling strength
+ can be controlled by passing a tuple, where the first item is the list of
+ indexes, and the second item is the list of coupling strengths.
+ value_parents :
+ Indexes to the node's value parents. The index can be passed as an integer
+ or a list of integers, in case of multiple children. The coupling strength
+ can be controlled by passing a tuple, where the first item is the list of
+ indexes, and the second item is the list of coupling strengths.
+ volatility_children :
+ Indexes to the node's volatility children. The index can be passed as an
+ integer or a list of integers, in case of multiple children. The coupling
+ strength can be controlled by passing a tuple, where the first item is the
+ list of indexes, and the second item is the list of coupling strengths.
+ volatility_parents :
+ Indexes to the node's volatility parents. The index can be passed as an
+ integer or a list of integers, in case of multiple children. The coupling
+ strength can be controlled by passing a tuple, where the first item is the
+ list of indexes, and the second item is the list of coupling strengths.
+ coupling_fn :
+ Coupling function(s) between the current node and its value children.
+ It has to be provided as a tuple. If multiple value children are specified,
+ the coupling functions must be stated in the same order of the children.
+ Note: if a node has multiple parents nodes with different coupling
+ functions, a coupling function should be indicated for all the parent nodes.
+ If no coupling function is stated, the relationship between nodes is assumed
+ linear.
+ **kwargs :
+ Additional keyword parameters will be passed and overwrite the node
+ attributes.
+
+ """
+ ifkindnotin[
+ "continuous-input",
+ "binary-input",
+ "categorical-input",
+ "DP-state",
+ "ef-normal",
+ "generic-input",
+ "continuous-state",
+ "binary-state",
+ ]:
+ raiseValueError(
+ (
+ "Invalid node type. Should be one of the following: "
+ "'continuous-input', 'binary-input', 'categorical-input', "
+ "'DP-state', 'continuous-state', 'binary-state', 'ef-normal'."
+ )
+ )
+
+ # assess children number
+ # this is required to ensure the coupling functions match
+ children_number=1
+ ifvalue_childrenisNone:
+ children_number=0
+ elifisinstance(value_children,int):
+ children_number=1
+ elifisinstance(value_children,list):
+ children_number=len(value_children)
+
+ # transform coupling parameter into tuple of indexes and strenghts
+ couplings=[]
+ forindexesin[
+ value_parents,
+ volatility_parents,
+ value_children,
+ volatility_children,
+ ]:
+ ifindexesisnotNone:
+ ifisinstance(indexes,int):
+ coupling_idxs=tuple([indexes])
+ coupling_strengths=tuple([1.0])
+ elifisinstance(indexes,list):
+ coupling_idxs=tuple(indexes)
+ coupling_strengths=tuple([1.0]*len(coupling_idxs))
+ elifisinstance(indexes,tuple):
+ coupling_idxs=tuple(indexes[0])
+ coupling_strengths=tuple(indexes[1])
+ else:
+ coupling_idxs,coupling_strengths=None,None
+ couplings.append((coupling_idxs,coupling_strengths))
+ value_parents,volatility_parents,value_children,volatility_children=(
+ couplings
+ )
+
+ # create the default parameters set according to the node type
+ ifkind=="continuous-state":
+ default_parameters={
+ "mean":0.0,
+ "expected_mean":0.0,
+ "precision":1.0,
+ "expected_precision":1.0,
+ "volatility_coupling_children":volatility_children[1],
+ "volatility_coupling_parents":volatility_parents[1],
+ "value_coupling_children":value_children[1],
+ "value_coupling_parents":value_parents[1],
+ "tonic_volatility":-4.0,
+ "tonic_drift":0.0,
+ "autoconnection_strength":1.0,
+ "observed":1,
+ "temp":{
+ "effective_precision":0.0,
+ "value_prediction_error":0.0,
+ "volatility_prediction_error":0.0,
+ "expected_precision_children":0.0,
+ },
+ }
+ elifkind=="binary-state":
+ default_parameters={
+ "mean":0.0,
+ "expected_mean":0.0,
+ "precision":1.0,
+ "expected_precision":1.0,
+ "volatility_coupling_children":volatility_children[1],
+ "volatility_coupling_parents":volatility_parents[1],
+ "value_coupling_children":value_children[1],
+ "value_coupling_parents":value_parents[1],
+ "tonic_volatility":0.0,
+ "tonic_drift":0.0,
+ "autoconnection_strength":1.0,
+ "observed":1,
+ "binary_expected_precision":0.0,
+ "temp":{
+ "effective_precision":0.0,
+ "value_prediction_error":0.0,
+ "volatility_prediction_error":0.0,
+ "expected_precision_children":0.0,
+ },
+ }
+ elifkind=="generic-input":
+ default_parameters={
+ "values":0.0,
+ "time_step":0.0,
+ "observed":0,
+ }
+ elifkind=="continuous-input":
+ default_parameters={
+ "volatility_coupling_parents":None,
+ "value_coupling_parents":None,
+ "input_precision":1e4,
+ "expected_precision":1e4,
+ "time_step":0.0,
+ "values":0.0,
+ "surprise":0.0,
+ "observed":0,
+ "temp":{
+ "effective_precision":1.0,# should be fixed to 1 for input nodes
+ "value_prediction_error":0.0,
+ "volatility_prediction_error":0.0,
+ },
+ }
+ elifkind=="binary-input":
+ default_parameters={
+ "expected_precision":jnp.inf,
+ "eta0":0.0,
+ "eta1":1.0,
+ "time_step":0.0,
+ "values":0.0,
+ "observed":0,
+ "surprise":0.0,
+ }
+ elifkind=="categorical-input":
+ if"n_categories"innode_parameters:
+ n_categories=node_parameters["n_categories"]
+ elif"n_categories"inadditional_parameters:
+ n_categories=additional_parameters["n_categories"]
+ else:
+ n_categories=4
+ binary_parameters={
+ "eta0":0.0,
+ "eta1":1.0,
+ "binary_precision":jnp.inf,
+ "n_categories":n_categories,
+ "precision_1":1.0,
+ "precision_2":1.0,
+ "precision_3":1.0,
+ "mean_1":0.0,
+ "mean_2":-jnp.log(n_categories-1),
+ "mean_3":0.0,
+ "tonic_volatility_2":-4.0,
+ "tonic_volatility_3":-4.0,
+ }
+ binary_idxs:List[int]=[
+ 1+i+len(self.edges)foriinrange(n_categories)
+ ]
+ default_parameters={
+ "binary_idxs":binary_idxs,# type: ignore
+ "n_categories":n_categories,
+ "surprise":0.0,
+ "kl_divergence":0.0,
+ "time_step":jnp.nan,
+ "alpha":jnp.ones(n_categories),
+ "pe":jnp.zeros(n_categories),
+ "xi":jnp.array([1.0/n_categories]*n_categories),
+ "mean":jnp.array([1.0/n_categories]*n_categories),
+ "values":jnp.zeros(n_categories),
+ "binary_parameters":binary_parameters,
+ }
+ elif"ef-normal"inkind:
+ default_parameters={
+ "nus":3.0,
+ "xis":jnp.array([0.0,1.0]),
+ "values":0.0,
+ "observed":1.0,
+ }
+ elifkind=="DP-state":
+
+ if"batch_size"inadditional_parameters.keys():
+ batch_size=additional_parameters["batch_size"]
+ elif"batch_size"innode_parameters.keys():
+ batch_size=node_parameters["batch_size"]
+ else:
+ batch_size=10
+
+ default_parameters={
+ "batch_size":batch_size,# number of branches available in the network
+ "n":jnp.zeros(batch_size),# number of observation in each cluster
+ "n_total":0,# the total number of observations in the node
+ "alpha":1.0,# concentration parameter for the implied Dirichlet dist.
+ "expected_means":jnp.zeros(batch_size),
+ "expected_sigmas":jnp.ones(batch_size),
+ "sensory_precision":1.0,
+ "activated":jnp.zeros(batch_size),
+ "value_coupling_children":(1.0,),
+ "values":0.0,
+ "n_active_cluster":0,
+ }
+
+ # Update the default node parameters using keywords args and dictonary
+ ifbool(additional_parameters):
+ # ensure that all passed values are valid keys
+ invalid_keys=[
+ key
+ forkeyinadditional_parameters.keys()
+ ifkeynotindefault_parameters.keys()
+ ]
+
+ ifinvalid_keys:
+ raiseValueError(
+ (
+ "Some parameter(s) passed as keyword arguments were not found "
+ f"in the default key list for this node (i.e. {invalid_keys})."
+ " If you want to create a new key in the node attributes, "
+ "please use the node_parameters argument instead."
+ )
+ )
+
+ # if keyword parameters were provided, update the default_parameters dict
+ default_parameters.update(additional_parameters)
+
+ # update the defaults using the dict parameters
+ default_parameters.update(node_parameters)
+ node_parameters=default_parameters
+
+ if"input"inkind:
+ # "continuous": 0, "binary": 1, "categorical": 2, "generic": 3
+ input_type=input_types[kind.split("-")[0]]
+ else:
+ input_type=None
+
+ # define the type of node that is created
+ if"input"inkind:
+ node_type=0
+ elif"binary-state"inkind:
+ node_type=1
+ elif"continuous-state"inkind:
+ node_type=2
+ elif"ef-normal"inkind:
+ node_type=3
+ elif"DP-state"inkind:
+ node_type=4
+
+ for_inrange(n_nodes):
+ # convert the structure to a list to modify it
+ edges_as_list:List=list(self.edges)
+
+ node_idx=len(self.attributes)# the index of the new node
+
+ # for mutiple value children, set a default tuple with corresponding length
+ ifchildren_number!=len(coupling_fn):
+ ifcoupling_fn==(None,):
+ coupling_fn=children_number*coupling_fn
+ else:
+ raiseValueError(
+ "The number of coupling fn and value children do not match"
+ )
+
+ # add a new edge
+ edges_as_list.append(
+ AdjacencyLists(
+ node_type,None,None,None,None,coupling_fn=coupling_fn
+ )
+ )
+
+ # convert the list back to a tuple
+ self.edges=tuple(edges_as_list)
+
+ ifnode_idx==0:
+ # this is the first node, create the node structure
+ self.attributes={node_idx:deepcopy(node_parameters)}
+ ifinput_typeisnotNone:
+ self.inputs=Inputs((node_idx,),(input_type,))
+ else:
+ # update the node structure
+ self.attributes[node_idx]=deepcopy(node_parameters)
+
+ ifinput_typeisnotNone:
+ # add information about the new input node in the indexes
+ new_idx=self.inputs.idx
+ new_idx+=(node_idx,)
+ new_kind=self.inputs.kind
+ new_kind+=(input_type,)
+ self.inputs=Inputs(new_idx,new_kind)
+
+ # Update the edges of the parents and children accordingly
+ # --------------------------------------------------------
+ ifvalue_parents[0]isnotNone:
+ self.add_edges(
+ kind="value",
+ parent_idxs=value_parents[0],
+ children_idxs=node_idx,
+ coupling_strengths=value_parents[1],# type: ignore
+ )
+ ifvalue_children[0]isnotNone:
+ self.add_edges(
+ kind="value",
+ parent_idxs=node_idx,
+ children_idxs=value_children[0],
+ coupling_strengths=value_children[1],# type: ignore
+ coupling_fn=coupling_fn,
+ )
+ ifvolatility_children[0]isnotNone:
+ self.add_edges(
+ kind="volatility",
+ parent_idxs=node_idx,
+ children_idxs=volatility_children[0],
+ coupling_strengths=volatility_children[1],# type: ignore
+ )
+ ifvolatility_parents[0]isnotNone:
+ self.add_edges(
+ kind="volatility",
+ parent_idxs=volatility_parents[0],
+ children_idxs=node_idx,
+ coupling_strengths=volatility_parents[1],# type: ignore
+ )
+
+ ifkind=="categorical-input":
+ # if we are creating a categorical state or state-transition node
+ # we have to generate the implied binary network(s) here
+ self=fill_categorical_state_node(
+ self,
+ node_idx=node_idx,
+ binary_input_idxs=node_parameters["binary_idxs"],# type: ignore
+ binary_parameters=binary_parameters,
+ )
+
+ returnself
+
+ defplot_nodes(self,node_idxs:Union[int,List[int]],**kwargs):
+"""Plot the node(s) beliefs trajectories."""
+ returnplot_nodes(network=self,node_idxs=node_idxs,**kwargs)
+
+ defplot_trajectories(self,**kwargs):
+"""Plot the parameters trajectories."""
+ returnplot_trajectories(network=self,**kwargs)
+
+ defplot_correlations(self):
+"""Plot the heatmap of cross-trajectories correlation."""
+ returnplot_correlations(network=self)
+
+ defplot_network(self):
+"""Visualization of node network using GraphViz."""
+ returnplot_network(network=self)
+
+ defto_pandas(self)->pd.DataFrame:
+"""Export the nodes trajectories and surprise as a Pandas data frame.
+
+ Returns
+ -------
+ structure_df :
+ Pandas data frame with the time series of sufficient statistics and
+ the surprise of each node in the structure.
+
+ """
+ returnto_pandas(self)
+
+ defsurprise(
+ self,
+ response_function:Callable,
+ response_function_inputs:Tuple=(),
+ response_function_parameters:Optional[
+ Union[np.ndarray,ArrayLike,float]
+ ]=None,
+ )->float:
+"""Surprise of the model conditioned by the response function.
+
+ The surprise (negative log probability) depends on the input data, the model
+ parameters, the response function, its inputs and its additional parameters
+ (optional).
+
+ Parameters
+ ----------
+ response_function :
+ The response function to use to compute the model surprise. If `None`
+ (default), return the sum of Gaussian surprise if `model_type=="continuous"`
+ or the sum of the binary surprise if `model_type=="binary"`.
+ response_function_inputs :
+ A list of tuples with the same length as the number of models. Each tuple
+ contains additional data and parameters that can be accessible to the
+ response functions.
+ response_function_parameters :
+ A list of additional parameters that will be passed to the response
+ function. This can include values over which inferece is performed in a
+ PyMC model (e.g. the inverse temperature of a binary softmax).
+
+ Returns
+ -------
+ surprise :
+ The model's surprise given the input data and the response function.
+
+ """
+ returnresponse_function(
+ hgf=self,
+ response_function_inputs=response_function_inputs,
+ response_function_parameters=response_function_parameters,
+ )
+ returnself
+
+ defadd_edges(
+ self,
+ kind="value",
+ parent_idxs=Union[int,List[int]],
+ children_idxs=Union[int,List[int]],
+ coupling_strengths:Union[float,List[float],Tuple[float]]=1.0,
+ coupling_fn:Tuple[Optional[Callable],...]=(None,),
+ )->"Network":
+"""Add a value or volatility coupling link between a set of nodes.
+
+ Parameters
+ ----------
+ kind :
+ The kind of coupling, can be `"value"` or `"volatility"`.
+ parent_idxs :
+ The index(es) of the parent node(s).
+ children_idxs :
+ The index(es) of the children node(s).
+ coupling_strengths :
+ The coupling strength betwen the parents and children.
+ coupling_fn :
+ Coupling function(s) between the current node and its value children.
+ It has to be provided as a tuple. If multiple value children are specified,
+ the coupling functions must be stated in the same order of the children.
+ Note: if a node has multiple parents nodes with different coupling
+ functions, a coupling function should be indicated for all the parent nodes.
+ If no coupling function is stated, the relationship between nodes is assumed
+ linear.
+
+ """
+ attributes,edges=add_edges(
+ attributes=self.attributes,
+ edges=self.edges,
+ kind=kind,
+ parent_idxs=parent_idxs,
+ children_idxs=children_idxs,
+ coupling_strengths=coupling_strengths,
+ coupling_fn=coupling_fn,
+ )
+
+ self.attributes=attributes
+ self.edges=edges
+
+ returnself
This is the core class to define and manipulate neural networks, that consists in
+1. attributes, 2. structure and 3. update sequences.
+
+
Attributes:
+
+
attributes
The attributes of the probabilistic nodes.
+
+
edges
The edges of the probabilistic nodes as a tuple of
+pyhgf.typing.AdjacencyLists. The tuple has the same length as the
+node number. For each node, the index lists the value/volatility
+parents/children.
+
+
inputs
Information on the input nodes.
+
+
node_trajectories
The dynamic of the node’s beliefs after updating.
+
+
update_sequence
The sequence of update functions that are applied during the belief propagation
+step.
+
+
scan_fn
The function that is passed to jax.lax.scan(). This is a pre-
+parametrized version of pyhgf.networks.beliefs_propagation().
Plot the trajectory of expected sufficient statistics of a set of nodes.
This function will plot the expected mean and precision (converted into standard
deviation) before observation, and the Gaussian surprise after observation. If
diff --git a/dev/generated/pyhgf.plots/pyhgf.plots.plot_trajectories.html b/dev/generated/pyhgf.plots/pyhgf.plots.plot_trajectories.html
index 4a5d4f90a..05c493a3d 100644
--- a/dev/generated/pyhgf.plots/pyhgf.plots.plot_trajectories.html
+++ b/dev/generated/pyhgf.plots/pyhgf.plots.plot_trajectories.html
@@ -33,7 +33,6 @@
-
@@ -48,14 +47,13 @@
-
-
+
@@ -466,6 +464,7 @@
Generate an update sequence from the network’s structure.
This function return an optimized update sequence considering the edges of the
network. The function ensures that the following principles apply:
@@ -572,7 +571,7 @@
The type of update to perform for volatility coupling. Can be “eHGF”
(defaults) or “standard”. The eHGF update step was proposed as an
diff --git a/dev/generated/pyhgf.utils/pyhgf.utils.list_branches.html b/dev/generated/pyhgf.utils/pyhgf.utils.list_branches.html
index 3a3289690..51c8ba160 100644
--- a/dev/generated/pyhgf.utils/pyhgf.utils.list_branches.html
+++ b/dev/generated/pyhgf.utils/pyhgf.utils.list_branches.html
@@ -33,7 +33,6 @@
-
@@ -48,7 +47,6 @@
-
@@ -466,6 +464,7 @@
In this section, you can find tutorial notebooks that describe the internals of pyhgf, the theory behind the Hierarchical Gaussian filter, and step-by-step application and use cases of the model. At the beginning of every tutorial, you will find a badge to run the notebook interactively in a Google Colab session.
-An introduction to the Hierarchical Gaussian Filter
-
How the generative model of the Hierarchical Gaussian filter can be turned into update functions that update nodes through value and volatility coupling?
-Creating and manipulating networks of probabilistic nodes
-
How to create and manipulate a network of probabilistic nodes for reinforcement learning? Working at the intersection of graphs, neural networks and probabilistic frameworks.
The categorical Hierarchical Gaussian Filter is a generalisation of the binary HGF to handle categorical distribution with and without transition probabilities.
Hands-on exercises for the Computational Psychiatry Course (Zurich) to build intuition around the generalised Hierarchical Gaussian Filter, how to create and manipulate probabilistic networks, design an agent to perform a reinforcement learning task and use MCMC sampling for parameter inference and model comparison—about 4 hours.
-
-
-
-
-
-
-
-Introduction to the Generalised Hierarchical Gaussian Filter
-
Theoretical introduction to the generative model of the generalised Hierarchical Gaussian Filter and presentation of the update functions (i.e. the first inversion of the model).
-Applying the Hierarchical Gaussian Filter to reinforcement learning
-
Practical application of the generalised Hierarchical Gaussian Filter to reinforcement learning problems and estimation of parameters through MCMC sampling (i.e. the second inversion of the model).
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 17 seconds.
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
+
+
+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 15 seconds.
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 16 seconds.
-
-
-
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
+
+
+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 15 seconds.
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
+
+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.
@@ -951,19 +962,19 @@
Recovering HGF parameters from the observed behaviors
The results above indicate that, given the responses provided by the participant, the most likely values for the parameter \(\omega_2\) are between -4.9 and -3.1, with a mean at -3.9 (you can find slightly different values if you sample different actions from the decisions function). We can consider this an excellent estimate given the sparsity of the data and the complexity of the model.
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 60 seconds.
-
There were 1998 divergences after tuning. Increase `target_accept` or reparameterize.
+
There were 2000 divergences after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
@@ -746,7 +757,7 @@
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 10 seconds.
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 8 seconds.
-
There were 3 divergences after tuning. Increase `target_accept` or reparameterize.
+
There were 1 divergences after tuning. Increase `target_accept` or reparameterize.
We recommend running at least 4 chains for robust computation of convergence diagnostics
@@ -1110,7 +1133,7 @@
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
+
+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.
/opt/hostedtoolcache/Python/3.12.5/x64/lib/python3.12/site-packages/rich/live.py:231: UserWarning: install
+"ipywidgets" for Jupyter support
+ warnings.warn('install "ipywidgets" for Jupyter support')
+
+
+
+
+
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.
We have saved the pointwise log probabilities as a variable; here we simply move this variable to the log_likelihood field of the idata object, so that ArviZ knows it can be used later for model comparison.