Skip to content

Commit

Permalink
Merge pull request #225 from CITCOM-project/estimator_params_for_linear
Browse files Browse the repository at this point in the history
add estimator params for linear regression estimate methods
  • Loading branch information
christopher-wild authored Aug 15, 2023
2 parents ee8be4d + d8439a7 commit 18b6ccd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
32 changes: 17 additions & 15 deletions causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,16 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres
# x = x[model.params.index]
return model.predict(x)

def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]:
def estimate_control_treatment(
self, adjustment_config: dict = None, bootstrap_size: int = 100
) -> tuple[pd.Series, pd.Series]:
"""Estimate the outcomes under control and treatment.
:return: The estimated control and treatment values and their confidence
intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)).
"""

if adjustment_config is None:
adjustment_config = {}
y = self.estimate(self.df, adjustment_config=adjustment_config)

try:
Expand Down Expand Up @@ -197,18 +200,16 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple

return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment))

def estimate_ate(self, estimator_params: dict = None) -> float:
def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value. Here, we actually
calculate the expected outcomes under control and treatment and take one away from the other. This
allows for custom terms to be put in such as squares, inverses, products, etc.
:return: The estimated average treatment effect and 95% confidence intervals
"""
if estimator_params is None:
estimator_params = {}
bootstrap_size = estimator_params.get("bootstrap_size", 100)
adjustment_config = estimator_params.get("adjustment_config", None)
if adjustment_config is None:
adjustment_config = {}
(control_outcome, control_bootstraps), (
treatment_outcome,
treatment_bootstraps,
Expand All @@ -231,18 +232,16 @@ def estimate_ate(self, estimator_params: dict = None) -> float:

return estimate, (ci_low, ci_high)

def estimate_risk_ratio(self, estimator_params: dict = None) -> float:
def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float:
"""Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value. Here, we actually
calculate the expected outcomes under control and treatment and divide one by the other. This
allows for custom terms to be put in such as squares, inverses, products, etc.
:return: The estimated risk ratio and 95% confidence intervals.
"""
if estimator_params is None:
estimator_params = {}
bootstrap_size = estimator_params.get("bootstrap_size", 100)
adjustment_config = estimator_params.get("adjustment_config", None)
if adjustment_config is None:
adjustment_config = {}
(control_outcome, control_bootstraps), (
treatment_outcome,
treatment_bootstraps,
Expand Down Expand Up @@ -374,7 +373,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
"""
if adjustment_config is None:
adjustment_config = {}

model = self._run_linear_regression()

x = pd.DataFrame(columns=self.df.columns)
Expand All @@ -393,13 +391,15 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd

return y.iloc[1], y.iloc[0]

def estimate_risk_ratio(self) -> tuple[float, list[float, float]]:
def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
"""Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
by changing the treatment variable from the control value to the treatment value.
:return: The average treatment effect and the 95% Wald confidence intervals.
"""
control_outcome, treatment_outcome = self.estimate_control_treatment()
if adjustment_config is None:
adjustment_config = {}
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]

Expand All @@ -413,6 +413,8 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float
:return: The average treatment effect and the 95% Wald confidence intervals.
"""
if adjustment_config is None:
adjustment_config = {}
control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]
Expand Down
7 changes: 4 additions & 3 deletions tests/testing_tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,15 @@ def test_ate_adjustment(self):
logistic_regression_estimator = LogisticRegressionEstimator(
"length_in", 65, 55, {"large_gauge"}, "completed", df
)
ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}})
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
self.assertEqual(round(ate, 4), -0.3388)

def test_ate_invalid_adjustment(self):
df = self.scarf_df.copy()
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
with self.assertRaises(ValueError):
ate, _ = logistic_regression_estimator.estimate_ate(
estimator_params={"adjustment_config": {"large_gauge": 0}}
adjustment_config = {"large_gauge": 0}
)

def test_ate_effect_modifiers(self):
Expand Down Expand Up @@ -392,8 +392,9 @@ def test_program_15_no_interaction_ate_calculated(self):
)
# terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"]
# for term_to_square in terms_to_square:

ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
{k: self.nhefs_df.mean()[k] for k in covariates}
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
)
self.assertEqual(round(ate, 1), 3.5)
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
Expand Down

0 comments on commit 18b6ccd

Please sign in to comment.