Skip to content

Commit

Permalink
Merge pull request #125 from CITCOM-project/enum_variables
Browse files Browse the repository at this point in the history
Working support for ENUM variables
  • Loading branch information
jmafoster1 authored Jan 24, 2023
2 parents 024948f + fcb25c2 commit 92ccc5d
Show file tree
Hide file tree
Showing 12 changed files with 389 additions and 123 deletions.
4 changes: 4 additions & 0 deletions causal_testing/data_collection/data_collector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from abc import ABC, abstractmethod
from enum import Enum

import pandas as pd
import z3
Expand Down Expand Up @@ -140,4 +141,7 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
for meta in self.scenario.metas():
meta.populate(execution_data_df)
scenario_execution_data_df = self.filter_valid_data(execution_data_df)
for var_name, var in self.scenario.variables.items():
if issubclass(var.datatype, Enum):
scenario_execution_data_df[var_name] = [var.datatype(x) for x in scenario_execution_data_df[var_name]]
return scenario_execution_data_df
83 changes: 58 additions & 25 deletions causal_testing/generation/abstract_causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import pandas as pd
import z3
from scipy import stats
import itertools

from causal_testing.specification.scenario import Scenario
from causal_testing.specification.variable import Variable
from causal_testing.testing.causal_test_case import CausalTestCase
from causal_testing.testing.causal_test_outcome import CausalTestOutcome

from enum import Enum

logger = logging.getLogger(__name__)


Expand All @@ -24,23 +27,25 @@ def __init__(
self,
scenario: Scenario,
intervention_constraints: set[z3.ExprRef],
treatment_variables: set[Variable],
treatment_variable: Variable,
expected_causal_effect: dict[Variable:CausalTestOutcome],
effect_modifiers: set[Variable] = None,
estimate_type: str = "ate",
effect: str = "total",
):
assert treatment_variables.issubset(scenario.variables.values()), (
assert treatment_variable in scenario.variables.values(), (
"Treatment variables must be a subset of variables."
+ f" Instead got:\ntreatment_variables={treatment_variables}\nvariables={scenario.variables}"
+ f" Instead got:\ntreatment_variable={treatment_variable}\nvariables={scenario.variables}"
)

assert len(expected_causal_effect) == 1, "We currently only support tests with one causal outcome"

self.scenario = scenario
self.intervention_constraints = intervention_constraints
self.treatment_variables = treatment_variables
self.treatment_variable = treatment_variable
self.expected_causal_effect = expected_causal_effect
self.estimate_type = estimate_type
self.effect = effect

if effect_modifiers is not None:
self.effect_modifiers = effect_modifiers
Expand Down Expand Up @@ -100,7 +105,12 @@ def _generate_concrete_tests(
for c in self.intervention_constraints:
optimizer.assert_and_track(c, str(c))

optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
for v in run_columns:
optimizer.add_soft(
self.scenario.variables[v].z3
== self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])
)

if optimizer.check() == z3.unsat:
logger.warning(
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",
Expand All @@ -110,14 +120,15 @@ def _generate_concrete_tests(
model = optimizer.model()

concrete_test = CausalTestCase(
control_input_configuration={v: v.cast(model[v.z3]) for v in self.treatment_variables},
control_input_configuration={v: v.cast(model[v.z3]) for v in [self.treatment_variable]},
treatment_input_configuration={
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in self.treatment_variables
v: v.cast(model[self.scenario.treatment_variables[v.name].z3]) for v in [self.treatment_variable]
},
expected_causal_effect=list(self.expected_causal_effect.values())[0],
outcome_variables=list(self.expected_causal_effect.keys()),
estimate_type=self.estimate_type,
effect_modifier_configuration={v: v.cast(model[v.z3]) for v in self.effect_modifiers},
effect=self.effect,
)

for v in self.scenario.inputs():
Expand All @@ -128,19 +139,20 @@ def _generate_concrete_tests(
+ f"{constraints}\nUsing value {v.cast(model[v.z3])} instead in test\n{concrete_test}"
)

concrete_tests.append(concrete_test)
# Control run
control_run = {
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
}
control_run["bin"] = index
runs.append(control_run)
# Treatment run
if rct:
treatment_run = control_run.copy()
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
treatment_run["bin"] = index
runs.append(treatment_run)
if not any([vars(t) == vars(concrete_test) for t in concrete_tests]):
concrete_tests.append(concrete_test)
# Control run
control_run = {
v.name: v.cast(model[v.z3]) for v in self.scenario.variables.values() if v.name in run_columns
}
control_run["bin"] = index
runs.append(control_run)
# Treatment run
if rct:
treatment_run = control_run.copy()
treatment_run.update({k.name: v for k, v in concrete_test.treatment_input_configuration.items()})
treatment_run["bin"] = index
runs.append(treatment_run)

return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])

