Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/theislab/pertpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zethson committed Sep 26, 2023
2 parents c18d1ab + 1c67eb6 commit 2038248
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 39 deletions.
139 changes: 104 additions & 35 deletions pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -46,16 +47,13 @@ class Distance:
obsm_key: Name of embedding in adata.obsm to use.
metric_fct: Distance metric function.
Example:
.. code-block:: python
import pertpy as pt
adata = pt.dt.distance_example_data()
Distance = pt.tools.Distance(metric="edistance")
X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
D = Distance(X, Y)
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.distance_example_data()
>>> Distance = pt.tools.Distance(metric="edistance")
>>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
>>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
>>> D = Distance(X, Y)
"""

def __init__(
Expand Down Expand Up @@ -101,26 +99,23 @@ def __call__(
Returns:
float: Distance between X and Y.
Example:
.. code-block:: python
import pertpy as pt
adata = pt.dt.distance_example_data()
Distance = pt.tools.Distance(metric="edistance")
X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
D = Distance(X, Y)
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.distance_example_data()
>>> Distance = pt.tools.Distance(metric="edistance")
>>> X = adata.obsm["X_pca"][adata.obs["perturbation"] == "p-sgCREB1-2"]
>>> Y = adata.obsm["X_pca"][adata.obs["perturbation"] == "control"]
>>> D = Distance(X, Y)
"""
return self.metric_fct(X, Y, **kwargs)

def pairwise(
self,
adata: AnnData,
groupby: str,
groups: list[str] | None = None,
verbose: bool = True,
n_jobs: int = 1,
groups: Iterable = None,
show_progressbar: bool = True,
n_jobs: int = -1,
**kwargs,
) -> pd.DataFrame:
"""Get pairwise distances between groups of cells.
Expand All @@ -130,25 +125,27 @@ def pairwise(
groupby: Column name in adata.obs.
groups: List of groups to compute pairwise distances for.
If None, uses all groups. Defaults to None.
verbose: Whether to show progress bar. Defaults to True.
show_progressbar: Whether to show progress bar. Defaults to True.
Returns:
pd.DataFrame: Dataframe with pairwise distances.
Example:
.. code-block:: python
import pertpy as pt
adata = pt.dt.distance_example_data()
Distance = pt.tools.Distance(metric="edistance")
pairwise_df = distance.pairwise(adata, groupby="perturbation")
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.distance_example_data()
>>> Distance = pt.tools.Distance(metric="edistance")
>>> pairwise_df = distance.pairwise(adata, groupby="perturbation")
"""
groups = adata.obs[groupby].unique() if groups is None else groups
grouping = adata.obs[groupby].copy()
df = pd.DataFrame(index=groups, columns=groups, dtype=float)
fct = track if verbose else lambda iterable: iterable
fct = track if show_progressbar else lambda iterable: iterable

