
Add update function for learning causal coupling #249

Open
wants to merge 34 commits into base: master

Changes from 1 commit (of 34 commits)
28152b7
Added example 4 notebook
Oct 30, 2024
7c438ba
Changes in the network and example 4 file
Oct 31, 2024
526526e
Add description to example 4 notebook
Nov 1, 2024
4b69597
Add functions for edge updating
Nov 5, 2024
adf4668
Update causal_functions.py
Nov 6, 2024
c195f23
Updated functions and example
Nov 8, 2024
04c96ff
Updated comments
Nov 8, 2024
84f82a2
Update Example_4_Causal.ipynb
Nov 9, 2024
22821a2
updated example file
Nov 12, 2024
5e9e1c2
Added example 4 notebook
Oct 30, 2024
738f748
Changes in the network and example 4 file
Oct 31, 2024
42b6cfa
Add description to example 4 notebook
Nov 1, 2024
336ba9a
Add functions for edge updating
Nov 5, 2024
989674e
Update causal_functions.py
Nov 6, 2024
025c3ea
Updated functions and example
Nov 8, 2024
d6bc75e
Updated comments
Nov 8, 2024
dc0f395
Update Example_4_Causal.ipynb
Nov 9, 2024
de080d5
updated example file
Nov 12, 2024
0c99872
Update causal functions
Nov 13, 2024
94b0267
Update causal functions
Nov 13, 2024
ff19871
authors
LegrandNico Nov 14, 2024
31c1ab4
add an example of causal update in the tutorial
LegrandNico Nov 14, 2024
3540736
Update example with new update functions
Nov 15, 2024
8c0662b
Update sigmoid transformed coupling
Nov 21, 2024
8756223
update sigmoid causal coupling
Nov 22, 2024
09c80c0
update description causal updating
Nov 22, 2024
1a97672
update sigmoid function for coupling function
Nov 25, 2024
9891093
added KL for measures of freedom and control
Nov 26, 2024
2d20d0f
added agents for simulations
Dec 2, 2024
18b5fab
update agents for simulation
Dec 5, 2024
4c9ffe9
add a working version of the causal prediction errors
LegrandNico Dec 6, 2024
f985462
added plots for agents
Dec 18, 2024
55a5090
updated example notebook
Dec 18, 2024
4b4dce3
add derivations of derivative and roots
Dec 19, 2024
Update causal_functions.py
Prediction error - sigmoid function; updated mutual information function based on entropy for Gaussian variables
Jane Doe committed Nov 6, 2024
commit adf4668d84c9487ce05e0a95231e043859ac1961
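
As a quick reference, the entropy-based identity that the updated function relies on (a standard result for Gaussian variables) is:

\[
I(X; Y) = H(X) + H(Y) - H(X, Y), \qquad H(X) = \frac{1}{2} \log\left(2 \pi e \, \sigma_X^2\right)
\]

where the joint entropy of a bivariate Gaussian with covariance matrix \( \Sigma \) is \( H(X, Y) = \frac{1}{2} \log\left((2 \pi e)^2 \det \Sigma\right) \).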
252 changes: 154 additions & 98 deletions pyhgf/updates/causal_functions.py
@@ -2,33 +2,16 @@
from scipy.stats import multivariate_normal
import itertools



def calculate_mutual_information_sampled(attributes, edges, node_idx, num_samples=1000, eps=1e-10):
def calculate_mutual_information(network, node_idx, num_samples=1000, eps=1e-8):
"""
Calculate mutual information for each parent-child, parent-parent, and self relationship.
Mutual information between two continuous Gaussian variables is estimated by sampling and
approximating the joint and marginal distributions.

Mathematical::
The mutual information \( I(X; Y) \) between two continuous Gaussian variables \( X \) and \( Y \) is defined as:

\[
I(X; Y) = \int \int p(x, y) \log \left( \frac{p(x, y)}{p(x) p(y)} \right) dx \, dy
\]

Using discrete samples, mutual information is approximated as:

\[
I(X; Y) \approx \sum_{i,j} p(x_i, y_j) \left( \log(p(x_i, y_j) + \text{eps}) - \log(p(x_i) + \text{eps}) - \log(p(y_j) + \text{eps}) \right)
\]
Calculate mutual information for each parent-child, parent-parent, and self relationship
in a Network instance using entropy-based mutual information calculations with sampled covariance.

Parameters:
- `attributes`: dictionary with node attributes, such as expected precision and mean.
- `edges`: tuple of adjacency lists for value and volatility parents and children.
- `node_idx`: index of the child node for mutual information calculation.
- `num_samples`: (default 1000) number of samples used for covariance and mutual information estimation.
- `eps`: small constant to avoid log(0) issues (default 1e-10).
- network: A Network instance.
- node_idx: Index of the child node for mutual information calculation.
- num_samples: Number of samples used for covariance and mutual information estimation (default 1000).
- eps: Small constant to avoid log(0) issues (default 1e-8).

Returns:
- Dictionary of mutual information values:
@@ -39,135 +22,208 @@ def calculate_mutual_information_sampled(attributes, edges, node_idx, num_sample
mutual_info_dict = {"parent_child": {}, "parent_parent": {}, "self": None}

# Sample the child node distribution
node_mean = attributes[node_idx]["expected_mean"]
node_precision = attributes[node_idx]["expected_precision"]
node_mean = network.attributes[node_idx]["expected_mean"]
node_precision = network.attributes[node_idx]["expected_precision"]
node_variance = 1 / node_precision
child_samples = np.random.normal(node_mean, np.sqrt(node_variance), num_samples)

# List the value parents
value_parents_idxs = edges[node_idx].value_parents
value_parents_idxs = network.edges[node_idx].value_parents

# Store sampled data for covariance and MI
# Sample data for each parent
data = []
for parent_idx in value_parents_idxs:
parent_mean = attributes[parent_idx]["expected_mean"]
parent_precision = attributes[parent_idx]["expected_precision"]
parent_mean = network.attributes[parent_idx]["expected_mean"]
parent_precision = network.attributes[parent_idx]["expected_precision"]
parent_variance = 1 / parent_precision
parent_samples = np.random.normal(parent_mean, np.sqrt(parent_variance), num_samples)
data.append(parent_samples)

# Append child samples to data array
data.append(child_samples)

# Stack data columns and calculate the covariance matrix
data = np.vstack(data).T
cov = np.cov(data, rowvar=False) + eps * np.eye(data.shape[1])  # add eps to the diagonal for numerical stability

# Means and variances from samples
means = data.mean(axis=0)
cov = np.cov(data.T)
# Self-entropy for the child node
child_variance = cov[-1, -1]
child_entropy = 0.5 * np.log(2 * np.pi * np.e * child_variance)
mutual_info_dict["self"] = {node_idx: child_entropy}

# Distributions, marginal and joint
# Calculate MI for each parent-child pair using entropy method
for i, parent_idx in enumerate(value_parents_idxs):
# Get marginal and joint distributions
d_child = multivariate_normal(means[-1], cov[-1, -1])
d_parent = multivariate_normal(means[i], cov[i, i])
jd_parent_child = multivariate_normal(means[[i, -1]], cov[[i, -1]][:, [i, -1]])

# Discretize samples and calculate MI
child_vals = np.linspace(child_samples.min(), child_samples.max(), 100)
parent_vals = np.linspace(data[:, i].min(), data[:, i].max(), 100)
triplets = ((jd_parent_child.pdf(tup), d_parent.pdf(tup[0]), d_child.pdf(tup[1]))
for tup in itertools.product(parent_vals, child_vals))
# Entropies of individual distributions
parent_variance = cov[i, i]
parent_entropy = 0.5 * np.log(2 * np.pi * np.e * parent_variance)

# Compute MI (using `eps` to avoid log(0))
mutual_information = np.sum([
p_xy * (np.log(p_xy + eps) - np.log(p_x + eps) - np.log(p_y + eps))
for p_xy, p_x, p_y in triplets
])
# Calculate the joint entropy of the parent-child pair
joint_entropy = 0.5 * np.log((2 * np.pi * np.e) ** 2 * np.linalg.det(cov[[i, -1]][:, [i, -1]]))

# Store the calculated MI for the current pair
mutual_info_dict["parent_child"][(parent_idx, node_idx)] = mutual_information
# MI(X; Y) = H(X) + H(Y) - H(X, Y)
mutual_info_dict["parent_child"][(parent_idx, node_idx)] = max(
parent_entropy + child_entropy - joint_entropy, 0 # Ensure non-negative MI
)

# if more than one parent, compute pairs
# Calculate MI for each parent-parent pair if there are multiple parents
if len(value_parents_idxs) > 1:
for i in range(len(value_parents_idxs)):
for j in range(i + 1, len(value_parents_idxs)):
parent_i = value_parents_idxs[i]
parent_j = value_parents_idxs[j]

# Marginal and joint dist
d_i = multivariate_normal(means[i], cov[i, i])
d_j = multivariate_normal(means[j], cov[j, j])
jd_ij = multivariate_normal(means[[i, j]], cov[[i, j]][:, [i, j]])

# discretize sampled data and calculate mutual information
vals_i = np.linspace(data[:, i].min(), data[:, i].max(), 100)
vals_j = np.linspace(data[:, j].min(), data[:, j].max(), 100)
triplets = ((jd_ij.pdf(tup), d_i.pdf(tup[0]), d_j.pdf(tup[1]))
for tup in itertools.product(vals_i, vals_j))

# Get MI for the pair
mutual_information = np.sum([
p_xy * (np.log(p_xy + eps) - np.log(p_x + eps) - np.log(p_y + eps))
for p_xy, p_x, p_y in triplets
])
# Entropies of individual distributions
var_i, var_j = cov[i, i], cov[j, j]
entropy_i = 0.5 * np.log(2 * np.pi * np.e * var_i)
entropy_j = 0.5 * np.log(2 * np.pi * np.e * var_j)

# Store mutual information for parent-parent relationship
mutual_info_dict["parent_parent"][(parent_i, parent_j)] = mutual_information
# Joint entropy of parent_i and parent_j
joint_entropy_ij = 0.5 * np.log((2 * np.pi * np.e) ** 2 * np.linalg.det(cov[[i, j]][:, [i, j]]))

# self-entropy for the child node (entropy of a Gaussian variable)
self_entropy = 0.5 * np.log(2 * np.pi * np.e * node_variance)
mutual_info_dict["self"] = {node_idx: self_entropy}
# MI(X; Y) = H(X) + H(Y) - H(X, Y)
mutual_info_dict["parent_parent"][(parent_i, parent_j)] = max(
entropy_i + entropy_j - joint_entropy_ij, 0 # Ensure non-negative MI
)

return mutual_info_dict
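
A minimal standalone sketch of the same entropy-based computation, assuming only NumPy and bypassing the Network instance; for a bivariate Gaussian with correlation rho, the exact mutual information is -0.5 * log(1 - rho**2), which the sampled estimate should approach:

import numpy as np

rng = np.random.default_rng(0)
rho = 0.8  # assumed ground-truth correlation between parent and child

# Sample a correlated parent/child pair and estimate their covariance
samples = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=10_000)
cov = np.cov(samples.T)

# Entropy-based MI: I(X; Y) = H(X) + H(Y) - H(X, Y)
h_x = 0.5 * np.log(2 * np.pi * np.e * cov[0, 0])
h_y = 0.5 * np.log(2 * np.pi * np.e * cov[1, 1])
h_xy = 0.5 * np.log((2 * np.pi * np.e) ** 2 * np.linalg.det(cov))
mi_sampled = max(h_x + h_y - h_xy, 0.0)

mi_exact = -0.5 * np.log(1 - rho**2)  # closed form, ~0.511 for rho = 0.8
print(f"sampled: {mi_sampled:.3f}, exact: {mi_exact:.3f}")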




def calculate_surd(mutual_info_dict, child_idx, parent_idx):
"""
Calculate the SURD for a specified causal parent in relation to a given child node.
Calculate the SURD decomposition for a specified
causal parent in relation to a given child node.

Parameters:
- mutual_info_dict: Dictionary output from `calculate_mutual_information`, containing
mutual information values for `parent_child`, `parent_parent`, and `self` relationships.
- child_idx: Index of the child node for which to calculate SURD values.
- child_idx: Index of the child node
- parent_idx: Index of the causal parent node in relation to the child node.

Returns:
- Dictionary of SURD values:
- `Unique`: Unique information provided by the causal parent to the child.
- `Redundant`: Redundant information shared by the causal parent with other parents.
- `Synergistic`: Synergistic information only provided when considering multiple parents together.
- `Differential`: Differential information, representing the remaining uncertainty in the child.
- `Redundant`: Redundant causality shared with other parents.
- `Unique`: Unique causality provided only by the specified parent.
- `Synergistic`: Synergistic causality arising from the joint effect with other parents.
- `Leak`: Causality leak representing unobserved influences on the child.
"""

# retrieve MI for parent child
unique_info = mutual_info_dict["parent_child"].get((parent_idx, child_idx), 0)
# Retrieve mutual information between the parent and child
parent_child_mi = mutual_info_dict["parent_child"].get((parent_idx, child_idx), 0)

# get self info for child
# Retrieve self-information of the child node
child_self_info = mutual_info_dict["self"].get(child_idx, 0)

# Initialise the information values
# Initialize redundant, unique, synergistic, and leak information
redundant_info = 0
unique_info = parent_child_mi
synergistic_info = 0
leak_info = child_self_info - parent_child_mi # Start with remaining uncertainty

# Check for other parents
# Identify other parents of the child node
other_parents = [p for (p, c) in mutual_info_dict["parent_child"] if c == child_idx and p != parent_idx]

# Calculate shared (redundant) info
# Get other MIs if other parents exist
if other_parents:
# sum mutual information between parent of interest and other parents
redundant_info = sum(
mutual_info_dict["parent_parent"].get((p, parent_idx), 0) for p in other_parents
) / len(other_parents) if len(other_parents) > 1 else 0

# Synergistic information would only be non-zero if more than one parent is contributing
# Since we have a single parent, it's zero
synergistic_info = 0
redundant_mis = []
for other_parent in other_parents:
# Mutual information between the current parent and each other parent (checked under both key orderings, since each pair is stored only once)
parent_parent_mi = mutual_info_dict["parent_parent"].get((other_parent, parent_idx), mutual_info_dict["parent_parent"].get((parent_idx, other_parent), 0))
redundant_mis.append(parent_parent_mi)

# Calculate total redundant information as the average shared info, or should it be total?
redundant_info = sum(redundant_mis) / len(redundant_mis) if redundant_mis else 0

# Calculate unique information by subtracting redundancy from the parent-child MI
unique_info = max(parent_child_mi - redundant_info, 0)

# Synergistic information is any information gain when considering multiple parents together
# Estimated as the extra information not accounted for by unique and redundant parts
combined_parent_mi = sum(mutual_info_dict["parent_child"].get((p, child_idx), 0) for p in other_parents)
synergistic_info = max(child_self_info - (combined_parent_mi + unique_info + redundant_info), 0)

# Differential information:
# remaining uncertainty in the child node after accounting for the other infos
differential_info = child_self_info - unique_info - redundant_info - synergistic_info
# Calculate causality leak as remaining uncertainty in the child not captured by any parent
leak_info = max(child_self_info - (unique_info + redundant_info + synergistic_info), 0)

return {
"Unique": unique_info,
"Redundant": redundant_info,
"Unique": unique_info,
"Synergistic": synergistic_info,
"Differential": max(differential_info, 0) # Ensures differential information is non-negative
"Leak": leak_info
}
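
A usage sketch for the decomposition above, with a hand-built mutual_info_dict (the numbers are illustrative only, not outputs of the MI function):

# Toy MI values for a child node (index 2) with two parents (indices 0 and 1)
mutual_info_dict = {
    "parent_child": {(0, 2): 0.6, (1, 2): 0.4},
    "parent_parent": {(0, 1): 0.2},
    "self": {2: 1.5},
}

surd = calculate_surd(mutual_info_dict, child_idx=2, parent_idx=0)
# With these toy values:
# Redundant = 0.2, Unique = 0.6 - 0.2 = 0.4,
# Synergistic = max(1.5 - (0.4 + 0.4 + 0.2), 0) = 0.5,
# Leak = max(1.5 - (0.4 + 0.2 + 0.5), 0) = 0.4
print(surd)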


def approximate_counterfactual_prediction_error(network, parent_idx, child_idx):
"""
Calculate the prediction error for a child node in a Network by
removing the influence of a specified parent node (for now, this is based on the current coupling strength)
and assessing the child's prediction error without that influence.
The function is based on the logic that the parent’s prediction error
influences the child’s prediction error in proportion to the coupling strength between them.

Parameters:
- network: A Network.
- parent_idx: Index of the causal parent node whose influence we want to exclude.
- child_idx: Index of the child node for which to calculate the counterfactual prediction error.

Returns:
- Approximated counterfactual prediction error: An estimate of the child node's prediction
error if the parent's influence were excluded.
"""
# retrieve the child’s prediction error
child_prediction_error = network.attributes[child_idx]["temp"]["value_prediction_error"]

# Get the current strength
# the coupling strength is currently not appropriately accessed
coupling_strength = network.attributes[parent_idx]["value_coupling_children"][child_idx]

# retrieve the parent’s prediction error, if observed
parent_prediction_error = network.attributes[parent_idx]["temp"].get("value_prediction_error", 0.0)
if not network.attributes[parent_idx].get("observed", True):
parent_prediction_error = 0.0

# Calculate the counterfactual prediction error as if the parent had no influence
counterfactual_prediction_error = child_prediction_error + coupling_strength * parent_prediction_error

return counterfactual_prediction_error
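
A numeric trace of the arithmetic above, with toy values standing in for real Network attributes:

child_pe = 0.5   # child's current value prediction error (assumed)
parent_pe = 0.3  # parent's value prediction error (0.0 if unobserved)
coupling = 0.8   # coupling strength on the parent-child edge (assumed)

# The child's error as if the parent had no influence
counterfactual_pe = child_pe + coupling * parent_pe  # 0.5 + 0.24 = 0.74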



def update_causal_coupling_strength(network, parent_idx, child_idx, learning_rate=0.1):
"""
Update the causal coupling strength between a specified parent and child node in a Network
based on whether the parent's influence positively influenced the child's prediction.

Parameters:
- network: A Network instance containing nodes and their attributes, edges, and observation flags.
- parent_idx: Index of the causal parent node.
- child_idx: Index of the child node.
- learning_rate: Learning rate parameter (eta) controlling the update step size.

Returns:
- Updated causal coupling strength.
"""
# Get the current causal coupling strength for the specific parent-child edge
# Same coupling-strength access issue as noted in approximate_counterfactual_prediction_error
current_coupling_strength = network.attributes[child_idx]["value_coupling_children"][parent_idx]

# Get the child’s prediction error
child_prediction_error = network.attributes[child_idx]["temp"]["value_prediction_error"]

# Estimate the counterfactual prediction error: the child's error as if the parent's influence were removed
counterfactual_prediction_error = approximate_counterfactual_prediction_error(network, parent_idx, child_idx)

# Calculate the influence effect
influence_effect = counterfactual_prediction_error - child_prediction_error

# Update based on sigmoid transformation, learning rate and influence
updated_coupling_strength = current_coupling_strength + learning_rate * (1 / (1 + np.exp(-influence_effect)))

# Store the updated coupling strength in attributes for the specific edge
# network.attributes[child_idx]["value_coupling_children"][parent_idx] = updated_coupling_strength

return updated_coupling_strength
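
And a numeric trace of the sigmoid update rule, continuing the toy values above (since the attribute write is commented out, the caller stores the returned value):

import numpy as np

kappa = 0.5       # current coupling strength (assumed)
eta = 0.1         # learning rate
influence = 0.24  # counterfactual PE minus observed child PE

# The sigmoid transformation bounds the update step within (0, eta)
kappa_new = kappa + eta * (1 / (1 + np.exp(-influence)))
print(round(kappa_new, 3))  # 0.556, since sigmoid(0.24) ~= 0.56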