Expand Down Expand Up @@ -176,9 +188,12 @@ def generate_concrete_tests(
runs = pd.DataFrame()
ks_stats = []

pre_break = False
for i in range(hard_max):
concrete_tests_, runs_ = self._generate_concrete_tests(sample_size, rct, seed + i)
concrete_tests += concrete_tests_
for t_ in concrete_tests_:
if not any([vars(t_) == vars(t) for t in concrete_tests]):
concrete_tests.append(t_)
runs = pd.concat([runs, runs_])
assert concrete_tests_ not in concrete_tests, "Duplicate entries unlikely unless something went wrong"

Expand All @@ -205,14 +220,32 @@ def generate_concrete_tests(
for var in effect_modifier_configs.columns
}
)
if target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
control_values = [test.control_input_configuration[self.treatment_variable] for test in concrete_tests]
treatment_values = [test.treatment_input_configuration[self.treatment_variable] for test in concrete_tests]

if self.treatment_variable.datatype is bool and set([(True, False), (False, True)]).issubset(
set(zip(control_values, treatment_values))
):
pre_break = True
break
if issubclass(self.treatment_variable.datatype, Enum) and set(
{
(x, y)
for x, y in itertools.product(self.treatment_variable.datatype, self.treatment_variable.datatype)
if x != y
}
).issubset(zip(control_values, treatment_values)):
pre_break = True
break
elif target_ks_score and all((stat <= target_ks_score for stat in ks_stats.values())):
pre_break = True
break

if target_ks_score is not None and not all((stat <= target_ks_score for stat in ks_stats.values())):
if target_ks_score is not None and not pre_break:
logger.error(
"Hard max of %s reached but could not achieve target ks_score of %s. Got %s.",
hard_max,
"Hard max reached but could not achieve target ks_score of %s. Got %s. Generated %s distinct tests",
target_ks_score,
ks_stats,
len(concrete_tests),
)
return concrete_tests, runs
29 changes: 14 additions & 15 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
:param metas:
"""
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
self.outputs = [Output(i["name"], i["type"]) for i in outputs]
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

def setup(self):
Expand All @@ -89,10 +89,11 @@ def setup(self):
self._populate_metas()

def _create_abstract_test_case(self, test, mutates, effects):
assert len(test["mutations"]) == 1
abstract_test = AbstractCausalTestCase(
scenario=self.modelling_scenario,
intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
treatment_variables={self.modelling_scenario.variables[v] for v in test["mutations"]},
treatment_variable=next(self.modelling_scenario.variables[v] for v in test["mutations"]),
expected_causal_effect={
self.modelling_scenario.variables[variable]: effects[effect]
for variable, effect in test["expectedEffect"].items()
Expand All @@ -101,6 +102,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
if "effect_modifiers" in test
else {},
estimate_type=test["estimate_type"],
effect=test.get("effect", "total"),
)
return abstract_test

Expand All @@ -121,10 +123,10 @@ def generate_tests(self, effects: dict, mutates: dict, estimators: dict, f_flag:
concrete_tests, dummy = abstract_test.generate_concrete_tests(5, 0.05)
logger.info("Executing test: %s", test["name"])
logger.info(abstract_test)
logger.info([(v.name, v.distribution) for v in abstract_test.treatment_variables])
logger.info([(v.name, v.distribution) for v in [abstract_test.treatment_variable]])
logger.info("Number of concrete tests for test case: %s", str(len(concrete_tests)))
failures = self._execute_tests(concrete_tests, estimators, test, f_flag)
logger.info("%s/%s failed", failures, len(concrete_tests))
logger.info("%s/%s failed for %s\n", failures, len(concrete_tests), test["name"])

def _execute_tests(self, concrete_tests, estimators, test, f_flag):
failures = 0
Expand All @@ -151,11 +153,12 @@ def _populate_metas(self):
meta.populate(self.data)

for var in self.metas + self.outputs:
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
fitter.fit()
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
var.distribution = getattr(scipy.stats, dist)(**params)
logger.info(var.name + f"{dist}({params})")
if not var.distribution:
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
fitter.fit()
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
var.distribution = getattr(scipy.stats, dist)(**params)
logger.info(var.name + f" {dist}({params})")

def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estimator, f_flag: bool) -> bool:
"""Executes a singular test case, prints the results and returns the test case result
Expand All @@ -178,19 +181,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
if causal_test_result.ci_low() and causal_test_result.ci_high():
result_string = f"{causal_test_result.ci_low()} < {causal_test_result.test_value.value} < {causal_test_result.ci_high()}"
else:
result_string = causal_test_result.test_value.value
result_string = f"{causal_test_result.test_value.value} no confidence intervals"
if f_flag:
assert test_passes, (
f"{causal_test_case}\n FAILED - expected {causal_test_case.expected_causal_effect}, "
f"got {result_string}"
)
if not test_passes:
failed = True
logger.warning(
" FAILED- expected %s, got %s",
causal_test_case.expected_causal_effect,
causal_test_result.test_value.value,
)
logger.warning(" FAILED- expected %s, got %s", causal_test_case.expected_causal_effect, result_string)
return failed

