Skip to content

Commit

Permalink
Merge branch 'main' into estimator_params_for_linear
Browse files Browse the repository at this point in the history
  • Loading branch information
christopher-wild authored Aug 15, 2023
2 parents ef094af + ee8be4d commit d8439a7
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 34 deletions.
16 changes: 0 additions & 16 deletions causal_testing/testing/causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,6 @@ def __init__(
else:
self.effect_modifier_configuration = {}

def get_treatment_variable(self):
"""Return the treatment variable name (as string) for this causal test case"""
return self.treatment_variable.name

def get_outcome_variable(self):
"""Return the outcome variable name (as string) for this causal test case."""
return self.outcome_variable.name

def get_control_value(self):
"""Return a the control value of the treatment variable in this causal test case."""
return self.control_value

def get_treatment_value(self):
"""Return the treatment value of the treatment variable in this causal test case."""
return self.treatment_value

def execute_test(self, estimator: type(Estimator), data_collector: DataCollector) -> CausalTestResult:
"""Execute a causal test case and return the causal test result.
Expand Down
8 changes: 5 additions & 3 deletions causal_testing/testing/causal_test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ def to_dict(self, json=False):
"outcome": self.estimator.outcome,
"adjustment_set": list(self.adjustment_set) if json else self.adjustment_set,
"effect_measure": self.test_value.type,
"effect_estimate": self.test_value.value,
"ci_low": self.ci_low(),
"ci_high": self.ci_high(),
"effect_estimate": self.test_value.value.to_dict()
if json and hasattr(self.test_value.value, "to_dict")
else self.test_value.value,
"ci_low": self.ci_low().to_dict() if json and hasattr(self.ci_low(), "to_dict") else self.ci_low(),
"ci_high": self.ci_high().to_dict() if json and hasattr(self.ci_high(), "to_dict") else self.ci_high(),
}
if self.adequacy:
base_dict["adequacy"] = self.adequacy.to_dict()
Expand Down
4 changes: 2 additions & 2 deletions examples/poisson-line-process/example_poisson_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def causal_test_intensity_num_shapes(
# 8. Set up an estimator
data = pd.read_csv(observational_data_path)

treatment = causal_test_case.get_treatment_variable()
outcome = causal_test_case.get_outcome_variable()
treatment = causal_test_case.treatment_variable.name
outcome = causal_test_case.outcome_variable.name

estimator = None
if empirical:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"fitter~=1.4",
"lhsmdu~=1.1",
"networkx~=2.6",
"numpy~=1.22.0",
"numpy~=1.23",
"pandas~=1.3",
"scikit_learn~=1.1",
"scipy~=1.7",
Expand Down
12 changes: 0 additions & 12 deletions tests/testing_tests/test_causal_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,6 @@ def setUp(self) -> None:
treatment_value=1,
)

def test_get_treatment_variable(self):
self.assertEqual(self.causal_test_case.get_treatment_variable(), "A")

def test_get_outcome_variable(self):
self.assertEqual(self.causal_test_case.get_outcome_variable(), "C")

def test_get_treatment_value(self):
self.assertEqual(self.causal_test_case.get_treatment_value(), 1)

def test_get_control_value(self):
self.assertEqual(self.causal_test_case.get_control_value(), 0)

def test_str(self):
self.assertEqual(
str(self.causal_test_case),
Expand Down

0 comments on commit d8439a7

Please sign in to comment.