diff --git a/.gitignore b/.gitignore
index 9ad684ea..ab5c6116 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
 __pycache__
 molpipeline.egg-info/
 lib/
 build/
+lightning_logs/
diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py
index f7182aa2..83f8e2a6 100644
--- a/molpipeline/estimators/chemprop/component_wrapper.py
+++ b/molpipeline/estimators/chemprop/component_wrapper.py
@@ -163,6 +163,7 @@ def __init__(
         task_weights: Tensor | None = None,
         threshold: float | None = None,
         output_transform: UnscaleTransform | None = None,
+        **kwargs: Any,
     ):
         """Initialize the BinaryClassificationFFN class.
 
@@ -188,6 +189,8 @@ def __init__(
             Threshold for binary classification.
         output_transform : UnscaleTransform or None, optional (default=None)
             Transformations to apply to the output. None defaults to UnscaleTransform.
+        kwargs : Any
+            Additional keyword arguments.
         """
         if task_weights is None:
             task_weights = torch.ones(n_tasks)
@@ -200,6 +203,7 @@ def __init__(
             activation=activation,
             criterion=criterion,
             output_transform=output_transform,
+            **kwargs,
         )
         self.n_tasks = n_tasks
         self._input_dim = input_dim
@@ -323,6 +327,63 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN
     _T_default_criterion = CrossEntropyLoss
     _T_default_metric = CrossEntropyMetric
 
+    def __init__(
+        self,
+        n_classes: int,
+        n_tasks: int = 1,
+        input_dim: int = DEFAULT_HIDDEN_DIM,
+        hidden_dim: int = 300,
+        n_layers: int = 1,
+        dropout: float = 0.0,
+        activation: str = "relu",
+        criterion: LossFunction | None = None,
+        task_weights: Tensor | None = None,
+        threshold: float | None = None,
+        output_transform: UnscaleTransform | None = None,
+    ):
+        """Initialize the MulticlassClassificationFFN class.
+
+        Parameters
+        ----------
+        n_classes : int
+            Number of classes expected in the output.
+        n_tasks : int, optional (default=1)
+            Number of tasks.
+        input_dim : int, optional (default=DEFAULT_HIDDEN_DIM)
+            Input dimension.
+        hidden_dim : int, optional (default=300)
+            Hidden dimension.
+        n_layers : int, optional (default=1)
+            Number of layers.
+        dropout : float, optional (default=0.0)
+            Dropout rate.
+        activation : str, optional (default="relu")
+            Activation function.
+        criterion : LossFunction or None, optional (default=None)
+            Loss function. None defaults to CrossEntropyLoss.
+        task_weights : Tensor or None, optional (default=None)
+            Task weights.
+        threshold : float or None, optional (default=None)
+            Threshold for classification.
+        output_transform : UnscaleTransform or None, optional (default=None)
+            Transformations to apply to the output. None defaults to UnscaleTransform.
+ """ + super().__init__( + n_tasks, + input_dim, + hidden_dim, + n_layers, + dropout, + activation, + criterion, + task_weights, + threshold, + output_transform, + n_classes=n_classes, + ) + + self.n_classes = n_classes + class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore """A wrapper for the MulticlassDirichletFFN class.""" diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index e3257ae5..b94bcb02 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -17,7 +17,6 @@ from chemprop.data import MoleculeDataset, build_dataloader from chemprop.nn.predictors import ( BinaryClassificationFFNBase, - MulticlassClassificationFFN, ) from lightning import pytorch as pl except ImportError as error: @@ -34,6 +33,7 @@ BondMessagePassing, RegressionFFN, SumAggregation, + MulticlassClassificationFFN, ) from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP @@ -348,3 +348,163 @@ def __init__( n_jobs=n_jobs, **kwargs, ) + + +class ChempropMulticlassClassifier(ChempropModel): + """Chemprop model with default parameters for multiclass classification tasks.""" + + def __init__( + self, + n_classes: int, + model: MPNN | None = None, + lightning_trainer: pl.Trainer | None = None, + batch_size: int = 64, + n_jobs: int = 1, + **kwargs: Any, + ) -> None: + """Initialize the chemprop multiclass model. + + Parameters + ---------- + n_classes : int + The number of classes for the classifier. + model : MPNN | None, optional + The chemprop model to wrap. If None, a default model will be used. + lightning_trainer : pl.Trainer, optional + The lightning trainer to use, by default None + batch_size : int, optional (default=64) + The batch size to use. + n_jobs : int, optional (default=1) + The number of jobs to use. + kwargs : Any + Parameters set using `set_params`. + Can be used to modify components of the model. + + Raises + ------ + AttributeError + If the passed model.predictor does not have an attribute n_classes. + ValueError + If the number of classes in the predictor does not match the number of classes given as attribute. + """ + if model is None: + bond_encoder = BondMessagePassing() + agg = SumAggregation() + predictor = MulticlassClassificationFFN(n_classes=n_classes) + model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor) + if not hasattr(model.predictor, "n_classes"): + raise AttributeError( + "The predictor does not have an attribute n_classes. Please use a MulticlassClassificationFFN predictor or define n_classes." + ) + if n_classes != model.predictor.n_classes: + raise ValueError( + "The number of classes in the predictor does not match the number of classes." + ) + super().__init__( + model=model, + lightning_trainer=lightning_trainer, + batch_size=batch_size, + n_jobs=n_jobs, + **kwargs, + ) + self._is_valid_multiclass_classifier() + + @property + def n_classes(self) -> int: + """Return the number of classes.""" + return self.model.predictor.n_classes + + @n_classes.setter + def n_classes(self, n_classes: int) -> None: + """Set the number of classes. + + Parameters + ---------- + n_classes : int + number of classes + """ + self.model.predictor.n_classes = n_classes + self.model.reinitialize_network() + + def set_params(self, **params: Any) -> Self: + """Set the parameters of the model and check if it is a multiclass classifier. + + Parameters + ---------- + **params + The parameters to set. 
+
+        Returns
+        -------
+        Self
+            The model with the new parameters.
+
+        Raises
+        ------
+        ValueError
+            If the new parameters do not yield a valid multiclass classifier.
+        """
+        super().set_params(**params)
+        if not self._is_valid_multiclass_classifier():
+            raise ValueError(
+                "The model's predictor or the number of classes are invalid. Use a multiclass predictor and more than 2 classes."
+            )
+        return self
+
+    def fit(
+        self,
+        X: MoleculeDataset,
+        y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64],
+    ) -> Self:
+        """Fit the model to the data.
+
+        Parameters
+        ----------
+        X : MoleculeDataset
+            The input data.
+        y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
+            The target data.
+
+        Returns
+        -------
+        Self
+            The fitted model.
+        """
+        self._check_correct_input(y)
+        return super().fit(X, y)
+
+    def _check_correct_input(
+        self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
+    ) -> None:
+        """Check if the input for the multiclass classifier is correct.
+
+        Parameters
+        ----------
+        y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
+            Intended class labels for the dataset.
+
+        Raises
+        ------
+        ValueError
+            If the number of unique classes in y does not match n_classes, or if the class labels are not exactly 0 to n_classes-1.
+        """
+        unique_y = np.unique(y)
+        log = []
+        if self.n_classes != len(unique_y):
+            log.append(
+                f"The number of classes given in init (n_classes) does not match the number of unique classes in the target data (found {unique_y})."
+            )
+        if sorted(unique_y) != list(range(self.n_classes)):
+            err = f"Classes need to be in the range from 0 to {self.n_classes - 1}. Found {unique_y}. Please correct the input data accordingly."
+            log.append(err)
+        if log:
+            raise ValueError("\n".join(log))
+
+    def _is_valid_multiclass_classifier(self) -> bool:
+        """Check whether the model is a valid multiclass classifier.
+
+        The predictor FFN must be a MulticlassClassificationFFN and the model must have more than 2 classes.
+
+        Returns
+        -------
+        bool
+            True if the model is a valid multiclass classifier, False otherwise.
+        """
+        has_correct_model = isinstance(
+            self.model.predictor, MulticlassClassificationFFN
+        )
+        has_classes = self.n_classes > 2
+        return has_correct_model and has_classes
diff --git a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py
index 89663866..7861024d 100644
--- a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py
+++ b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py
@@ -7,8 +7,10 @@
     MPNN,
     BinaryClassificationFFN,
     BondMessagePassing,
+    MulticlassClassificationFFN,
     SumAggregation,
 )
+from molpipeline.estimators.chemprop.loss_wrapper import CrossEntropyLoss
 
 # These are model parameters which are copied by value, but are too complex to check for equality.
 # Thus, for these model parameters, only the type is checked.
@@ -23,7 +25,7 @@
 
 
 # Default parameters for the Chemprop model.
-DEFAULT_PARAMS = {
+DEFAULT_BINARY_CLASSIFICATION_PARAMS = {
     "batch_size": 64,
     "lightning_trainer": None,
     "lightning_trainer__enable_checkpointing": False,
@@ -94,3 +96,142 @@
     "model__predictor__threshold": None,
     "n_jobs": 1,
 }
+
+DEFAULT_SET_PARAMS = {
+    "batch_size": 64,
+    "lightning_trainer": None,
+    "lightning_trainer__enable_checkpointing": False,
+    "lightning_trainer__enable_model_summary": False,
+    "lightning_trainer__max_epochs": 500,
+    "lightning_trainer__accelerator": "cpu",
+    "lightning_trainer__default_root_dir": None,
+    "lightning_trainer__limit_predict_batches": 1.0,
+    "lightning_trainer__detect_anomaly": False,
+    "lightning_trainer__reload_dataloaders_every_n_epochs": 0,
+    "lightning_trainer__precision": "32-true",
+    "lightning_trainer__min_steps": None,
+    "lightning_trainer__max_time": None,
+    "lightning_trainer__limit_train_batches": 1.0,
+    "lightning_trainer__strategy": "auto",
+    "lightning_trainer__gradient_clip_algorithm": None,
+    "lightning_trainer__log_every_n_steps": 50,
+    "lightning_trainer__limit_val_batches": 1.0,
+    "lightning_trainer__gradient_clip_val": None,
+    "lightning_trainer__overfit_batches": 0.0,
+    "lightning_trainer__num_nodes": 1,
+    "lightning_trainer__use_distributed_sampler": True,
+    "lightning_trainer__check_val_every_n_epoch": 1,
+    "lightning_trainer__benchmark": False,
+    "lightning_trainer__inference_mode": True,
+    "lightning_trainer__limit_test_batches": 1.0,
+    "lightning_trainer__fast_dev_run": False,
+    "lightning_trainer__logger": None,
+    "lightning_trainer__max_steps": -1,
+    "lightning_trainer__num_sanity_val_steps": 2,
+    "lightning_trainer__devices": "auto",
+    "lightning_trainer__min_epochs": None,
+    "lightning_trainer__val_check_interval": 1.0,
+    "lightning_trainer__barebones": False,
+    "lightning_trainer__accumulate_grad_batches": 1,
+    "lightning_trainer__deterministic": False,
+    "lightning_trainer__enable_progress_bar": True,
+    "model__agg__dim": 0,
+    "model__batch_norm": True,
+    "model__final_lr": 0.0001,
+    "model__init_lr": 0.0001,
+    "model__max_lr": 0.001,
+    "model__message_passing__activation": "relu",
+    "model__message_passing__bias": False,
+    "model__message_passing__d_e": 14,
+    "model__message_passing__d_h": 300,
+    "model__message_passing__d_v": 72,
+    "model__message_passing__d_vd": None,
+    "model__message_passing__depth": 3,
+    "model__message_passing__dropout_rate": 0.0,
+    "model__message_passing__undirected": False,
+    "model__metric_list": None,
+    "model__predictor__activation": "relu",
+    "model__warmup_epochs": 2,
+    "model__predictor__dropout": 0,
+    "model__predictor__hidden_dim": 300,
+    "model__predictor__input_dim": 300,
+    "model__predictor__n_layers": 1,
+    "model__predictor__n_tasks": 1,
+    "model__predictor__threshold": None,
+    "n_jobs": 1,
+}
+
+
+DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS = {
+    "batch_size": 64,
+    "lightning_trainer": None,
+    "lightning_trainer__enable_checkpointing": False,
+    "lightning_trainer__enable_model_summary": False,
+    "lightning_trainer__max_epochs": 500,
+    "lightning_trainer__accelerator": "cpu",
+    "lightning_trainer__default_root_dir": None,
+    "lightning_trainer__limit_predict_batches": 1.0,
+    "lightning_trainer__detect_anomaly": False,
+    "lightning_trainer__reload_dataloaders_every_n_epochs": 0,
+    "lightning_trainer__precision": "32-true",
+    "lightning_trainer__min_steps": None,
+    "lightning_trainer__max_time": None,
+    "lightning_trainer__limit_train_batches": 1.0,
+    "lightning_trainer__strategy": "auto",
+    "lightning_trainer__gradient_clip_algorithm": None,
"lightning_trainer__log_every_n_steps": 50, + "lightning_trainer__limit_val_batches": 1.0, + "lightning_trainer__gradient_clip_val": None, + "lightning_trainer__overfit_batches": 0.0, + "lightning_trainer__num_nodes": 1, + "lightning_trainer__use_distributed_sampler": True, + "lightning_trainer__check_val_every_n_epoch": 1, + "lightning_trainer__benchmark": False, + "lightning_trainer__inference_mode": True, + "lightning_trainer__limit_test_batches": 1.0, + "lightning_trainer__fast_dev_run": False, + "lightning_trainer__logger": None, + "lightning_trainer__max_steps": -1, + "lightning_trainer__num_sanity_val_steps": 2, + "lightning_trainer__devices": "auto", + "lightning_trainer__min_epochs": None, + "lightning_trainer__val_check_interval": 1.0, + "lightning_trainer__barebones": False, + "lightning_trainer__accumulate_grad_batches": 1, + "lightning_trainer__deterministic": False, + "lightning_trainer__enable_progress_bar": True, + "model": MPNN, + "model__agg__dim": 0, + "model__agg": SumAggregation, + "model__batch_norm": True, + "model__final_lr": 0.0001, + "model__init_lr": 0.0001, + "model__max_lr": 0.001, + "model__message_passing__activation": "relu", + "model__message_passing__bias": False, + "model__message_passing__d_e": 14, + "model__message_passing__d_h": 300, + "model__message_passing__d_v": 72, + "model__message_passing__d_vd": None, + "model__message_passing__depth": 3, + "model__message_passing__dropout_rate": 0.0, + "model__message_passing__undirected": False, + "model__message_passing": BondMessagePassing, + "model__metric_list": None, + "model__predictor__activation": "relu", + "model__warmup_epochs": 2, + "model__predictor": MulticlassClassificationFFN, + "model__predictor__criterion": CrossEntropyLoss, + "model__predictor__criterion__task_weights": Tensor([1.0, 1.0, 1.0]), + "model__predictor__dropout": 0, + "model__predictor__hidden_dim": 300, + "model__predictor__input_dim": 300, + "model__predictor__n_classes": 3, + "model__predictor__n_layers": 1, + "model__predictor__n_tasks": 1, + "model__predictor__output_transform": nn.Identity, + "model__predictor__task_weights": Tensor([1.0, 1.0, 1.0]), + "model__predictor__threshold": None, + "n_classes": 3, + "n_jobs": 1, +} diff --git a/test_extras/test_chemprop/test_chemprop_pipeline.py b/test_extras/test_chemprop/test_chemprop_pipeline.py index c5f66fa0..646ac99c 100644 --- a/test_extras/test_chemprop/test_chemprop_pipeline.py +++ b/test_extras/test_chemprop/test_chemprop_pipeline.py @@ -23,6 +23,7 @@ ChempropClassifier, ChempropModel, ChempropRegressor, + ChempropMulticlassClassifier, ) from molpipeline.mol2any.mol2chemprop import MolToChemprop from molpipeline.pipeline import Pipeline @@ -139,6 +140,40 @@ def get_classification_pipeline() -> Pipeline: return model_pipeline +def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline: + """Get the Chemprop model pipeline for multiclass classification. + + Parameters + ---------- + n_classes : int + The number of classes for model initialization. + + Returns + ------- + Pipeline + The Chemprop model pipeline for multiclass classification. 
+ """ + smiles2mol = SmilesToMol() + mol2chemprop = MolToChemprop() + error_filter = ErrorFilter(filter_everything=True) + filter_reinserter = FilterReinserter.from_error_filter( + error_filter, fill_value=np.nan + ) + chemprop_model = ChempropMulticlassClassifier( + n_classes=n_classes, lightning_trainer=DEFAULT_TRAINER + ) + model_pipeline = Pipeline( + steps=[ + ("smiles2mol", smiles2mol), + ("mol2chemprop", mol2chemprop), + ("error_filter", error_filter), + ("model", chemprop_model), + ("filter_reinserter", PostPredictionWrapper(filter_reinserter)), + ], + ) + return model_pipeline + + _T = TypeVar("_T") @@ -282,7 +317,6 @@ def test_prediction(self) -> None: molecule_net_bbbp_df = pd.read_csv( TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100 ) - molecule_net_bbbp_df.to_csv("molecule_net_bbbp.tsv.gz", sep="\t", index=False) classification_model = get_classification_pipeline() classification_model.fit( molecule_net_bbbp_df["smiles"].tolist(), @@ -306,3 +340,49 @@ def test_prediction(self) -> None: self.assertEqual(proba.shape, proba_copy.shape) self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices])) + + +class TestMulticlassClassificationPipeline(unittest.TestCase): + """Test the Chemprop model pipeline for multiclass classification.""" + + def test_prediction(self) -> None: + """Test the prediction of the multiclass classification model.""" + + test_data_df = pd.read_csv( + TEST_DATA_DIR / "multiclass_mock.tsv", sep="\t", index_col=False + ) + classification_model = get_multiclass_classification_pipeline(n_classes=3) + mols = test_data_df["Molecule"].tolist() + classification_model.fit( + mols, + test_data_df["Label"].to_numpy(), + ) + pred = classification_model.predict(mols) + proba = classification_model.predict_proba(mols) + self.assertEqual(len(pred), len(test_data_df)) + self.assertEqual(proba.shape[1], 3) + self.assertEqual(proba.shape[0], len(test_data_df)) + + model_copy = joblib_dump_load(classification_model) + pred_copy = model_copy.predict(mols) + proba_copy = model_copy.predict_proba(mols) + + nan_mask = np.isnan(pred) + self.assertListEqual(nan_mask.tolist(), np.isnan(pred_copy).tolist()) + self.assertTrue(np.allclose(pred[~nan_mask], pred_copy[~nan_mask])) + + self.assertEqual(proba.shape, proba_copy.shape) + self.assertEqual(pred.shape, pred_copy.shape) + self.assertTrue(np.allclose(proba[~nan_mask], proba_copy[~nan_mask])) + + with self.assertRaises(ValueError): + classification_model.fit( + mols, + test_data_df["Label"].add(1).to_numpy(), + ) + with self.assertRaises(ValueError): + classification_model = get_multiclass_classification_pipeline(n_classes=2) + classification_model.fit( + mols, + test_data_df["Label"].to_numpy(), + ) diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index d6eb6df9..9afaf111 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -10,12 +10,16 @@ from molpipeline.estimators.chemprop.component_wrapper import ( MPNN, + BondMessagePassing, MeanAggregation, + MulticlassClassificationFFN, RegressionFFN, + SumAggregation, ) from molpipeline.estimators.chemprop.models import ( ChempropClassifier, ChempropModel, + ChempropMulticlassClassifier, ChempropRegressor, ) from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP @@ -24,8 +28,10 @@ # pylint: disable=relative-beyond-top-level from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params from 
-    DEFAULT_PARAMS,
     NO_IDENTITY_CHECK,
+    DEFAULT_SET_PARAMS,
+    DEFAULT_BINARY_CLASSIFICATION_PARAMS,
+    DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS,
 )
 from test_extras.test_chemprop.chemprop_test_utils.default_models import (
     get_chemprop_model_binary_classification_mpnn,
@@ -41,7 +47,7 @@ def test_get_params(self) -> None:
         """Test the get_params and set_params methods."""
         chemprop_model = get_chemprop_model_binary_classification_mpnn()
         orig_params = chemprop_model.get_params(deep=True)
-        expected_params = dict(DEFAULT_PARAMS)  # Shallow copy
+        expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS)  # Shallow copy
 
         self.assertSetEqual(set(orig_params), set(expected_params))
         # Check if the parameters are as expected
@@ -108,8 +114,10 @@ def test_json_serialization(self) -> None:
         chemprop_model_copy = recursive_from_json(chemprop_json)
         param_dict = chemprop_model_copy.get_params(deep=True)
 
-        self.assertSetEqual(set(param_dict.keys()), set(DEFAULT_PARAMS.keys()))
-        for param_name, param in DEFAULT_PARAMS.items():
+        self.assertSetEqual(
+            set(param_dict.keys()), set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys())
+        )
+        for param_name, param in DEFAULT_BINARY_CLASSIFICATION_PARAMS.items():
             if param_name in NO_IDENTITY_CHECK:
                 if isinstance(param, Iterable):
                     self.assertIsInstance(param_dict[param_name], type(param))
@@ -134,7 +142,7 @@ def test_get_params(self) -> None:
         """Test the get_params and set_params methods."""
         chemprop_model = ChempropClassifier(lightning_trainer__accelerator="cpu")
         param_dict = chemprop_model.get_params(deep=True)
-        expected_params = dict(DEFAULT_PARAMS)  # Shallow copy
+        expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS)  # Shallow copy
         self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys()))
         for param_name, param in expected_params.items():
             if param_name in NO_IDENTITY_CHECK:
@@ -151,6 +159,14 @@ def test_get_params(self) -> None:
                 param_dict[param_name], param, f"Test failed for {param_name}"
             )
 
+    def test_set_params(self) -> None:
+        """Test the set_params method."""
+        chemprop_model = ChempropClassifier(lightning_trainer__accelerator="cpu")
+        chemprop_model.set_params(**DEFAULT_SET_PARAMS)
+        current_params = chemprop_model.get_params(deep=True)
+        for param, value in DEFAULT_SET_PARAMS.items():
+            self.assertEqual(current_params[param], value)
+
 
 class TestChempropRegressor(unittest.TestCase):
     """Test the Chemprop regressor model."""
@@ -159,7 +175,7 @@ def test_get_params(self) -> None:
         """Test the get_params and set_params methods."""
         chemprop_model = ChempropRegressor(lightning_trainer__accelerator="cpu")
         param_dict = chemprop_model.get_params(deep=True)
-        expected_params = dict(DEFAULT_PARAMS)
+        expected_params = dict(DEFAULT_BINARY_CLASSIFICATION_PARAMS)
         expected_params["model__predictor"] = RegressionFFN
         expected_params["model__predictor__criterion"] = MSELoss
         self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys()))
@@ -177,3 +193,69 @@ def test_get_params(self) -> None:
             self.assertEqual(
                 param_dict[param_name], param, f"Test failed for {param_name}"
             )
+
+
+class TestChempropMulticlassClassifier(unittest.TestCase):
+    """Test the Chemprop multiclass classifier model."""
+
+    def test_get_params(self) -> None:
+        """Test the get_params and set_params methods."""
+        n_classes = 3
+        chemprop_model = ChempropMulticlassClassifier(
+            lightning_trainer__accelerator="cpu", n_classes=n_classes
+        )
+        param_dict = chemprop_model.get_params(deep=True)
+        expected_params = dict(DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS)  # Shallow copy
+        expected_params["model__predictor__n_classes"] = n_classes
+        expected_params["n_classes"] = n_classes
+        self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys()))
+        for param_name, param in expected_params.items():
+            if param_name in NO_IDENTITY_CHECK:
+                if isinstance(param, Iterable):
+                    self.assertIsInstance(param_dict[param_name], type(param))
+                    for i, p in enumerate(param):
+                        self.assertIsInstance(param_dict[param_name][i], p)
+                elif isinstance(param, type):
+                    self.assertIsInstance(param_dict[param_name], param)
+                else:
+                    raise ValueError(f"{param_name} should be a type.")
+            elif isinstance(param, torch.Tensor):
+                self.assertTrue(torch.allclose(param_dict[param_name], param))
+            else:
+                self.assertEqual(
+                    param_dict[param_name], param, f"Test failed for {param_name}"
+                )
+
+    def test_set_params(self) -> None:
+        """Test the set_params method."""
+        chemprop_model = ChempropMulticlassClassifier(
+            lightning_trainer__accelerator="cpu", n_classes=3
+        )
+        chemprop_model.set_params(**DEFAULT_SET_PARAMS)
+        params = {
+            "n_classes": 4,
+            "batch_size": 20,
+            "lightning_trainer__max_epochs": 10,
+            "model__predictor__n_layers": 2,
+        }
+        chemprop_model.set_params(**params)
+        current_params = chemprop_model.get_params(deep=True)
+        for param, value in params.items():
+            self.assertEqual(current_params[param], value)
+
+    def test_error_for_multiclass_predictor(self) -> None:
+        """Test that invalid predictor configurations raise the expected errors."""
+        bond_encoder = BondMessagePassing()
+        agg = SumAggregation()
+        with self.assertRaises(ValueError):
+            predictor = MulticlassClassificationFFN(n_classes=2)
+            model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
+            ChempropMulticlassClassifier(n_classes=2, model=model)
+        with self.assertRaises(ValueError):
+            predictor = MulticlassClassificationFFN(n_classes=3)
+            model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
+            ChempropMulticlassClassifier(n_classes=4, model=model)
+        with self.assertRaises(AttributeError):
+            predictor = RegressionFFN()
+            model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
+            ChempropMulticlassClassifier(n_classes=4, model=model)
diff --git a/tests/test_data/multiclass_mock.tsv b/tests/test_data/multiclass_mock.tsv
new file mode 100644
index 00000000..ec494222
--- /dev/null
+++ b/tests/test_data/multiclass_mock.tsv
@@ -0,0 +1,13 @@
+Molecule	Label
+"CCCCCC"	0
+"CCCCCCCO"	1
+"CCCC"	0
+"CCCN"	2
+"CCCCCC"	0
+"CCCO"	1
+"CCCCC"	0
+"CCCCCN"	2
+"CC(C)CCC"	0
+"CCCCCCO"	1
+"CCCCCl"	0
+"CCC#N"	2
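
Review note: a minimal usage sketch of the new ChempropMulticlassClassifier, mirroring the pipeline assembled in get_multiclass_classification_pipeline above. The SmilesToMol import path is assumed from the rest of the package; the error-filter steps and the explicit trainer are omitted here, so the Lightning defaults apply.

```python
import numpy as np

from molpipeline.any2mol import SmilesToMol  # assumed import path
from molpipeline.estimators.chemprop.models import ChempropMulticlassClassifier
from molpipeline.mol2any.mol2chemprop import MolToChemprop
from molpipeline.pipeline import Pipeline

# Assemble SMILES -> RDKit mol -> chemprop dataset -> multiclass MPNN.
pipeline = Pipeline(
    steps=[
        ("smiles2mol", SmilesToMol()),
        ("mol2chemprop", MolToChemprop()),
        ("model", ChempropMulticlassClassifier(n_classes=3)),
    ]
)

smiles = ["CCCCCC", "CCCO", "CCCN"]
# Labels must be exactly 0 .. n_classes-1; otherwise fit raises ValueError.
labels = np.array([0, 1, 2])
pipeline.fit(smiles, labels)
proba = pipeline.predict_proba(smiles)  # shape: (n_samples, n_classes)
```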