def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) -> tuple[CausalTestEngine, Estimator]:
Expand Down
5 changes: 4 additions & 1 deletion causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def direct_effect_adjustment_sets(self, treatments: list[str], outcomes: list[st
gam.add_edges_from(edges_to_add)

min_seps = list(list_all_min_sep(gam, "TREATMENT", "OUTCOME", set(treatments), set(outcomes)))
# min_seps.remove(set(outcomes))
if set(outcomes) in min_seps:
min_seps.remove(set(outcomes))
return min_seps

def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: list[str]) -> list[set[str]]:
Expand All @@ -278,6 +279,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
:param outcomes: A list of strings representing outcomes.
:return: A list of strings representing the minimal adjustment set.
"""

# 1. Construct the proper back-door graph's ancestor moral graph
proper_backdoor_graph = self.get_proper_backdoor_graph(treatments, outcomes)
ancestor_proper_backdoor_graph = proper_backdoor_graph.get_ancestor_graph(treatments, outcomes)
Expand Down Expand Up @@ -316,6 +318,7 @@ def enumerate_minimal_adjustment_sets(self, treatments: list[str], outcomes: lis
for adj in minimum_adjustment_sets
if self.constructive_backdoor_criterion(proper_backdoor_graph, treatments, outcomes, adj)
]

return valid_minimum_adjustment_sets

def adjustment_set_is_minimal(self, treatments: list[str], outcomes: list[str], adjustment_set: set[str]) -> bool:
Expand Down
36 changes: 15 additions & 21 deletions causal_testing/specification/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import lhsmdu
from pandas import DataFrame
from scipy.stats._distn_infrastructure import rv_generic
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef

# Declare type variable
# Is there a better way? I'd really like to do Variable[T](ExprRef)
Expand All @@ -22,7 +22,7 @@ def z3_types(datatype):
if datatype in types:
return types[datatype]
if issubclass(datatype, Enum):
dtype, _ = EnumSort(datatype.__name__, [x.name for x in datatype])
dtype, _ = EnumSort(datatype.__name__, [str(x.value) for x in datatype])
return lambda x: Const(x, dtype)
if hasattr(datatype, "to_z3"):
return datatype.to_z3()
Expand Down Expand Up @@ -153,19 +153,27 @@ def cast(self, val: Any) -> T:
:rtype: T
"""
assert val is not None, f"Invalid value None for variable {self}"
if isinstance(val, self.datatype):
return val
if isinstance(val, BoolRef) and self.datatype == bool:
return str(val) == "True"
if isinstance(val, RatNumRef) and self.datatype == float:
return float(val.numerator().as_long() / val.denominator().as_long())
if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
return val.as_string()
if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
if (isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)) and (
self.datatype == int or self.datatype == float or self.datatype == bool
):
return self.datatype(val)
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
return self.datatype(str(val))
return self.datatype(str(val))

def z3_val(self, z3_var, val: Any) -> T:
native_val = self.cast(val)
if isinstance(native_val, Enum):
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
values = [v for v in values if str(v) == str(val)]
values = [v for v in values if val.__class__(str(v)) == val]
assert len(values) == 1, f"Expected {values} to be length 1"
return values[0]
return native_val
Expand Down Expand Up @@ -193,7 +201,6 @@ def typestring(self) -> str:
"""
return type(self).__name__

@abstractmethod
def copy(self, name: str = None) -> Variable:
"""Return a new instance of the Variable with the given name, or with
the original name if no name is supplied.
Expand All @@ -203,26 +210,18 @@ def copy(self, name: str = None) -> Variable:
:rtype: Variable
"""
raise NotImplementedError("Method `copy` must be instantiated.")
if name:
return self.__class__(name, self.datatype, self.distribution)
return self.__class__(self.name, self.datatype, self.distribution)


class Input(Variable):
"""An extension of the Variable class representing inputs."""

def copy(self, name=None) -> Input:
if name:
return Input(name, self.datatype, self.distribution)
return Input(self.name, self.datatype, self.distribution)


class Output(Variable):
"""An extension of the Variable class representing outputs."""

def copy(self, name=None) -> Output:
if name:
return Output(name, self.datatype, self.distribution)
return Output(self.name, self.datatype, self.distribution)


class Meta(Variable):
"""An extension of the Variable class representing metavariables. These are variables which are relevant to the
Expand All @@ -242,8 +241,3 @@ class Meta(Variable):
def __init__(self, name: str, datatype: T, populate: Callable[[DataFrame], DataFrame]):
super().__init__(name, datatype)
self.populate = populate

def copy(self, name=None) -> Meta:
if name:
return Meta(name, self.datatype, self.distribution)
return Meta(self.name, self.datatype, self.distribution)
Loading

0 comments on commit 92ccc5d

Please sign in to comment.