Merge pull request #475 from theislab/fix/discriminator_classifier_NN_dimensions

Fix/discriminator classifier nn dimensions
Lilly-May authored Dec 29, 2023
2 parents 8299e5d + 1aa5f42 commit 51db830
Showing 2 changed files with 47 additions and 25 deletions.
68 changes: 45 additions & 23 deletions pertpy/tools/_perturbation_space/_discriminator_classifier.py
@@ -3,21 +3,20 @@
from typing import TYPE_CHECKING

import anndata
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import scipy
import torch
from anndata import AnnData
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
from torch import optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace

if TYPE_CHECKING:
import numpy as np


class DiscriminatorClassifierSpace(PerturbationSpace):
"""Leveraging discriminator classifier. Fit a regressor model to the data and take the feature space.
@@ -66,7 +65,7 @@ def load( # type: ignore
>>> dcs.load(adata, target_col="gene_target")
"""
if layer_key is not None and layer_key not in adata.obs.columns:
raise ValueError(f"Layer key {layer_key} not found in adata. {layer_key}")
raise ValueError(f"Layer key {layer_key} not found in adata.")

if target_col not in adata.obs:
raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")
@@ -76,10 +75,10 @@ def load( # type: ignore

# Labels are strings, one hot encoding for classification
n_classes = len(adata.obs[target_col].unique())
labels = adata.obs[target_col]
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(labels)
adata.obs["encoded_perturbations"] = encoded_labels
labels = adata.obs[target_col].values.reshape(-1, 1)
encoder = OneHotEncoder()
encoded_labels = encoder.fit_transform(labels).toarray()
adata.obs["encoded_perturbations"] = [np.float32(label) for label in encoded_labels]

# Split the data in train, test and validation
X = list(range(0, adata.n_obs))
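
The hunk above swaps LabelEncoder for OneHotEncoder and stores each encoded row as a float32 array in .obs. A minimal, hedged sketch of what the two encoders produce for a toy label column (the label values are illustrative, not pertpy code):

# Illustrative sketch: LabelEncoder vs. OneHotEncoder output for toy labels.
import numpy as np
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

labels = np.array(["control", "KRAS", "TP53", "control"])

int_codes = LabelEncoder().fit_transform(labels)                # shape (4,), one integer code per cell
one_hot = OneHotEncoder().fit_transform(labels.reshape(-1, 1))  # sparse matrix of shape (4, 3)
encoded = [np.float32(row) for row in one_hot.toarray()]        # float32 vectors, as stored in .obs above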
@@ -103,13 +102,12 @@ def load( # type: ignore
# Fix class unbalance (likely to happen in perturbation datasets)
# Usually control cells are overrepresented such that predicting control all time would give good results
# Cells with rare perturbations are sampled more
class_weights = 1.0 / torch.bincount(torch.tensor(train_dataset.labels.values))
train_weights = class_weights[train_dataset.labels]
train_weights = 1 / (1 + torch.sum(torch.tensor(train_dataset.labels), dim=1))
train_sampler = WeightedRandomSampler(train_weights, len(train_weights))

self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
self.test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
self.valid_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Define the network
sizes = [adata.n_vars] + hidden_dim + [n_classes]
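
The weighting above feeds a WeightedRandomSampler so that cells carrying rare perturbations are drawn more often. A hedged sketch of the general oversampling mechanism, using toy integer class labels rather than the one-hot labels the PR operates on:

import torch
from torch.utils.data import WeightedRandomSampler

# Toy labels: class 0 (e.g. control) dominates the dataset
class_labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2])
class_counts = torch.bincount(class_labels)           # tensor([6, 2, 1])
sample_weights = 1.0 / class_counts[class_labels]     # per-sample weight, higher for rare classes
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights))
# A DataLoader built with sampler=sampler then draws rare-class cells more often per epoch.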
@@ -119,7 +117,10 @@ def load( # type: ignore
total_dataset = PLDataset(
adata=adata, target_col="encoded_perturbations", label_col=target_col, layer_key=layer_key
)
self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=4)
self.entire_dataset = DataLoader(total_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=0)

# Save adata observations for embedding annotations in get_embeddings
self.adata_obs = adata.obs.reset_index(drop=True)

return self
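
The sizes list above, [adata.n_vars] + hidden_dim + [n_classes], is the dimension fix the PR title refers to: the network runs from the number of genes, through the requested hidden layers, to one output per perturbation class. A hedged sketch of how such a list typically becomes an MLP (build_mlp is illustrative, not pertpy's MLP class):

import torch

def build_mlp(sizes: list[int]) -> torch.nn.Sequential:
    """Illustrative builder: Linear + ReLU between consecutive sizes, no activation on the head."""
    layers: list[torch.nn.Module] = []
    for i in range(len(sizes) - 1):
        layers.append(torch.nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            layers.append(torch.nn.ReLU())
    return torch.nn.Sequential(*layers)

# e.g. 2000 genes, one hidden layer of 128 units, 10 perturbation classes
net = build_mlp([2000, 128, 10])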

@@ -148,7 +149,7 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int =
accelerator="auto",
)

self.model = PerturbationClassifier(model=self.net)
self.model = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)

self.trainer.fit(
model=self.model, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader
@@ -173,13 +174,25 @@ def get_embeddings(self) -> AnnData:
self.model.eval()
for dataset_count, batch in enumerate(self.entire_dataset):
emb, y = self.model.get_embeddings(batch)
emb = torch.squeeze(emb)
batch_adata = AnnData(X=emb.cpu().numpy())
batch_adata.obs["perturbations"] = y
if dataset_count == 0:
pert_adata = batch_adata
else:
pert_adata = anndata.concat([pert_adata, batch_adata])

# Add .obs annotations to the pert_adata. Because shuffle=False and num_workers=0, the order of the data is stable
# and we can just add the annotations from the original AnnData object
pert_adata.obs = pert_adata.obs.reset_index(drop=True)
if "perturbations" in self.adata_obs.columns:
self.adata_obs = self.adata_obs.drop("perturbations", axis=1)
pert_adata.obs = pd.concat([pert_adata.obs, self.adata_obs], axis=1)

# Drop the 'encoded_perturbations' column, since it stores the one-hot encoded labels as numpy arrays,
# which would cause errors in the downstream processing of the AnnData object (e.g. when plotting)
pert_adata.obs = pert_adata.obs.drop("encoded_perturbations", axis=1)

return pert_adata
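
Because entire_dataset is built with shuffle=False and num_workers=0, the embedding rows come back in the original cell order, so the original .obs can be concatenated column-wise onto the new AnnData. A hedged sketch of that re-attachment with toy objects (the gene_target column is illustrative):

import anndata
import numpy as np
import pandas as pd
from anndata import AnnData

# Three embedding "batches", produced in stable (unshuffled) order
batches = [AnnData(X=np.random.rand(2, 8).astype(np.float32)) for _ in range(3)]
pert_adata = anndata.concat(batches)

# Original per-cell annotations, in the same row order as the embeddings
original_obs = pd.DataFrame({"gene_target": ["control"] * 3 + ["KRAS"] * 3})
pert_adata.obs = pert_adata.obs.reset_index(drop=True)
pert_adata.obs = pd.concat([pert_adata.obs, original_obs], axis=1)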


@@ -284,8 +297,7 @@ def __len__(self):

def __getitem__(self, idx):
"""Returns a sample and corresponding perturbations applied (labels)"""

sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx]
sample = self.data[idx].A.squeeze() if scipy.sparse.issparse(self.data) else self.data[idx]
num_label = self.labels.iloc[idx]
str_label = self.pert_labels.iloc[idx]

@@ -296,6 +308,7 @@ class PerturbationClassifier(pl.LightningModule):
def __init__(
self,
model: torch.nn.Module,
batch_size: int,
layers: list = [512], # noqa
dropout: float = 0.0,
batch_norm: bool = True,
@@ -306,6 +319,8 @@ def __init__(
):
"""
Args:
model: model to be trained
batch_size: batch size
layers: list of layers of the MLP
dropout: dropout probability
batch_norm: whether to apply batch norm
@@ -315,6 +330,7 @@ def __init__(
seed: random seed
"""
super().__init__()
self.batch_size = batch_size
self.save_hyperparameters()
if model:
self.net = model
@@ -342,36 +358,42 @@ def configure_optimizers(self):
def training_step(self, batch, batch_idx):
x, y, _ = batch
x = x.to(torch.float32)
y = y.to(torch.long)

y_hat = self.forward(x)

y = torch.argmax(y, dim=1)
y_hat = y_hat.squeeze()

loss = torch.nn.functional.cross_entropy(y_hat, y)
self.log("train_loss", loss, prog_bar=True)
self.log("train_loss", loss, prog_bar=True, batch_size=self.batch_size)

return loss

def validation_step(self, batch, batch_idx):
x, y, _ = batch
x = x.to(torch.float32)
y = y.to(torch.long)

y_hat = self.forward(x)

y = torch.argmax(y, dim=1)
y_hat = y_hat.squeeze()

loss = torch.nn.functional.cross_entropy(y_hat, y)
self.log("val_loss", loss, prog_bar=True)
self.log("val_loss", loss, prog_bar=True, batch_size=self.batch_size)

return loss

def test_step(self, batch, batch_idx):
x, y, _ = batch
x = x.to(torch.float32)
y = y.to(torch.long)

y_hat = self.forward(x)

y = torch.argmax(y, dim=1)
y_hat = y_hat.squeeze()

loss = torch.nn.functional.cross_entropy(y_hat, y)
self.log("test_loss", loss, prog_bar=True)
self.log("test_loss", loss, prog_bar=True, batch_size=self.batch_size)

return loss
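
In each of the steps above, the one-hot targets are converted back to class indices with torch.argmax before computing the loss, and the batch size is passed explicitly to self.log so Lightning does not have to infer it. A hedged sketch of the loss computation on toy tensors:

import torch

logits = torch.randn(4, 3)                   # model outputs for 4 cells, 3 classes
one_hot = torch.tensor([[1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                        [1.0, 0.0, 0.0]])

targets = torch.argmax(one_hot, dim=1)       # tensor([0, 1, 2, 0]), integer class indices
loss = torch.nn.functional.cross_entropy(logits, targets)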

@@ -5,7 +5,7 @@


def test_discriminator_classifier():
X = np.zeros((20, 5))
X = np.zeros((20, 5), dtype=np.float32)

pert_index = [
"control",
@@ -44,7 +44,7 @@ def test_discriminator_classifier():

# Compute the embeddings using the classifier
ps = pt.tl.DiscriminatorClassifierSpace()
classifier_ps = ps.load(adata)
classifier_ps = ps.load(adata, hidden_dim=[128])
classifier_ps.train(max_epochs=5)
pert_embeddings = classifier_ps.get_embeddings()
