Skip to content

Commit

Permalink
add tests for occurring bug
Browse files Browse the repository at this point in the history
  • Loading branch information
c-w-feldmann committed Nov 20, 2024
1 parent 78b0fe0 commit 8c282a4
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ def test_prediction(self) -> None:
pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist())
self.assertTrue(np.allclose(pred, pred_copy))

# Test single prediction, this was causing an error before
_ = regression_model.predict([molecule_net_logd_df["smiles"].iloc[0]])


class TestClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for classification."""
Expand Down Expand Up @@ -341,6 +344,9 @@ def test_prediction(self) -> None:
self.assertEqual(proba.shape, proba_copy.shape)
self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices]))

# Test single prediction, this was causing an error before
_ = classification_model.predict([molecule_net_bbbp_df["smiles"].iloc[0]])


class TestMulticlassClassificationPipeline(unittest.TestCase):
"""Test the Chemprop model pipeline for multiclass classification."""
Expand Down

0 comments on commit 8c282a4

Please sign in to comment.