Skip to content

Commit

Permalink
update description causal updating
Browse files Browse the repository at this point in the history
  • Loading branch information
Jane Doe committed Nov 22, 2024
1 parent 8756223 commit 09c80c0
Showing 1 changed file with 6 additions and 16 deletions.
22 changes: 6 additions & 16 deletions docs/source/notebooks/Example_4_Causal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -721,7 +721,7 @@
},
{
"cell_type": "code",
"execution_count": 538,
"execution_count": null,
"id": "45c8cbfa",
"metadata": {},
"outputs": [],
Expand All @@ -736,27 +736,19 @@
" r\"\"\"\n",
"Causal coupling strength update.\n",
"\n",
"The causal coupling strength \\( w_{ij} \\) is updated in three steps:\n",
"The causal coupling strength \\( w_{ij} \\) is updated as follows:\n",
"\n",
"1. **Precision-weighted prediction error**:\n",
" \n",
" .. math::\n",
" \\Delta_j = \\pi_j \\cdot \\epsilon_j\n",
"\n",
" where:\n",
" - \\( \\pi_j \\) is the precision of the child node.\n",
" - \\( \\epsilon_j = y_j - \\hat{y}_j \\) is the prediction error, defined as the difference between the observed value \\( y_j \\) and the predicted value \\( \\hat{y}_j \\).\n",
"\n",
"2. **Raw update**:\n",
"1. **Raw update**:\n",
" \n",
" .. math::\n",
" w_{ij}^{t} = w_{ij}^{t-1} + \\eta \\cdot \\Delta_j \\cdot \\mu_i\n",
"\n",
" where:\n",
" - \\( \\eta \\) is the learning rate.\n",
" - \\( \\mu_i \\) is the parent's expected mean.\n",
" - \\( \\Delta_j \\) is the precision-weighted PE.\n",
"\n",
"3. **Rescaling with Sigmoid-like Transformation**:\n",
"2. **Rescaling with Sigmoid-like transformation**:\n",
" \n",
" .. math::\n",
" w_{ij} = \\frac{x^\\text{temperature}}{x^\\text{temperature} + (1 - x)^\\text{temperature}}\n",
Expand Down Expand Up @@ -844,7 +836,7 @@
},
{
"cell_type": "code",
"execution_count": 541,
"execution_count": null,
"id": "9f8a053b",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -909,7 +901,6 @@
"\n",
"axs.plot(\n",
" causal_hgf.node_trajectories[1][\"temp\"]['value_prediction_error'], \n",
" # colour = \"blue\", \n",
" linestyle=\"--\", \n",
" linewidth=1.0, \n",
" label=\"PE child\"\n",
Expand All @@ -920,7 +911,6 @@
"\n",
"axs.plot(\n",
" causal_hgf.node_trajectories[0]['mean'], \n",
" # colour = \"green\", \n",
" linestyle=\"-\", \n",
" linewidth=1.0, \n",
" label=\"Mean parent\"\n",
Expand Down

0 comments on commit 09c80c0

Please sign in to comment.