diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 81efa334..49a227a2 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -74,15 +74,15 @@ def perturbation_signature( Returns: If `copy=True`, returns the copy of `adata` with the perturbation signature in `.layers["X_pert"]`. - Otherwise writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`. + Otherwise, writes the perturbation signature directly to `.layers["X_pert"]` of the provided `adata`. Examples: Calcutate perturbation signature for each cell in the dataset: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') """ if copy: adata = adata.copy() @@ -95,18 +95,17 @@ def perturbation_signature( split_masks = [np.full(adata.n_obs, True, dtype=bool)] else: split_obs = adata.obs[split_by] - cats = split_obs.unique() - split_masks = [split_obs == cat for cat in cats] + split_masks = [split_obs == cat for cat in split_obs.unique()] - R = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) for split_mask in split_masks: control_mask_split = control_mask & split_mask - R_split = R[split_mask] - R_control = R[control_mask_split] + R_split = representation[split_mask] + R_control = representation[control_mask_split] - from pynndescent import NNDescent # saves a lot of import time + from pynndescent import NNDescent eps = kwargs.pop("epsilon", 0.1) nn_index = NNDescent(R_control, **kwargs) @@ -170,7 +169,6 @@ def mixscape( Args: adata: The annotated data object. - pert_key: The column of `.obs` with perturbation categories, should also contain `control`. labels: The column of `.obs` with target gene labels. control: Control category from the `pert_key` column. new_class_name: Name of mixscape classification to be stored in `.obs`. @@ -186,26 +184,26 @@ def mixscape( Returns: If `copy=True`, returns the copy of `adata` with the classification result in `.obs`. - Otherwise writes the results directly to `.obs` of the provided `adata`. + Otherwise, writes the results directly to `.obs` of the provided `adata`. - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`). - Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class. + - mixscape_class: pandas.Series (`adata.obs['mixscape_class']`). + Classification result with cells being either classified as perturbed (KO, by default) or non-perturbed (NP) based on their target gene class. - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`). - Global classification result (perturbed, NP or NT) + - mixscape_class_global: pandas.Series (`adata.obs['mixscape_class_global']`). + Global classification result (perturbed, NP or NT). - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`). - Posterior probabilities used to determine if a cell is KO (default). - Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP + - mixscape_class_p_ko: pandas.Series (`adata.obs['mixscape_class_p_ko']`). + Posterior probabilities used to determine if a cell is KO (default). + Name of this item will change to match perturbation_type parameter setting. (>0.5) or NP. Examples: Calcutate perturbation signature for each cell in the dataset: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') """ if copy: adata = adata.copy() @@ -229,10 +227,9 @@ def mixscape( try: X = adata_comp.layers["X_pert"] except KeyError: - print( - '[bold yellow]No "X_pert" found in .layers! -- Please run pert_sign first to calculate perturbation signature!' - ) - raise + raise KeyError( + "No 'X_pert' found in .layers! Please run pert_sign first to calculate perturbation signature!" + ) from None # initialize return variables adata.obs[f"{new_class_name}_p_{perturbation_type.lower()}"] = 0 adata.obs[new_class_name] = adata.obs[labels].astype(str) @@ -351,15 +348,17 @@ def lda( control: Control category from the `pert_key` column. Defaults to 'NT'. n_comps: Number of principal components to use. Defaults to 10. min_de_genes: Required number of genes that are differentially expressed for method to separate perturbed and non-perturbed cells. - logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. Defaults to 0.25. + logfc_threshold: Limit testing to genes which show, on average, at least X-fold difference (log-scale) between the two groups of cells. + Defaults to 0.25. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate pval_cutoff: P-value cut-off for selection of significantly DE genes. - perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to KO. + perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. + Defaults to KO. copy: Determines whether a copy of the `adata` is returned. Returns: If `copy=True`, returns the copy of `adata` with the LDA result in `.uns`. - Otherwise writes the results directly to `.uns` of the provided `adata`. + Otherwise, writes the results directly to `.uns` of the provided `adata`. mixscape_lda: numpy.ndarray (`adata.uns['mixscape_lda']`). LDA result. @@ -369,10 +368,10 @@ def lda( >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> mixscape_identifier.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') """ if copy: adata = adata.copy() @@ -524,12 +523,12 @@ def plot_barplot( # pragma: no cover show: bool | None = None, save: bool | str | None = None, ): - """Barplot to visualize perturbation scores calculated from RunMixscape function. + """Barplot to visualize perturbation scores calculated by the `mixscape` function. Args: adata: The annotated data object. guide_rna_column: The column of `.obs` with guide RNA labels. The target gene labels. - The format must be g<#>. For example, 'STAT2g1' and 'ATF2g1'. + The format must be g<#>. Examples are 'STAT2g1' and 'ATF2g1'. mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT). show: Show the plot, do not return axis. save: If True or a str, save the figure. A string is appended to the default filename. @@ -541,13 +540,13 @@ def plot_barplot( # pragma: no cover Examples: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> pt.pl.ms.barplot(mdata['rna'], guide_rna_column='NT') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.plot_barplot(mdata['rna'], guide_rna_column='NT') """ if mixscape_class_global not in adata.obs: - raise ValueError("Please run `pt.tl.mixscape` first.") + raise ValueError("Please run the `mixscape` function first.") count = pd.crosstab(index=adata.obs[mixscape_class_global], columns=adata.obs[guide_rna_column]) all_cells_percentage = pd.melt(count / count.sum(), ignore_index=False).reset_index() KO_cells_percentage = all_cells_percentage[all_cells_percentage[mixscape_class_global] == "KO"] @@ -604,7 +603,7 @@ def plot_barplot( # pragma: no cover ) pl.tight_layout() - return pl.gcf() + return ax def plot_heatmap( # pragma: no cover self, @@ -642,10 +641,10 @@ def plot_heatmap( # pragma: no cover Examples: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.plot_heatmap(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', layer='X_pert', control='NT') """ if "mixscape_class" not in adata.obs: raise ValueError("Please run `pt.tl.mixscape` first.") @@ -676,8 +675,10 @@ def plot_perturbscore( # pragma: no cover split_by: str = None, before_mixscape=False, perturbation_type: str = "KO", - ): - """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. Requires `pt.tl.mixscape` to be run first. + ) -> None: + """Density plots to visualize perturbation scores calculated by the `pt.tl.mixscape` function. + + Requires `pt.tl.mixscape` to be run first. https://satijalab.org/seurat/reference/plotperturbscore @@ -688,9 +689,10 @@ def plot_perturbscore( # pragma: no cover mixscape_class: The column of `.obs` with mixscape classifications. color: Specify color of target gene class or knockout cell class. For control non-targeting and non-perturbed cells, colors are set to different shades of grey. split_by: Provide the column `.obs` if multiple biological replicates exist to calculate - the perturbation signature for every replicate separately. - before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. Default is set to NULL and plots cells by original class ID. - perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Default is KO. + the perturbation signature for every replicate separately. + before_mixscape: Option to split densities based on mixscape classification (default) or original target gene classification. + Default is set to NULL and plots cells by original class ID. + perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to `KO`. Returns: The ggplot object used for drawn. @@ -700,13 +702,13 @@ def plot_perturbscore( # pragma: no cover >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> mixscape_identifier = pt.tl.Mixscape() - >>> mixscape_identifier.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> mixscape_identifier.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> pt.pl.ms.perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.plot_perturbscore(adata = mdata['rna'], labels='gene_target', target_gene='IFNGR2', color = 'orange') """ if "mixscape" not in adata.uns: - raise ValueError("Please run `pt.tl.mixscape` first.") + raise ValueError("Please run the `mixscape` function first.") perturbation_score = None for key in adata.uns["mixscape"][target_gene].keys(): perturbation_score_temp = adata.uns["mixscape"][target_gene][key] @@ -807,8 +809,6 @@ def plot_perturbscore( # pragma: no cover pl.legend(title="mixscape class", title_fontsize=14, fontsize=12) sns.despine() - return pl.gcf() - def plot_violin( # pragma: no cover self, adata: AnnData, @@ -859,10 +859,10 @@ def plot_violin( # pragma: no cover Examples: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.plot_violin(adata = mdata['rna'], target_gene_idents=['NT', 'IFNGR2 NP', 'IFNGR2 KO'], groupby='mixscape_class') """ if isinstance(target_gene_idents, str): mixscape_class_mask = adata.obs[groupby] == target_gene_idents @@ -1023,12 +1023,11 @@ def plot_lda( # pragma: no cover Args: adata: The annotated data object. control: Control category from the `pert_key` column. - labels: The column of `.obs` with target gene labels. mixscape_class: The column of `.obs` with the mixscape classification result. mixscape_class_global: The column of `.obs` with mixscape global classification result (perturbed, NP or NT). perturbation_type: Specify type of CRISPR perturbation expected for labeling mixscape classifications. Defaults to 'KO'. - lda_key: If not speficied, lda looks .uns["mixscape_lda"] for the LDA results. + lda_key: If not specified, lda looks .uns["mixscape_lda"] for the LDA results. n_components: The number of dimensions of the embedding. show: Show the plot, do not return axis. save: If `True` or a `str`, save the figure. A string is appended to the default filename. @@ -1038,16 +1037,18 @@ def plot_lda( # pragma: no cover Examples: >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() - >>> ms = pt.tl.Mixscape() - >>> ms.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') - >>> ms.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') - >>> ms.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') - >>> ms.plot_lda(adata=mdata['rna'], control='NT') + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata['rna'], 'perturbation', 'NT', 'replicate') + >>> ms_pt.mixscape(adata = mdata['rna'], control = 'NT', labels='gene_target', layer='X_pert') + >>> ms_pt.lda(adata=mdata['rna'], control='NT', labels='gene_target', layer='X_pert') + >>> ms_pt.plot_lda(adata=mdata['rna'], control='NT') """ if mixscape_class not in adata.obs: - raise ValueError(f'Did not find .obs["{mixscape_class!r}"]. Please run `pt.tl.mixscape` first.') + raise ValueError( + f'Did not find `.obs["{mixscape_class!r}"]`. Please run the `mixscape` function first first.' + ) if lda_key not in adata.uns: - raise ValueError(f'Did not find .uns["{lda_key!r}"]. Run `pt.tl.neighbors` first.') + raise ValueError(f'Did not find `.uns["{lda_key!r}"]`. Run the `lda` function first.') adata_subset = adata[ (adata.obs[mixscape_class_global] == perturbation_type) | (adata.obs[mixscape_class_global] == control)