diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index fddc54c2..30f466b9 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -44,6 +44,7 @@ def perturbation_signature( split_by: str | None = None, n_neighbors: int = 20, use_rep: str | None = None, + n_dims: int | None = 15, n_pcs: int | None = None, batch_size: int | None = None, copy: bool = False, @@ -66,7 +67,8 @@ def perturbation_signature( If `None`, the representation is chosen automatically: For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used. If 'X_pca' is not present, it’s computed with default parameters. - n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`. + n_dims: Number of dimensions to use from the representation to calculate the perturbation signature. If `None`, use all dimensions. + n_pcs: If PCA representation is used, the number of principal components to compute. If `n_pcs==0` use `.X` if `use_rep is None`. batch_size: Size of batch to calculate the perturbation signature. If 'None', the perturbation signature is calcuated in the full mode, requiring more memory. The batched mode is very inefficient for sparse data. 
@@ -99,6 +101,8 @@ def perturbation_signature( split_masks = [split_obs == cat for cat in split_obs.unique()] representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) + if n_dims is not None and n_dims < representation.shape[1]: + representation = representation[:, :n_dims] for split_mask in split_masks: control_mask_split = control_mask & split_mask @@ -126,7 +130,7 @@ def perturbation_signature( shape=(n_split, n_control), ) neigh_matrix /= n_neighbors - adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control) + adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask] else: is_sparse = issparse(X_control) split_indices = np.where(split_mask)[0] @@ -144,7 +148,7 @@ def perturbation_signature( means_batch = means_batch.toarray() if is_sparse else means_batch means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1) - adata.layers["X_pert"][split_batch] -= np.log1p(means_batch) + adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch] if copy: return adata @@ -162,11 +166,13 @@ def mixscape( split_by: str | None = None, pval_cutoff: float | None = 5e-2, perturbation_type: str | None = "KO", + random_state: int | None = 0, copy: bool | None = False, ): """Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations. - The implementation resembles https://satijalab.org/seurat/reference/runmixscape + The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the + perturbation signature is calculated on unscaled data by default and we therefore recommend doing the same. Args: adata: The annotated data object. @@ -181,6 +187,7 @@ def mixscape( the perturbation signature for every replicate separately. pval_cutoff: P-value cut-off for selection of significantly DE genes. 
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications. + random_state: Random seed for the GaussianMixture model. copy: Determines whether a copy of the `adata` is returned. Returns: @@ -293,6 +300,7 @@ def mixscape( covariance_type="spherical", means_init=means_init, precisions_init=precisions_init, + random_state=random_state, ).fit(np.asarray(pvec).reshape(-1, 1)) probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1)) lik_ratio = probabilities[:, 0] / probabilities[:, 1] diff --git a/tests/tools/test_mixscape.py b/tests/tools/test_mixscape.py index ec9387c1..8f160669 100644 --- a/tests/tools/test_mixscape.py +++ b/tests/tools/test_mixscape.py @@ -98,3 +98,46 @@ def test_lda(adata): mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT") assert "mixscape_lda" in adata.uns + + +def test_deterministic_perturbation_signature(): + """Check `perturbation_signature` on synthetic data with a known perturbation effect. + + Control ("NT") and non-perturbed ("NP") cells sit exactly at their group's baseline, so their signature + must be zero; "KO" cells are shifted by `pert_effect`, so their signature must be `-pert_effect`. + """ + n_genes = 5 + n_cells_per_class = 50 + cell_classes = ["NT", "KO", "NP"] + groups = ["Group1", "Group2"] + + cell_classes_array = np.repeat(cell_classes, n_cells_per_class) + groups_array = np.tile(np.repeat(groups, n_cells_per_class // len(groups)), len(cell_classes)) + obs = pd.DataFrame({"cell_class": cell_classes_array, "group": groups_array, + "perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array]}) + + data = np.zeros((len(obs), n_genes)) + pert_effect = np.random.default_rng(0).uniform(-1, 1, size=(n_cells_per_class // len(groups), n_genes)) + for group in groups: + baseline_expr = 2 if group == "Group1" else 10 + group_mask = obs["group"] == group + + nt_mask = (obs["cell_class"] == "NT") & group_mask + data[nt_mask] = baseline_expr + + ko_mask = (obs["cell_class"] == "KO") & group_mask + data[ko_mask] = baseline_expr + pert_effect + + np_mask = (obs["cell_class"] == "NP") & group_mask + data[np_mask] = baseline_expr + + var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)]) + adata = 
anndata.AnnData(X=data, obs=obs, var=var) + + mixscape_identifier = pt.tl.Mixscape() + mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", n_neighbors=5, split_by="group") + + assert "X_pert" in adata.layers + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0) + assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0))