72 make chemprop multiclass classification model #73

Merged · 25 commits · Sep 5, 2024

Changes from all commits:
737db7f
add multi class classifier
JenniferHem Aug 26, 2024
8692bc2
use input check to prevent confusing message by torch of the class la…
JenniferHem Aug 26, 2024
b560d18
remove get_params
JenniferHem Aug 26, 2024
b3d0af8
make n classes non-optional
JenniferHem Aug 26, 2024
fe40e4a
black
JenniferHem Aug 26, 2024
70e5928
ignore lightning logs
JenniferHem Aug 27, 2024
3de60fd
add test for multiclass
JenniferHem Aug 27, 2024
3003493
mock data for test
JenniferHem Aug 27, 2024
1b84120
remove random write csv
JenniferHem Aug 27, 2024
3dbcff8
add test for full coverage of multiclass chemprop
JenniferHem Aug 28, 2024
dd0ebbe
add missing parameters for docsig
JenniferHem Aug 28, 2024
d744d4d
code review requests
JenniferHem Aug 28, 2024
e404579
Adapt error message
JenniferHem Aug 28, 2024
2b2d687
check classifier in init
JenniferHem Aug 28, 2024
f87d68b
docstring adaptations
JenniferHem Aug 28, 2024
7faedc1
fix docstrings and naming in tests
JenniferHem Aug 28, 2024
4834614
split instance check from validation
JenniferHem Aug 30, 2024
261d7db
add test for set_params and initialize Multiclass FFN properly
JenniferHem Sep 3, 2024
4e84111
raise attribute error if wrong model.predictor is passed
JenniferHem Sep 3, 2024
c5810fc
test multiclass setter and getter
JenniferHem Sep 3, 2024
a064100
pass correct tasks
JenniferHem Sep 3, 2024
33e1202
black
JenniferHem Sep 3, 2024
4a117d5
docsig and pydocstyle
JenniferHem Sep 3, 2024
fdb1d31
lint: docstrings and tests
JenniferHem Sep 3, 2024
eef3d22
missing space
JenniferHem Sep 4, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -3,4 +3,5 @@ __pycache__
molpipeline.egg-info/
lib/
build/
lightning_logs/

61 changes: 61 additions & 0 deletions molpipeline/estimators/chemprop/component_wrapper.py
@@ -163,6 +163,7 @@ def __init__(
task_weights: Tensor | None = None,
threshold: float | None = None,
output_transform: UnscaleTransform | None = None,
**kwargs: Any,
):
"""Initialize the BinaryClassificationFFN class.

@@ -188,6 +189,8 @@
Threshold for binary classification.
output_transform : UnscaleTransform or None, optional (default=None)
Transformations to apply to the output. None defaults to UnscaleTransform.
kwargs : Any
Additional keyword arguments.
"""
if task_weights is None:
task_weights = torch.ones(n_tasks)
@@ -200,6 +203,7 @@
activation=activation,
criterion=criterion,
output_transform=output_transform,
**kwargs,
)
self.n_tasks = n_tasks
self._input_dim = input_dim
@@ -323,6 +327,63 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN
_T_default_criterion = CrossEntropyLoss
_T_default_metric = CrossEntropyMetric

def __init__(
self,
n_classes: int,
n_tasks: int = 1,
input_dim: int = DEFAULT_HIDDEN_DIM,
hidden_dim: int = 300,
n_layers: int = 1,
dropout: float = 0.0,
activation: str = "relu",
criterion: LossFunction | None = None,
task_weights: Tensor | None = None,
threshold: float | None = None,
output_transform: UnscaleTransform | None = None,
):
"""Initialize the MulticlassClassificationFFN class.

Parameters
----------
n_classes : int
Number of classes expected in the output.
n_tasks : int, optional (default=1)
Number of tasks.
input_dim : int, optional (default=DEFAULT_HIDDEN_DIM)
Input dimension.
hidden_dim : int, optional (default=300)
Hidden dimension.
n_layers : int, optional (default=1)
Number of layers.
dropout : float, optional (default=0.0)
Dropout rate.
activation : str, optional (default="relu")
Activation function.
criterion : LossFunction or None, optional (default=None)
Loss function. None defaults to CrossEntropyLoss.
task_weights : Tensor or None, optional (default=None)
Task weights.
threshold : float or None, optional (default=None)
Threshold for classification.
output_transform : UnscaleTransform or None, optional (default=None)
Transformations to apply to the output. None defaults to UnscaleTransform.
"""
super().__init__(
n_tasks,
input_dim,
hidden_dim,
n_layers,
dropout,
activation,
criterion,
task_weights,
threshold,
output_transform,
n_classes=n_classes,
)

self.n_classes = n_classes


class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore
"""A wrapper for the MulticlassDirichletFFN class."""
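For orientation, here is a minimal construction sketch of the new wrapper. It is not part of the diff; the import path follows the models.py changes below, and the defaults are taken from the signature above.

```python
from molpipeline.estimators.chemprop.component_wrapper import (
    MulticlassClassificationFFN,
)

# n_classes is the only new required argument; everything else keeps the
# defaults shown in the signature above (n_tasks=1, hidden_dim=300, ...).
predictor = MulticlassClassificationFFN(n_classes=3)
print(predictor.n_classes)  # 3
```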
162 changes: 161 additions & 1 deletion molpipeline/estimators/chemprop/models.py
@@ -17,7 +17,6 @@
from chemprop.data import MoleculeDataset, build_dataloader
from chemprop.nn.predictors import (
BinaryClassificationFFNBase,
MulticlassClassificationFFN,
)
from lightning import pytorch as pl
except ImportError as error:
@@ -34,6 +33,7 @@
BondMessagePassing,
RegressionFFN,
SumAggregation,
MulticlassClassificationFFN,
)
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP

@@ -348,3 +348,163 @@ def __init__(
n_jobs=n_jobs,
**kwargs,
)


class ChempropMulticlassClassifier(ChempropModel):
"""Chemprop model with default parameters for multiclass classification tasks."""

def __init__(
self,
n_classes: int,
model: MPNN | None = None,
lightning_trainer: pl.Trainer | None = None,
batch_size: int = 64,
n_jobs: int = 1,
**kwargs: Any,
) -> None:
"""Initialize the chemprop multiclass model.

Parameters
----------
n_classes : int
The number of classes for the classifier.
model : MPNN | None, optional
The chemprop model to wrap. If None, a default model will be used.
lightning_trainer : pl.Trainer or None, optional (default=None)
The lightning trainer to use.
batch_size : int, optional (default=64)
The batch size to use.
n_jobs : int, optional (default=1)
The number of jobs to use.
kwargs : Any
Parameters set using `set_params`.
Can be used to modify components of the model.

Raises
------
AttributeError
If the passed model.predictor does not have an attribute n_classes.
ValueError
If the number of classes in the predictor does not match the number of classes given as argument.
"""
if model is None:
bond_encoder = BondMessagePassing()
agg = SumAggregation()
predictor = MulticlassClassificationFFN(n_classes=n_classes)
model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
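# Validate the predictor (default or user-supplied) before the base-class setup.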
if not hasattr(model.predictor, "n_classes"):
raise AttributeError(
"The predictor does not have an attribute n_classes. Please use a MulticlassClassificationFFN predictor or define n_classes."
)
if n_classes != model.predictor.n_classes:
raise ValueError(
"The number of classes in the predictor does not match the number of classes."
)
super().__init__(
model=model,
lightning_trainer=lightning_trainer,
batch_size=batch_size,
n_jobs=n_jobs,
**kwargs,
)
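# Consistency check for predictor type and class count; set_params repeats
# this check and raises ValueError on failure.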
self._is_valid_multiclass_classifier()

@property
def n_classes(self) -> int:
"""Return the number of classes."""
return self.model.predictor.n_classes

@n_classes.setter
def n_classes(self, n_classes: int) -> None:
"""Set the number of classes.

Parameters
----------
n_classes : int
Number of classes.
"""
self.model.predictor.n_classes = n_classes
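# Rebuild the network so the FFN output layer matches the new class count.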
self.model.reinitialize_network()

def set_params(self, **params: Any) -> Self:
"""Set the parameters of the model and check if it is a multiclass classifier.

Parameters
----------
**params
The parameters to set.

Returns
-------
Self
The model with the new parameters.
"""
super().set_params(**params)
if not self._is_valid_multiclass_classifier():
raise ValueError(
"The model's predictor or the number of classes is invalid. Use a multiclass predictor and more than two classes."
)
return self

def fit(
self,
X: MoleculeDataset,
y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64],
) -> Self:
"""Fit the model to the data.

Parameters
----------
X : MoleculeDataset
The input data.
y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
The target data.

Returns
-------
Self
The fitted model.
"""
self._check_correct_input(y)
return super().fit(X, y)

def _check_correct_input(
self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
) -> None:
"""Check if the input for the multi-class classifier is correct.

Parameters
----------
y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
Intended classes for the dataset.

Raises
------
ValueError
If the number of unique classes in y does not match n_classes, or if the class labels are not exactly 0 to n_classes-1.
"""
unique_y = np.unique(y)
log = []
if self.n_classes != len(unique_y):
log.append(
f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data."
)
if sorted(unique_y) != list(range(self.n_classes)):
err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly."
log.append(err)
if log:
raise ValueError("\n".join(log))

def _is_valid_multiclass_classifier(self) -> bool:
Check if the multiclass classifier is valid. The model's FFN must be a MulticlassClassificationFFN and the model must have more than 2 classes.

Returns
-------
bool
True if is a valid multiclass classifier, False otherwise.
"""
has_correct_model = isinstance(
self.model.predictor, MulticlassClassificationFFN
)
has_classes = self.n_classes > 2
return has_correct_model and has_classes
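To close, a hedged usage sketch of the new estimator. Again, this is not part of the diff; it exercises only behaviour visible above, namely the default-model construction in __init__ and the label check that fit runs via the internal _check_correct_input.

```python
import numpy as np

from molpipeline.estimators.chemprop.models import ChempropMulticlassClassifier

# Default model: BondMessagePassing + SumAggregation + a
# MulticlassClassificationFFN with the requested number of classes.
clf = ChempropMulticlassClassifier(n_classes=3)
assert clf.n_classes == 3

# Labels must be exactly 0 .. n_classes-1. fit() enforces this before
# delegating to ChempropModel.fit; calling the internal check directly
# shows the error raised for a label set that skips class 0.
try:
    clf._check_correct_input(np.array([1, 2, 3]))
except ValueError as err:
    print(err)  # Classes need to be in the range from 0 to 2. ...
```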