diff --git a/pertpy/tools/_perturbation_space/_clustering.py b/pertpy/tools/_perturbation_space/_clustering.py index 30fa2482..b86acf14 100644 --- a/pertpy/tools/_perturbation_space/_clustering.py +++ b/pertpy/tools/_perturbation_space/_clustering.py @@ -31,6 +31,8 @@ def evaluate_clustering( true_label_col: ground truth labels. cluster_col: cluster computed labels. metrics: Metrics to compute. Defaults to ['nmi', 'ari', 'asw']. + **kwargs: Additional arguments to pass to the metrics. For nmi, average_method can be passed. + For asw, metric, distances, sample_size, and random_state can be passed. Examples: Example usage with KMeansSpace: diff --git a/pertpy/tools/_perturbation_space/_discriminator_classifier.py b/pertpy/tools/_perturbation_space/_discriminator_classifier.py index 7d43fb85..8df569c3 100644 --- a/pertpy/tools/_perturbation_space/_discriminator_classifier.py +++ b/pertpy/tools/_perturbation_space/_discriminator_classifier.py @@ -286,8 +286,8 @@ def __getitem__(self, idx): """Returns a sample and corresponding perturbations applied (labels)""" sample = self.data[idx].A if scipy.sparse.issparse(self.data) else self.data[idx] - num_label = self.labels[idx] - str_label = self.pert_labels[idx] + num_label = self.labels.iloc[idx] + str_label = self.pert_labels.iloc[idx] return sample, num_label, str_label diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index 275f4078..32ae0ab6 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -34,7 +34,7 @@ def compute_control_diff( # type: ignore new_embedding_key: str = "control_diff", all_data: bool = False, copy: bool = False, - ): + ) -> AnnData: """Subtract mean of the control from the perturbation. Args: @@ -43,14 +43,18 @@ def compute_control_diff( # type: ignore group_col: .obs column name that stores the label of the group of eah cell. If None, ignore groups. Defaults to 'perturbations'. reference_key: The key of the control values. Defaults to 'control'. layer_key: Key of the AnnData layer to use for computation. Defaults to the `X` matrix otherwise. - new_layer_key: the results are stored in the given layer. Defaults to 'differential diff'. + new_layer_key: the results are stored in the given layer. Defaults to 'control_diff'. embedding_key: `obsm` key of the AnnData embedding to use for computation. Defaults to the 'X' matrix otherwise. - new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control diff'. + new_embedding_key: Results are stored in a new embedding in `obsm` with this key. Defaults to 'control_diff'. all_data: if True, do the computation in all data representations (X, all layers and all embeddings) copy: If True returns a new Anndata of same size with the new column; otherwise it updates the initial AnnData object. + Returns: + Updated AnnData object. + Examples: Example usage with PseudobulkSpace: + >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() @@ -144,18 +148,24 @@ def add( reference_key: str = "control", ensure_consistency: bool = False, target_col: str = "perturbations", - ): + ) -> AnnData: """Add perturbations linearly. Assumes input of size n_perts x dimensionality Args: adata: Anndata object of size n_perts x dim. perturbations: Perturbations to add. - reference_key: perturbation source from which the perturbation summation starts. + reference_key: perturbation source from which the perturbation summation starts. Defaults to 'control'. ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'. + Returns: + Anndata object of size (n_perts+1) x dim, where the last row is the addition of the specified perturbations. + If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object + provided as input but updated using compute_control_diff. + Examples: Example usage with PseudobulkSpace: + >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() @@ -245,23 +255,29 @@ def subtract( reference_key: str = "control", ensure_consistency: bool = False, target_col: str = "perturbations", - ): + ) -> AnnData: """Subtract perturbations linearly. Assumes input of size n_perts x dimensionality Args: adata: Anndata object of size n_perts x dim. - perturbations: Perturbations to subtract, - reference_key: Perturbation source from which the perturbation subtraction starts + perturbations: Perturbations to subtract. + reference_key: Perturbation source from which the perturbation subtraction starts. Defaults to 'control'. ensure_consistency: If True, runs differential expression on all data matrices to ensure consistency of linear space. target_col: .obs column name that stores the label of the perturbation applied to each cell. Defaults to 'perturbations'. + Returns: + Anndata object of size (n_perts+1) x dim, where the last row is the subtraction of the specified perturbations. + If ensure_consistency is True, returns a tuple of (new_perturbation, adata) where adata is the AnnData object + provided as input but updated using compute_control_diff. + Examples: Example usage with PseudobulkSpace: + >>> import pertpy as pt >>> mdata = pt.dt.papalexi_2021() >>> ps = pt.tl.PseudobulkSpace() >>> ps_adata = ps.compute(mdata["rna"], target_col="gene_target", groups_col="gene_target") - >>> new_perturbation = ps.add(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"]) + >>> new_perturbation = ps.subtract(ps_adata, reference_key="ATF2", perturbations=["BRD4", "CUL3"]) """ new_pert_name = reference_key + "-" for perturbation in perturbations: diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 8c21ad5d..7a7dbb18 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -254,7 +254,7 @@ def compute( # type: ignore Returns: If return_object is True, the adata and the clustering object is returned. Otherwise, only the adata is returned. The adata is updated with a new .obs column as specified in cluster_key, - that stores the cluster labels. + that stores the cluster labels. Examples: >>> import pertpy as pt