Skip to content

Commit

Permalink
Merge pull request #142 from CITCOM-project/pylint_refactoring
Browse files Browse the repository at this point in the history
Pylint refactoring
  • Loading branch information
christopher-wild authored Mar 9, 2023
2 parents bbe2300 + d745bde commit d22709d
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 131 deletions.
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ disable=raw-checker-failed,
deprecated-pragma,
use-symbolic-message-instead,
logging-fstring-interpolation,
import-error,

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
10 changes: 10 additions & 0 deletions causal_testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""
This is the CausalTestingFramework Module
It contains 5 subpackages:
data_collection
generation
json_front
specification
testing
"""

import logging

logger = logging.getLogger(__name__)
Expand Down
8 changes: 7 additions & 1 deletion causal_testing/generation/abstract_causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class AbstractCausalTestCase:
"""

def __init__(
# pylint: disable=too-many-arguments
self,
scenario: Scenario,
intervention_constraints: set[z3.ExprRef],
Expand Down Expand Up @@ -77,7 +78,11 @@ def sanitise(string):
)

def _generate_concrete_tests(
self, sample_size: int, rct: bool = False, seed: int = 0
# pylint: disable=too-many-locals
self,
sample_size: int,
rct: bool = False,
seed: int = 0,
) -> tuple[list[CausalTestCase], pd.DataFrame]:
"""Generates a list of `num` concrete test cases.
Expand Down Expand Up @@ -151,6 +156,7 @@ def _generate_concrete_tests(
return concrete_tests, pd.DataFrame(runs, columns=run_columns + ["bin"])

def generate_concrete_tests(
# pylint: disable=too-many-arguments, too-many-locals
self,
sample_size: int,
target_ks_score: float = None,
Expand Down
76 changes: 51 additions & 25 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""This module contains the JsonUtility class, details of using this class can be found here:
https://causal-testing-framework.readthedocs.io/en/latest/json_front_end.html"""

import argparse
import json
import logging

from abc import ABC
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -42,49 +44,38 @@ class JsonUtility(ABC):
"""

def __init__(self, log_path):
self.json_path = None
self.dag_path = None
self.data_path = None
self.inputs = None
self.outputs = None
self.metas = None
self.paths = None
self.variables = None
self.data = None
self.test_plan = None
self.modelling_scenario = None
self.causal_specification = None
self.setup_logger(log_path)

def set_path(self, json_path: str, dag_path: str, data_path: str):
def set_paths(self, json_path: str, dag_path: str, data_path: str):
"""
Takes a path of the directory containing all scenario specific files and creates individual paths for each file
:param json_path: string path representation to .json file containing test specifications
:param dag_path: string path representation to the .dot file containing the Causal DAG
:param data_path: string path representation to the data file
:returns:
- json_path -
- dag_path -
- data_path -
"""
self.json_path = Path(json_path)
self.dag_path = Path(dag_path)
self.data_path = Path(data_path)
self.paths = JsonClassPaths(json_path=json_path, dag_path=dag_path, data_path=data_path)

def set_variables(self, inputs: dict, outputs: dict, metas: dict):
def set_variables(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
"""Populate the Causal Variables
:param inputs:
:param outputs:
:param metas:
"""
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []

self.variables = CausalVariables(inputs=inputs, outputs=outputs, metas=metas)

def setup(self):
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
self.modelling_scenario = Scenario(self.inputs + self.outputs + self.metas, None)
self.modelling_scenario = Scenario(self.variables.inputs + self.variables.outputs + self.variables.metas, None)
self.modelling_scenario.setup_treatment_variables()
self.causal_specification = CausalSpecification(
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.dag_path)
scenario=self.modelling_scenario, causal_dag=CausalDAG(self.paths.dag_path)
)
self._json_parse()
self._populate_metas()
Expand Down Expand Up @@ -139,20 +130,20 @@ def _execute_tests(self, concrete_tests, estimators, test, f_flag):

def _json_parse(self):
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
with open(self.json_path, encoding="utf-8") as f:
with open(self.paths.json_path, encoding="utf-8") as f:
self.test_plan = json.load(f)

self.data = pd.read_csv(self.data_path)
self.data = pd.read_csv(self.paths.data_path)

def _populate_metas(self):
"""
Populate data with meta-variable values and add distributions to Causal Testing Framework Variables
"""

for meta in self.metas:
for meta in self.variables.metas:
meta.populate(self.data)

for var in self.metas + self.outputs:
for var in self.variables.metas + self.variables.outputs:
if not var.distribution:
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
fitter.fit()
Expand Down Expand Up @@ -202,7 +193,7 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
- causal_test_engine - Test Engine instance for the test being run
- estimation_model - Estimator instance for the test being run
"""
data_collector = ObservationalDataCollector(self.modelling_scenario, self.data_path)
data_collector = ObservationalDataCollector(self.modelling_scenario, self.paths.data_path)
causal_test_engine = CausalTestEngine(self.causal_specification, data_collector, index_col=0)
minimal_adjustment_set = self.causal_specification.causal_dag.identification(causal_test_case.base_test_case)
treatment_var = causal_test_case.treatment_variable
Expand Down Expand Up @@ -273,3 +264,38 @@ def get_args(test_args=None) -> argparse.Namespace:
required=True,
)
return parser.parse_args(test_args)


