Skip to content

Commit

Permalink
109 chemprop predictions fail for a single molecule (#110)
Browse files Browse the repository at this point in the history
* add tests for occurring bug
* adapt squeeze function
  • Loading branch information
c-w-feldmann authored Nov 21, 2024
1 parent 78b0fe0 commit 2a3ef6e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
3 changes: 1 addition & 2 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def _predict(
test_data = build_dataloader(X, num_workers=self.n_jobs, shuffle=False)
predictions = self.lightning_trainer.predict(self.model, test_data)
prediction_array = np.vstack(predictions) # type: ignore
prediction_array = prediction_array.squeeze()

prediction_array = prediction_array.squeeze(axis=1)
# Check if the predictions have the same length as the input dataset
if prediction_array.shape[0] != len(X):
raise AssertionError(
Expand Down
26 changes: 26 additions & 0 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ def test_prediction(self) -> None:
pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist())
self.assertTrue(np.allclose(pred, pred_copy))

# Test single prediction, this was causing an error before
single_mol_pred = regression_model.predict(
[molecule_net_logd_df["smiles"].iloc[0]]
)
self.assertEqual(single_mol_pred.shape, (1,))


class TestClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for classification."""
Expand Down Expand Up @@ -341,6 +347,16 @@ def test_prediction(self) -> None:
self.assertEqual(proba.shape, proba_copy.shape)
self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices]))

# Test single prediction, this was causing an error before
single_mol_pred = classification_model.predict(
[molecule_net_bbbp_df["smiles"].iloc[0]]
)
self.assertEqual(single_mol_pred.shape, (1,))
single_mol_proba = classification_model.predict_proba(
[molecule_net_bbbp_df["smiles"].iloc[0]]
)
self.assertEqual(single_mol_proba.shape, (1, 2))


class TestMulticlassClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for multiclass classification."""
Expand Down Expand Up @@ -375,6 +391,16 @@ def test_prediction(self) -> None:
self.assertEqual(pred.shape, pred_copy.shape)
self.assertTrue(np.allclose(proba[~nan_mask], proba_copy[~nan_mask]))

# Test single prediction, this was causing an error before
single_mol_pred = classification_model.predict(
[test_data_df["Molecule"].iloc[0]]
)
self.assertEqual(single_mol_pred.shape, (1,))
single_mol_proba = classification_model.predict_proba(
[test_data_df["Molecule"].iloc[0]]
)
self.assertEqual(single_mol_proba.shape, (1, 3))

with self.assertRaises(ValueError):
classification_model.fit(
mols,
Expand Down

0 comments on commit 2a3ef6e

Please sign in to comment.