From cce173da97447a7233d513f6ccb744d54359aae9 Mon Sep 17 00:00:00 2001 From: Louie Date: Sat, 7 Dec 2024 16:17:55 +0100 Subject: [PATCH] updated and cleaned latent_update_function. Plot of function doesn't work --- .../notebooks/Latent_var_notebook.ipynb | 153 ++++++++---------- 1 file changed, 64 insertions(+), 89 deletions(-) diff --git a/docs/source/notebooks/Latent_var_notebook.ipynb b/docs/source/notebooks/Latent_var_notebook.ipynb index bc6f282a9..426670c17 100644 --- a/docs/source/notebooks/Latent_var_notebook.ipynb +++ b/docs/source/notebooks/Latent_var_notebook.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 74, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ @@ -45,29 +45,17 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 48, "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'jax' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[3], line 2\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[38;5;66;03m# Disable JIT compilation globally\u001b[39;00m\n\u001b[1;32m----> 2\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mjax_disable_jit\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;66;03m# True - If I want the compiler disabled.\u001b[39;00m\n", - "\u001b[1;31mNameError\u001b[0m: name 'jax' is not defined" - ] - } - ], + "outputs": [], "source": [ "# Disable JIT compilation globally\n", - "jax.config.update(\"jax_disable_jit\", False) # True - If I want the compiler disabled." + "# jax.config.update(\"jax_disable_jit\", False) # True - If I want the compiler disabled." ] }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 49, "metadata": {}, "outputs": [ { @@ -75,39 +63,46 @@ "output_type": "stream", "text": [ "5\n", - "4\n", - "{-1: {'time_step': 0.0}, 0: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': (1.0,), 'value_coupling_children': None, 'value_coupling_parents': (1.0,), 'tonic_volatility': 0.0, 'tonic_drift': 0.0, 'autoconnection_strength': 0.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 1: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': (1.0,), 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 2: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10.0, 'expected_precision': 1.0, 'volatility_coupling_children': (1.0,), 'volatility_coupling_parents': None, 'value_coupling_children': None, 'value_coupling_parents': None, 'tonic_volatility': -2.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 3: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': None, 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}}\n" + "4\n" ] } ], "source": [ + "timeserie = load_data(\"continuous\")\n", + "\n", + "# latent_hgf = (\n", + "# Network()\n", + "# .add_nodes(precision=1e4)\n", + "# .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n", + "# .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=0)\n", + "# .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=1)\n", + "# # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=1)\n", + "# # .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=2)\n", + "# # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=2)\n", + "# ).create_belief_propagation_fn()\n", + "\n", "latent_hgf = (\n", " Network()\n", " .add_nodes(precision=1e4)\n", " .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n", - " .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=0)\n", - " .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=1)\n", + " .add_nodes(precision=1e1, tonic_volatility=-2.0, value_children=1)\n", + " .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n", " # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=1)\n", " # .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=2)\n", " # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=2)\n", ").create_belief_propagation_fn()\n", "\n", - "# attributes, edges, update_sequence = (\n", - "# latent_hgf.get_network()\n", - "# )\n", - "\n", "attributes, edges, update_sequence = (\n", " latent_hgf.get_network()\n", ")\n", "\n", "print(len(attributes))\n", - "print(len(edges))\n", - "print(attributes)" + "print(len(edges))" ] }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 50, "metadata": {}, "outputs": [ { @@ -145,35 +140,35 @@ "\n", "\n", "x_2\n", - "\n", - "2\n", + "\n", + "2\n", "\n", - "\n", + "\n", "\n", - "x_2->x_0\n", - "\n", - "\n", + "x_2->x_1\n", + "\n", + "\n", "\n", "\n", "\n", "x_3\n", - "\n", - "3\n", + "\n", + "3\n", "\n", - "\n", + "\n", "\n", - "x_3->x_1\n", - "\n", - "\n", + "x_3->x_0\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 77, + "execution_count": 50, "metadata": {}, "output_type": "execute_result" } @@ -184,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -198,18 +193,10 @@ " edges: Edges,\n", " index: int\n", " ):\n", - " # Add new node to attributes\n", - "\n", - " print('attributes before changing anything:')\n", - " print(attributes)\n", - "\n", - " print('node to be connected to:')\n", - " print(index)\n", "\n", " new_node_idx = len(edges) # Use len() to get the next index\n", - " print('new_node_idx:')\n", "\n", - " print(new_node_idx)\n", + " # Add new node to attributes\n", " attributes[new_node_idx] = {\n", " \"mean\": 0.0,\n", " \"expected_mean\": 0.0,\n", @@ -229,57 +216,34 @@ " \"volatility_prediction_error\": 0.0,\n", " },\n", " }\n", - " \n", + "\n", " # Add new AdjacencyList(empty values) to Edges tuple\n", " new_adj_list = AdjacencyLists(2, None, None, None, None, None)\n", " edges = edges + (new_adj_list,)\n", - "\n", - " print('attributes after adding new empty node:')\n", - " print(attributes)\n", " \n", " # Use add_edges to integrate the newly altered attributes and edges\n", - " attributes, edges = add_edges(attributes, edges, 'value', new_node_idx, index) # If i understand correctly, the new node is incorrectly coupled to node nr. 1, instread of the intended node nr. 2\n", + " attributes, edges = add_edges(attributes, edges, 'value', new_node_idx, index) # If i understand correctly, the new node is incorrectly coupled to node nr. 1, instread of the intended node nr. 2 in the new attributes.\n", " \n", - " print('attributes after adding value coupling:')\n", - " print(attributes)\n", - "\n", " # Return altered attributes and edges\n", " return attributes, edges" ] }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 52, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "attributes before changing anything:\n", - "{-1: {'time_step': 0.0}, 0: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': (1.0,), 'value_coupling_children': None, 'value_coupling_parents': (1.0,), 'tonic_volatility': 0.0, 'tonic_drift': 0.0, 'autoconnection_strength': 0.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 1: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': (1.0,), 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 2: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10.0, 'expected_precision': 1.0, 'volatility_coupling_children': (1.0,), 'volatility_coupling_parents': None, 'value_coupling_children': None, 'value_coupling_parents': None, 'tonic_volatility': -2.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 3: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': None, 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}}\n", - "node to be connected to:\n", - "2\n", - "new_node_idx:\n", - "4\n", - "attributes after adding new empty node:\n", - "{-1: {'time_step': 0.0}, 0: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': (1.0,), 'value_coupling_children': None, 'value_coupling_parents': (1.0,), 'tonic_volatility': 0.0, 'tonic_drift': 0.0, 'autoconnection_strength': 0.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 1: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': (1.0,), 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 2: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10.0, 'expected_precision': 1.0, 'volatility_coupling_children': (1.0,), 'volatility_coupling_parents': None, 'value_coupling_children': None, 'value_coupling_parents': None, 'tonic_volatility': -2.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 3: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': None, 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 4: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 1.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': None, 'value_coupling_parents': None, 'tonic_volatility': -4.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}}\n", - "attributes after adding value coupling:\n", - "{-1: {'time_step': 0.0}, 0: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': (1.0,), 'value_coupling_children': None, 'value_coupling_parents': (1.0,), 'tonic_volatility': 0.0, 'tonic_drift': 0.0, 'autoconnection_strength': 0.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 1: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': (1.0,), 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 2: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 10.0, 'expected_precision': 1.0, 'volatility_coupling_children': (1.0,), 'volatility_coupling_parents': None, 'value_coupling_children': None, 'value_coupling_parents': None, 'tonic_volatility': -2.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 3: {'mean': 1.0357, 'expected_mean': 0.0, 'precision': 10000.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': (1.0,), 'tonic_volatility': -13.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}, 4: {'mean': 0.0, 'expected_mean': 0.0, 'precision': 1.0, 'expected_precision': 1.0, 'volatility_coupling_children': None, 'volatility_coupling_parents': None, 'value_coupling_children': (1.0,), 'value_coupling_parents': None, 'tonic_volatility': -4.0, 'tonic_drift': 0.0, 'autoconnection_strength': 1.0, 'observed': 1, 'temp': {'effective_precision': 0.0, 'value_prediction_error': 0.0, 'volatility_prediction_error': 0.0}}}\n" - ] - } - ], + "outputs": [], "source": [ "attributes, edges, update_sequence = (\n", " latent_hgf.get_network()\n", ")\n", "\n", - "latent_hgf_alt_attributes, latent_hgf_alt_edges = latent_update(attributes, edges, 2)" + "latent_hgf_alt_attributes, latent_hgf_alt_edges = latent_update(attributes, edges, 3)" ] }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 53, "metadata": {}, "outputs": [ { @@ -298,24 +262,35 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "latent_hgf.attributes = latent_hgf_alt_attributes\n", + "latent_hgf.edges = latent_hgf_alt_edges" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [ { "ename": "TypeError", - "evalue": "'Network' object does not support item assignment", + "evalue": "'NoneType' object is not subscriptable", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[1;32mIn[71], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mlatent_hgf\u001b[49m\u001b[43m[\u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;241m=\u001b[39m latent_hgf_alt_attributes\n\u001b[0;32m 2\u001b[0m latent_hgf[edges] \u001b[38;5;241m=\u001b[39m latent_hgf_alt_edges\n", - "\u001b[1;31mTypeError\u001b[0m: 'Network' object does not support item assignment" + "Cell \u001b[1;32mIn[56], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mlatent_hgf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot_network\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32mc:\\Users\\hesse\\miniconda3\\envs\\pymc_env\\Lib\\site-packages\\pyhgf\\model\\network.py:673\u001b[0m, in \u001b[0;36mNetwork.plot_network\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 671\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplot_network\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 672\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Visualization of node network using GraphViz.\"\"\"\u001b[39;00m\n\u001b[1;32m--> 673\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mplot_network\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnetwork\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32mc:\\Users\\hesse\\miniconda3\\envs\\pymc_env\\Lib\\site-packages\\pyhgf\\plots.py:324\u001b[0m, in \u001b[0;36mplot_network\u001b[1;34m(network)\u001b[0m\n\u001b[0;32m 320\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m value_parents_idx \u001b[38;5;129;01min\u001b[39;00m value_parents:\n\u001b[0;32m 321\u001b[0m \n\u001b[0;32m 322\u001b[0m \u001b[38;5;66;03m# get the coupling function from the value parent\u001b[39;00m\n\u001b[0;32m 323\u001b[0m child_idx \u001b[38;5;241m=\u001b[39m network\u001b[38;5;241m.\u001b[39medges[value_parents_idx]\u001b[38;5;241m.\u001b[39mvalue_children\u001b[38;5;241m.\u001b[39mindex(i)\n\u001b[1;32m--> 324\u001b[0m coupling_fn \u001b[38;5;241m=\u001b[39m \u001b[43mnetwork\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43medges\u001b[49m\u001b[43m[\u001b[49m\u001b[43mvalue_parents_idx\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcoupling_fn\u001b[49m\u001b[43m[\u001b[49m\u001b[43mchild_idx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[0;32m 325\u001b[0m graphviz_structure\u001b[38;5;241m.\u001b[39medge(\n\u001b[0;32m 326\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mvalue_parents_idx\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 327\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 328\u001b[0m color\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblack\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m coupling_fn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mblack:invis:black\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 329\u001b[0m )\n\u001b[0;32m 331\u001b[0m \u001b[38;5;66;03m# connect volatility parents\u001b[39;00m\n", + "\u001b[1;31mTypeError\u001b[0m: 'NoneType' object is not subscriptable" ] } ], "source": [ - "latent_hgf[attributes] = latent_hgf_alt_attributes\n", - "latent_hgf[edges] = latent_hgf_alt_edges" + "latent_hgf.plot_network() # Not sure why the plot function doesn't function with altered Attributes and Edges..." ] } ],