120 model collection #122

Merged (2 commits) on Oct 11, 2023
4 changes: 3 additions & 1 deletion .github/workflows/build_test_package.yml
@@ -29,6 +29,8 @@ jobs:

- name: Ruff Linting
uses: chartboost/ruff-action@v1
with:
version: 0.0.289

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
@@ -47,7 +49,7 @@ jobs:
options: "--check"
src: "."
jupyter: false
-        version: "~= 23.3" # this is the version that ships with the vs code extension, currently
+        version: "~= 23.3.0" # this is the version that ships with the vs code extension, currently

- name: Test
run: |
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,3 +1,5 @@
Attic/

*.swp
.DS_Store
/_build
@@ -167,4 +169,4 @@ cython_debug/
#.idea/


notebooks/test_dump.pkl
notebooks/test_dump.pkl
11 changes: 11 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -6,6 +6,17 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.

HEAD
----
Major changes:

- Adds the initial [multidms.model_collection](https://github.com/matsengrp/multidms/blob/120_model_collection/multidms/model_collection.py) module, with `multidms.fit_models` for fitting multiple models across a range of parameter values in parallel using `multiprocessing`. This is inspired by the `polyclonal.fit_models` function.
- Adds the `ModelCollection` class, a split-apply-combine interface to the mutational dataframes of a collection of models.
- Adds two altair plotting methods to `ModelCollection`: (1) `mut_param_heatmap`, for visualizing aggregated parameter sets across fits, and (2) `mut_param_traceplot`, for making trace plots across fits with variable lasso coefficient strengths.
- Removes the `utils` module.
- Cleans up #114:
    - Optionally removes the "wts", "sites", and "muts" columns from the mutations dataframe returned by `Model.get_mutations_df`, as they carried redundant information.
    - Changes the naming of columns produced by `Model.get_mutations_df()`: the condition name for the predicted functional score is now a suffix (as with shift and time_seen) rather than a prefix, e.g. "delta_predicted_func_score" -> "predicted_func_score_delta".
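The underlying pattern of the new parallel-fitting interface (expand a parameter grid, fit each combination concurrently, collect results) can be sketched generically. Everything below is hypothetical illustration, not the actual multidms API; the real implementation uses `multiprocessing`, while this sketch uses threads to stay self-contained:

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def fit_one(dataset, scale_coeff_lasso, num_training_steps):
    """Stand-in for fitting a single model; returns a result record."""
    # A real implementation would train a model here; we fake a loss.
    loss = 1.0 / (1 + num_training_steps) + scale_coeff_lasso
    return {
        "dataset": dataset,
        "scale_coeff_lasso": scale_coeff_lasso,
        "num_training_steps": num_training_steps,
        "loss": loss,
    }


def fit_models(datasets, lasso_coeffs, training_steps, n_workers=4):
    """Fan a grid of parameter combinations out to a worker pool."""
    grid = list(product(datasets, lasso_coeffs, training_steps))
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        return list(pool.map(lambda args: fit_one(*args), grid))


results = fit_models(["rep1", "rep2"], [0.0, 1e-5], [100])
print(len(results))  # 4 fits: 2 datasets x 2 lasso coefficients x 1 step count
```

The split-apply-combine step of `ModelCollection` then operates over such a list of fit records.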


0.2.2
---------
1 change: 1 addition & 0 deletions docs/index.rst
Expand Up @@ -42,6 +42,7 @@ and how much the effects differ between experiments.
installation
biophysical_model
fit_delta_BA1_example
model_collection
multidms
acknowledgments
contributing
3 changes: 3 additions & 0 deletions docs/model_collection.nblink
@@ -0,0 +1,3 @@
{
"path": "../notebooks/model_collection.ipynb"
}
16 changes: 4 additions & 12 deletions multidms/__init__.py
@@ -18,12 +18,8 @@
- :mod:`~multidms.data.Data`

- :mod:`~multidms.model.Model`

for some helpful utilities for working with multiple model fits
and wrangle their output, see:

- :mod:`~multidms.utils`


- :mod:`~multidms.model_collection.ModelCollection`

For a brief description about how the :class:`~multidms.model.Model`
class works to compose, compile, and optimize the model parameters
@@ -37,8 +33,6 @@ class works to compose, compile, and optimize the model parameters
:mod:`~multidms.plot` mostly contains code for interactive plotting
at the moment.



It also imports the following alphabets:

- :const:`~multidms.alphabets.AAS`
@@ -63,9 +57,7 @@ class works to compose, compile, and optimize the model parameters

from multidms.data import Data # noqa: F401
from multidms.model import Model # noqa: F401

import multidms.biophysical # noqa: F401
import multidms.utils # noqa: F401
from multidms.model_collection import ModelCollection, fit_models # noqa: F401

# This lets Sphinx know you want to document foo.foo.Foo as foo.Foo.
-__all__ = ["Data", "Model"]
+__all__ = ["Data", "Model", "ModelCollection", "fit_models"]
6 changes: 3 additions & 3 deletions multidms/biophysical.py
@@ -368,7 +368,7 @@ def _gamma_corrected_cost_smooth(
scale_coeff_ridge_shift=0,
scale_coeff_ridge_beta=0,
scale_coeff_ridge_gamma=0,
-    scale_coeff_ridge_cd=0,
+    scale_coeff_ridge_alpha_d=0,
**kwargs,
):
"""
@@ -391,7 +391,7 @@ def _gamma_corrected_cost_smooth(
Ridge penalty coefficient for beta parameters
scale_coeff_ridge_gamma : float
Ridge penalty coefficient for gamma parameters
-    scale_coeff_ridge_cd : float
+    scale_coeff_ridge_alpha_d : float
Ridge penalty coefficient for alpha parameters
kwargs : dict
Additional keyword arguments to pass to the biophysical model function
@@ -429,7 +429,7 @@
# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
loss += scale_coeff_ridge_shift * jnp.sum(d_params["s_md"] ** 2)
-    loss += scale_coeff_ridge_cd * jnp.sum(d_params["alpha_d"] ** 2)
+    loss += scale_coeff_ridge_alpha_d * jnp.sum(d_params["alpha_d"] ** 2)
loss += scale_coeff_ridge_gamma * jnp.sum(d_params["gamma_d"] ** 2)

loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2)
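The renamed `scale_coeff_ridge_alpha_d` scales a standard ridge (L2) penalty on the `alpha_d` parameters; the term itself is simple, shown here as a pure-Python sketch independent of JAX:

```python
def ridge_penalty(params, scale_coeff):
    """Sum of squared parameters, scaled by the ridge coefficient."""
    return scale_coeff * sum(p ** 2 for p in params)


print(ridge_penalty([1.0, 2.0, -3.0], 0.5))  # 7.0
```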
69 changes: 56 additions & 13 deletions multidms/data.py
@@ -10,6 +10,7 @@
import os
from functools import partial
import warnings
import re

import binarymap as bmap
import numpy as onp
@@ -19,18 +20,38 @@
from tqdm.auto import tqdm

from multidms import AAS
from multidms.utils import split_subs

import jax

import jax.numpy as jnp
import seaborn as sns
from jax.experimental import sparse
from matplotlib import pyplot as plt
from pandarallel import pandarallel

jax.config.update("jax_enable_x64", True)
# tqdm.pandas()


def split_sub(sub_string):
    """String match the wt, site, and mutant aa
    in a given string denoting a single substitution.
    """
    pattern = r"(?P<aawt>[A-Z])(?P<site>[\d\w]+)(?P<aamut>[A-Z\*])"
    match = re.search(pattern, sub_string)
    assert match is not None, sub_string
    return match.group("aawt"), str(match.group("site")), match.group("aamut")


def split_subs(subs_string, parser=split_sub):
    """Wrap the split_sub function to work on a
    string containing multiple substitutions.
    """
    wts, sites, muts = [], [], []
    for sub in subs_string.split():
        wt, site, mut = parser(sub)
        wts.append(wt)
        sites.append(site)
        muts.append(mut)
    return wts, sites, muts
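Pulling the two helpers together into a runnable example, a multi-substitution string splits into parallel lists of wild types, sites, and mutants:

```python
import re


def split_sub(sub_string):
    """Match the wild type aa, site, and mutant aa in one substitution."""
    pattern = r"(?P<aawt>[A-Z])(?P<site>[\d\w]+)(?P<aamut>[A-Z\*])"
    match = re.search(pattern, sub_string)
    assert match is not None, sub_string
    return match.group("aawt"), str(match.group("site")), match.group("aamut")


def split_subs(subs_string, parser=split_sub):
    """Apply split_sub to each space-separated substitution."""
    wts, sites, muts = [], [], []
    for sub in subs_string.split():
        wt, site, mut = parser(sub)
        wts.append(wt)
        sites.append(site)
        muts.append(mut)
    return wts, sites, muts


print(split_sub("M1E"))       # ('M', '1', 'E')
print(split_subs("M1E P3R"))  # (['M', 'P'], ['1', '3'], ['E', 'R'])
```

Because the site group `[\d\w]+` also matches letters, letter-suffixed sites such as "214a" parse correctly via backtracking.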


class Data:
@@ -80,23 +101,32 @@ class Data:
where all sites that are non-identical to the reference are 1's.
For motivation, see the `Model overview` section in :class:`multidms.Model`
class notes.
alphabet : array-like
Allowed characters in mutation strings.
collapse_identical_variants : {'mean', 'median', False}
If identical variants in ``variants_df`` (same 'aa_substitutions'),
exist within individual condition groups,
collapse them by taking mean or median of 'func_score', or
(if `False`) do not collapse at all. Collapsing will make fitting faster,
but *not* a good idea if you are doing bootstrapping.
assert_site_integrity : bool
If True, will assert that all sites in the data frame
have the same wild type amino acid, grouped by condition.
alphabet : array-like
Allowed characters in mutation strings.
condition_colors : array-like or dict
Maps each condition to the color used for plotting. Either a dict keyed
by each condition, or an array of colors that are sequentially assigned
to the conditions.
letter_suffixed_sites : bool
True if sites are sequential and integer, False otherwise.
assert_site_integrity : bool
If True, will assert that all sites in the data frame
have the same wild type amino acid, grouped by condition.
verbose : bool
If True, will print progress bars.
nb_workers : int
Number of workers to use for parallel operations.
If None, will use all available CPUs.
name : str or None
Name of the data object. If None, will be assigned
a unique name based upon the number of data objects
instantiated.
Example
-------
Expand Down Expand Up @@ -134,7 +164,6 @@ class notes.
... alphabet = multidms.AAS_WITHSTOP,
... reference = "a",
... ) # doctest: +ELLIPSIS
INFO: Pandarallel will run on ... workers.
...
Note this may take some time due to the string
@@ -181,6 +210,8 @@ class notes.
8 b M1E P3R -2.7 G3R M1E
"""

counter = 0

def __init__(
self,
variants_df: pd.DataFrame,
@@ -192,6 +223,7 @@ def __init__(
assert_site_integrity=False,
verbose=False,
nb_workers=None,
name=None,
):
"""See main class docstring."""
# Check and initialize conditions attribute
@@ -205,10 +237,10 @@ def __init__(
self._conditions = tuple(variants_df["condition"].astype(str).unique())

if str(reference) not in self._conditions:
-        if isinstance(reference, str):
+        if not isinstance(reference, str):
raise ValueError(
"reference must be a string, note that if your "
-                "condition names are numeric, they are being"
+                "condition names are numeric, they are being "
"converted to string"
)
raise ValueError("reference must be in condition factor levels")
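The corrected check (previously the `not` was missing, so the non-string branch was unreachable) can be exercised with a standalone sketch of the same logic:

```python
def check_reference(reference, conditions):
    """Mirror the reference-validation logic in Data.__init__."""
    if str(reference) not in conditions:
        if not isinstance(reference, str):
            raise ValueError(
                "reference must be a string, note that if your "
                "condition names are numeric, they are being "
                "converted to string"
            )
        raise ValueError("reference must be in condition factor levels")


check_reference("a", ("a", "b"))  # valid reference, no error
check_reference(1, ("1", "2"))    # numeric reference matching a converted name

try:
    check_reference(3, ("a", "b"))
except ValueError as err:
    print("non-string:", "must be a string" in str(err))  # non-string: True
```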
@@ -268,7 +300,9 @@

# Use the "aa_substitutions" to infer the
# wild type for each condition
-        site_map = pd.DataFrame()
+        # site_map = pd.DataFrame()
+        site_map = pd.DataFrame(columns=self.conditions)
+        # print(site_map.info())
for hom, hom_func_df in df.groupby("condition"):
if verbose:
print(f"inferring site map for {hom}")
@@ -295,7 +329,9 @@ def __init__(
site_map.dropna(inplace=True)

nb_workers = min(os.cpu_count(), 4) if nb_workers is None else nb_workers
-        pandarallel.initialize(progress_bar=verbose, nb_workers=nb_workers)
+        pandarallel.initialize(
+            progress_bar=verbose, verbose=0 if not verbose else 2, nb_workers=nb_workers
+        )

def flags_invalid_sites(disallowed_sites, sites_list):
"""Check to see if a sites list contains
@@ -470,6 +506,13 @@ def get_nis_from_site_map(site_map):
).fillna(0)

self._mutations_df = mut_df
self._name = name if isinstance(name, str) else f"Data-{Data.counter}"
Data.counter += 1

@property
def name(self) -> str:
"""The name of the data object."""
return self._name
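The class-level counter used for default names follows a common auto-naming pattern; a minimal standalone version (a toy class, not the multidms one):

```python
class Record:
    """Toy class showing the auto-naming pattern used by Data."""

    counter = 0

    def __init__(self, name=None):
        # Fall back to a unique "Record-N" name when none is given;
        # the counter increments either way, matching the Data code.
        self._name = name if isinstance(name, str) else f"Record-{Record.counter}"
        Record.counter += 1

    @property
    def name(self):
        """The name of the object."""
        return self._name


print(Record().name)          # Record-0
print(Record("custom").name)  # custom
print(Record().name)          # Record-2
```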

@property
def conditions(self) -> tuple: