156 generalized lasso model #157

Merged 6 commits on Jun 17, 2024
Changes from all commits
5 changes: 2 additions & 3 deletions .github/workflows/build_test_package.yml
@@ -53,10 +53,9 @@ jobs:

- name: Test
run: |
pytest
pytest --doctest-modules
pytest --doctest-modules multidms tests

- name: Test docs build
run: |
make -C docs clean
make -C docs html
make -C docs html
3 changes: 2 additions & 1 deletion .github/workflows/publish_package_pypi.yml
@@ -49,7 +49,8 @@ jobs:

- name: Test
run: |
pytest
pytest multidms tests
# pytest --doctest-modules multidms tests

- name: Build python package
run: |
2 changes: 0 additions & 2 deletions .gitignore
@@ -7,8 +7,6 @@ Attic/
docs/multidms*rst
.vscode

notebooks/output/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
8 changes: 4 additions & 4 deletions multidms/__init__.py
@@ -50,10 +50,10 @@ class works to compose, compile, and optimize the model parameters
__version__ = "0.4.0"
__url__ = "https://github.com/matsengrp/multidms"

from polyclonal.alphabets import AAS # noqa: F401
from polyclonal.alphabets import AAS_WITHGAP # noqa: F401
from polyclonal.alphabets import AAS_WITHSTOP # noqa: F401
from polyclonal.alphabets import AAS_WITHSTOP_WITHGAP # noqa: F401
from binarymap.binarymap import AAS_NOSTOP as AAS # noqa: F401
from binarymap.binarymap import AAS_WITHGAP # noqa: F401
from binarymap.binarymap import AAS_WITHSTOP # noqa: F401
from binarymap.binarymap import AAS_WITHSTOP_WITHGAP # noqa: F401

from multidms.data import Data # noqa: F401
from multidms.model import Model # noqa: F401
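For orientation on the import swap above: the amino-acid alphabets are now sourced from binarymap rather than polyclonal, while keeping the same re-exported names. A minimal usage sketch (the lengths assume the standard 20-letter alphabet plus optional stop and gap symbols):

from binarymap.binarymap import AAS_NOSTOP as AAS, AAS_WITHSTOP, AAS_WITHSTOP_WITHGAP

# the exports differ only by the extra stop ("*") and gap ("-") characters
print(len(AAS), len(AAS_WITHSTOP), len(AAS_WITHSTOP_WITHGAP))  # expected 20, 21, 22 (assumed)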
180 changes: 117 additions & 63 deletions multidms/biophysical.py
@@ -35,11 +35,12 @@

import jax.numpy as jnp
from jaxopt.loss import huber_loss
from jaxopt.prox import prox_lasso
import jax

# jax.config.update("jax_enable_x64", True)
from multidms.utils import transform
import pyproximal
import jax

jax.config.update("jax_enable_x64", True)

r"""
+++++++++++++++++++++++++++++
@@ -76,11 +76,7 @@ def additive_model(d_params: dict, X_d: jnp.array):
jnp.array
Predicted latent phenotypes for each row in ``X_d``
"""
return (
d_params["beta_naught"]
+ d_params["alpha_d"]
+ (X_d @ (d_params["beta_m"] + d_params["s_md"]))
)
return d_params["beta0"] + X_d @ d_params["beta"]


r"""
@@ -279,9 +276,12 @@ def softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwar
"""
return (
hinge_scale
* (jnp.logaddexp(0, (act - (lower_bound + d_params["gamma_d"])) / hinge_scale))
# GAMMA
# * (jnp.logaddexp(0, (act - (lower_bound + d_params["gamma_d"])) / hinge_scale))
* (jnp.logaddexp(0, (act - lower_bound) / hinge_scale))
+ lower_bound
+ d_params["gamma_d"]
# GAMMA
# + d_params["gamma_d"]
)
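With the gamma correction commented out, the activation reduces to a plain smooth clip at lower_bound. A standalone sketch of that behaviour (illustrative values only):

import jax.numpy as jnp

def softplus_clip(act, lower_bound=-3.5, hinge_scale=0.1):
    # smooth approximation to max(act, lower_bound), as in softplus_activation above
    return hinge_scale * jnp.logaddexp(0, (act - lower_bound) / hinge_scale) + lower_bound

print(softplus_clip(jnp.array([-10.0, 0.0, 2.0])))  # ~[-3.5, 0.0, 2.0]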


@@ -320,56 +320,102 @@ def _abstract_epistasis(
return t(d_params, g(d_params["theta"], additive_model(d_params, X_h)), **kwargs)


def _lasso_lock_prox(
params,
hyperparams_prox=dict(
lasso_params=None, lock_params=None, upper_bound_theta_ge_scale=None
),
scaling=1.0,
):
def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
"""
Apply lasso and lock constraints to parameters
Proximal operator for box constraints for single condition models.

Parameters
----------
params : dict
Dictionary of parameters to constrain
hyperparams_prox : dict
Dictionary of hyperparameters for proximal operators
scaling : float
Scaling factor for lasso penalty
Note that *args and **kwargs are placeholders for additional arguments
that may be passed to this function by the optimizer.
"""
# enforce monotonic epistasis and constrain ge_scale upper limit
(
ge_scale_upper_bound,
lock_params,
bundle_idxs,
) = hyperparameters

params = transform(params, bundle_idxs)

# clamp theta scale to monotonic, and with optional upper bound
if "ge_scale" in params["theta"]:
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, hyperparams_prox["upper_bound_theta_ge_scale"]
0, ge_scale_upper_bound
)
# Any params to constrain during fit
# clamp beta0 for reference condition in non-scaled parameterization
# (where it's a box constraint)
if lock_params is not None:
for (param, subparam), value in lock_params.items():
params[param][subparam] = value

params = transform(params, bundle_idxs)
return params
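The box constraints applied in proximal_box_constraints amount to projecting parameters onto an interval, which is the proximal operator of that interval's indicator function. A toy illustration, separate from the package API:

import jax.numpy as jnp

def prox_box(x, lower=0.0, upper=jnp.inf):
    # projection onto [lower, upper], e.g. keeping ge_scale non-negative and bounded
    return jnp.clip(x, lower, upper)

print(prox_box(jnp.array([-0.3, 0.5, 7.0]), upper=5.0))  # [0.  0.5 5. ]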


if "p_weights_1" in params["theta"]:
params["theta"]["p_weights_1"] = params["theta"]["p_weights_1"].clip(0)
params["theta"]["p_weights_2"] = params["theta"]["p_weights_2"].clip(0)
def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
"""ADMM generalized lasso optimization."""
(
scale_coeff_lasso_shift,
admm_niter,
admm_tau,
admm_mu,
ge_scale_upper_bound,
lock_params,
bundle_idxs,
) = hyperparameters
# apply prox
beta_ravel = jnp.vstack(params["beta"].values()).ravel(order="F")

# see https://pyproximal.readthedocs.io/en/stable/index.html
beta_ravel, shift_ravel = pyproximal.optimization.primal.LinearizedADMM(
pyproximal.L2(b=beta_ravel),
pyproximal.L1(sigma=scaling * scale_coeff_lasso_shift),
Dop,
niter=admm_niter,
tau=admm_tau,
mu=admm_mu,
x0=beta_ravel,
show=False,
)

beta = beta_ravel.reshape(-1, len(beta_ravel) // len(params["beta"]), order="F")
shift = shift_ravel.reshape(-1, len(shift_ravel) // len(params["beta"]), order="F")

# update beta dict
for i, d in enumerate(params["beta"]):
params["beta"][d] = beta[i]

# update shifts
for i, d in enumerate(params["shift"]):
params["shift"][d] = shift[i]

if hyperparams_prox["lasso_params"] is not None:
for key, value in hyperparams_prox["lasso_params"].items():
params[key] = prox_lasso(params[key], value, scaling)
# clamp beta0 for reference condition in non-scaled parameterization
# (where it's a box constraint)
params = transform(params, bundle_idxs)

# clamp theta scale to monotonic, and with optional upper bound
if "ge_scale" in params["theta"]:
params["theta"]["ge_scale"] = params["theta"]["ge_scale"].clip(
0, ge_scale_upper_bound
)
# Any params to constrain during fit
if hyperparams_prox["lock_params"] is not None:
for key, value in hyperparams_prox["lock_params"].items():
params[key] = value
if lock_params is not None:
for (param, subparam), value in lock_params.items():
params[param][subparam] = value

# params["beta0"][params["beta0"].keys()] = 0.0
params = transform(params, bundle_idxs)

return params
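For intuition on the proximal_objective step above: pyproximal's LinearizedADMM solves a generalized lasso of the form min ||x - b||_2^2 + sigma * ||D x||_1. A self-contained toy version on a 1-D signal, not the multidms difference operator; pylops and pyproximal are assumed installed, and tau/mu are illustrative step sizes that may need tuning to meet the solver's stability condition:

import numpy as np
import pylops
import pyproximal

b = np.array([0.0, 0.1, 0.05, 2.0, 2.1, 1.95])     # noisy piecewise-constant signal
Dop = pylops.FirstDerivative(len(b), edge=True)     # D x = successive differences

x, z = pyproximal.optimization.primal.LinearizedADMM(
    pyproximal.L2(b=b),            # smooth data-fidelity term, as in the diff
    pyproximal.L1(sigma=0.5),      # sparsity penalty on D x (the "shifts")
    Dop,
    niter=200,
    tau=1.0,
    mu=0.2,
    x0=np.zeros_like(b),
    show=False,
)
print(x.round(2))  # denoised estimate; exact values depend on sigma, tau, mu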


def _gamma_corrected_cost_smooth(
def smooth_objective(
f,
params,
data,
scale_coeff_ridge_beta=0.0,
scale_coeff_ridge_ge_scale=0.0,
scale_coeff_ridge_ge_bias=0.0,
huber_scale=1,
scale_coeff_ridge_shift=0,
scale_coeff_ridge_beta=0,
scale_coeff_ridge_gamma=0,
scale_coeff_ridge_alpha_d=0,
**kwargs,
):
"""
@@ -386,14 +432,12 @@ def _gamma_corrected_cost_smooth(
return the respective binarymap and the row associated target functional scores
huber_scale : float
Scale parameter for Huber loss function
scale_coeff_ridge_shift : float
Ridge penalty coefficient for shift parameters
scale_coeff_ridge_beta : float
Ridge penalty coefficient for beta parameters
scale_coeff_ridge_gamma : float
Ridge penalty coefficient for gamma parameters
scale_coeff_ridge_alpha_d : float
Ridge penalty coefficient for alpha parameters
Ridge penalty coefficient for shift parameters
scale_coeff_ridge_ge_scale : float
Ridge penalty coefficient for global epistasis scale parameter
scale_coeff_ridge_ge_bias : float
Ridge penalty coefficient for global epistasis bias parameter
kwargs : dict
Additional keyword arguments to pass to the biophysical model function

Expand All @@ -403,36 +447,46 @@ def _gamma_corrected_cost_smooth(
Summed loss across all conditions.
"""
X, y = data
loss = 0
huber_cost = 0
beta_ridge_penalty = 0

# Sum the huber loss across all conditions
# shift_ridge_penalty = 0
for condition, X_d in X.items():
# Subset the params for condition-specific prediction
d_params = {
"beta0": params["beta0"][condition],
"beta": params["beta"][condition],
# GAMMA
# "gamma": params["gamma"][condition],
"theta": params["theta"],
"beta_m": params["beta"],
"beta_naught": params["beta_naught"],
"s_md": params[f"shift_{condition}"],
"alpha_d": params[f"alpha_{condition}"],
"gamma_d": params[f"gamma_{condition}"],
}

# compute predictions
y_d_predicted = f(d_params, X_d, **kwargs)

# compute the Huber loss between observed and predicted
# functional scores
loss += huber_loss(
y[condition] + d_params["gamma_d"], y_d_predicted, huber_scale
huber_cost += huber_loss(
# GAMMA
# y[condition] + d_params["gamma"], y_d_predicted, huber_scale
y[condition],
y_d_predicted,
huber_scale,
).mean()

# compute a regularization term that penalizes non-zero
# parameters and add it to the loss function
loss += scale_coeff_ridge_shift * jnp.sum(d_params["s_md"] ** 2)
loss += scale_coeff_ridge_alpha_d * jnp.sum(d_params["alpha_d"] ** 2)
loss += scale_coeff_ridge_gamma * jnp.sum(d_params["gamma_d"] ** 2)
beta_ridge_penalty += scale_coeff_ridge_beta * (d_params["beta"] ** 2).sum()

loss += scale_coeff_ridge_beta * jnp.sum(params["beta"] ** 2)
huber_cost /= len(X)

return loss
ge_scale_ridge_penalty = (
scale_coeff_ridge_ge_scale * (params["theta"]["ge_scale"] ** 2).sum()
)
ge_bias_ridge_penalty = (
scale_coeff_ridge_ge_bias * (params["theta"]["ge_bias"] ** 2).sum()
)

return (
huber_cost + beta_ridge_penalty + ge_scale_ridge_penalty + ge_bias_ridge_penalty
)
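The reworked objective above is an averaged Huber data term plus separate ridge penalties. A small sketch of assembling that cost for a single condition; the values and coefficients here are made up:

import jax.numpy as jnp
from jaxopt.loss import huber_loss

y_obs = jnp.array([0.0, -1.5, 1.0])      # observed functional scores
y_pred = jnp.array([0.1, -1.0, 0.7])     # model predictions
beta = jnp.array([0.2, -0.4, 0.0, 1.1])  # per-condition mutation effects

huber_cost = huber_loss(y_obs, y_pred, 1.0).mean()  # third argument plays the role of huber_scale
beta_ridge = 1e-3 * (beta ** 2).sum()
print(huber_cost + beta_ridge)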