Skip to content

Commit

Permalink
Merge pull request #481 from theislab/feature/categorical_perturbations
Browse files Browse the repository at this point in the history
Save perturbation and clustering labels as categorical
  • Loading branch information
Lilly-May authored Jan 2, 2024
2 parents 3b82005 + 7ded6ba commit 2c6e816
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def add(
key_name = key.removesuffix("_control_diff")
new_perturbation.obsm[key_name] = data["embeddings"][key]

new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")

if ensure_consistency:
return new_perturbation, adata

Expand Down Expand Up @@ -350,6 +352,8 @@ def subtract(
key_name = key.removesuffix("_control_diff")
new_perturbation.obsm[key_name] = data["embeddings"][key]

new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category")

if ensure_consistency:
return new_perturbation, adata

Expand Down
6 changes: 6 additions & 0 deletions pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def compute(
mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
ps_adata.obs[obs_name] = ps_adata.obs[target_col].map(mapping)

ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

return ps_adata


Expand Down Expand Up @@ -159,6 +161,8 @@ def compute(

ps_adata = dc.get_pseudobulk(adata, sample_col=target_col, layer=layer_key, **kwargs) # type: ignore

ps_adata.obs[target_col] = ps_adata.obs[target_col].astype("category")

return ps_adata


Expand Down Expand Up @@ -220,6 +224,7 @@ def compute( # type: ignore

clustering = KMeans(**kwargs).fit(self.X)
adata.obs[cluster_key] = clustering.labels_
adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

if return_object:
return adata, clustering
Expand Down Expand Up @@ -282,6 +287,7 @@ def compute( # type: ignore

clustering = DBSCAN(**kwargs).fit(self.X)
adata.obs[cluster_key] = clustering.labels_
adata.obs[cluster_key] = adata.obs[cluster_key].astype("category")

if return_object:
return adata, clustering
Expand Down

0 comments on commit 2c6e816

Please sign in to comment.