From 65e86767a2dc642e106b66e95b96579140b9bb72 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 5 Sep 2024 18:11:51 +0200 Subject: [PATCH 1/8] fix get and set params of PostPredictionWrapper --- molpipeline/post_prediction.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/molpipeline/post_prediction.py b/molpipeline/post_prediction.py index f5a5881e..b47c5173 100644 --- a/molpipeline/post_prediction.py +++ b/molpipeline/post_prediction.py @@ -11,7 +11,7 @@ from typing_extensions import Self from numpy import typing as npt -from sklearn.base import BaseEstimator, TransformerMixin, clone +from sklearn.base import BaseEstimator, TransformerMixin from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement from molpipeline.error_handling import FilterReinserter @@ -194,15 +194,10 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: dict[str, Any] Parameters. """ + param_dict = {"wrapped_estimator": self.wrapped_estimator} if deep: - param_dict = { - "wrapped_estimator": clone(self.wrapped_estimator), - } - else: - param_dict = { - "wrapped_estimator": self.wrapped_estimator, - } - param_dict.update(self.wrapped_estimator.get_params(deep=deep)) + for key, value in self.wrapped_estimator.get_params(deep=deep).items(): + param_dict[f"wrapped_estimator__{key}"] = value return param_dict def set_params(self, **params: Any) -> Self: @@ -219,12 +214,12 @@ def set_params(self, **params: Any) -> Self: Parameters. """ param_copy = dict(params) - wrapped_estimator = param_copy.pop("wrapped_estimator") - if wrapped_estimator: - self.wrapped_estimator = wrapped_estimator - if param_copy: - if isinstance(self.wrapped_estimator, ABCPipelineElement): - self.wrapped_estimator.set_params(**param_copy) - else: - self.wrapped_estimator.set_params(**param_copy) + if "wrapped_estimator" in param_copy: + self.wrapped_estimator = param_copy.pop("wrapped_estimator") + wrapped_params = {} + for key, value in param_copy.items(): + estimator, _, param = key.partition("__") + if estimator == "wrapped_estimator": + wrapped_params[param] = value + self.wrapped_estimator.set_params(**param_copy) return self From 840910e23c9694e3b70629a1b57516acfc870c74 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 5 Sep 2024 18:15:53 +0200 Subject: [PATCH 2/8] fix param setting --- molpipeline/post_prediction.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/molpipeline/post_prediction.py b/molpipeline/post_prediction.py index b47c5173..e0525629 100644 --- a/molpipeline/post_prediction.py +++ b/molpipeline/post_prediction.py @@ -216,10 +216,10 @@ def set_params(self, **params: Any) -> Self: param_copy = dict(params) if "wrapped_estimator" in param_copy: self.wrapped_estimator = param_copy.pop("wrapped_estimator") - wrapped_params = {} + wrapped_estimator_params = {} for key, value in param_copy.items(): estimator, _, param = key.partition("__") if estimator == "wrapped_estimator": - wrapped_params[param] = value - self.wrapped_estimator.set_params(**param_copy) + wrapped_estimator_params[param] = value + self.wrapped_estimator.set_params(**wrapped_estimator_params) return self From 31ac61ee6b0474ba7164ca1f0f5fa89a4c734bf9 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 18 Sep 2024 11:35:49 +0200 Subject: [PATCH 3/8] Add tests --- tests/test_elements/test_post_prediction.py | 75 +++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/test_elements/test_post_prediction.py diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py new file mode 100644 index 00000000..b3f74e98 --- /dev/null +++ b/tests/test_elements/test_post_prediction.py @@ -0,0 +1,75 @@ +"""Test the module post_prediction.py.""" + +import unittest + +import numpy as np +from sklearn.base import clone +from sklearn.decomposition import PCA +from sklearn.ensemble import RandomForestClassifier + +from molpipeline.post_prediction import PostPredictionWrapper + + +class TestPostPredictionWrapper(unittest.TestCase): + """Test the PostPredictionWrapper class.""" + + def test_get_params(self) -> None: + """Test get_params method.""" + rf = RandomForestClassifier() + rf_params = rf.get_params(deep=True) + + ppw = PostPredictionWrapper(rf) + ppw_params = ppw.get_params(deep=True) + + wrapped_params = {} + for key, value in ppw_params.items(): + first, sep, rest = key.partition("__") + if first == "wrapped_estimator": + if rest == "": + self.assertEqual(rf, value) + else: + wrapped_params[rest] = value + + self.assertDictEqual(rf_params, wrapped_params) + + def test_set_params(self) -> None: + """Test set_params method.""" + rf = RandomForestClassifier() + ppw = PostPredictionWrapper(rf) + + ppw.set_params(wrapped_estimator__n_estimators=10) + if not isinstance(ppw.wrapped_estimator, RandomForestClassifier): + raise TypeError("Wrapped estimator is not a RandomForestClassifier.") + self.assertEqual(ppw.wrapped_estimator.n_estimators, 10) + + ppw_params = ppw.get_params(deep=True) + self.assertEqual(ppw_params["wrapped_estimator__n_estimators"], 10) + + def test_fit_transform(self) -> None: + """Test fit method.""" + rng = np.random.default_rng(20240918) + features = rng.random((100, 10)) + + pca = PCA(n_components=3) + pca.fit(features) + pca_transformed = pca.transform(features) + + ppw = PostPredictionWrapper(clone(pca)) + ppw.fit(features) + ppw_transformed = ppw.transform(features) + + self.assertEqual(pca_transformed.shape, ppw_transformed.shape) + self.assertTrue(np.allclose(pca_transformed, ppw_transformed)) + + def test_inverse_transform(self) -> None: + """Test inverse_transform method.""" + rng = np.random.default_rng(20240918) + features = rng.random((100, 10)) + + ppw = PostPredictionWrapper(PCA(n_components=3)) + ppw.fit(features) + ppw_transformed = ppw.transform(features) + ppw_inverse = ppw.inverse_transform(ppw_transformed) + + self.assertEqual(features.shape, ppw_inverse.shape) + self.assertTrue(np.allclose(features, ppw_inverse)) From d36e2f2cd16d495a8b7445e0a07249c3633444c2 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 18 Sep 2024 11:45:48 +0200 Subject: [PATCH 4/8] remove unused var --- tests/test_elements/test_post_prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py index b3f74e98..0b5ca691 100644 --- a/tests/test_elements/test_post_prediction.py +++ b/tests/test_elements/test_post_prediction.py @@ -23,7 +23,7 @@ def test_get_params(self) -> None: wrapped_params = {} for key, value in ppw_params.items(): - first, sep, rest = key.partition("__") + first, _, rest = key.partition("__") if first == "wrapped_estimator": if rest == "": self.assertEqual(rf, value) From c81b04882654c086d171f4441b3891e75ba25007 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 18 Sep 2024 12:11:18 +0200 Subject: [PATCH 5/8] Fix test. Inverse transform does not yield original values, as it is not lossless. --- tests/test_elements/test_post_prediction.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py index 0b5ca691..be76a863 100644 --- a/tests/test_elements/test_post_prediction.py +++ b/tests/test_elements/test_post_prediction.py @@ -64,12 +64,19 @@ def test_fit_transform(self) -> None: def test_inverse_transform(self) -> None: """Test inverse_transform method.""" rng = np.random.default_rng(20240918) - features = rng.random((100, 10)) + features = rng.random((5, 10)) - ppw = PostPredictionWrapper(PCA(n_components=3)) + pca = PCA(n_components=3) + pca.fit(features) + pca_transformed = pca.transform(features) + pca_inverse = pca.inverse_transform(pca_transformed) + + ppw = PostPredictionWrapper(clone(pca)) ppw.fit(features) ppw_transformed = ppw.transform(features) ppw_inverse = ppw.inverse_transform(ppw_transformed) self.assertEqual(features.shape, ppw_inverse.shape) - self.assertTrue(np.allclose(features, ppw_inverse)) + self.assertEqual(pca_inverse.shape, ppw_inverse.shape) + + self.assertTrue(np.allclose(pca_inverse, ppw_inverse)) From fbe318e1d5c520371bf7f68d95d4f261144058a7 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 26 Sep 2024 17:17:18 +0200 Subject: [PATCH 6/8] Check for identity --- tests/test_elements/test_post_prediction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py index be76a863..becd71b7 100644 --- a/tests/test_elements/test_post_prediction.py +++ b/tests/test_elements/test_post_prediction.py @@ -26,7 +26,7 @@ def test_get_params(self) -> None: first, _, rest = key.partition("__") if first == "wrapped_estimator": if rest == "": - self.assertEqual(rf, value) + self.assertIs(rf, value) else: wrapped_params[rest] = value From 2730782b6adef534063f900186383b5243ecc31d Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 26 Sep 2024 17:20:08 +0200 Subject: [PATCH 7/8] Add test for Instance. Keep raise to make mypy happy --- tests/test_elements/test_post_prediction.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py index becd71b7..9ee3f9aa 100644 --- a/tests/test_elements/test_post_prediction.py +++ b/tests/test_elements/test_post_prediction.py @@ -38,6 +38,7 @@ def test_set_params(self) -> None: ppw = PostPredictionWrapper(rf) ppw.set_params(wrapped_estimator__n_estimators=10) + self.assertIsInstance(ppw.wrapped_estimator, RandomForestClassifier) if not isinstance(ppw.wrapped_estimator, RandomForestClassifier): raise TypeError("Wrapped estimator is not a RandomForestClassifier.") self.assertEqual(ppw.wrapped_estimator.n_estimators, 10) From 642f6e9095cb9961417f9ae0fa3fb80b497b29d3 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 26 Sep 2024 17:23:45 +0200 Subject: [PATCH 8/8] Reduce test data size --- tests/test_elements/test_post_prediction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_elements/test_post_prediction.py b/tests/test_elements/test_post_prediction.py index 9ee3f9aa..cbfcf901 100644 --- a/tests/test_elements/test_post_prediction.py +++ b/tests/test_elements/test_post_prediction.py @@ -49,7 +49,7 @@ def test_set_params(self) -> None: def test_fit_transform(self) -> None: """Test fit method.""" rng = np.random.default_rng(20240918) - features = rng.random((100, 10)) + features = rng.random((10, 5)) pca = PCA(n_components=3) pca.fit(features) @@ -65,7 +65,7 @@ def test_fit_transform(self) -> None: def test_inverse_transform(self) -> None: """Test inverse_transform method.""" rng = np.random.default_rng(20240918) - features = rng.random((5, 10)) + features = rng.random((10, 5)) pca = PCA(n_components=3) pca.fit(features)