Skip to content

Commit

Permalink
Switch to formulaic-contrasts (#682)
Browse files Browse the repository at this point in the history
* Switch to formulaic-contrasts

* Cleanup

* removing design matrix workaround (#691)

Co-authored-by: Emma Dann <emmadann@comino.stanford.edu>

* Fix PyDESeq2

* Update tests

* fix typo in gitignore

* Remove contrast dataclass, which isnt used anywhere

* Fix edgeR rpy2 tests (#692)

* fix broken rpy2 edger tests

* updated edger tests

* Fix tests (scipy)

Signed-off-by: zethson <lukas.heumos@posteo.net>

* submodule

Signed-off-by: zethson <lukas.heumos@posteo.net>

* Remove unused code

Signed-off-by: zethson <lukas.heumos@posteo.net>

* type hints

Signed-off-by: zethson <lukas.heumos@posteo.net>

---------

Signed-off-by: zethson <lukas.heumos@posteo.net>
Co-authored-by: Emma Dann <32264060+emdann@users.noreply.github.com>
Co-authored-by: Emma Dann <emmadann@comino.stanford.edu>
Co-authored-by: zethson <lukas.heumos@posteo.net>
  • Loading branch information
4 people authored Jan 4, 2025
1 parent 447fec9 commit e43d8ff
Show file tree
Hide file tree
Showing 16 changed files with 146 additions and 560 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ dmypy.json
# Jetbrains IDE
.idea/

# VSCode
.vscode

# Coala
*.orig

Expand All @@ -160,3 +163,6 @@ node_modules
test.ipynb
test-perturbation
test-bug

# uv
uv.lock
3 changes: 3 additions & 0 deletions docs/tutorials/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ For questions about the usage of pertpy use the [scverse discourse](https://disc
## Quick start: Tool specific tutorials

### Data transformation

```{eval-rst}
.. nbgallery::
Expand All @@ -25,6 +26,7 @@ For questions about the usage of pertpy use the [scverse discourse](https://disc
```

### Knowledge inference

```{eval-rst}
.. nbgallery::
Expand All @@ -43,6 +45,7 @@ For questions about the usage of pertpy use the [scverse discourse](https://disc
```

## Use cases

Our use cases showcase a variety of pertpy tools applied to one dataset.
They are designed to give you a sense of how to use pertpy in a real-world scenario.
The use cases featured here are those we present in the pertpy [preprint](https://www.biorxiv.org/content/10.1101/2024.08.04.606516v1).
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
3 changes: 1 addition & 2 deletions pertpy/tools/_differential_gene_expression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._base import ContrastType, LinearModelBase, MethodBase
from ._base import LinearModelBase, MethodBase
from ._dge_comparison import DGEEVAL
from ._edger import EdgeR
from ._pydeseq2 import PyDESeq2
Expand All @@ -14,7 +14,6 @@
"SimpleComparisonBase",
"WilcoxonTest",
"TTest",
"ContrastType",
]

AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest]
128 changes: 33 additions & 95 deletions pertpy/tools/_differential_gene_expression/_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import math
import os
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from itertools import chain, zip_longest
from collections.abc import Iterable, Mapping, Sequence
from itertools import zip_longest
from types import MappingProxyType

import adjustText
Expand All @@ -12,34 +10,15 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
import statsmodels
from formulaic_contrasts import FormulaicContrasts
from lamin_utils import logger
from matplotlib.pyplot import Figure
from matplotlib.ticker import MaxNLocator

from pertpy._doc import _doc_params, doc_common_plot_args
from pertpy.tools import PseudobulkSpace
from pertpy.tools._differential_gene_expression._checks import check_is_numeric_matrix
from pertpy.tools._differential_gene_expression._formulaic import (
AmbiguousAttributeError,
Factor,
get_factor_storage_and_materializer,
resolve_ambiguous,
)


@dataclass
class Contrast:
"""Simple contrast for comparison between groups"""

column: str
baseline: str
group_to_compare: str


ContrastType = Contrast | tuple[str, str, str]


class MethodBase(ABC):
Expand Down Expand Up @@ -910,8 +889,7 @@ def _get_significance(p_val):

class LinearModelBase(MethodBase):
def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
"""
Initialize the method.
"""Initialize the method.
Args:
adata: AnnData object, usually pseudobulked.
Expand All @@ -923,26 +901,24 @@ def __init__(self, adata, design, *, mask=None, layer=None, **kwargs):
super().__init__(adata, mask=mask, layer=layer)
self._check_counts()

self.factor_storage = None
self.variable_to_factors = None

self.formulaic_contrasts = None
if isinstance(design, str):
self.factor_storage, self.variable_to_factors, materializer_class = get_factor_storage_and_materializer()
self.design = materializer_class(adata.obs, record_factor_metadata=True).get_model_matrix(design)
self.formulaic_contrasts = FormulaicContrasts(adata.obs, design)
self.design = self.formulaic_contrasts.design_matrix
else:
self.design = design

@classmethod
def compare_groups(
cls,
adata,
column,
baseline,
groups_to_compare,
adata: ad.AnnData,
column: str,
baseline: str,
groups_to_compare: str | Iterable[str],
*,
paired_by=None,
mask=None,
layer=None,
paired_by: str | None = None,
mask: pd.Series | None = None,
layer: str | None = None,
fit_kwargs=MappingProxyType({}),
test_kwargs=MappingProxyType({}),
):
Expand All @@ -968,17 +944,16 @@ def compare_groups(
@property
def variables(self):
"""Get the names of the variables used in the model definition."""
try:
return self.design.model_spec.variables_by_source["data"]
except AttributeError:
if self.formulaic_contrasts is None:
raise ValueError(
"Retrieving variables is only possible if the model was initialized using a formula."
) from None
else:
return self.formulaic_contrasts.variables

@abstractmethod
def _check_counts(self):
"""
Check that counts are valid for the specific method.
"""Check that counts are valid for the specific method.
Raises:
ValueError: if the data matrix does not comply with the expectations.
Expand All @@ -987,8 +962,7 @@ def _check_counts(self):

@abstractmethod
def fit(self, **kwargs):
"""
Fit the model.
"""Fit the model.
Args:
**kwargs: Additional arguments for fitting the specific method.
Expand All @@ -998,9 +972,8 @@ def fit(self, **kwargs):
@abstractmethod
def _test_single_contrast(self, contrast, **kwargs): ...

def test_contrasts(self, contrasts, **kwargs):
"""
Perform a comparison as specified in a contrast vector.
def test_contrasts(self, contrasts: np.ndarray | Mapping[str | None, np.ndarray], **kwargs):
"""Perform a comparison as specified in a contrast vector.
Args:
contrasts: Either a numeric contrast vector, or a dictionary of numeric contrast vectors.
Expand All @@ -1016,11 +989,11 @@ def test_contrasts(self, contrasts, **kwargs):
results.append(self._test_single_contrast(contrast, **kwargs).assign(contrast=name))

results_df = pd.concat(results)

return results_df

def test_reduced(self, modelB):
"""
Test against a reduced model.
"""Test against a reduced model.
Args:
modelB: the reduced model against which to test.
Expand All @@ -1034,61 +1007,22 @@ def test_reduced(self, modelB):
raise NotImplementedError

def cond(self, **kwargs):
"""
Get a contrast vector representing a specific condition.
"""Get a contrast vector representing a specific condition.
Args:
**kwargs: column/value pairs.
Returns:
A contrast vector that aligns to the columns of the design matrix.
"""
if self.factor_storage is None:
if self.formulaic_contrasts is None:
raise RuntimeError(
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
)
cond_dict = kwargs
if not set(cond_dict.keys()).issubset(self.variables):
raise ValueError(
"You specified a variable that is not part of the model. Available variables: "
+ ",".join(self.variables)
)
for var in self.variables:
if var in cond_dict:
self._check_category(var, cond_dict[var])
else:
cond_dict[var] = self._get_default_value(var)
df = pd.DataFrame([kwargs])
return self.design.model_spec.get_model_matrix(df).iloc[0]

def _get_factor_metadata_for_variable(self, var):
factors = self.variable_to_factors[var]
return list(chain.from_iterable(self.factor_storage[f] for f in factors))

def _get_default_value(self, var):
factor_metadata = self._get_factor_metadata_for_variable(var)
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL:
try:
tmp_base = resolve_ambiguous(factor_metadata, "base")
except AmbiguousAttributeError as e:
raise ValueError(
f"Could not automatically resolve base category for variable {var}. Please specify it explicity in `model.cond`."
) from e
return tmp_base if tmp_base is not None else "\0"
else:
return 0
return self.formulaic_contrasts.cond(**kwargs)

def _check_category(self, var, value):
factor_metadata = self._get_factor_metadata_for_variable(var)
tmp_categories = resolve_ambiguous(factor_metadata, "categories")
if resolve_ambiguous(factor_metadata, "kind") == Factor.Kind.CATEGORICAL and value not in tmp_categories:
raise ValueError(
f"You specified a non-existant category for {var}. Possible categories: {', '.join(tmp_categories)}"
)

def contrast(self, column, baseline, group_to_compare):
"""
Build a simple contrast for pairwise comparisons.
def contrast(self, *args, **kwargs):
"""Build a simple contrast for pairwise comparisons.
Args:
column: column in adata.obs to test on.
Expand All @@ -1098,4 +1032,8 @@ def contrast(self, column, baseline, group_to_compare):
Returns:
Numeric contrast vector.
"""
return self.cond(**{column: group_to_compare}) - self.cond(**{column: baseline})
if self.formulaic_contrasts is None:
raise RuntimeError(
"Building contrasts with `cond` only works if you specified the model using a formulaic formula. Please manually provide a contrast vector."
)
return self.formulaic_contrasts.contrast(*args, **kwargs)
4 changes: 2 additions & 2 deletions pertpy/tools/_differential_gene_expression/_edger.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def fit(self, **kwargs): # adata, design, mask, layer
logger.info("Calculating NormFactors")
dge = edger.calcNormFactors(dge)

with localconverter(get_conversion() + pandas2ri.converter):
design_r = ro.conversion.py2rpy(pd.DataFrame(self.design))
with localconverter(get_conversion() + numpy2ri.converter):
design_r = ro.conversion.py2rpy(self.design.values)

logger.info("Estimating Dispersions")
dge = edger.estimateDisp(dge, design=design_r)
Expand Down
Loading

0 comments on commit e43d8ff

Please sign in to comment.