@dataclass
class JsonClassPaths:
"""
A dataclass that converts strings of paths to Path objects for use in the JsonUtility class
:param json_path: string path representation to .json file containing test specifications
:param dag_path: string path representation to the .dot file containing the Causal DAG
:param data_path: string path representation to the data file
"""

json_path: Path
dag_path: Path
data_path: Path

def __init__(self, json_path: str, dag_path: str, data_path: str):
self.json_path = Path(json_path)
self.dag_path = Path(dag_path)
self.data_path = Path(data_path)


@dataclass()
class CausalVariables:
"""
A dataclass that converts
"""

inputs: list[Input]
outputs: list[Output]
metas: list[Meta]

def __init__(self, inputs: list[dict], outputs: list[dict], metas: list[dict]):
self.inputs = [Input(i["name"], i["type"], i["distribution"]) for i in inputs]
self.outputs = [Output(i["name"], i["type"], i.get("distribution", None)) for i in outputs]
self.metas = [Meta(i["name"], i["type"], i["populate"]) for i in metas] if metas else []
6 changes: 3 additions & 3 deletions causal_testing/specification/causal_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,19 +150,19 @@ def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")

# (ii) Instrument does not affect outcome except through its potential effect on treatment
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
if not all((treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome))):
raise ValueError(
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
)

# (iii) Instrument and outcome do not share causes
if any(
[
(
cause
for cause in self.graph.nodes
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
]
)
):
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")

Expand Down
9 changes: 3 additions & 6 deletions causal_testing/specification/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import lhsmdu
from pandas import DataFrame
from scipy.stats._distn_infrastructure import rv_generic
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String

# Declare type variable
T = TypeVar("T")
Z3 = TypeVar("Z3")
z3 = TypeVar("Z3")


def z3_types(datatype: T) -> Z3:
def z3_types(datatype: T) -> z3:
"""Cast datatype to Z3 datatype
:param datatype: python datatype to be cast
:return: Type name compatible with Z3 library
Expand Down Expand Up @@ -76,7 +76,6 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None):
def __repr__(self):
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"

# TODO: We're going to need to implement all the supported Z3 operations like this
def __ge__(self, other: Any) -> BoolRef:
"""Create the Z3 expression `other >= self`.
Expand Down Expand Up @@ -167,8 +166,6 @@ def cast(self, val: Any) -> T:
return val.as_string()
if (isinstance(val, (float, int, bool))) and (self.datatype in (float, int, bool)):
return self.datatype(val)
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
return self.datatype(str(val))
return self.datatype(str(val))

def z3_val(self, z3_var, val: Any) -> T:
Expand Down
2 changes: 2 additions & 0 deletions causal_testing/testing/causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


class CausalTestCase:
# pylint: disable=too-many-instance-attributes
"""
A CausalTestCase extends the information held in a BaseTestCase. As well as storing the treatment and outcome
variables, a CausalTestCase stores the values of these variables. Also the outcome variable and value are
Expand All @@ -22,6 +23,7 @@ class CausalTestCase:
"""

