
jaxmodels #164 (Draft)

wants to merge 17 commits into base: main
Conversation

@wsdewitt (Contributor) commented Jul 25, 2024

NOTE: WIP, more sections coming

Summary

This PR includes a number of interrelated changes that are briefly summarized in the following bullets, and elaborated in the subsequent subsections.

  • New module multidms/jaxmodels implements core modeling and optimization, separate from any downstream analysis/plotting functionality. My recommendation is to call this API from the current modules that handle analysis/plotting.
    • Revised global epistasis model that eliminates parameter redundancy using a standard fitness picture.
    • A simpler approach to dealing with ill-conditioning in mutation effect inference due to the set of mutations that define non-reference homologs.
    • Count-based loss function.
  • New notebook notebooks/jaxmodels
    • Demonstrates recommended data aggregation, filtering, and stop codon transformation.
    • Demonstrates usage of the jaxmodels module for loading and fitting multidms data.
  • Removed the pandarallel dependency.

Revised global epistasis model

Define a fitness landscape $f:\mathbb{R}\to[0, 1]$ via a Hill function $f(\phi) = \frac{1}{1+e^{-\phi}}$ that maps from latent molecular phenotype $\phi\in\mathbb{R}$ to fitness $f\in[0, 1]$. The ratio of relative Malthusian growth in variants with fitnesses $f_1$ and $f_2$ is $\exp(\alpha(f_1 - f_2))$, where $\alpha$ is the effective number of generations of selection. The functional score in a DMS is the logarithm of relative enrichment for a variant $v$ with respect to WT, so it is modeled as $\hat y = \alpha(f_v - f_{\text{WT}})$.
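To spell out that step (a sketch, under the assumption that variant abundances grow as $n_v(t) = n_v(0)e^{f_v t}$ over $\alpha$ effective generations of selection): the enrichment of $v$ relative to WT is

$$ \frac{n_v(\alpha)/n_v(0)}{n_{\text{WT}}(\alpha)/n_{\text{WT}}(0)} = e^{\alpha(f_v - f_{\text{WT}})}, $$

and taking the logarithm recovers the functional score model $\hat y = \alpha(f_v - f_{\text{WT}})$.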

Latent molecular phenotypes are modeled via a genotype-phenotype map $\phi:\{0, 1\}^M\to\mathbb{R}$ that is assumed to be linear: $\phi(x) = \beta_0 + \beta^\intercal x$, where $x\in\{0, 1\}^M$ is a binary sequence encoding. In multidms, we relax this to a locally linear model, so that each experiment, indexed $d$, has its own mutation effects, and $\phi_d(x) = \beta_{0,d} + \beta_d^\intercal x$ is the latent phenotype of sequence $x$ in experiment $d$. Functional score predictors (in base $e$) for each experiment are then

$$ \hat y_d(x) = \alpha_d\left(f\left(\beta_{0,d}+\beta_d^\intercal x\right) - f\left(\beta_{0,d}+\beta_d^\intercal w_d\right)\right), $$

where $w_d\in\{0,1\}^M$ is the encoding of the WT sequence in experiment $d$.
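This predictor is straightforward to express in JAX. The sketch below is illustrative only: the function and argument names are mine, not the jaxmodels API, and it assumes a dense binary encoding `X_d` of shape (variants, mutations).

```python
from jax.nn import sigmoid

def predict_functional_scores(beta0_d, beta_d, alpha_d, X_d, w_d):
    """Predicted functional scores for all variants in experiment d.

    beta0_d: scalar latent-phenotype offset for experiment d
    beta_d:  (M,) mutation effects for experiment d
    alpha_d: effective number of generations of selection in experiment d
    X_d:     (N, M) binary variant encodings
    w_d:     (M,) binary encoding of the WT sequence in experiment d
    """
    f_variants = sigmoid(beta0_d + X_d @ beta_d)  # fitness of each variant
    f_wt = sigmoid(beta0_d + w_d @ beta_d)        # fitness of the WT sequence
    return alpha_d * (f_variants - f_wt)
```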

Count-based loss

Given the counts in the pre-selection pool and a functional score predictor, we can predict the counts in the post-selection pool. This involves filtering out variants whose pre-selection counts are too low, then using the pre-count together with the functional score prediction to predict the post-count. We use a negative binomial with a per-experiment overdispersion parameter $\theta_d$ to model varying noise levels and bottlenecking in each experiment.
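As a concrete (hedged) illustration of what such a count likelihood could look like, here is a minimal NB2-style log-likelihood in JAX; the names, the `size_factor` depth normalization, and the exact parameterization are my assumptions, not necessarily what jaxmodels implements.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def neg_binom_log_lik(post_counts, pre_counts, y_hat, theta_d, size_factor=1.0):
    """Negative binomial log-likelihood of post-selection counts.

    The mean post-count scales the pre-count by the predicted enrichment
    exp(y_hat); size_factor absorbs the difference in sequencing depth
    between the pre- and post-selection pools. theta_d is the per-experiment
    overdispersion parameter (variance = mu + mu**2 / theta_d).
    """
    mu = size_factor * pre_counts * jnp.exp(y_hat)
    k = post_counts
    return jnp.sum(
        gammaln(k + theta_d) - gammaln(theta_d) - gammaln(k + 1.0)
        + theta_d * jnp.log(theta_d / (theta_d + mu))
        + k * jnp.log(mu / (theta_d + mu))
    )
```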

Preconditioned gradient

We have a badly conditioned problem because the loss is much more sensitive to the mutation effects that are present in the WT sequences of the non-reference homologs. The previous way to handle this relied on a seemingly clever reparameterization, which had the downside that it required a much more expensive proximal step (using ADMM for generalized lasso). In this PR, we take a simpler and more intuitive approach to preconditioning.

The usual forward-backward step in FISTA looks like

$$ \beta_{t+1} = \mathrm{prox}_{\gamma g}(\beta_t - \gamma \nabla f(\beta_t)). $$

We know which components of the gradient are most sensitive, so, somewhat analogously to a second-order method that uses the Hessian, we adjust our descent direction with a matrix $P$:

$$ \beta_{t+1} = \mathrm{prox}_{\gamma P g}(\beta_t - \gamma P^{-1} \nabla f(\beta_t)), $$

where $P = I + \mathrm{diag}(1^\intercal X)$ and

$$ \mathrm{prox}_{\gamma P g}(v) = \arg\min_\beta \frac{1}{2\gamma}\|\beta - v\|_P^2 + g(\beta). $$
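Because $P$ is diagonal, the $P$-weighted prox stays separable. Below is a minimal sketch of one such step, assuming a plain lasso penalty $g(\beta) = \lambda\|\beta\|_1$ (function names are mine, and FISTA's momentum step is omitted):

```python
import jax.numpy as jnp
from jax import grad

def soft_threshold(v, thresh):
    """Elementwise soft-thresholding: the prox of a (rescaled) L1 penalty."""
    return jnp.sign(v) * jnp.maximum(jnp.abs(v) - thresh, 0.0)

def preconditioned_prox_step(beta, smooth_loss, X, gamma, lam):
    """One preconditioned forward-backward step with diagonal P = I + diag(1^T X).

    The gradient step is scaled by P^{-1}, and the lasso prox in the P-weighted
    norm reduces to soft-thresholding with per-coordinate threshold gamma * lam / p.
    """
    p = 1.0 + X.sum(axis=0)                         # diagonal of P
    v = beta - gamma * grad(smooth_loss)(beta) / p  # forward (gradient) step
    return soft_threshold(v, gamma * lam / p)       # backward (prox) step
```

With $P = I$ this reduces to the standard forward-backward update above.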

@jgallowa07 (Member) commented Jul 31, 2024

This is great, thanks @wsdewitt! I'm digesting the math and code and reproducing the analysis now. The next step will be a comparative analysis with the current HEAD. This will likely spur thoughts on the integration of jaxmodels into the rest of the multidms namespace. More to come!

@jgallowa07 (Member) commented:
We can discuss these in today's meeting but I'll be posting some questions/discussion points here:

  1. Use of the $\alpha$ parameter: I'm not sure I fully grasp this. You state: "where $\alpha$ is the effective number of generations of selection". But isn't the number of generations always 1, i.e., the pre-selection generation and the post-selection generation? AFAIU, $\log\alpha$ is a free parameter we fit to scale the sigmoid; what does the sigmoid scaling have to do with this "generations" description (my popgen brain is firing but not quite clicking)?
  2. I'm still unclear on the motivation/advantage of using the count-based model. Is it so we can use a negative binomial as opposed to a Huber-regression-style loss? That may be reason enough, as you explained about the dispersion parameter, but I could use a little more explanation of why we think this is better than the Huber loss regression model.
  3. The sklearn dependency should be added to pyproject.toml if we're going to keep this. It seems it is only used for the warm start of the individual linear models.
  4. To be clear, it seems you must put the wildtype as the first variant of each condition, but remove it from the training data? I get that all functional scores are relative to the wildtype, but then how is the model informed about the correct $\beta_0$ value to infer?
  5. It seems the fitting procedure is not deterministic. I assume this has something to do with sklearn.linear_model.Ridge.fit needing a seed to be set? That said, the results are very similar between runs, which is good.
  6. The distribution of betas seems a little odd ... why are there some extremely large, positive values? Could this be the effect of the negative binomial dispersion parameter? This is also causing there to be some very large
  7. The code seems to be very memory intensive. Without formal benchmarking, fitting just one of these models takes nearly 10 GB of RAM; running 6 model fits in parallel crashed my computer's applications. This could be because I cannot share the training binarymaps between models, which was a nice feature of the current multidms.

More to come

@jgallowa07 (Member) commented Aug 1, 2024

To Try (after discussions with @wsdewitt):

Investigate large betas:

  • Try no warm start with scikit-learn.
  • Filter counts more strictly.
  • Try tuning the L2 regularization hyperparameter.
  • Try lowering the tolerance.
  • Stops vs. non-stops.
  • Make sure the wildtype is first, as that is an assumption of the new code.

Comparative analysis to old results

TODO
