Skip to content

Commit

Permalink
Compute method for MLPClassifierSpace (#565)
Browse files Browse the repository at this point in the history
* Merged MLP methods into one compute method

* Improved parameter explanation

* Method formatting

* Remove not so important comment

* Decrease default parameters

---------

Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>
  • Loading branch information
Lilly-May and Zethson authored Mar 26, 2024
1 parent 413910d commit b3cbbea
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 51 deletions.
93 changes: 46 additions & 47 deletions pertpy/tools/_perturbation_space/_discriminator_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,13 @@ class MLPClassifierSpace(PerturbationSpace):
"""Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
feature space, resulting in one embedding per cell.
feature space, resulting in one embedding per cell. Consider employing the PseudoBulk or another PerturbationSpace
to obtain one embedding per perturbation.
See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
We use either the coefficients of the model for each perturbation as a feature or train a classifier example
(simple MLP or logistic regression) and take the penultimate layer as feature space and apply pseudobulking approach.
"""

def load( # type: ignore
def compute( # type: ignore
self,
adata: AnnData,
target_col: str = "perturbations",
Expand All @@ -149,32 +148,49 @@ def load( # type: ignore
batch_size: int = 256,
test_split_size: float = 0.2,
validation_split_size: float = 0.25,
):
"""Creates a classifier model and dataloaders required for training and testing.
max_epochs: int = 20,
val_epochs_check: int = 2,
patience: int = 2,
) -> AnnData:
"""Creates cell embeddings by training a MLP classifier model to distinguish between perturbations.
A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
It further creates dataloaders and fixes class imbalance due to control.
Sets the device to a GPU if available.
Dataloaders that take into account class imbalances are created. Next, the model is trained and tested, using the
GPU if available. The embeddings are obtained by passing the data through the model and extracting the values in
the last layer of the MLP. You will get one embedding per cell, so be aware that you might need to apply another
perturbation space to aggregate the embeddings per perturbation.
Args:
adata: AnnData object of size cells x genes
target_col: .obs column that stores the perturbations. Defaults to "perturbations".
layer_key: Layer in adata to use. Defaults to None.
hidden_dim: List of hidden layers of the neural network. For instance: [512, 256].
hidden_dim: List of number of neurons in each hidden layers of the neural network. For instance, [512, 256]
will create a neural network with two hidden layers, the first with 512 neurons and the second with 256 neurons.
Defaults to [512].
dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
batch_norm: Whether to apply batch normalization. Defaults to True.
batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
test_split_size: Fraction of data to put in the test set. Default to 0.2.
validation_split_size: Fraction of data to put in the validation set of the resultant train set.
E.g. a test_split_size of 0.2 and a validation_split_size of 0.25 means that 25% of 80% of the data
will be used for validation. Defaults to 0.25.
max_epochs: Maximum number of epochs for training. Defaults to 20.
val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
Note that this affects early stopping, as the model will be stopped if the validation performance does not
improve for patience epochs. Defaults to 2.
patience: Number of validation performance checks without improvement, after which the early stopping flag
is activated and training is therefore stopped. Defaults to 2.
Returns:
AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
The AnnData will have shape (n_cells, n_features) where n_features is the number of features in the last layer of the MLP.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> adata = pt.dt.norman_2019()
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
>>> cell_embeddings = dcs.compute(adata, target_col="perturbation_name")
"""
if layer_key is not None and layer_key not in adata.obs.columns:
raise ValueError(f"Layer key {layer_key} not found in adata.")
Expand Down Expand Up @@ -234,25 +250,6 @@ def load( # type: ignore
# Save adata observations for embedding annotations in get_embeddings
self.adata_obs = adata.obs.reset_index(drop=True)

return self

def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
"""Trains and tests the ANN model defined in the load step.
Args:
max_epochs: Maximum number of epochs for training. Defaults to 40.
val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
Defaults to 5.
patience: Number of validation performance checks without improvement, after which the early stopping flag
is activated and training is therefore stopped. Defaults to 2.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
>>> dcs.train(max_epochs=5)
"""
self.trainer = pl.Trainer(
min_epochs=1,
max_epochs=max_epochs,
Expand All @@ -267,23 +264,7 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int =
self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)

def get_embeddings(self) -> AnnData:
"""Obtain the embeddings of the data.
The embeddings correspond to the values in the last layer of the MLP. You will get one embedding per cell,
so be aware that you might need to apply another perturbation space to aggregate the embeddings per perturbation.
Returns:
AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
>>> dcs.train()
>>> embeddings = dcs.get_embeddings()
"""
# Obtain cell embeddings
with torch.no_grad():
self.mlp.eval()
for dataset_count, batch in enumerate(self.entire_dataset):
Expand All @@ -309,6 +290,24 @@ def get_embeddings(self) -> AnnData:

return pert_adata

def load(self, adata, **kwargs):
"""This method is deprecated and will be removed in the future. Please use the compute method instead."""
raise DeprecationWarning(
"The load method is deprecated and will be removed in the future. Please use the compute method instead."
)

def train(self, **kwargs):
"""This method is deprecated and will be removed in the future. Please use the compute method instead."""
raise DeprecationWarning(
"The train method is deprecated and will be removed in the future. Please use the compute method instead."
)

def get_embeddings(self, **kwargs):
"""This method is deprecated and will be removed in the future. Please use the compute method instead."""
raise DeprecationWarning(
"The get_embeddings method is deprecated and will be removed in the future. Please use the compute method instead."
)


class MLP(torch.nn.Module):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ def adata():


def test_mlp_classifier_space(adata):
ps = pt.tl.MLPClassifierSpace()
classifier_ps = ps.load(adata, hidden_dim=[128])
classifier_ps.train(max_epochs=2)
pert_embeddings = classifier_ps.get_embeddings()
classifier_ps = pt.tl.MLPClassifierSpace()
pert_embeddings = classifier_ps.compute(adata, hidden_dim=[128], max_epochs=2)

# The embeddings should cluster in 3 perfects clusters since the perturbations are easily separable
ps = pt.tl.KMeansSpace()
Expand Down

0 comments on commit b3cbbea

Please sign in to comment.