From 737db7fda3c9c89c162b8fb2c4e3d0d26fd1f828 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Mon, 26 Aug 2024 15:18:00 +0200 Subject: [PATCH 01/25] add multi class classifier --- molpipeline/estimators/chemprop/models.py | 66 +++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index e3257ae5..8baaecac 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -348,3 +348,69 @@ def __init__( n_jobs=n_jobs, **kwargs, ) + +class ChempropMulticlassClassifier(ChempropModel): + """Chemprop model with default parameters for regression tasks.""" + + def __init__( + self, + model: MPNN | None = None, + lightning_trainer: pl.Trainer | None = None, + batch_size: int = 64, + n_jobs: int = 1, + n_classes: int = 3, + **kwargs: Any, + ) -> None: + """Initialize the chemprop regressor model. + + Parameters + ---------- + model : MPNN | None, optional + The chemprop model to wrap. If None, a default model will be used. + lightning_trainer : pl.Trainer, optional + The lightning trainer to use, by default None + batch_size : int, optional (default=64) + The batch size to use. + n_jobs : int, optional (default=1) + The number of jobs to use. + kwargs : Any + Parameters set using `set_params`. + Can be used to modify components of the model. + """ + if model is None: + bond_encoder = BondMessagePassing() + agg = SumAggregation() + predictor = MulticlassClassificationFFN(n_classes=n_classes) + model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + super().__init__( + model=model, + lightning_trainer=lightning_trainer, + batch_size=batch_size, + n_jobs=n_jobs, + **kwargs, + ) + self.n_classes = n_classes + + def set_params(self, **params: Any) -> Self: + """Set the parameters of the model and check if it is a binary classifier. + + Parameters + ---------- + **params + The parameters to set. + + Returns + ------- + Self + The model with the new parameters. + """ + super().set_params(**params) + if not self._is_multiclass_classifier(): + raise ValueError("ChempropMulticlassClassifier should contain more than 2 classes.") + return self + + def get_params(self, deep: bool = False) -> dict[str, Any]: + params = super().get_params(deep) + return params + + From 8692bc2fa04a98172f1f0204b128a9ae96a2b321 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Mon, 26 Aug 2024 15:18:35 +0200 Subject: [PATCH 02/25] use input check to prevent confusing message by torch of the class labels do not match requirements --- molpipeline/estimators/chemprop/models.py | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 8baaecac..8a0388f2 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -413,4 +413,48 @@ def get_params(self, deep: bool = False) -> dict[str, Any]: params = super().get_params(deep) return params + def fit( + self, + X: MoleculeDataset, + y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64], + ) -> Self: + """Fit the model to the data. + + Parameters + ---------- + X : MoleculeDataset + The input data. + y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64] + The target data. + Returns + ------- + Self + The fitted model. + """ + self._check_correct_input(y) + return super().fit(X, y) + + def _check_correct_input(self,y) -> None: + """Checks if the input for the multi-class classifier is correct. + + Parameters + ---------- + y : _type_ + Indended classes for the dataset + + Raises + ------ + ValueError + if the classes found in y are not matching n_classes or if the class labels do not start from 0 to n_classes-1 + """ + unique_y = np.unique(y) + log = [] + if self.n_classes != len(unique_y): + log.append(f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data.") + if sorted(unique_y) != list(range(self.n_classes)): + err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly." + print(err) + log.append(err) + if log: + raise ValueError("\n".join(log)) From b560d18c5a7849eb1d7c0e629e85f5bce9ea4ac6 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Mon, 26 Aug 2024 15:22:38 +0200 Subject: [PATCH 03/25] remove get_params --- molpipeline/estimators/chemprop/models.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 8a0388f2..cb8866f2 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -349,6 +349,7 @@ def __init__( **kwargs, ) + class ChempropMulticlassClassifier(ChempropModel): """Chemprop model with default parameters for regression tasks.""" @@ -373,6 +374,8 @@ def __init__( The batch size to use. n_jobs : int, optional (default=1) The number of jobs to use. + n_classes : int, optional (default=3) + The number of classes for the classifier. kwargs : Any Parameters set using `set_params`. Can be used to modify components of the model. @@ -406,13 +409,11 @@ def set_params(self, **params: Any) -> Self: """ super().set_params(**params) if not self._is_multiclass_classifier(): - raise ValueError("ChempropMulticlassClassifier should contain more than 2 classes.") + raise ValueError( + "ChempropMulticlassClassifier should contain more than 2 classes." + ) return self - - def get_params(self, deep: bool = False) -> dict[str, Any]: - params = super().get_params(deep) - return params - + def fit( self, X: MoleculeDataset, @@ -434,8 +435,8 @@ def fit( """ self._check_correct_input(y) return super().fit(X, y) - - def _check_correct_input(self,y) -> None: + + def _check_correct_input(self, y) -> None: """Checks if the input for the multi-class classifier is correct. Parameters @@ -451,7 +452,9 @@ def _check_correct_input(self,y) -> None: unique_y = np.unique(y) log = [] if self.n_classes != len(unique_y): - log.append(f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data.") + log.append( + f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data." + ) if sorted(unique_y) != list(range(self.n_classes)): err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly." print(err) From b3d0af83dd3b7810619c7fa99ea23fba3acec7d4 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Mon, 26 Aug 2024 15:26:13 +0200 Subject: [PATCH 04/25] make n classes non-optional --- molpipeline/estimators/chemprop/models.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index cb8866f2..4e035bc5 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -354,18 +354,20 @@ class ChempropMulticlassClassifier(ChempropModel): """Chemprop model with default parameters for regression tasks.""" def __init__( - self, + self, + n_classes: int, model: MPNN | None = None, lightning_trainer: pl.Trainer | None = None, batch_size: int = 64, n_jobs: int = 1, - n_classes: int = 3, **kwargs: Any, ) -> None: """Initialize the chemprop regressor model. Parameters - ---------- + ---------- + n_classes : int + The number of classes for the classifier. model : MPNN | None, optional The chemprop model to wrap. If None, a default model will be used. lightning_trainer : pl.Trainer, optional @@ -374,8 +376,6 @@ def __init__( The batch size to use. n_jobs : int, optional (default=1) The number of jobs to use. - n_classes : int, optional (default=3) - The number of classes for the classifier. kwargs : Any Parameters set using `set_params`. Can be used to modify components of the model. @@ -436,8 +436,8 @@ def fit( self._check_correct_input(y) return super().fit(X, y) - def _check_correct_input(self, y) -> None: - """Checks if the input for the multi-class classifier is correct. + def _check_correct_input(self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64]) -> None: + """Check if the input for the multi-class classifier is correct. Parameters ---------- From fe40e4a530fab37ef154e8b472dfd91fc565e712 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Mon, 26 Aug 2024 15:26:27 +0200 Subject: [PATCH 05/25] black --- molpipeline/estimators/chemprop/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 4e035bc5..de22511e 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -354,7 +354,7 @@ class ChempropMulticlassClassifier(ChempropModel): """Chemprop model with default parameters for regression tasks.""" def __init__( - self, + self, n_classes: int, model: MPNN | None = None, lightning_trainer: pl.Trainer | None = None, @@ -365,7 +365,7 @@ def __init__( """Initialize the chemprop regressor model. Parameters - ---------- + ---------- n_classes : int The number of classes for the classifier. model : MPNN | None, optional @@ -436,7 +436,9 @@ def fit( self._check_correct_input(y) return super().fit(X, y) - def _check_correct_input(self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64]) -> None: + def _check_correct_input( + self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64] + ) -> None: """Check if the input for the multi-class classifier is correct. Parameters From 70e59280f08159329db2224cab1c104a6e5ab54e Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 27 Aug 2024 10:31:08 +0200 Subject: [PATCH 06/25] ignore loghtning logs --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 9ad684ea..ab5c6116 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ __pycache__ molpipeline.egg-info/ lib/ build/ +lightning_logs/ From 3de60fdd84f04c0dd44850231ae7105d284ac8c8 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 27 Aug 2024 10:31:35 +0200 Subject: [PATCH 07/25] add test for multiclass --- .../test_chemprop/test_chemprop_pipeline.py | 75 ++++++++++++++++++- 1 file changed, 74 insertions(+), 1 deletion(-) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index c5f66fa0..fa4a9687 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -23,6 +23,7 @@ ChempropClassifier, ChempropModel, ChempropRegressor, + ChempropMulticlassClassifier, ) from molpipeline.mol2any.mol2chemprop import MolToChemprop from molpipeline.pipeline import Pipeline @@ -139,6 +140,35 @@ def get_classification_pipeline() -> Pipeline: return model_pipeline +def get_multiclass_classification_pipeline() -> Pipeline: + """Get the Chemprop model pipeline for classification. + + Returns + ------- + Pipeline + The Chemprop model pipeline for classification. + """ + smiles2mol = SmilesToMol() + mol2chemprop = MolToChemprop() + error_filter = ErrorFilter(filter_everything=True) + filter_reinserter = FilterReinserter.from_error_filter( + error_filter, fill_value=np.nan + ) + chemprop_model = ChempropMulticlassClassifier( + n_classes=3, lightning_trainer=DEFAULT_TRAINER + ) + model_pipeline = Pipeline( + steps=[ + ("smiles2mol", smiles2mol), + ("mol2chemprop", mol2chemprop), + ("error_filter", error_filter), + ("model", chemprop_model), + ("filter_reinserter", PostPredictionWrapper(filter_reinserter)), + ], + ) + return model_pipeline + + _T = TypeVar("_T") @@ -282,7 +312,9 @@ def test_prediction(self) -> None: molecule_net_bbbp_df = pd.read_csv( TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100 ) - molecule_net_bbbp_df.to_csv("molecule_net_bbbp.tsv.gz", sep="\t", index=False) + molecule_net_bbbp_df.to_csv( + "molecule_net_bbbp.tsv.gz", sep="\t", index=False + ) # TODO: remove this line? classification_model = get_classification_pipeline() classification_model.fit( molecule_net_bbbp_df["smiles"].tolist(), @@ -306,3 +338,44 @@ def test_prediction(self) -> None: self.assertEqual(proba.shape, proba_copy.shape) self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + + +class TestMulticlassClassificationPipeline(unittest.TestCase): + """Test the Chemprop model pipeline for classification.""" + + def test_prediction(self) -> None: + """Test the prediction of the classification model.""" + + test_data_df = pd.read_csv( + TEST_DATA_DIR / "multiclass_mock.tsv", sep="\t", index_col=False + ) + print(test_data_df.head()) + print(test_data_df.columns) + classification_model = get_multiclass_classification_pipeline() + mols = test_data_df["Molecule"].tolist() + classification_model.fit( + mols, + test_data_df["Label"].to_numpy(), + ) + pred = classification_model.predict(mols) + proba = classification_model.predict_proba(mols) + self.assertEqual(len(pred), len(test_data_df)) + self.assertEqual(proba.shape[1], 3) + self.assertEqual(proba.shape[0], len(test_data_df)) + + model_copy = joblib_dump_load(classification_model) + pred_copy = model_copy.predict(mols) + proba_copy = model_copy.predict_proba(mols) + + nan_indices = np.isnan(pred) + self.assertListEqual(nan_indices.tolist(), np.isnan(pred_copy).tolist()) + self.assertTrue(np.allclose(pred[~nan_indices], pred_copy[~nan_indices])) + + self.assertEqual(proba.shape, proba_copy.shape) + self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + + with self.assertRaises(ValueError): + classification_model.fit( + mols, + test_data_df["Label"].add(1).to_numpy(), + ) From 3003493e110cf1a29f03a65791ccf3cf6061fad2 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 27 Aug 2024 10:31:44 +0200 Subject: [PATCH 08/25] mock data for test --- tests/test_data/multiclass_mock.tsv | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/test_data/multiclass_mock.tsv diff --git a/tests/test_data/multiclass_mock.tsv b/tests/test_data/multiclass_mock.tsv new file mode 100644 index 00000000..ec494222 --- /dev/null +++ b/tests/test_data/multiclass_mock.tsv @@ -0,0 +1,13 @@ +Molecule Label +"CCCCCC" 0 +"CCCCCCCO" 1 +"CCCC" 0 +"CCCN" 2 +"CCCCCC" 0 +"CCCO" 1 +"CCCCC" 0 +"CCCCCN" 2 +"CC(C)CCC" 0 +"CCCCCCO" 1 +"CCCCCl" 0 +"CCC#N" 2 From 1b84120cf966c417e936f8ae3998dc63a63f1ac3 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 27 Aug 2024 13:08:50 +0200 Subject: [PATCH 09/25] remove random write csv --- test_extras/test_chemprop/test_chemprop_pipeline.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index fa4a9687..aab5cc5c 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -312,9 +312,6 @@ def test_prediction(self) -> None: molecule_net_bbbp_df = pd.read_csv( TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100 ) - molecule_net_bbbp_df.to_csv( - "molecule_net_bbbp.tsv.gz", sep="\t", index=False - ) # TODO: remove this line? classification_model = get_classification_pipeline() classification_model.fit( molecule_net_bbbp_df["smiles"].tolist(), From 3dbcff8d1f8d9642a22c49ffc695f1a79df03047 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 08:24:10 +0200 Subject: [PATCH 10/25] add test for full coverage of multiclass chemprop --- test_extras/test_chemprop/test_chemprop_pipeline.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index aab5cc5c..22d5e0d4 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -140,7 +140,7 @@ def get_classification_pipeline() -> Pipeline: return model_pipeline -def get_multiclass_classification_pipeline() -> Pipeline: +def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline: """Get the Chemprop model pipeline for classification. Returns @@ -155,7 +155,7 @@ def get_multiclass_classification_pipeline() -> Pipeline: error_filter, fill_value=np.nan ) chemprop_model = ChempropMulticlassClassifier( - n_classes=3, lightning_trainer=DEFAULT_TRAINER + n_classes=n_classes, lightning_trainer=DEFAULT_TRAINER ) model_pipeline = Pipeline( steps=[ @@ -348,7 +348,7 @@ def test_prediction(self) -> None: ) print(test_data_df.head()) print(test_data_df.columns) - classification_model = get_multiclass_classification_pipeline() + classification_model = get_multiclass_classification_pipeline(n_classes=3) mols = test_data_df["Molecule"].tolist() classification_model.fit( mols, @@ -376,3 +376,9 @@ def test_prediction(self) -> None: mols, test_data_df["Label"].add(1).to_numpy(), ) + with self.assertRaises(ValueError): + classification_model = get_multiclass_classification_pipeline(n_classes=2) + classification_model.fit( + mols, + test_data_df["Label"].to_numpy(), + ) From dd0ebbe1c5bf160bb69c514f9b5213bffd9524ec Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 08:31:37 +0200 Subject: [PATCH 11/25] add missing parameters for docsig --- test_extras/test_chemprop/test_chemprop_pipeline.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index 22d5e0d4..4cf19486 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -143,6 +143,11 @@ def get_classification_pipeline() -> Pipeline: def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline: """Get the Chemprop model pipeline for classification. + Parameters + ---------- + n_classes : int + The number of classes for model initialization. + Returns ------- Pipeline From d744d4d33bd21d0cff9d846272bb15eb68cc8772 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 08:39:11 +0200 Subject: [PATCH 12/25] code review requests --- molpipeline/estimators/chemprop/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index de22511e..6b81ce92 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -112,7 +112,7 @@ def _is_multiclass_classifier(self) -> bool: bool True if the model is a multiclass classifier, False otherwise. """ - if isinstance(self.model.predictor, MulticlassClassificationFFN): + if isinstance(self.model.predictor, MulticlassClassificationFFN) and self.n_classes > 2: return True return False @@ -351,7 +351,7 @@ def __init__( class ChempropMulticlassClassifier(ChempropModel): - """Chemprop model with default parameters for regression tasks.""" + """Chemprop model with default parameters for multiclass classification tasks.""" def __init__( self, @@ -362,7 +362,7 @@ def __init__( n_jobs: int = 1, **kwargs: Any, ) -> None: - """Initialize the chemprop regressor model. + """Initialize the chemprop multiclass model. Parameters ---------- @@ -395,7 +395,7 @@ def __init__( self.n_classes = n_classes def set_params(self, **params: Any) -> Self: - """Set the parameters of the model and check if it is a binary classifier. + """Set the parameters of the model and check if it is a multiclass classifier. Parameters ---------- From e404579c6d8df7f94c8cf53c0d0e7b0b8b3065fa Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 15:06:00 +0200 Subject: [PATCH 13/25] Adapt Eror message --- molpipeline/estimators/chemprop/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 6b81ce92..d0b4f83b 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -410,7 +410,7 @@ def set_params(self, **params: Any) -> Self: super().set_params(**params) if not self._is_multiclass_classifier(): raise ValueError( - "ChempropMulticlassClassifier should contain more than 2 classes." + "The model's predictor or the number of classes are invalid. Use a multiclass predictor and more than 2 classes." ) return self From 2b2d687a1e1fe861695321a3831cedd1c1528373 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 15:08:17 +0200 Subject: [PATCH 14/25] check classifier in init --- molpipeline/estimators/chemprop/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index d0b4f83b..5036755b 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -393,6 +393,7 @@ def __init__( **kwargs, ) self.n_classes = n_classes + self._is_multiclass_classifier() def set_params(self, **params: Any) -> Self: """Set the parameters of the model and check if it is a multiclass classifier. From f87d68b614f3cfcb73c593fe2f27476d580c4299 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 15:12:31 +0200 Subject: [PATCH 15/25] docstring adaptations --- molpipeline/estimators/chemprop/models.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 5036755b..c724b3af 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -444,13 +444,13 @@ def _check_correct_input( Parameters ---------- - y : _type_ + y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64] Indended classes for the dataset Raises ------ ValueError - if the classes found in y are not matching n_classes or if the class labels do not start from 0 to n_classes-1 + If the classes found in y are not matching n_classes or if the class labels do not start from 0 to n_classes-1. """ unique_y = np.unique(y) log = [] @@ -460,7 +460,6 @@ def _check_correct_input( ) if sorted(unique_y) != list(range(self.n_classes)): err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly." - print(err) log.append(err) if log: raise ValueError("\n".join(log)) From 7faedc19c67c5f06b80a3e569f7ef6cbf28f4fca Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 28 Aug 2024 15:17:13 +0200 Subject: [PATCH 16/25] fix docstings and naming in tests --- .../test_chemprop/test_chemprop_pipeline.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index 4cf19486..646ac99c 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -141,7 +141,7 @@ def get_classification_pipeline() -> Pipeline: def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline: - """Get the Chemprop model pipeline for classification. + """Get the Chemprop model pipeline for multiclass classification. Parameters ---------- @@ -151,7 +151,7 @@ def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline: Returns ------- Pipeline - The Chemprop model pipeline for classification. + The Chemprop model pipeline for multiclass classification. """ smiles2mol = SmilesToMol() mol2chemprop = MolToChemprop() @@ -343,16 +343,14 @@ def test_prediction(self) -> None: class TestMulticlassClassificationPipeline(unittest.TestCase): - """Test the Chemprop model pipeline for classification.""" + """Test the Chemprop model pipeline for multiclass classification.""" def test_prediction(self) -> None: - """Test the prediction of the classification model.""" + """Test the prediction of the multiclass classification model.""" test_data_df = pd.read_csv( TEST_DATA_DIR / "multiclass_mock.tsv", sep="\t", index_col=False ) - print(test_data_df.head()) - print(test_data_df.columns) classification_model = get_multiclass_classification_pipeline(n_classes=3) mols = test_data_df["Molecule"].tolist() classification_model.fit( @@ -369,12 +367,13 @@ def test_prediction(self) -> None: pred_copy = model_copy.predict(mols) proba_copy = model_copy.predict_proba(mols) - nan_indices = np.isnan(pred) - self.assertListEqual(nan_indices.tolist(), np.isnan(pred_copy).tolist()) - self.assertTrue(np.allclose(pred[~nan_indices], pred_copy[~nan_indices])) + nan_mask = np.isnan(pred) + self.assertListEqual(nan_mask.tolist(), np.isnan(pred_copy).tolist()) + self.assertTrue(np.allclose(pred[~nan_mask], pred_copy[~nan_mask])) self.assertEqual(proba.shape, proba_copy.shape) - self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + self.assertEqual(pred.shape, pred_copy.shape) + self.assertTrue(np.allclose(proba[~nan_mask], proba_copy[~nan_mask])) with self.assertRaises(ValueError): classification_model.fit( From 4834614c3083721f2bdf836ea8b7b265304b2537 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Fri, 30 Aug 2024 15:00:12 +0200 Subject: [PATCH 17/25] split instace check from validation --- molpipeline/estimators/chemprop/models.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index c724b3af..b464af83 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -112,7 +112,7 @@ def _is_multiclass_classifier(self) -> bool: bool True if the model is a multiclass classifier, False otherwise. """ - if isinstance(self.model.predictor, MulticlassClassificationFFN) and self.n_classes > 2: + if isinstance(self.model.predictor, MulticlassClassificationFFN): return True return False @@ -385,6 +385,7 @@ def __init__( agg = SumAggregation() predictor = MulticlassClassificationFFN(n_classes=n_classes) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + self.n_classes = n_classes super().__init__( model=model, lightning_trainer=lightning_trainer, @@ -392,8 +393,7 @@ def __init__( n_jobs=n_jobs, **kwargs, ) - self.n_classes = n_classes - self._is_multiclass_classifier() + self._is_valid_multiclass_classifier() def set_params(self, **params: Any) -> Self: """Set the parameters of the model and check if it is a multiclass classifier. @@ -409,7 +409,7 @@ def set_params(self, **params: Any) -> Self: The model with the new parameters. """ super().set_params(**params) - if not self._is_multiclass_classifier(): + if not self._is_valid_multiclass_classifier(): raise ValueError( "The model's predictor or the number of classes are invalid. Use a multiclass predictor and more than 2 classes." ) @@ -463,3 +463,15 @@ def _check_correct_input( log.append(err) if log: raise ValueError("\n".join(log)) + + def _is_valid_multiclass_classifier(self) -> bool: + """Check if a multiclass classifier is valid. Needs to be of the correct class and have more than 2 classes. + + Returns + ------- + bool + True if is a valid multiclass classifier, False otherwise. + """ + has_correct_class = self._is_multiclass_classifier() + has_classes = self.n_classes > 2 + return has_correct_class and has_classes From 261d7db2e63177a4a0a3f83e860751a7f65031da Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 15:52:23 +0200 Subject: [PATCH 18/25] add test for set_params and initialize Multiclass FFN properlky --- .../estimators/chemprop/component_wrapper.py | 32 +++++++++ molpipeline/estimators/chemprop/models.py | 24 +++++-- .../chemprop_test_utils/constant_vars.py | 64 ++++++++++++++++++ test_extras/test_chemprop/test_models.py | 65 +++++++++++++++++++ 4 files changed, 180 insertions(+), 5 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index f7182aa2..b4f1b3cb 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -163,6 +163,7 @@ def __init__( task_weights: Tensor | None = None, threshold: float | None = None, output_transform: UnscaleTransform | None = None, + **kwargs: Any, ): """Initialize the BinaryClassificationFFN class. @@ -200,6 +201,7 @@ def __init__( activation=activation, criterion=criterion, output_transform=output_transform, + **kwargs, ) self.n_tasks = n_tasks self._input_dim = input_dim @@ -323,6 +325,36 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN _T_default_criterion = CrossEntropyLoss _T_default_metric = CrossEntropyMetric + def __init__( + self, + n_classes: int, + n_tasks: int = 1, + input_dim: int = DEFAULT_HIDDEN_DIM, + hidden_dim: int = 300, + n_layers: int = 1, + dropout: float = 0.0, + activation: str = "relu", + criterion: LossFunction | None = None, + task_weights: Tensor | None = None, + threshold: float | None = None, + output_transform: UnscaleTransform | None = None, + ): + super().__init__( + n_tasks * n_classes, + input_dim, + hidden_dim, + n_layers, + dropout, + activation, + criterion, + task_weights, + threshold, + output_transform, + n_classes=n_classes + ) + + self.n_classes = n_classes + class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore """A wrapper for the MulticlassDirichletFFN class.""" diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index b464af83..797bd733 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -17,7 +17,6 @@ from chemprop.data import MoleculeDataset, build_dataloader from chemprop.nn.predictors import ( BinaryClassificationFFNBase, - MulticlassClassificationFFN, ) from lightning import pytorch as pl except ImportError as error: @@ -34,6 +33,7 @@ BondMessagePassing, RegressionFFN, SumAggregation, + MulticlassClassificationFFN, ) from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP @@ -385,7 +385,10 @@ def __init__( agg = SumAggregation() predictor = MulticlassClassificationFFN(n_classes=n_classes) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - self.n_classes = n_classes + if n_classes != predictor.n_classes: + raise ValueError( + "The number of classes in the predictor does not match the number of classes." + ) super().__init__( model=model, lightning_trainer=lightning_trainer, @@ -395,6 +398,17 @@ def __init__( ) self._is_valid_multiclass_classifier() + @property + def n_classes(self) -> int: + """Return the number of classes.""" + return self.model.predictor.n_classes + + @n_classes.setter + def n_classes(self, n_classes: int) -> None: + """Set the number of classes.""" + self.model.predictor.n_classes = n_classes + self.model.reinitialize_network() + def set_params(self, **params: Any) -> Self: """Set the parameters of the model and check if it is a multiclass classifier. @@ -465,13 +479,13 @@ def _check_correct_input( raise ValueError("\n".join(log)) def _is_valid_multiclass_classifier(self) -> bool: - """Check if a multiclass classifier is valid. Needs to be of the correct class and have more than 2 classes. + """Check if a multiclass classifier is valid. Model FFN needs to be of the correct class and model needs to have more than 2 classes. Returns ------- bool True if is a valid multiclass classifier, False otherwise. """ - has_correct_class = self._is_multiclass_classifier() + has_correct_model = isinstance(self.model.predictor, MulticlassClassificationFFN) has_classes = self.n_classes > 2 - return has_correct_class and has_classes + return has_correct_model and has_classes diff --git a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py index 89663866..2af7e6b9 100644 --- a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py +++ b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py @@ -94,3 +94,67 @@ "model__predictor__threshold": None, "n_jobs": 1, } + +DEFAULT_SET_PARAMS = { + "batch_size": 64, + "lightning_trainer": None, + "lightning_trainer__enable_checkpointing": False, + "lightning_trainer__enable_model_summary": False, + "lightning_trainer__max_epochs": 500, + "lightning_trainer__accelerator": "cpu", + "lightning_trainer__default_root_dir": None, + "lightning_trainer__limit_predict_batches": 1.0, + "lightning_trainer__detect_anomaly": False, + "lightning_trainer__reload_dataloaders_every_n_epochs": 0, + "lightning_trainer__precision": "32-true", + "lightning_trainer__min_steps": None, + "lightning_trainer__max_time": None, + "lightning_trainer__limit_train_batches": 1.0, + "lightning_trainer__strategy": "auto", + "lightning_trainer__gradient_clip_algorithm": None, + "lightning_trainer__log_every_n_steps": 50, + "lightning_trainer__limit_val_batches": 1.0, + "lightning_trainer__gradient_clip_val": None, + "lightning_trainer__overfit_batches": 0.0, + "lightning_trainer__num_nodes": 1, + "lightning_trainer__use_distributed_sampler": True, + "lightning_trainer__check_val_every_n_epoch": 1, + "lightning_trainer__benchmark": False, + "lightning_trainer__inference_mode": True, + "lightning_trainer__limit_test_batches": 1.0, + "lightning_trainer__fast_dev_run": False, + "lightning_trainer__logger": None, + "lightning_trainer__max_steps": -1, + "lightning_trainer__num_sanity_val_steps": 2, + "lightning_trainer__devices": "auto", + "lightning_trainer__min_epochs": None, + "lightning_trainer__val_check_interval": 1.0, + "lightning_trainer__barebones": False, + "lightning_trainer__accumulate_grad_batches": 1, + "lightning_trainer__deterministic": False, + "lightning_trainer__enable_progress_bar": True, + "model__agg__dim": 0, + "model__batch_norm": True, + "model__final_lr": 0.0001, + "model__init_lr": 0.0001, + "model__max_lr": 0.001, + "model__message_passing__activation": "relu", + "model__message_passing__bias": False, + "model__message_passing__d_e": 14, + "model__message_passing__d_h": 300, + "model__message_passing__d_v": 72, + "model__message_passing__d_vd": None, + "model__message_passing__depth": 3, + "model__message_passing__dropout_rate": 0.0, + "model__message_passing__undirected": False, + "model__metric_list": None, + "model__predictor__activation": "relu", + "model__warmup_epochs": 2, + "model__predictor__dropout": 0, + "model__predictor__hidden_dim": 300, + "model__predictor__input_dim": 300, + "model__predictor__n_layers": 1, + "model__predictor__n_tasks": 1, + "model__predictor__threshold": None, + "n_jobs": 1, +} diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index d6eb6df9..7348f7f1 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -10,12 +10,17 @@ from molpipeline.estimators.chemprop.component_wrapper import ( MPNN, + BondMessagePassing, MeanAggregation, + MulticlassClassificationFFN, + BinaryClassificationFFN, RegressionFFN, + SumAggregation, ) from molpipeline.estimators.chemprop.models import ( ChempropClassifier, ChempropModel, + ChempropMulticlassClassifier, ChempropRegressor, ) from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP @@ -26,6 +31,7 @@ from test_extras.test_chemprop.chemprop_test_utils.constant_vars import ( DEFAULT_PARAMS, NO_IDENTITY_CHECK, + DEFAULT_SET_PARAMS, ) from test_extras.test_chemprop.chemprop_test_utils.default_models import ( get_chemprop_model_binary_classification_mpnn, @@ -127,6 +133,7 @@ def test_json_serialization(self) -> None: ) + class TestChempropClassifier(unittest.TestCase): """Test the Chemprop classifier model.""" @@ -150,6 +157,13 @@ def test_get_params(self) -> None: self.assertEqual( param_dict[param_name], param, f"Test failed for {param_name}" ) + def test_set_params(self) -> None: + """Test the set_params methods.""" + chemprop_model = ChempropClassifier(lightning_trainer__accelerator="cpu") + chemprop_model.set_params(**DEFAULT_SET_PARAMS) + current_params = chemprop_model.get_params(deep=True) + for param,value in DEFAULT_SET_PARAMS.items(): + self.assertEqual(current_params[param], value) class TestChempropRegressor(unittest.TestCase): @@ -177,3 +191,54 @@ def test_get_params(self) -> None: self.assertEqual( param_dict[param_name], param, f"Test failed for {param_name}" ) + +class TestChempropMulticlassClassifier(unittest.TestCase): + """Test the Chemprop classifier model.""" + + def test_get_params(self) -> None: + """Test the get_params and set_params methods.""" + n_classes = 3 + chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=n_classes) + param_dict = chemprop_model.get_params(deep=True) + expected_params = dict(DEFAULT_PARAMS) # Shallow copy + expected_params["n_classes"] = n_classes + self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) + for param_name, param in expected_params.items(): + if param_name in NO_IDENTITY_CHECK: + if isinstance(param, Iterable): + self.assertIsInstance(param_dict[param_name], type(param)) + for i, p in enumerate(param): + self.assertIsInstance(param_dict[param_name][i], p) + elif isinstance(param, type): + self.assertIsInstance(param_dict[param_name], param) + else: + raise ValueError(f"{param_name} should be a type.") + else: + self.assertEqual( + param_dict[param_name], param, f"Test failed for {param_name}" + ) + + def test_set_params(self) -> None: + """Test the set_params methods.""" + chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=3) + chemprop_model.set_params(**DEFAULT_SET_PARAMS) + params = {"n_classes": 4, "batch_size": 20, "lightning_trainer__max_epochs":10, "model__predictor__n_layers":2} + chemprop_model.set_params(**params) + current_params = chemprop_model.get_params(deep=True) + for param,value in params.items(): + self.assertEqual(current_params[param], value) + + def test_error_for_multiclass_predictor(self) -> None: + """Test the error for using a multiclass predictor for a binary classification model.""" + chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=3) + with self.assertRaises(ValueError): + chemprop_model.set_params(model__predictor=RegressionFFN) + bond_encoder = BondMessagePassing() + agg = SumAggregation() + predictor = MulticlassClassificationFFN(n_classes=2) + with self.assertRaises(ValueError): + model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + predictor = RegressionFFN() + with self.assertRaises(ValueError): + model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + From 4e841118212a54a16c516771578556cd28d89ce9 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:22:43 +0200 Subject: [PATCH 19/25] raise attribute error if wrong model.predictor is passed --- molpipeline/estimators/chemprop/models.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 797bd733..d5435c97 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -379,13 +379,22 @@ def __init__( kwargs : Any Parameters set using `set_params`. Can be used to modify components of the model. + + Raises + ------ + AttributeError + If the passed model.predictor does not have an attribute n_classes. + ValueError + If the number of classes in the predictor does not match the number of classes given as attribute. """ if model is None: bond_encoder = BondMessagePassing() agg = SumAggregation() predictor = MulticlassClassificationFFN(n_classes=n_classes) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - if n_classes != predictor.n_classes: + if not hasattr(model.predictor, "n_classes"): + raise AttributeError("The predictor does not have an attribute n_classes.Please use a MulticlassClassificationFFN predictor or define n_classes.") + if n_classes != model.predictor.n_classes: raise ValueError( "The number of classes in the predictor does not match the number of classes." ) From c5810fc5fe8432f93878d4595107ce9d99e64a21 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:23:01 +0200 Subject: [PATCH 20/25] test multiclass setter and getter --- .../chemprop_test_utils/constant_vars.py | 79 ++++++++++++++++++- test_extras/test_chemprop/test_models.py | 31 +++++--- 2 files changed, 97 insertions(+), 13 deletions(-) diff --git a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py index 2af7e6b9..51b72743 100644 --- a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py +++ b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py @@ -7,8 +7,10 @@ MPNN, BinaryClassificationFFN, BondMessagePassing, + MulticlassClassificationFFN, SumAggregation, ) +from molpipeline.estimators.chemprop.loss_wrapper import CrossEntropyLoss # These are model parameters which are copied by value, but are too complex to check for equality. # Thus, for these model parameters, only the type is checked. @@ -23,7 +25,7 @@ # Default parameters for the Chemprop model. -DEFAULT_PARAMS = { +DEFAULT_BINARY_CLASSIFICATION_PARAMS = { "batch_size": 64, "lightning_trainer": None, "lightning_trainer__enable_checkpointing": False, @@ -158,3 +160,78 @@ "model__predictor__threshold": None, "n_jobs": 1, } + + +DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS = { + "batch_size": 64, + "lightning_trainer": None, + "lightning_trainer__enable_checkpointing": False, + "lightning_trainer__enable_model_summary": False, + "lightning_trainer__max_epochs": 500, + "lightning_trainer__accelerator": "cpu", + "lightning_trainer__default_root_dir": None, + "lightning_trainer__limit_predict_batches": 1.0, + "lightning_trainer__detect_anomaly": False, + "lightning_trainer__reload_dataloaders_every_n_epochs": 0, + "lightning_trainer__precision": "32-true", + "lightning_trainer__min_steps": None, + "lightning_trainer__max_time": None, + "lightning_trainer__limit_train_batches": 1.0, + "lightning_trainer__strategy": "auto", + "lightning_trainer__gradient_clip_algorithm": None, + "lightning_trainer__log_every_n_steps": 50, + "lightning_trainer__limit_val_batches": 1.0, + "lightning_trainer__gradient_clip_val": None, + "lightning_trainer__overfit_batches": 0.0, + "lightning_trainer__num_nodes": 1, + "lightning_trainer__use_distributed_sampler": True, + "lightning_trainer__check_val_every_n_epoch": 1, + "lightning_trainer__benchmark": False, + "lightning_trainer__inference_mode": True, + "lightning_trainer__limit_test_batches": 1.0, + "lightning_trainer__fast_dev_run": False, + "lightning_trainer__logger": None, + "lightning_trainer__max_steps": -1, + "lightning_trainer__num_sanity_val_steps": 2, + "lightning_trainer__devices": "auto", + "lightning_trainer__min_epochs": None, + "lightning_trainer__val_check_interval": 1.0, + "lightning_trainer__barebones": False, + "lightning_trainer__accumulate_grad_batches": 1, + "lightning_trainer__deterministic": False, + "lightning_trainer__enable_progress_bar": True, + "model": MPNN, + "model__agg__dim": 0, + "model__agg": SumAggregation, + "model__batch_norm": True, + "model__final_lr": 0.0001, + "model__init_lr": 0.0001, + "model__max_lr": 0.001, + "model__message_passing__activation": "relu", + "model__message_passing__bias": False, + "model__message_passing__d_e": 14, + "model__message_passing__d_h": 300, + "model__message_passing__d_v": 72, + "model__message_passing__d_vd": None, + "model__message_passing__depth": 3, + "model__message_passing__dropout_rate": 0.0, + "model__message_passing__undirected": False, + "model__message_passing": BondMessagePassing, + "model__metric_list": None, + "model__predictor__activation": "relu", + "model__warmup_epochs": 2, + "model__predictor": MulticlassClassificationFFN, + "model__predictor__criterion": CrossEntropyLoss, + "model__predictor__criterion__task_weights": Tensor([1.0, 1.0, 1.0]), + "model__predictor__dropout": 0, + "model__predictor__hidden_dim": 300, + "model__predictor__input_dim": 300, + "model__predictor__n_classes" : 3, + "model__predictor__n_layers": 1, + "model__predictor__n_tasks": 1, + "model__predictor__output_transform": nn.Identity, + "model__predictor__task_weights": Tensor([1.0, 1.0, 1.0]), + "model__predictor__threshold": None, + "n_classes" : 3, + "n_jobs": 1, +} diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index 7348f7f1..56564148 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -29,9 +29,10 @@ # pylint: disable=relative-beyond-top-level from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params from test_extras.test_chemprop.chemprop_test_utils.constant_vars import ( - DEFAULT_PARAMS, NO_IDENTITY_CHECK, DEFAULT_SET_PARAMS, + DEFAULT_BINARY_CLASSIFICATION_PARAMS, + DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS ) from test_extras.test_chemprop.chemprop_test_utils.default_models import ( get_chemprop_model_binary_classification_mpnn, @@ -47,7 +48,7 @@ def test_get_params(self) -> None: """Test the get_params and set_params methods.""" chemprop_model = get_chemprop_model_binary_classification_mpnn() orig_params = chemprop_model.get_params(deep=True) - expected_params = dict(DEFAULT_PARAMS) # Shallow copy + expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS) # Shallow copy self.assertSetEqual(set(orig_params), set(expected_params)) # Check if the parameters are as expected @@ -114,8 +115,8 @@ def test_json_serialization(self) -> None: chemprop_model_copy = recursive_from_json(chemprop_json) param_dict = chemprop_model_copy.get_params(deep=True) - self.assertSetEqual(set(param_dict.keys()), set(DEFAULT_PARAMS.keys())) - for param_name, param in DEFAULT_PARAMS.items(): + self.assertSetEqual(set(param_dict.keys()), set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys())) + for param_name, param in DEFAULT_BINARY_CLASSIFICATION_PARAMS.items(): if param_name in NO_IDENTITY_CHECK: if isinstance(param, Iterable): self.assertIsInstance(param_dict[param_name], type(param)) @@ -141,7 +142,7 @@ def test_get_params(self) -> None: """Test the get_params and set_params methods.""" chemprop_model = ChempropClassifier(lightning_trainer__accelerator="cpu") param_dict = chemprop_model.get_params(deep=True) - expected_params = dict(DEFAULT_PARAMS) # Shallow copy + expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS) # Shallow copy self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) for param_name, param in expected_params.items(): if param_name in NO_IDENTITY_CHECK: @@ -173,7 +174,7 @@ def test_get_params(self) -> None: """Test the get_params and set_params methods.""" chemprop_model = ChempropRegressor(lightning_trainer__accelerator="cpu") param_dict = chemprop_model.get_params(deep=True) - expected_params = dict(DEFAULT_PARAMS) + expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS) expected_params["model__predictor"] = RegressionFFN expected_params["model__predictor__criterion"] = MSELoss self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) @@ -200,7 +201,8 @@ def test_get_params(self) -> None: n_classes = 3 chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=n_classes) param_dict = chemprop_model.get_params(deep=True) - expected_params = dict(DEFAULT_PARAMS) # Shallow copy + expected_params = dict(DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS) # Shallow copy + expected_params["model__predictor__n_classes"] = n_classes expected_params["n_classes"] = n_classes self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) for param_name, param in expected_params.items(): @@ -213,6 +215,8 @@ def test_get_params(self) -> None: self.assertIsInstance(param_dict[param_name], param) else: raise ValueError(f"{param_name} should be a type.") + elif (isinstance(param, torch.Tensor)): + self.assertTrue(torch.allclose(param_dict[param_name], param)) else: self.assertEqual( param_dict[param_name], param, f"Test failed for {param_name}" @@ -230,15 +234,18 @@ def test_set_params(self) -> None: def test_error_for_multiclass_predictor(self) -> None: """Test the error for using a multiclass predictor for a binary classification model.""" - chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=3) - with self.assertRaises(ValueError): - chemprop_model.set_params(model__predictor=RegressionFFN) bond_encoder = BondMessagePassing() agg = SumAggregation() - predictor = MulticlassClassificationFFN(n_classes=2) with self.assertRaises(ValueError): + predictor = MulticlassClassificationFFN(n_classes=2) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - predictor = RegressionFFN() + chemprop_model = ChempropMulticlassClassifier(n_classes=2, model=model) with self.assertRaises(ValueError): + predictor = MulticlassClassificationFFN(n_classes=3) + model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) + with self.assertRaises(AttributeError): + predictor = RegressionFFN() model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) From a064100b430fa040d1fe721268388415c19bd1ca Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:23:11 +0200 Subject: [PATCH 21/25] pass correct tasks --- molpipeline/estimators/chemprop/component_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index b4f1b3cb..8840f2d8 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -340,7 +340,7 @@ def __init__( output_transform: UnscaleTransform | None = None, ): super().__init__( - n_tasks * n_classes, + n_tasks, input_dim, hidden_dim, n_layers, From 33e120262339e2ce8b6328ed02186a137ad99f92 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:37:18 +0200 Subject: [PATCH 22/25] black --- .../estimators/chemprop/component_wrapper.py | 2 +- molpipeline/estimators/chemprop/models.py | 12 ++++--- .../chemprop_test_utils/constant_vars.py | 6 ++-- test_extras/test_chemprop/test_models.py | 35 ++++++++++++------- 4 files changed, 35 insertions(+), 20 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index 8840f2d8..bfc8b065 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -350,7 +350,7 @@ def __init__( task_weights, threshold, output_transform, - n_classes=n_classes + n_classes=n_classes, ) self.n_classes = n_classes diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index d5435c97..3fce11df 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -379,7 +379,7 @@ def __init__( kwargs : Any Parameters set using `set_params`. Can be used to modify components of the model. - + Raises ------ AttributeError @@ -393,7 +393,9 @@ def __init__( predictor = MulticlassClassificationFFN(n_classes=n_classes) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) if not hasattr(model.predictor, "n_classes"): - raise AttributeError("The predictor does not have an attribute n_classes.Please use a MulticlassClassificationFFN predictor or define n_classes.") + raise AttributeError( + "The predictor does not have an attribute n_classes.Please use a MulticlassClassificationFFN predictor or define n_classes." + ) if n_classes != model.predictor.n_classes: raise ValueError( "The number of classes in the predictor does not match the number of classes." @@ -411,7 +413,7 @@ def __init__( def n_classes(self) -> int: """Return the number of classes.""" return self.model.predictor.n_classes - + @n_classes.setter def n_classes(self, n_classes: int) -> None: """Set the number of classes.""" @@ -495,6 +497,8 @@ def _is_valid_multiclass_classifier(self) -> bool: bool True if is a valid multiclass classifier, False otherwise. """ - has_correct_model = isinstance(self.model.predictor, MulticlassClassificationFFN) + has_correct_model = isinstance( + self.model.predictor, MulticlassClassificationFFN + ) has_classes = self.n_classes > 2 return has_correct_model and has_classes diff --git a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py index 51b72743..7861024d 100644 --- a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py +++ b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py @@ -225,13 +225,13 @@ "model__predictor__criterion__task_weights": Tensor([1.0, 1.0, 1.0]), "model__predictor__dropout": 0, "model__predictor__hidden_dim": 300, - "model__predictor__input_dim": 300, - "model__predictor__n_classes" : 3, + "model__predictor__input_dim": 300, + "model__predictor__n_classes": 3, "model__predictor__n_layers": 1, "model__predictor__n_tasks": 1, "model__predictor__output_transform": nn.Identity, "model__predictor__task_weights": Tensor([1.0, 1.0, 1.0]), "model__predictor__threshold": None, - "n_classes" : 3, + "n_classes": 3, "n_jobs": 1, } diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index 56564148..66a964e6 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -32,7 +32,7 @@ NO_IDENTITY_CHECK, DEFAULT_SET_PARAMS, DEFAULT_BINARY_CLASSIFICATION_PARAMS, - DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS + DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS, ) from test_extras.test_chemprop.chemprop_test_utils.default_models import ( get_chemprop_model_binary_classification_mpnn, @@ -115,7 +115,9 @@ def test_json_serialization(self) -> None: chemprop_model_copy = recursive_from_json(chemprop_json) param_dict = chemprop_model_copy.get_params(deep=True) - self.assertSetEqual(set(param_dict.keys()), set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys())) + self.assertSetEqual( + set(param_dict.keys()), set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys()) + ) for param_name, param in DEFAULT_BINARY_CLASSIFICATION_PARAMS.items(): if param_name in NO_IDENTITY_CHECK: if isinstance(param, Iterable): @@ -134,7 +136,6 @@ def test_json_serialization(self) -> None: ) - class TestChempropClassifier(unittest.TestCase): """Test the Chemprop classifier model.""" @@ -158,12 +159,13 @@ def test_get_params(self) -> None: self.assertEqual( param_dict[param_name], param, f"Test failed for {param_name}" ) + def test_set_params(self) -> None: """Test the set_params methods.""" chemprop_model = ChempropClassifier(lightning_trainer__accelerator="cpu") chemprop_model.set_params(**DEFAULT_SET_PARAMS) current_params = chemprop_model.get_params(deep=True) - for param,value in DEFAULT_SET_PARAMS.items(): + for param, value in DEFAULT_SET_PARAMS.items(): self.assertEqual(current_params[param], value) @@ -193,13 +195,16 @@ def test_get_params(self) -> None: param_dict[param_name], param, f"Test failed for {param_name}" ) + class TestChempropMulticlassClassifier(unittest.TestCase): """Test the Chemprop classifier model.""" def test_get_params(self) -> None: """Test the get_params and set_params methods.""" n_classes = 3 - chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=n_classes) + chemprop_model = ChempropMulticlassClassifier( + lightning_trainer__accelerator="cpu", n_classes=n_classes + ) param_dict = chemprop_model.get_params(deep=True) expected_params = dict(DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS) # Shallow copy expected_params["model__predictor__n_classes"] = n_classes @@ -215,23 +220,30 @@ def test_get_params(self) -> None: self.assertIsInstance(param_dict[param_name], param) else: raise ValueError(f"{param_name} should be a type.") - elif (isinstance(param, torch.Tensor)): + elif isinstance(param, torch.Tensor): self.assertTrue(torch.allclose(param_dict[param_name], param)) else: self.assertEqual( param_dict[param_name], param, f"Test failed for {param_name}" ) - + def test_set_params(self) -> None: """Test the set_params methods.""" - chemprop_model = ChempropMulticlassClassifier(lightning_trainer__accelerator="cpu", n_classes=3) + chemprop_model = ChempropMulticlassClassifier( + lightning_trainer__accelerator="cpu", n_classes=3 + ) chemprop_model.set_params(**DEFAULT_SET_PARAMS) - params = {"n_classes": 4, "batch_size": 20, "lightning_trainer__max_epochs":10, "model__predictor__n_layers":2} + params = { + "n_classes": 4, + "batch_size": 20, + "lightning_trainer__max_epochs": 10, + "model__predictor__n_layers": 2, + } chemprop_model.set_params(**params) current_params = chemprop_model.get_params(deep=True) - for param,value in params.items(): + for param, value in params.items(): self.assertEqual(current_params[param], value) - + def test_error_for_multiclass_predictor(self) -> None: """Test the error for using a multiclass predictor for a binary classification model.""" bond_encoder = BondMessagePassing() @@ -248,4 +260,3 @@ def test_error_for_multiclass_predictor(self) -> None: predictor = RegressionFFN() model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) - From 4a117d5a79121e279f210ece6a203a45d0e76713 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:39:47 +0200 Subject: [PATCH 23/25] docsig and pydocstyle --- .../estimators/chemprop/component_wrapper.py | 27 +++++++++++++++++++ molpipeline/estimators/chemprop/models.py | 8 +++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index bfc8b065..bd4481e1 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -339,6 +339,33 @@ def __init__( threshold: float | None = None, output_transform: UnscaleTransform | None = None, ): + """Initialize the MulticlassClassificationFFN class. + + Parameters + ---------- + n_classes : int + how many classes are expected in the output + n_tasks : int, optional (default=1) + Number of tasks. + input_dim : int, optional (default=DEFAULT_HIDDEN_DIM) + Input dimension. + hidden_dim : int, optional (default=300) + Hidden dimension. + n_layers : int, optional (default=1) + Number of layers. + dropout : float, optional (default=0) + Dropout rate. + activation : str, optional (default="relu") + Activation function. + criterion : LossFunction or None, optional (default=None) + Loss function. None defaults to BCELoss. + task_weights : Tensor or None, optional (default=None) + Task weights. + threshold : float or None, optional (default=None) + Threshold for binary classification. + output_transform : UnscaleTransform or None, optional (default=None) + Transformations to apply to the output. None defaults to UnscaleTransform. + """ super().__init__( n_tasks, input_dim, diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 3fce11df..c2e09d2c 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -416,7 +416,13 @@ def n_classes(self) -> int: @n_classes.setter def n_classes(self, n_classes: int) -> None: - """Set the number of classes.""" + """Set the number of classes. + + Parameters + ---------- + n_classes : int + number of classes + """ self.model.predictor.n_classes = n_classes self.model.reinitialize_network() From fdb1d311fc10983881f51c55ef7ca3001f11f251 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:49:30 +0200 Subject: [PATCH 24/25] lint: docstrings and tests --- molpipeline/estimators/chemprop/component_wrapper.py | 2 ++ test_extras/test_chemprop/test_models.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index bd4481e1..83f8e2a6 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -189,6 +189,8 @@ def __init__( Threshold for binary classification. output_transform : UnscaleTransform or None, optional (default=None) Transformations to apply to the output. None defaults to UnscaleTransform. + kwargs : Any + Additional keyword arguments. """ if task_weights is None: task_weights = torch.ones(n_tasks) diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index 66a964e6..9afaf111 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -13,7 +13,6 @@ BondMessagePassing, MeanAggregation, MulticlassClassificationFFN, - BinaryClassificationFFN, RegressionFFN, SumAggregation, ) @@ -251,12 +250,12 @@ def test_error_for_multiclass_predictor(self) -> None: with self.assertRaises(ValueError): predictor = MulticlassClassificationFFN(n_classes=2) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=2, model=model) + ChempropMulticlassClassifier(n_classes=2, model=model) with self.assertRaises(ValueError): predictor = MulticlassClassificationFFN(n_classes=3) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) + ChempropMulticlassClassifier(n_classes=4, model=model) with self.assertRaises(AttributeError): predictor = RegressionFFN() model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) + ChempropMulticlassClassifier(n_classes=4, model=model) From eef3d224168adbb5dbfca14d8c54ea9d6e73e6e1 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Wed, 4 Sep 2024 13:34:52 +0200 Subject: [PATCH 25/25] missing space --- molpipeline/estimators/chemprop/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index c2e09d2c..b94bcb02 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -394,7 +394,7 @@ def __init__( model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) if not hasattr(model.predictor, "n_classes"): raise AttributeError( - "The predictor does not have an attribute n_classes.Please use a MulticlassClassificationFFN predictor or define n_classes." + "The predictor does not have an attribute n_classes. Please use a MulticlassClassificationFFN predictor or define n_classes." ) if n_classes != model.predictor.n_classes: raise ValueError(