120 model collection (#122)
* Add ModelCollection interface.
jgallowa07 authored Oct 11, 2023
1 parent d008828 commit 014ac85
Showing 16 changed files with 2,756 additions and 610 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/build_test_package.yml
@@ -29,6 +29,8 @@ jobs:

- name: Ruff Linting
uses: chartboost/ruff-action@v1
with:
version: 0.0.289

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
@@ -47,7 +49,7 @@ jobs:
options: "--check"
src: "."
jupyter: false
version: "~= 23.3" # this is the version that ships with the vs code extension, currently
version: "~= 23.3.0" # this is the version that ships with the vs code extension, currently

- name: Test
run: |
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,3 +1,5 @@
Attic/

*.swp
.DS_Store
/_build
@@ -167,4 +169,4 @@ cython_debug/
#.idea/


notebooks/test_dump.pkl
notebooks/test_dump.pkl
11 changes: 11 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,17 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

HEAD
----
Major Changes:
- Adds the initial [multidms.model_collection](https://github.com/matsengrp/multidms/blob/120_model_collection/multidms/model_collection.py) module with `multidms.fit_models`, which fits multiple models across a range of parameter spaces in parallel using `multiprocessing` (inspired by the `polyclonal.fit_models` function); see the usage sketch after this list.
- Adds the `ModelCollection` class, providing a split-apply-combine interface to the mutational dataframes for a collection of models.
- Adds two altair plotting methods to `ModelCollection`: (1) `mut_param_heatmap`, for visualizing aggregated parameter sets across fits, and (2) `mut_param_traceplot`, for making trace plots across fits with varying lasso coefficient strengths.
- Removes the `utils` module.
- Cleans up #114.
- Optionally removes "wts", "sites", and "muts" from the mutations dataframe returned by `Model.get_mutations_df`, as those columns were unnecessary.
- Changes the naming of columns produced by `Model.get_mutations_df()`: the condition name for the predicted functional score is now a suffix (as with shift and times_seen) rather than a prefix, e.g. "delta_predicted_func_score" -> "predicted_func_score_delta".
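
A minimal usage sketch of the new interface (illustrative only: the exact keyword names in `params` and the `(n_fit, n_failed, df)` return shape are assumptions modeled on `polyclonal.fit_models`, not confirmed by this diff):

import multidms

# hypothetical parameter grid; "dataset" and "scale_coeff_lasso_shift"
# are placeholder key names for illustration
params = {
    "dataset": [data],  # a multidms.Data object
    "scale_coeff_lasso_shift": [0.0, 1e-5, 1e-4],
}

# fit one model per parameter combination, in parallel
n_fit, n_failed, fit_models_df = multidms.fit_models(params, n_threads=4)

# aggregate the fits for split-apply-combine analysis and plotting
mc = multidms.ModelCollection(fit_models_df)
heatmap = mc.mut_param_heatmap()   # aggregated mutation parameters
traces = mc.mut_param_traceplot()  # traces across lasso strengths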


0.2.2
---------
1 change: 1 addition & 0 deletions docs/index.rst
@@ -42,6 +42,7 @@ and how much the effects differ between experiments.
installation
biophysical_model
fit_delta_BA1_example
model_collection
multidms
acknowledgments
contributing
3 changes: 3 additions & 0 deletions docs/model_collection.nblink
@@ -0,0 +1,3 @@
{
"path": "../notebooks/model_collection.ipynb"
}
16 changes: 4 additions & 12 deletions multidms/__init__.py
@@ -18,12 +18,8 @@
- :mod:`~multidms.data.Data`
- :mod:`~multidms.model.Model`
for some helpful utilities for working with multiple model fits
and wrangling their output, see:
- :mod:`~multidms.utils`
- :class:`~multidms.model_collection.ModelCollection`
For a brief description about how the :class:`~multidms.model.Model`
class works to compose, compile, and optimize the model parameters
@@ -37,8 +33,6 @@ class works to compose, compile, and optimize the model parameters
:mod:`~multidms.plot` mostly contains code for interactive plotting
at the moment.
It also imports the following alphabets:
- :const:`~multidms.alphabets.AAS`
@@ -63,9 +57,7 @@ class works to compose, compile, and optimize the model parameters

from multidms.data import Data # noqa: F401
from multidms.model import Model # noqa: F401

import multidms.biophysical # noqa: F401
import multidms.utils # noqa: F401
from multidms.model_collection import ModelCollection, fit_models # noqa: F401

# This lets Sphinx know you want to document foo.foo.Foo as foo.Foo.
__all__ = ["Data", "Model"]
__all__ = ["Data", "Model", "ModelCollection", "fit_models"]
6 changes: 3 additions & 3 deletions multidms/biophysical.py
@@ -368,7 +368,7 @@ def _gamma_corrected_cost_smooth(
scale_coeff_ridge_shift=0,
scale_coeff_ridge_beta=0,
scale_coeff_ridge_gamma=0,
scale_coeff_ridge_cd=0,
scale_coeff_ridge_alpha_d=0,
**kwargs,
):
"""
@@ -391,7 +391,7 @@ def _gamma_corrected_cost_smooth(
Ridge penalty coefficient for beta parameters
scale_coeff_ridge_gamma : float
Ridge penalty coefficient for gamma parameters
scale_coeff_ridge_cd : float
scale_coeff_ridge_alpha_d : float
Ridge penalty coefficient for alpha parameters
kwargs : dict
Additional keyword arguments to pass to the biophysical model function
@@ -429,7 +429,7 @@ def _gamma_corrected_cost_smooth(
# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
loss += scale_coeff_ridge_shift * jnp.sum(d_params["s_md"] ** 2)
loss += scale_coeff_ridge_cd * jnp.sum(d_params["alpha_d"] ** 2)
loss += scale_coeff_ridge_alpha_d * jnp.sum(d_params["alpha_d"] ** 2)
loss += scale_coeff_ridge_gamma * jnp.sum(d_params["gamma_d"] ** 2)

loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2)
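
Each scale_coeff_ridge_* coefficient above scales an independent L2 (ridge) penalty added to the data loss. A minimal standalone sketch of that pattern (the helper name and dict layout are illustrative, not part of multidms):

import jax.numpy as jnp

def ridge_penalty(param_groups: dict, scale_coeffs: dict) -> float:
    """Sum of independent L2 penalties, one per parameter group."""
    return sum(
        coeff * jnp.sum(param_groups[name] ** 2)
        for name, coeff in scale_coeffs.items()
    )

# equivalent to the `loss +=` lines above, e.g.:
# loss += ridge_penalty(
#     {"s_md": d_params["s_md"], "alpha_d": d_params["alpha_d"]},
#     {"s_md": scale_coeff_ridge_shift, "alpha_d": scale_coeff_ridge_alpha_d},
# )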
69 changes: 56 additions & 13 deletions multidms/data.py
@@ -10,6 +10,7 @@
import os
from functools import partial
import warnings
import re

import binarymap as bmap
import numpy as onp
@@ -19,18 +20,38 @@
from tqdm.auto import tqdm

from multidms import AAS
from multidms.utils import split_subs

import jax

import jax.numpy as jnp
import seaborn as sns
from jax.experimental import sparse
from matplotlib import pyplot as plt
from pandarallel import pandarallel

jax.config.update("jax_enable_x64", True)
# tqdm.pandas()


def split_sub(sub_string):
"""String match the wt, site, and sub aa
in a given string denoting a single substitution
"""
pattern = r"(?P<aawt>[A-Z])(?P<site>[\d\w]+)(?P<aamut>[A-Z\*])"
match = re.search(pattern, sub_string)
assert match is not None, sub_string
return match.group("aawt"), str(match.group("site")), match.group("aamut")


def split_subs(subs_string, parser=split_sub):
"""Wrap the split_sub func to work for a
string containing multiple substitutions
"""
wts, sites, muts = [], [], []
for sub in subs_string.split():
wt, site, mut = parser(sub)
wts.append(wt)
sites.append(site)
muts.append(mut)
return wts, sites, muts
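
A quick usage sketch of the two helpers above; the expected output follows directly from the regex shown:

wts, sites, muts = split_subs("M1E P3R")
print(wts)    # ['M', 'P']
print(sites)  # ['1', '3']
print(muts)   # ['E', 'R']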


class Data:
@@ -80,23 +101,32 @@ class Data:
where all sites that are non-identical to the reference are 1's.
For motivation, see the `Model overview` section in :class:`multidms.Model`
class notes.
alphabet : array-like
Allowed characters in mutation strings.
collapse_identical_variants : {'mean', 'median', False}
If identical variants (same 'aa_substitutions') exist within
individual condition groups in ``variants_df``,
collapse them by taking the mean or median of 'func_score', or
(if `False`) do not collapse at all. Collapsing will make fitting faster,
but is *not* a good idea if you are doing bootstrapping.
assert_site_integrity : bool
If True, will assert that all sites in the data frame
have the same wild type amino acid, grouped by condition.
alphabet : array-like
Allowed characters in mutation strings.
condition_colors : array-like or dict
Maps each condition to the color used for plotting. Either a dict keyed
by each condition, or an array of colors that are sequentially assigned
to the conditions.
letter_suffixed_sites: bool
True if site identifiers may carry letter suffixes (e.g. "214a")
rather than being purely sequential integers, False otherwise.
assert_site_integrity : bool
If True, will assert that all sites in the data frame
have the same wild type amino acid, grouped by condition.
verbose : bool
If True, will print progress bars.
nb_workers : int
Number of workers to use for parallel operations.
If None, will use all available CPUs.
name : str or None
Name of the data object. If None, will be assigned
a unique name based upon the number of data objects
instantiated.
Example
-------
@@ -134,7 +164,6 @@ class notes.
... alphabet = multidms.AAS_WITHSTOP,
... reference = "a",
... ) # doctest: +ELLIPSIS
INFO: Pandarallel will run on ... workers.
...
Note this may take some time due to the string
@@ -181,6 +210,8 @@ class notes.
8 b M1E P3R -2.7 G3R M1E
"""

counter = 0

def __init__(
self,
variants_df: pd.DataFrame,
@@ -192,6 +223,7 @@ def __init__(
assert_site_integrity=False,
verbose=False,
nb_workers=None,
name=None,
):
"""See main class docstring."""
# Check and initialize conditions attribute
@@ -205,10 +237,10 @@ def __init__(
self._conditions = tuple(variants_df["condition"].astype(str).unique())

if str(reference) not in self._conditions:
if isinstance(reference, str):
if not isinstance(reference, str):
raise ValueError(
"reference must be a string, note that if your "
"condition names are numeric, they are being"
"condition names are numeric, they are being "
"converted to string"
)
raise ValueError("reference must be in condition factor levels")
@@ -268,7 +300,9 @@ def __init__(

# Use the "aa_substitutions" to infer the
# wild type for each condition
site_map = pd.DataFrame()
site_map = pd.DataFrame(columns=self.conditions)
for hom, hom_func_df in df.groupby("condition"):
if verbose:
print(f"inferring site map for {hom}")
@@ -295,7 +329,9 @@ def __init__(
site_map.dropna(inplace=True)

nb_workers = min(os.cpu_count(), 4) if nb_workers is None else nb_workers
pandarallel.initialize(progress_bar=verbose, nb_workers=nb_workers)
pandarallel.initialize(
progress_bar=verbose, verbose=0 if not verbose else 2, nb_workers=nb_workers
)
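
The new verbose argument suppresses pandarallel's startup banner (the "INFO: Pandarallel will run on ... workers." line dropped from the doctest above) unless verbose output was requested. A minimal standalone sketch of that initialization, assuming a stock pandarallel install:

import pandas as pd
from pandarallel import pandarallel

# verbose=0 silences the "INFO: Pandarallel will run on N workers." banner
pandarallel.initialize(progress_bar=False, nb_workers=2, verbose=0)

df = pd.DataFrame({"aa_substitutions": ["M1E P3R", "G3R"]})
# parallel_apply is pandarallel's drop-in replacement for Series.apply
n_subs = df["aa_substitutions"].parallel_apply(lambda s: len(s.split()))
print(n_subs.tolist())  # [2, 1]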

def flags_invalid_sites(disallowed_sites, sites_list):
"""Check to see if a sites list contains
Expand Down Expand Up @@ -470,6 +506,13 @@ def get_nis_from_site_map(site_map):
).fillna(0)

self._mutations_df = mut_df
self._name = name if isinstance(name, str) else f"Data-{Data.counter}"
Data.counter += 1

@property
def name(self) -> str:
"""The name of the data object."""
return self._name
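
The counter class attribute added above gives each unnamed instance a unique default label. A compressed standalone sketch of the idiom, using a stand-in class rather than the real Data:

class Labeled:
    counter = 0  # class-level count of instances created so far

    def __init__(self, name=None):
        # fall back to a unique auto-generated name
        self._name = name if isinstance(name, str) else f"Labeled-{Labeled.counter}"
        Labeled.counter += 1

print(Labeled()._name)             # Labeled-0
print(Labeled(name="rep1")._name)  # rep1  (the counter still increments)
print(Labeled()._name)             # Labeled-2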

@property
def conditions(self) -> tuple:
