Skip to content

Commit

Permalink
PertSpace docs improvements (#471)
Browse files Browse the repository at this point in the history
* Added return statements to perturbation space docs

* Fixed default parameter annotation

* Added and improved parameter descriptions

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Lilly-May and pre-commit-ci[bot] authored Dec 27, 2023
1 parent 3019978 commit 0879d41
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 12 deletions.
2 changes: 2 additions & 0 deletions pertpy/tools/_perturbation_space/_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def evaluate_clustering(
true_label_col: ground truth labels.
cluster_col: cluster computed labels.
metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw'].
**kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed.
For asw, metric, distances, sample_size, and random_state can be passed.
Examples:
Example usage with KMeansSpace:
Expand Down
4 changes: 2 additions & 2 deletions pertpy/tools/_perturbation_space/_discriminator_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def __getitem__(self, idx):
"""Returns a sample and corresponding perturbations applied (labels)"""

sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx]
num_label = self.labels[idx]
str_label = self.pert_labels[idx]
num_label = self.labels.iloc[idx]
str_label = self.pert_labels.iloc[idx]

return sample, num_label, str_label

Expand Down
34 changes: 25 additions & 9 deletions pertpy/tools/_perturbation_space/_perturbation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def compute_control_diff( # type: ignore
new_embedding_key: str = "control_diff",
all_data: bool = False,
copy: bool = False,
):
) -> AnnData:
"""Subtract mean of the control from the perturbation.
Args:
Expand All @@ -43,14 +43,18 @@ def compute_control_diff( # type: ignore
group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups. Defaults to 'perturbations'.
reference_key: The key of the control values. Defaults to 'control'.
layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise.
new_layer_key: the results are stored in the given layer. Defaults to 'differential diff'.
new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'.
embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise.
new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'.
new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'.
all_data: if True, do the computation in all data representations (X, all layers and all embeddings)
copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object.
Returns:
Updated AnnData object.
Examples:
Example usage with PseudobulkSpace:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
Expand Down Expand Up @@ -144,18 +148,24 @@ def add(
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
):
) -> AnnData:
"""Add perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to add.
reference_key: perturbation source from which the perturbation summation starts.
reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations.
If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
provided as input but updated using compute_control_diff.
Examples:
Example usage with PseudobulkSpace:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
Expand Down Expand Up @@ -245,23 +255,29 @@ def subtract(
reference_key: str = "control",
ensure_consistency: bool = False,
target_col: str = "perturbations",
):
) -> AnnData:
"""Subtract perturbations linearly. Assumes input of size n_perts x dimensionality
Args:
adata: Anndata object of size n_perts x dim.
perturbations: Perturbations to subtract,
reference_key: Perturbation source from which the perturbation subtraction starts
perturbations: Perturbations to subtract.
reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'.
ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space.
target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'.
Returns:
Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations.
If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object
provided as input but updated using compute_control_diff.
Examples:
Example usage with PseudobulkSpace:
>>> import pertpy as pt
>>> mdata = pt.dt.papalexi_2021()
>>> ps = pt.tl.PseudobulkSpace()
>>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target")
>>> new_perturbation = ps.add(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
>>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"])
"""
new_pert_name = reference_key + "-"
for perturbation in perturbations:
Expand Down
2 changes: 1 addition & 1 deletion pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def compute( # type: ignore
Returns:
If return_object is True, the adata and the clustering object is returned.
Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key,
that stores the cluster labels.
that stores the cluster labels.
Examples:
>>> import pertpy as pt
Expand Down

0 comments on commit 0879d41

Please sign in to comment.