diff --git a/CHANGELOG.md b/CHANGELOG.md index b980b239..f428146e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## [2.2.1] - 2022-09-06 +- **Bug Fix** + - Including a necessary init file to allow the import of the causal cate learners. + - Fix a docstring issue where the description of causal learners were not showing all parameters. + ## [2.2.0] - 2022-08-25 - **Enhancement** - Including Classification S-Learner and T-Learner models to the causal cate learning library. diff --git a/src/fklearn/causal/cate_learning/meta_learners.py b/src/fklearn/causal/cate_learning/meta_learners.py index 956e33ee..219f0094 100644 --- a/src/fklearn/causal/cate_learning/meta_learners.py +++ b/src/fklearn/causal/cate_learning/meta_learners.py @@ -191,6 +191,7 @@ def causal_s_classification_learner( [1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html [2] https://causalml.readthedocs.io/en/latest/methodology.html + Parameters ---------- df : pd.DataFrame @@ -369,34 +370,29 @@ def causal_t_classification_learner( and $M_{1}$ are traditional Machine Learning models such as a LightGBM Classifier and that $x_{i}$ is the feature set of sample $i$. - References: + **References:** + [1] https://matheusfacure.github.io/python-causality-handbook/21-Meta-Learners.html + [2] https://causalml.readthedocs.io/en/latest/methodology.html Parameters ---------- - df : pd.DataFrame A Pandas' DataFrame with features and target columns. The model will be trained to predict the target column from the features. - treatment_col: str The name of the column in `df` which contains the names of the treatments and control to which each data sample was subjected. - control_name: str The name of the control group. - prediction_column : str The name of the column with the predictions from the provided learner. - learner: LearnerFnType A fklearn classification learner function. - treatment_learner: LearnerFnType An optional fklearn classification learner function. - learner_transformers: List[LearnerFnType] A list of fklearn transformer functions to be applied after the learner and before estimating the CATE. This parameter may be useful, for example, to estimate the CATE with calibrated classifiers. @@ -447,6 +443,6 @@ def p(new_df: pd.DataFrame) -> pd.DataFrame: return p, p(df), log -causal_t_classification_learner.__doc__ = learner_return_docstring( +causal_t_classification_learner.__doc__ += learner_return_docstring( "Causal T-Learner Classifier" ) diff --git a/src/fklearn/exceptions/__init__.py b/src/fklearn/exceptions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/fklearn/exceptions/exceptions.py b/src/fklearn/exceptions/exceptions.py index beac1101..07d54cc8 100644 --- a/src/fklearn/exceptions/exceptions.py +++ b/src/fklearn/exceptions/exceptions.py @@ -1,17 +1,31 @@ +from typing import Any, Dict, List + + class MultipleTreatmentsError(Exception): - def __init__(self, msg="Data contains multiple treatments.", *args, **kwargs): + def __init__( + self, + msg: str = "Data contains multiple treatments.", + *args: List[Any], + **kwargs: Dict[str, Any] + ) -> None: super().__init__(msg, *args, **kwargs) class MissingControlError(Exception): def __init__( - self, msg="Data does not contain the specified control.", *args, **kwargs - ): + self, + msg: str = "Data does not contain the specified control.", + *args: List[Any], + **kwargs: Dict[str, Any] + ) -> None: super().__init__(msg, *args, **kwargs) class MissingTreatmentError(Exception): def __init__( - self, msg="Data does not contain the specified treatment.", *args, **kwargs - ): + self, + msg: str = "Data does not contain the specified treatment.", + *args: List[Any], + **kwargs: Dict[str, Any] + ) -> None: super().__init__(msg, *args, **kwargs) diff --git a/src/fklearn/resources/VERSION b/src/fklearn/resources/VERSION index ccbccc3d..c043eea7 100644 --- a/src/fklearn/resources/VERSION +++ b/src/fklearn/resources/VERSION @@ -1 +1 @@ -2.2.0 +2.2.1 diff --git a/tests/validation/test_evaluators.py b/tests/validation/test_evaluators.py index cc84f7f1..445179c4 100644 --- a/tests/validation/test_evaluators.py +++ b/tests/validation/test_evaluators.py @@ -469,7 +469,7 @@ def test_exponential_coefficient_evaluator(): result = exponential_coefficient_evaluator(predictions) - assert result['exponential_coefficient_evaluator__target'] == a1 + assert result['exponential_coefficient_evaluator__target'] == pytest.approx(a1) def test_logistic_coefficient_evaluator():