diff --git a/docs/source/notebooks/Example_4_Iowa_Gambling_Task_Short.ipynb b/docs/source/notebooks/Example_4_Iowa_Gambling_Task_Short.ipynb index 7efb19034..812672568 100644 --- a/docs/source/notebooks/Example_4_Iowa_Gambling_Task_Short.ipynb +++ b/docs/source/notebooks/Example_4_Iowa_Gambling_Task_Short.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "(example_4)=\n", - "# Example 4: Iowa-Gambling Task" + "# Example 4: Inferring optimistic bias in the Iowa-Gambling Task using variable autoconnection strength" ] }, { @@ -15,6 +15,17 @@ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ilabcode/pyhgf/blob/master/docs/source/notebooks/Example_4_Iowa_Gambling_Task.ipynb)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "authors:\n", + " - Aleksandrs Baskakovs, Aarhus University, Denmark (aleks@mgmt.au.dk)\n", + " - Nicolas Legrand, Aarhus University, Denmark (nicolas.legrand@cas.au.dk)\n", + "---" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -33,12 +44,11 @@ "import jax\n", "import jax.numpy as jnp\n", "import matplotlib.pyplot as plt\n", - "from jax import grad\n", "import numpy as np\n", "import pymc as pm\n", "import pytensor.tensor as pt\n", "import seaborn as sns\n", - "from jax import jit\n", + "from jax import grad, jit\n", "from jax.nn import softmax\n", "from jax.tree_util import Partial\n", "from pyhgf.math import binary_surprise\n", @@ -693,22 +703,14 @@ "):\n", "\n", " # update the autoconnection strengths at the first level\n", - " network.attributes[4][\n", - " \"autoconnection_strength\"\n", - " ] = autoconnection_strength_1\n", - " network.attributes[5][\n", - " \"autoconnection_strength\"\n", - " ] = autoconnection_strength_2\n", - " network.attributes[6][\n", - " \"autoconnection_strength\"\n", - " ] = autoconnection_strength_1\n", - " network.attributes[7][\n", - " \"autoconnection_strength\"\n", - " ] = autoconnection_strength_2\n", - " \n", + " network.attributes[4][\"autoconnection_strength\"] = autoconnection_strength_1\n", + " network.attributes[5][\"autoconnection_strength\"] = autoconnection_strength_2\n", + " network.attributes[6][\"autoconnection_strength\"] = autoconnection_strength_1\n", + " network.attributes[7][\"autoconnection_strength\"] = autoconnection_strength_2\n", + "\n", " # run the model forward\n", " network.input_data(input_data=u, observed=observed)\n", - " \n", + "\n", " # compute decision probabilities given the belief trajectories\n", " expected_means = jnp.array(\n", " [\n", @@ -722,24 +724,24 @@ " for i in range(4, 8)\n", " ]\n", " )\n", - " \n", + "\n", " # Compute the decision probabilities\n", " x = beta_1 * expected_means + beta_2 * expected_variances\n", " x -= jnp.max(x, axis=0)\n", " decision_probabilities = softmax(x, axis=1)\n", - " \n", + "\n", " # compute the binary surprise over each bandit x trials\n", " surprises = binary_surprise(x=decisions.T, expected_mean=decision_probabilities)\n", - " \n", + "\n", " # avoid numerical overflow\n", " surprises = jnp.where(surprises > 1e6, 1e6, surprises)\n", - " \n", + "\n", " # sum all the binary surprises\n", " surprise = surprises.sum()\n", - " \n", + "\n", " # returns inf if the model cannot fit somewhere\n", " surprise = jnp.where(jnp.isnan(surprise), jnp.inf, surprise)\n", - " \n", + "\n", " return -surprise" ] }, @@ -809,7 +811,9 @@ "\n", " def grad(self, inputs, output_gradients):\n", " # Create a PyTensor expression of the gradient\n", - " grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_custom_op(*inputs)\n", + " grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_custom_op(\n", + " *inputs\n", + " )\n", "\n", " output_gradient = output_gradients[0]\n", " # We reference the VJP Op created below, which encapsulates\n", @@ -819,22 +823,30 @@ " output_gradient * grad_autoconnection_strength_2,\n", " ]\n", "\n", + "\n", "class GradCustomOp(Op):\n", " def make_node(self, autoconnection_strength_1, autoconnection_strength_2):\n", " # Make sure the two inputs are tensor variables\n", " inputs = [\n", - " pt.as_tensor_variable(autoconnection_strength_1), \n", - " pt.as_tensor_variable(autoconnection_strength_2), \n", + " pt.as_tensor_variable(autoconnection_strength_1),\n", + " pt.as_tensor_variable(autoconnection_strength_2),\n", " ]\n", " # Output has the shape type and shape as the first input\n", " outputs = [inp.type() for inp in inputs]\n", " return Apply(self, inputs, outputs)\n", "\n", " def perform(self, node, inputs, outputs):\n", - " grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_logp_fn(*inputs)\n", + " grad_autoconnection_strength_1, grad_autoconnection_strength_2 = grad_logp_fn(\n", + " *inputs\n", + " )\n", + "\n", + " outputs[0][0] = np.asarray(\n", + " grad_autoconnection_strength_1, dtype=node.outputs[0].dtype\n", + " )\n", + " outputs[1][0] = np.asarray(\n", + " grad_autoconnection_strength_2, dtype=node.outputs[1].dtype\n", + " )\n", "\n", - " outputs[0][0] = np.asarray(grad_autoconnection_strength_1, dtype=node.outputs[0].dtype)\n", - " outputs[1][0] = np.asarray(grad_autoconnection_strength_2, dtype=node.outputs[1].dtype)\n", "\n", "# Instantiate the Ops\n", "custom_op = CustomOp()\n", @@ -950,11 +962,13 @@ "source": [ "with pm.Model() as model:\n", " autoconnection_strength = pm.Beta(\"autoconnection_strength\", 1.0, 1.0, shape=2)\n", - " pm.Potential(\"hgf\", custom_op(\n", - " autoconnection_strength_1=autoconnection_strength[0], \n", - " autoconnection_strength_2=autoconnection_strength[1]\n", + " pm.Potential(\n", + " \"hgf\",\n", + " custom_op(\n", + " autoconnection_strength_1=autoconnection_strength[0],\n", + " autoconnection_strength_2=autoconnection_strength[1],\n", + " ),\n", " )\n", - " )\n", " idata = pm.sample(chains=2, cores=1)" ] },