From fdb1d311fc10983881f51c55ef7ca3001f11f251 Mon Sep 17 00:00:00 2001 From: hemmerj3 Date: Tue, 3 Sep 2024 16:49:30 +0200 Subject: [PATCH] lint: docstrings and tests --- molpipeline/estimators/chemprop/component_wrapper.py | 2 ++ test_extras/test_chemprop/test_models.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index bd4481e1..83f8e2a6 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -189,6 +189,8 @@ def __init__( Threshold for binary classification. output_transform : UnscaleTransform or None, optional (default=None) Transformations to apply to the output. None defaults to UnscaleTransform. + kwargs : Any + Additional keyword arguments. """ if task_weights is None: task_weights = torch.ones(n_tasks) diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index 66a964e6..9afaf111 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -13,7 +13,6 @@ BondMessagePassing, MeanAggregation, MulticlassClassificationFFN, - BinaryClassificationFFN, RegressionFFN, SumAggregation, ) @@ -251,12 +250,12 @@ def test_error_for_multiclass_predictor(self) -> None: with self.assertRaises(ValueError): predictor = MulticlassClassificationFFN(n_classes=2) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=2, model=model) + ChempropMulticlassClassifier(n_classes=2, model=model) with self.assertRaises(ValueError): predictor = MulticlassClassificationFFN(n_classes=3) model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) + ChempropMulticlassClassifier(n_classes=4, model=model) with self.assertRaises(AttributeError): predictor = RegressionFFN() model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) - chemprop_model = ChempropMulticlassClassifier(n_classes=4, model=model) + ChempropMulticlassClassifier(n_classes=4, model=model)