Skip to content

Commit

Permalink
Refactor mixscape
Browse files Browse the repository at this point in the history
Signed-off-by: zethson <lukas.heumos@posteo.net>
  • Loading branch information
Zethson committed Jan 13, 2024
1 parent 5c784ac commit c95bbcb
Showing 1 changed file with 70 additions and 69 deletions.
139 changes: 70 additions & 69 deletions pertpy/tools/_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ def perturbation_signature(
Returns:
If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`.
Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`.
Examples:
Calculate perturbation signature for each cell in the dataset:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
"""
if copy:
adata = adata.copy()
Expand All @@ -95,18 +95,17 @@ def perturbation_signature(
split_masks = [np.full(adata.n_obs, True, dtype=bool)]
else:
split_obs = adata.obs[split_by]
cats = split_obs.unique()
split_masks = [split_obs == cat for cat in cats]
split_masks = [split_obs == cat for cat in split_obs.unique()]

R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)

for split_mask in split_masks:
control_mask_split = control_mask & split_mask

R_split = R[split_mask]
R_control = R[control_mask_split]
R_split = representation[split_mask]
R_control = representation[control_mask_split]

from pynndescent import NNDescent # saves a lot of import time
from pynndescent import NNDescent

eps = kwargs.pop("epsilon", 0.1)
nn_index = NNDescent(R_control, **kwargs)
Expand Down Expand Up @@ -170,7 +169,6 @@ def mixscape(
Args:
adata: The annotated data object.
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
labels: The column of `.obs` with target gene labels.
control: Control category from the `pert_key` column.
new_class_name: Name of mixscape classification to be stored in `.obs`.
Expand All @@ -186,26 +184,26 @@ def mixscape(
Returns:
If `copy=True`, returns the copy of `adata` with the classification result in `.obs`.
Otherwise writes the results directly to `.obs` of the provided `adata`.
Otherwise, writes the results directly to `.obs` of the provided `adata`.
mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
- mixscape_class: pandas.Series (`adata.obs['mixscape_class']`).
Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class.
mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
Global classification result (perturbed, NP or NT)
- mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`).
Global classification result (perturbed, NP or NT).
mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
Posterior probabilities used to determine if a cell is KO (default).
Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP
- mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`).
Posterior probabilities used to determine if a cell is KO (default).
Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP.
Examples:
Calculate perturbation signature for each cell in the dataset:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
"""
if copy:
adata = adata.copy()
Expand All @@ -229,10 +227,9 @@ def mixscape(
try:
X = adata_comp.layers["X_pert"]
except KeyError:
print(
'[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!'
)
raise
raise KeyError(
"No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!"
) from None
# initialize return variables
adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0
adata.obs[new_class_name] = adata.obs[labels].astype(str)
Expand Down Expand Up @@ -351,15 +348,17 @@ def lda(
control: Control category from the `pert_key` column. Defaults to 'NT'.
n_comps: Number of principal components to use. Defaults to 10.
min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells.
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25.
logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells.
Defaults to 0.25.
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate the perturbation signature for every replicate separately.
pval_cutoff: P-value cut-off for selection of significantly DE genes.
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO.
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
Defaults to KO.
copy: Determines whether a copy of the `adata` is returned.
Returns:
If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`.
Otherwise writes the results directly to `.uns` of the provided `adata`.
Otherwise, writes the results directly to `.uns` of the provided `adata`.
mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`).
LDA result.
Expand All @@ -369,10 +368,10 @@ def lda(
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
"""
if copy:
adata = adata.copy()
Expand Down Expand Up @@ -524,12 +523,12 @@ def plot_barplot( # pragma: no cover
show: bool | None = None,
save: bool | str | None = None,
):
"""Barplot to visualize perturbation scores calculated from RunMixscape function.
"""Barplot to visualize perturbation scores calculated by the `mixscape` function.
Args:
adata: The annotated data object.
guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels.
The format must be <gene_target>g<#>. For example, 'STAT2g1' and 'ATF2g1'.
The format must be <gene_target>g<#>. Examples are 'STAT2g1' and 'ATF2g1'.
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
show: Show the plot, do not return axis.
save: If True or a str, save the figure. A string is appended to the default filename.
Expand All @@ -541,13 +540,13 @@ def plot_barplot( # pragma: no cover
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.plot_barplot(mdata['rna'], guide_rna_column='NT')
"""
if mixscape_class_global not in adata.obs:
raise ValueError("Please run `pt.tl.mixscape` first.")
raise ValueError("Please run the `mixscape` function first.")
count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column])
all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index()
KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"]
Expand Down Expand Up @@ -604,7 +603,7 @@ def plot_barplot( # pragma: no cover
)
pl.tight_layout()

return pl.gcf()
return ax

def plot_heatmap( # pragma: no cover
self,
Expand Down Expand Up @@ -642,10 +641,10 @@ def plot_heatmap( # pragma: no cover
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ms = pt.tl.Mixscape()
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT')
"""
if "mixscape_class" not in adata.obs:
raise ValueError("Please run `pt.tl.mixscape` first.")
Expand Down Expand Up @@ -676,8 +675,10 @@ def plot_perturbscore( # pragma: no cover
split_by: str = None,
before_mixscape=False,
perturbation_type: str = "KO",
):
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first.
) -> None:
"""Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function.
Requires `pt.tl.mixscape` to be run first.
https://satijalab.org/seurat/reference/plotperturbscore
Expand All @@ -688,9 +689,10 @@ def plot_perturbscore( # pragma: no cover
mixscape_class: The column of `.obs` with mixscape classifications.
color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey.
split_by: Provide the column `.obs` if multiple biological replicates exist to calculate
the perturbation signature for every replicate separately.
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID.
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO.
the perturbation signature for every replicate separately.
before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification.
Default is set to NULL and plots cells by original class ID.
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to `KO`.
Returns:
The ggplot object used for drawing.
Expand All @@ -700,13 +702,13 @@ def plot_perturbscore( # pragma: no cover
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> mixscape_identifier = pt.tl.Mixscape()
>>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.plot_perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange')
"""
if "mixscape" not in adata.uns:
raise ValueError("Please run `pt.tl.mixscape` first.")
raise ValueError("Please run the `mixscape` function first.")
perturbation_score = None
for key in adata.uns["mixscape"][target_gene].keys():
perturbation_score_temp = adata.uns["mixscape"][target_gene][key]
Expand Down Expand Up @@ -807,8 +809,6 @@ def plot_perturbscore( # pragma: no cover
pl.legend(title="mixscape class", title_fontsize=14, fontsize=12)
sns.despine()

return pl.gcf()

def plot_violin( # pragma: no cover
self,
adata: AnnData,
Expand Down Expand Up @@ -859,10 +859,10 @@ def plot_violin( # pragma: no cover
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ms = pt.tl.Mixscape()
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class')
"""
if isinstance(target_gene_idents, str):
mixscape_class_mask = adata.obs[groupby] == target_gene_idents
Expand Down Expand Up @@ -1023,12 +1023,11 @@ def plot_lda( # pragma: no cover
Args:
adata: The annotated data object.
control: Control category from the `pert_key` column.
labels: The column of `.obs` with target gene labels.
mixscape_class: The column of `.obs` with the mixscape classification result.
mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT).
perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications.
Defaults to 'KO'.
lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results.
lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results.
n_components: The number of dimensions of the embedding.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Expand All @@ -1038,16 +1037,18 @@ def plot_lda( # pragma: no cover
Examples:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ms = pt.tl.Mixscape()
>>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
>>> ms.plot_lda(adata=mdata['rna'], control='NT')
>>> ms_pt = pt.tl.Mixscape()
>>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate')
>>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert')
>>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert')
>>> ms_pt.plot_lda(adata=mdata['rna'], control='NT')
"""
if mixscape_class not in adata.obs:
raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.')
raise ValueError(
f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.'
)
if lda_key not in adata.uns:
raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.')
raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.')

adata_subset = adata[
(adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)
Expand Down

0 comments on commit c95bbcb

Please sign in to comment.