Skip to content

Commit

Permalink
Sparse guide RNA plot (#554)
Browse files Browse the repository at this point in the history
* Rename cmap to palette

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Rename cmap to palette

Signed-off-by: zethson <lukas.heumos@posteo.net>

* sparse support for plot guide rna

Signed-off-by: zethson <lukas.heumos@posteo.net>

---------

Signed-off-by: zethson <lukas.heumos@posteo.net>
  • Loading branch information
Zethson authored Mar 11, 2024
1 parent 27fdd78 commit e6ec187
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions pertpy/preprocessing/_guide_rna.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import uuid
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import scanpy as sc
import scipy

Expand Down Expand Up @@ -150,28 +152,38 @@ def plot_heatmap(
data = adata.X if layer is None else adata.layers[layer]

if order_by is None:
max_guide_index = np.where(
np.array(data.max(axis=1)).squeeze() != data.min(), np.array(data.argmax(axis=1)).squeeze(), -1
)
if scipy.sparse.issparse(data):
max_values = data.max(axis=1).A.squeeze()
data_argmax = data.argmax(axis=1).A.squeeze()
max_guide_index = np.where(max_values != data.min(axis=1).A.squeeze(), data_argmax, -1)
else:
max_guide_index = np.where(
data.max(axis=1).squeeze() != data.min(axis=1).squeeze(), data.argmax(axis=1).squeeze(), -1
)
order = np.argsort(max_guide_index)
elif isinstance(order_by, str):
order = adata.obs[order_by]
order = np.argsort(adata.obs[order_by])
else:
order = order_by

adata.obs["_tmp_pertpy_grna_plot_dummy_group"] = ""
temp_col_name = f"_tmp_pertpy_grna_plot_{uuid.uuid4()}"
adata.obs[temp_col_name] = pd.Categorical(["" for _ in range(adata.shape[0])])

if key_to_save_order is not None:
adata.obs[key_to_save_order] = order
axis_group = sc.pl.heatmap(
adata[order],
adata.var.index.tolist(),
groupby="_tmp_pertpy_grna_plot_dummy_group",
cmap="viridis",
use_raw=False,
dendrogram=False,
layer=layer,
**kwargs,
)
del adata.obs["_tmp_pertpy_grna_plot_dummy_group"]
adata.obs[key_to_save_order] = pd.Categorical(order)

try:
axis_group = sc.pl.heatmap(
adata[order, :],
var_names=adata.var.index.tolist(),
groupby=temp_col_name,
cmap="viridis",
use_raw=False,
dendrogram=False,
layer=layer,
**kwargs,
)
finally:
del adata.obs[temp_col_name]

return axis_group

0 comments on commit e6ec187

Please sign in to comment.