Skip to content

Commit

Permalink
Added Mixscape seeds and test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilly-May committed Dec 9, 2024
1 parent d94428c commit 0c08bf5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
16 changes: 12 additions & 4 deletions pertpy/tools/_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def perturbation_signature(
split_by: str | None = None,
n_neighbors: int = 20,
use_rep: str | None = None,
n_dims: int | None = 15,
n_pcs: int | None = None,
batch_size: int | None = None,
copy: bool = False,
Expand All @@ -66,7 +67,8 @@ def perturbation_signature(
If `None`, the representation is chosen automatically:
For `.n_vars` < 50, `.X` is used, otherwise 'X_pca' is used.
If 'X_pca' is not present, it’s computed with default parameters.
n_pcs: Use this many PCs. If `n_pcs==0` use `.X` if `use_rep is None`.
n_dims: Number of dimensions to use from the representation to calculate the perturbation signature. If `None`, use all dimensions.
n_pcs: If PCA representation is used, the number of principal components to compute. If `n_pcs==0` use `.X` if `use_rep is None`.
batch_size: Size of batch to calculate the perturbation signature.
If 'None', the perturbation signature is calcuated in the full mode, requiring more memory.
The batched mode is very inefficient for sparse data.
Expand Down Expand Up @@ -99,6 +101,8 @@ def perturbation_signature(
split_masks = [split_obs == cat for cat in split_obs.unique()]

representation = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs)
if n_dims is not None and n_dims < representation.shape[1]:
representation = representation[:, :n_dims]

for split_mask in split_masks:
control_mask_split = control_mask & split_mask
Expand Down Expand Up @@ -126,7 +130,7 @@ def perturbation_signature(
shape=(n_split, n_control),
)
neigh_matrix /= n_neighbors
adata.layers["X_pert"][split_mask] -= np.log1p(neigh_matrix @ X_control)
adata.layers["X_pert"][split_mask] = np.log1p(neigh_matrix @ X_control) - adata.layers["X_pert"][split_mask]
else:
is_sparse = issparse(X_control)
split_indices = np.where(split_mask)[0]
Expand All @@ -144,7 +148,7 @@ def perturbation_signature(
means_batch = means_batch.toarray() if is_sparse else means_batch
means_batch = means_batch.reshape(size, n_neighbors, -1).mean(1)

adata.layers["X_pert"][split_batch] -= np.log1p(means_batch)
adata.layers["X_pert"][split_batch] = np.log1p(means_batch) - adata.layers["X_pert"][split_batch]

if copy:
return adata
Expand All @@ -162,11 +166,13 @@ def mixscape(
split_by: str | None = None,
pval_cutoff: float | None = 5e-2,
perturbation_type: str | None = "KO",
random_state: int | None = 0,
copy: bool | None = False,
):
"""Identify perturbed and non-perturbed gRNA expressing cells that accounts for multiple treatments/conditions/chemical perturbations.
The implementation resembles https://satijalab.org/seurat/reference/runmixscape
The implementation resembles https://satijalab.org/seurat/reference/runmixscape. Note that in the original implementation, the
perturbation signature is calculated on unscaled data by default and we therefore recommend to do the same.
Args:
adata: The annotated data object.
Expand All @@ -181,6 +187,7 @@ def mixscape(
the perturbation signature for every replicate separately.
pval_cutoff: P-value cut-off for selection of significantly DE genes.
perturbation_type: specify type of CRISPR perturbation expected for labeling mixscape classifications.
random_state: Random seed for the GaussianMixture model.
copy: Determines whether a copy of the `adata` is returned.
Returns:
Expand Down Expand Up @@ -293,6 +300,7 @@ def mixscape(
covariance_type="spherical",
means_init=means_init,
precisions_init=precisions_init,
random_state=random_state,
).fit(np.asarray(pvec).reshape(-1, 1))
probabilities = mm.predict_proba(np.array(pvec[orig_guide_cells_index]).reshape(-1, 1))
lik_ratio = probabilities[:, 0] / probabilities[:, 1]
Expand Down
38 changes: 38 additions & 0 deletions tests/tools/test_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,41 @@ def test_lda(adata):
mixscape_identifier.lda(adata=adata, labels="gene_target", control="NT")

assert "mixscape_lda" in adata.uns

def test_deterministic_perturbation_signature():
n_genes = 5
n_cells_per_class = 50
cell_classes = ["NT", "KO", "NP"]
groups = ["Group1", "Group2"]

cell_classes_array = np.repeat(cell_classes, n_cells_per_class)
groups_array = np.tile(np.repeat(groups, n_cells_per_class // 2), len(cell_classes))
obs = pd.DataFrame({"cell_class": cell_classes_array, "group": groups_array,
"perturbation": ["control" if cell_class == "NT" else "pert1" for cell_class in cell_classes_array]})

data = np.zeros((len(obs), n_genes))
pert_effect = np.random.uniform(-1, 1, size=(n_cells_per_class//len(groups), n_genes))
for group_idx, group in enumerate(groups):
baseline_expr = 2 if group == "Group1" else 10
group_mask = obs["group"] == group

nt_mask = (obs["cell_class"] == "NT") & group_mask
data[nt_mask] = baseline_expr

ko_mask = (obs["cell_class"] == "KO") & group_mask
data[ko_mask] = baseline_expr + pert_effect

np_mask = (obs["cell_class"] == "NP") & group_mask
data[np_mask] = baseline_expr

var = pd.DataFrame(index=[f"Gene{i + 1}" for i in range(n_genes)])
adata = anndata.AnnData(X=data, obs=obs, var=var)

mixscape_identifier = pt.tl.Mixscape()
mixscape_identifier.perturbation_signature(adata, pert_key="perturbation", control="control", n_neighbors=5, split_by="group")

assert "X_pert" in adata.layers
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NT"], 0)
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "NP"], 0)
assert np.allclose(adata.layers["X_pert"][obs["cell_class"] == "KO"], -np.concatenate([pert_effect] * len(groups), axis=0))

0 comments on commit 0c08bf5

Please sign in to comment.