diff --git a/causal_testing/testing/estimators.py b/causal_testing/testing/estimators.py index 6532d8c7..0dab3c55 100644 --- a/causal_testing/testing/estimators.py +++ b/causal_testing/testing/estimators.py @@ -161,13 +161,16 @@ def estimate(self, data: pd.DataFrame, adjustment_config: dict = None) -> Regres # x = x[model.params.index] return model.predict(x) - def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple[pd.Series, pd.Series]: + def estimate_control_treatment( + self, adjustment_config: dict = None, bootstrap_size: int = 100 + ) -> tuple[pd.Series, pd.Series]: """Estimate the outcomes under control and treatment. :return: The estimated control and treatment values and their confidence intervals in the form ((ci_low, control, ci_high), (ci_low, treatment, ci_high)). """ - + if adjustment_config is None: + adjustment_config = {} y = self.estimate(self.df, adjustment_config=adjustment_config) try: @@ -197,7 +200,7 @@ def estimate_control_treatment(self, bootstrap_size, adjustment_config) -> tuple return (y.iloc[1], np.array(control)), (y.iloc[0], np.array(treatment)) - def estimate_ate(self, estimator_params: dict = None) -> float: + def estimate_ate(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float: """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused by changing the treatment variable from the control value to the treatment value. Here, we actually calculate the expected outcomes under control and treatment and take one away from the other. This @@ -205,10 +208,8 @@ def estimate_ate(self, estimator_params: dict = None) -> float: :return: The estimated average treatment effect and 95% confidence intervals """ - if estimator_params is None: - estimator_params = {} - bootstrap_size = estimator_params.get("bootstrap_size", 100) - adjustment_config = estimator_params.get("adjustment_config", None) + if adjustment_config is None: + adjustment_config = {} (control_outcome, control_bootstraps), ( treatment_outcome, treatment_bootstraps, @@ -231,7 +232,7 @@ def estimate_ate(self, estimator_params: dict = None) -> float: return estimate, (ci_low, ci_high) - def estimate_risk_ratio(self, estimator_params: dict = None) -> float: + def estimate_risk_ratio(self, adjustment_config: dict = None, bootstrap_size: int = 100) -> float: """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused by changing the treatment variable from the control value to the treatment value. Here, we actually calculate the expected outcomes under control and treatment and divide one by the other. This @@ -239,10 +240,8 @@ def estimate_risk_ratio(self, estimator_params: dict = None) -> float: :return: The estimated risk ratio and 95% confidence intervals. """ - if estimator_params is None: - estimator_params = {} - bootstrap_size = estimator_params.get("bootstrap_size", 100) - adjustment_config = estimator_params.get("adjustment_config", None) + if adjustment_config is None: + adjustment_config = {} (control_outcome, control_bootstraps), ( treatment_outcome, treatment_bootstraps, @@ -374,7 +373,6 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd """ if adjustment_config is None: adjustment_config = {} - model = self._run_linear_regression() x = pd.DataFrame(columns=self.df.columns) @@ -393,13 +391,15 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd return y.iloc[1], y.iloc[0] - def estimate_risk_ratio(self) -> tuple[float, list[float, float]]: + def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]: """Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused by changing the treatment variable from the control value to the treatment value. :return: The average treatment effect and the 95% Wald confidence intervals. """ - control_outcome, treatment_outcome = self.estimate_control_treatment() + if adjustment_config is None: + adjustment_config = {} + control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config) ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"] ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"] @@ -413,6 +413,8 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float :return: The average treatment effect and the 95% Wald confidence intervals. """ + if adjustment_config is None: + adjustment_config = {} control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config) ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"] ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"] diff --git a/tests/testing_tests/test_estimators.py b/tests/testing_tests/test_estimators.py index c9184472..835a1144 100644 --- a/tests/testing_tests/test_estimators.py +++ b/tests/testing_tests/test_estimators.py @@ -124,7 +124,7 @@ def test_ate_adjustment(self): logistic_regression_estimator = LogisticRegressionEstimator( "length_in", 65, 55, {"large_gauge"}, "completed", df ) - ate, _ = logistic_regression_estimator.estimate_ate(estimator_params={"adjustment_config": {"large_gauge": 0}}) + ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0}) self.assertEqual(round(ate, 4), -0.3388) def test_ate_invalid_adjustment(self): @@ -132,7 +132,7 @@ def test_ate_invalid_adjustment(self): logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df) with self.assertRaises(ValueError): ate, _ = logistic_regression_estimator.estimate_ate( - estimator_params={"adjustment_config": {"large_gauge": 0}} + adjustment_config = {"large_gauge": 0} ) def test_ate_effect_modifiers(self): @@ -392,8 +392,9 @@ def test_program_15_no_interaction_ate_calculated(self): ) # terms_to_square = ["age", "wt71", "smokeintensity", "smokeyrs"] # for term_to_square in terms_to_square: + ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated( - {k: self.nhefs_df.mean()[k] for k in covariates} + adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates} ) self.assertEqual(round(ate, 1), 3.5) self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])