# Some metrics are able to handle precomputed distances. This means that
# the pairwise distances between all cells are computed once and then
# passed to the metric function. This is much faster than computing the
# pairwise distances for each group separately. Other metrics are not
# able to handle precomputed distances such as the PsuedobulkDistance.
if self.metric_fct.accepts_precomputed:
# Precompute the pairwise distances if needed
if f"{self.obsm_key}_predistances" not in adata.obsp.keys():
Expand All @@ -158,9 +155,10 @@ def pairwise(
idx_x = grouping == group_x
for group_y in groups[index_x:]:
if group_x == group_y:
dist = 0.0
dist = 0.0 # by distance axiom
else:
idx_y = grouping == group_y
# subset the pairwise distance matrix to the two groups
sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
sub_idx = grouping[idx_x | idx_y] == group_x
dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
Expand All @@ -182,8 +180,79 @@ def pairwise(
df.columns.name = groupby
df.name = f"pairwise {self.metric}"
return df

def onesided_distances(
self,
adata: AnnData,
groupby: str,
selected_group: str | None = None,
groups: Iterable = None,
show_progressbar: bool = True,
n_jobs: int = -1,
**kwargs,
) -> pd.DataFrame:
"""Get pairwise distances between groups of cells.
Args:
adata: Annotated data matrix.
groupby: Column name in adata.obs.
selected_group: Group to compute pairwise distances to all other.
groups: List of groups to compute distances to selected_group for.
If None, uses all groups. Defaults to None.
show_progressbar: Whether to show progress bar. Defaults to True.
Returns:
pd.DataFrame: Dataframe with distances of groups to selected_group.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.distance_example_data()
>>> Distance = pt.tools.Distance(metric="edistance")
>>> pairwise_df = distance.onesided_distances(adata, groupby="perturbation", selected_group="control")
"""
groups = adata.obs[groupby].unique() if groups is None else groups
grouping = adata.obs[groupby].copy()
df = pd.Series(index=groups, dtype=float)
fct = track if show_progressbar else lambda iterable: iterable

# Some metrics are able to handle precomputed distances. This means that
# the pairwise distances between all cells are computed once and then
# passed to the metric function. This is much faster than computing the
# pairwise distances for each group separately. Other metrics are not
# able to handle precomputed distances such as the PsuedobulkDistance.
if self.metric_fct.accepts_precomputed:
# Precompute the pairwise distances if needed
if f"{self.obsm_key}_predistances" not in adata.obsp.keys():
self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
pwd = adata.obsp[f"{self.obsm_key}_predistances"]
for group_x in fct(groups):
idx_x = grouping == group_x
group_y = selected_group
if group_x == group_y:
dist = 0.0 # by distance axiom
else:
idx_y = grouping == group_y
# subset the pairwise distance matrix to the two groups
sub_pwd = pwd[idx_x | idx_y, :][:, idx_x | idx_y]
sub_idx = grouping[idx_x | idx_y] == group_x
dist = self.metric_fct.from_precomputed(sub_pwd, sub_idx, **kwargs)
df.loc[group_x] = dist
else:
embedding = adata.obsm[self.obsm_key].copy()
for group_x in fct(groups):
cells_x = embedding[grouping == group_x].copy()
group_y = selected_group
if group_x == group_y:
dist = 0.0
else:
cells_y = embedding[grouping == group_y].copy()
dist = self.metric_fct(cells_x, cells_y, **kwargs)
df.loc[group_x] = dist
df.index.name = groupby
df.name = f"{self.metric} to {selected_group}"
return df

def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean", n_jobs: int = None) -> None:
def precompute_distances(self, adata: AnnData, cell_wise_metric: str = "euclidean", n_jobs: int = -1) -> None:
"""Precompute pairwise distances between all cells, writes to adata.obsp.
The precomputed distances are stored in adata.obsp under the key
Expand Down
16 changes: 12 additions & 4 deletions tests/tools/_distances/test_distances.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pertpy as pt
from pandas import DataFrame
from pandas import DataFrame, Series
from pytest import fixture, mark

actual_distances = ["edistance", "pseudobulk", "wasserstein"]
Expand All @@ -17,7 +17,7 @@ def adata(self):
def test_distance_axioms(self, adata, distance):
# Test if distances are well-defined in accordance with metric axioms
Distance = pt.tl.Distance(distance, "X_pca")
df = Distance.pairwise(adata, groupby="perturbation", verbose=True)
df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True)
# (M1) Positiv definiteness
assert all(np.diag(df.values) == 0) # distance to self is 0
assert len(df) == np.sum(df.values == 0) # distance to other is not 0
Expand All @@ -30,9 +30,17 @@ def test_distance_axioms(self, adata, distance):
assert df.loc[triplet[0], triplet[1]] + df.loc[triplet[1], triplet[2]] >= df.loc[triplet[0], triplet[2]]

@mark.parametrize("distance", actual_distances + pseudo_distances)
def test_distance(self, adata, distance):
def test_distance_pairwise(self, adata, distance):
Distance = pt.tl.Distance(distance, "X_pca")
df = Distance.pairwise(adata, groupby="perturbation", verbose=True)
df = Distance.pairwise(adata, groupby="perturbation", show_progressbar=True)
assert isinstance(df, DataFrame)
assert df.columns.equals(df.index)
assert np.sum(df.values - df.values.T) == 0 # symmetry

@mark.parametrize("distance", actual_distances + pseudo_distances)
def test_distance_onesided(self, adata, distance):
Distance = pt.tl.Distance(distance, "X_pca")
selected_group = adata.obs.perturbation.unique()[0]
df = Distance.onesided_distances(adata, groupby="perturbation", selected_group=selected_group, show_progressbar=True)
assert isinstance(df, Series)
assert df.loc[selected_group] == 0 # distance to self is 0

0 comments on commit 2038248

Please sign in to comment.