jaxmodels #164
Conversation
This is great, thanks @wsdewitt! I'm digesting the math and code and reproducing the analysis now. The next step will be a comparative analysis with the current HEAD. This will likely spur thoughts on the integration of …
We can discuss these in today's meeting, but I'll be posting some questions/discussion points here:
More to come…
To try (after discussions with @wsdewitt): investigate large $\beta$'s.
Comparative analysis to old results: TODO…
NOTE: WIP, more sections coming
Summary
This PR includes a number of interrelated changes that are briefly summarized in the following bullets, and elaborated in the subsequent subsections.
- `multidms/jaxmodels` implements core modeling and optimization in a dedicated module, separate from any downstream analysis/plotting functionality. My recommendation is to call this API from the current modules that handle analysis/plotting.
- `notebooks/jaxmodels` demonstrates the `jaxmodels` module for loading and fitting multidms data.

Revised global epistasis model
Define a fitness landscape $f:\mathbb{R}\to[0, 1]$ via a Hill function $f(\phi) = \frac{1}{1+e^{-\phi}}$ that maps from latent molecular phenotype $\phi\in\mathbb{R}$ to fitness $f\in[0, 1]$. The ratio of relative Malthusian growth between variants with fitnesses $f_1$ and $f_2$ is $\exp(\alpha(f_1 - f_2))$, where $\alpha$ is the effective number of generations of selection. The functional score in a DMS is the logarithm of relative enrichment for a variant $v$ with respect to WT, so it is modeled as $\hat y = \alpha(f_v - f_{\text{WT}})$.
Latent molecular phenotypes are modeled via a genotype-phenotype map $\phi:\{0, 1\}^M\to\mathbb{R}$ that is assumed to be linear: $\phi(x) = \beta_0 + \beta^\intercal x$, where $x\in\{0, 1\}^M$ is a binary sequence encoding. In multidms, we relax this to a locally linear model, so that each experiment, indexed $d$, has its own mutation effects, and $\phi_d(x) = \beta_{0,d} + \beta_d^\intercal x$ is the latent phenotype of sequence $x$ in experiment $d$. Functional score predictors (in base $e$) for each experiment are then

$$\hat y_d(x) = \alpha\left(f(\phi_d(x)) - f(\phi_d(w_d))\right),$$

where $w_d\in\{0,1\}^M$ is the encoding of the WT sequence in experiment $d$.
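Below is a minimal JAX sketch of this model; the function names and signatures are illustrative, not necessarily those used in `multidms/jaxmodels`:

```python
import jax.numpy as jnp

def latent_phenotype(beta0_d, beta_d, x):
    """Locally linear genotype-phenotype map: phi_d(x) = beta_{0,d} + beta_d^T x."""
    return beta0_d + x @ beta_d

def fitness(phi):
    """Hill-function fitness landscape f(phi) = 1 / (1 + exp(-phi)), mapping R -> [0, 1]."""
    return 1.0 / (1.0 + jnp.exp(-phi))

def functional_score(beta0_d, beta_d, alpha, x, w_d):
    """Predicted functional score (base e): alpha * (f(phi_d(x)) - f(phi_d(w_d)))."""
    f_v = fitness(latent_phenotype(beta0_d, beta_d, x))
    f_wt = fitness(latent_phenotype(beta0_d, beta_d, w_d))
    return alpha * (f_v - f_wt)
```

Batched over variants, `x` becomes an $N\times M$ binary matrix and the same code works unchanged via broadcasting.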
Count-based loss
Given the counts in the pre-selection pool and a functional score predictor, we can predict the counts in the post-selection pool. This involves filtering the data so that pre-selection counts are not too low, and using them along with the functional score prediction to predict post-selection counts. We use a negative binomial with overdispersion parameter $\theta_d$ to model varying noise levels/bottlenecking in each experiment.
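As a concrete sketch of this loss, assuming a standard mean/overdispersion parameterization of the negative binomial (variance $\mu + \mu^2/\theta_d$); the exact parameterization, normalization, and `size_factor` handling in `jaxmodels` may differ:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def predicted_post_counts(pre_counts, y_hat, size_factor=1.0):
    """Mean post-selection counts implied by pre-selection counts and the
    functional score predictor y_hat (log relative enrichment, base e),
    up to an experiment-level normalization (size_factor is illustrative)."""
    return size_factor * pre_counts * jnp.exp(y_hat)

def neg_binom_nll(post_counts, mu, theta_d):
    """Negative log-likelihood of observed post-selection counts under a
    negative binomial with mean mu and overdispersion theta_d."""
    return -(
        gammaln(post_counts + theta_d) - gammaln(theta_d) - gammaln(post_counts + 1.0)
        + theta_d * jnp.log(theta_d / (theta_d + mu))
        + post_counts * jnp.log(mu / (theta_d + mu))
    )
```

The pre-count filter would then just be a boolean mask, e.g. `pre_counts >= min_pre_count`, applied before summing per-variant losses.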
Preconditioned gradient
We have a badly conditioned problem because the loss is much more sensitive to the mutation effects that are present in the WT sequences of the non-reference homologs. The previous way to handle this relied on a seemingly clever reparameterization, which had the downside that it required a much more expensive proximal step (using ADMM for generalized lasso). In this PR, we take a simpler and more intuitive approach to preconditioning.
The usual forward-backward step in FISTA looks like

$$\beta^{(t+1)} = \mathrm{prox}_{\gamma g}\left(\beta^{(t)} - \gamma\,\nabla\mathcal{L}\big(\beta^{(t)}\big)\right),$$

where $g$ is the nonsmooth penalty and $\gamma$ is the step size (FISTA's momentum extrapolation omitted here).
We know which components of the gradient are sensitive, so, somewhat analogous to a second-order method that uses the Hessian, we adjust our descent direction with a matrix $P$:

$$\beta^{(t+1)} = \mathrm{prox}^P_{\gamma g}\left(\beta^{(t)} - \gamma\,P^{-1}\nabla\mathcal{L}\big(\beta^{(t)}\big)\right),$$

where $P = I + \mathrm{diag}(1^\intercal X)$ and $X$ is the matrix of binary variant encodings, so $1^\intercal X$ counts the number of variants carrying each mutation.
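A minimal JAX sketch of one such preconditioned proximal step, assuming a lasso penalty with weight `lam` (names are illustrative): since $P$ is diagonal, applying $P^{-1}$ is an elementwise division, and the prox in the $P$-metric reduces to soft-thresholding with per-coordinate thresholds.

```python
import jax.numpy as jnp
from jax import grad

def precondition_diag(X):
    """Diagonal of P = I + diag(1^T X): one plus the number of variants
    carrying each mutation (column sums of the binary encoding matrix X)."""
    return 1.0 + X.sum(axis=0)

def soft_threshold(z, threshold):
    """Elementwise proximal operator of the l1 penalty."""
    return jnp.sign(z) * jnp.maximum(jnp.abs(z) - threshold, 0.0)

def preconditioned_prox_step(beta, loss, step_size, lam, p_diag):
    """One forward-backward step with preconditioned descent direction
    P^{-1} grad. For diagonal P, the prox in the P-norm thresholds
    coordinate i at lam * step_size / P_ii."""
    g = grad(loss)(beta)
    z = beta - step_size * g / p_diag
    return soft_threshold(z, lam * step_size / p_diag)
```

The effect is that mutations appearing in many variants (large column sums, notably those fixed in the WT backgrounds of non-reference homologs) take proportionally smaller steps, taming the ill-conditioning without the expensive ADMM machinery.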