def __init__(
# pylint: disable=too-many-arguments
self,
base_test_case: BaseTestCase,
expected_causal_effect: CausalTestOutcome,
Expand Down
41 changes: 11 additions & 30 deletions causal_testing/testing/causal_test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
causal_test_result objects
"""
if self.scenario_execution_data_df.empty:
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
test_suite_results = {}
for edge in test_suite:
print("edge: ")
Expand All @@ -75,7 +75,7 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
list(minimal_adjustment_set) + [edge.treatment_variable.name] + [edge.outcome_variable.name]
)
if self._check_positivity_violation(variables_for_positivity):
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")

estimators = test_suite[edge]["estimators"]
tests = test_suite[edge]["tests"]
Expand All @@ -85,13 +85,10 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
causal_test_results = []

for test in tests:
treatment_variable = test.treatment_variable
treatment_value = test.treatment_value
control_value = test.control_value
estimator = estimator_class(
treatment_variable.name,
treatment_value,
control_value,
test.treatment_variable.name,
test.treatment_value,
test.control_value,
minimal_adjustment_set,
test.outcome_variable.name,
)
Expand Down Expand Up @@ -125,7 +122,7 @@ def execute_test(
:return causal_test_result: A CausalTestResult for the executed causal test case.
"""
if self.scenario_execution_data_df.empty:
raise Exception("No data has been loaded. Please call load_data prior to executing a causal test case.")
raise ValueError("No data has been loaded. Please call load_data prior to executing a causal test case.")
if estimator.df is None:
estimator.df = self.scenario_execution_data_df
treatment_variable = causal_test_case.treatment_variable
Expand All @@ -141,7 +138,7 @@ def execute_test(
variables_for_positivity = list(minimal_adjustment_set) + [treatment_variable.name] + [outcome_variable.name]

if self._check_positivity_violation(variables_for_positivity):
raise Exception("POSITIVITY VIOLATION -- Cannot proceed.")
raise ValueError("POSITIVITY VIOLATION -- Cannot proceed.")

causal_test_result = self._return_causal_test_results(estimate_type, estimator, causal_test_case)
return causal_test_result
Expand All @@ -161,11 +158,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case

cates_df, confidence_intervals = estimator.estimate_cates()
causal_test_result = CausalTestResult(
treatment=estimator.treatment,
outcome=estimator.outcome,
treatment_value=estimator.treatment_value,
control_value=estimator.control_value,
adjustment_set=estimator.adjustment_set,
estimator=estimator,
test_value=TestValue("ate", cates_df),
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
Expand All @@ -174,11 +167,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
logger.debug("calculating risk_ratio")
risk_ratio, confidence_intervals = estimator.estimate_risk_ratio()
causal_test_result = CausalTestResult(
treatment=estimator.treatment,
outcome=estimator.outcome,
treatment_value=estimator.treatment_value,
control_value=estimator.control_value,
adjustment_set=estimator.adjustment_set,
estimator=estimator,
test_value=TestValue("risk_ratio", risk_ratio),
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
Expand All @@ -187,11 +176,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
logger.debug("calculating ate")
ate, confidence_intervals = estimator.estimate_ate()
causal_test_result = CausalTestResult(
treatment=estimator.treatment,
outcome=estimator.outcome,
treatment_value=estimator.treatment_value,
control_value=estimator.control_value,
adjustment_set=estimator.adjustment_set,
estimator=estimator,
test_value=TestValue("ate", ate),
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
Expand All @@ -202,11 +187,7 @@ def _return_causal_test_results(self, estimate_type, estimator, causal_test_case
logger.debug("calculating ate")
ate, confidence_intervals = estimator.estimate_ate_calculated()
causal_test_result = CausalTestResult(
treatment=estimator.treatment,
outcome=estimator.outcome,
treatment_value=estimator.treatment_value,
control_value=estimator.control_value,
adjustment_set=estimator.adjustment_set,
estimator=estimator,
test_value=TestValue("ate", ate),
effect_modifier_configuration=causal_test_case.effect_modifier_configuration,
confidence_intervals=confidence_intervals,
Expand Down
Loading

0 comments on commit d22709d

Please sign in to comment.