diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index 32c4e677..07fbc1e5 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -307,6 +307,9 @@ def test_prediction(self) -> None: pred_copy = model_copy.predict(molecule_net_logd_df["smiles"].tolist()) self.assertTrue(np.allclose(pred, pred_copy)) + # Test single prediction, this was causing an error before + _ = regression_model.predict([molecule_net_logd_df["smiles"].iloc[0]]) + class TestClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for classification.""" @@ -341,6 +344,9 @@ def test_prediction(self) -> None: self.assertEqual(proba.shape, proba_copy.shape) self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + # Test single prediction, this was causing an error before + _ = classification_model.predict([molecule_net_bbbp_df["smiles"].iloc[0]]) + class TestMulticlassClassificationPipeline(unittest.TestCase): """Test the Chemprop model pipeline for multiclass classification."""