From 2a3ef6e75bd27b90e1bffed347b20c70977c1592 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:58:37 +0100 Subject: [PATCH] 109 chemprop predictions fail for a single molecule (#110) * add tests for occurring bug * adapt squeeze function --- molpipeline/estimators/chemprop/models.py | 3 +-- .../test_chemprop/test_chemprop_pipeline.py | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index e720e029..47f5e59a 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -143,8 +143,7 @@ def _predict( test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False) predictions = self.lightning_trainer.predict(self.model, test_data) prediction_array = np.vstack(predictions) # type: ignore - prediction_array = prediction_array.squeeze() - + prediction_array = prediction_array.squeeze(axis=1) # Check if the predictions have the same length as the input dataset if prediction_array.shape[0] != len(X): raise AssertionError( diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index 32c4e677..c600a3a3 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -307,6 +307,12 @@ def test_prediction(self) -> None: pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist()) self.assertTrue(np.allclose(pred, pred_copy)) + # Test single prediction, this was causing an error before + single_mol_pred = regression_model.predict( + [molecule_net_logd_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + class TestClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for classification.""" @@ -341,6 +347,16 @@ def test_prediction(self) -> None: self.assertEqual(proba.shape, proba_copy.shape) self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + # Test single prediction, this was causing an error before + single_mol_pred = classification_model.predict( + [molecule_net_bbbp_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + single_mol_proba = classification_model.predict_proba( + [molecule_net_bbbp_df["smiles"].iloc[0]] + ) + self.assertEqual(single_mol_proba.shape, (1, 2)) + class TestMulticlassClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for multiclass classification.""" @@ -375,6 +391,16 @@ def test_prediction(self) -> None: self.assertEqual(pred.shape, pred_copy.shape) self.assertTrue(np.allclose(proba[~nan_mask], proba_copy[~nan_mask])) + # Test single prediction, this was causing an error before + single_mol_pred = classification_model.predict( + [test_data_df["Molecule"].iloc[0]] + ) + self.assertEqual(single_mol_pred.shape, (1,)) + single_mol_proba = classification_model.predict_proba( + [test_data_df["Molecule"].iloc[0]] + ) + self.assertEqual(single_mol_proba.shape, (1, 3)) + with self.assertRaises(ValueError): classification_model.fit( mols,