Convergence infrastructure and add prototype nb (#151)
jgallowa07 authored Apr 9, 2024
1 parent 79dc61a commit 8e30b43
Showing 7 changed files with 3,804 additions and 228 deletions.
9 changes: 7 additions & 2 deletions multidms/data.py
@@ -21,12 +21,15 @@

from multidms import AAS

import jax
import jax.numpy as jnp
import seaborn as sns
from jax.experimental import sparse
from matplotlib import pyplot as plt
from pandarallel import pandarallel

jax.config.update("jax_enable_x64", True)


def split_sub(sub_string):
"""String match the wt, site, and sub aa
@@ -463,7 +466,7 @@ def get_nis_from_site_map(site_map):

# Make BinaryMap representations for each condition
allowed_subs = {s for subs in df.var_wrt_ref for s in subs.split()}
binmaps, X, y = {}, {}, {}
binmaps, X, y, w = {}, {}, {}, {}
for condition, condition_func_score_df in df.groupby("condition"):
ref_bmap = bmap.BinaryMap(
condition_func_score_df,
@@ -475,10 +478,12 @@ def get_nis_from_site_map(site_map):
binmaps[condition] = ref_bmap
X[condition] = sparse.BCOO.from_scipy_sparse(ref_bmap.binary_variants)
y[condition] = jnp.array(condition_func_score_df["func_score"].values)
if "weight" in condition_func_score_df.columns:
w[condition] = jnp.array(condition_func_score_df["weight"].values)

df.drop(["wts", "sites", "muts"], axis=1, inplace=True)
self._variants_df = df
self._training_data = {"X": X, "y": y}
self._training_data = {"X": X, "y": y, "w": w}
self._binarymaps = binmaps
self._mutations = tuple(ref_bmap.all_subs)

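The data.py change above threads an optional per-variant "weight" column through to the training data: when the functional-score DataFrame carries that column, a per-condition weight vector w is stored alongside X and y. Below is a minimal standalone sketch of that grouping pattern; only the column and dict names mirror the diff, and the DataFrame contents are invented for illustration:

import jax.numpy as jnp
import pandas as pd

df = pd.DataFrame(
    {
        "condition": ["a", "a", "b", "b"],
        "func_score": [2.0, -7.0, 1.0, -5.0],
        "weight": [1.0, 0.5, 1.0, 2.0],  # optional per-variant weights
    }
)

y, w = {}, {}
for condition, condition_func_score_df in df.groupby("condition"):
    y[condition] = jnp.array(condition_func_score_df["func_score"].values)
    # w is only populated when the column exists, so downstream
    # consumers must treat it as possibly empty
    if "weight" in condition_func_score_df.columns:
        w[condition] = jnp.array(condition_func_score_df["weight"].values)

training_data = {"y": y, "w": w}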
186 changes: 138 additions & 48 deletions multidms/model.py
@@ -27,6 +27,8 @@
import multidms.biophysical
from multidms.plot import _lineplot_and_heatmap

jax.config.update("jax_enable_x64", True)


class Model:
r"""
@@ -139,11 +141,11 @@ class Model:
>>> model.get_mutations_df() # doctest: +NORMALIZE_WHITESPACE
beta shift_b predicted_func_score_a predicted_func_score_b \
mutation
M1E 1.816086 0.0 1.800479 1.379661
M1W -0.754885 0.0 -0.901211 -1.322029
G3P 0.339889 0.0 0.420818 0.000000
G3R -0.534835 0.0 -0.653051 -1.073869
mutation
M1E 0.080868 0.0 0.101030 0.565154
M1W -0.386247 0.0 -0.476895 -0.012770
G3P -0.375656 0.0 -0.464124 0.000000
G3R 1.668974 0.0 1.707195 2.171319
<BLANKLINE>
times_seen_a times_seen_b wts sites muts
mutation
@@ -160,26 +162,27 @@ class Model:
>>> model.get_variants_df() # doctest: +NORMALIZE_WHITESPACE
condition aa_substitutions func_score var_wrt_ref predicted_latent \
0 a M1E 2.0 M1E 1.816086
1 a G3R -7.0 G3R -0.534835
2 a G3P -0.5 G3P 0.339889
3 a M1W 2.3 M1W -0.754885
4 b M1E 1.0 G3P M1E 1.816086
5 b P3R -5.0 G3R -0.874724
6 b P3G 0.4 -0.339889
7 b M1E P3G 2.7 M1E 1.476197
8 b M1E P3R -2.7 G3R M1E 0.941362
0 a M1E 2.0 M1E 0.080868
1 a G3R -7.0 G3R 1.668974
2 a G3P -0.5 G3P -0.375656
3 a M1W 2.3 M1W -0.386247
4 b M1E 1.0 G3P M1E 0.080868
5 b P3R -5.0 G3R 2.044630
6 b P3G 0.4 0.375656
7 b M1E P3G 2.7 M1E 0.456523
8 b M1E P3R -2.7 G3R M1E 2.125498
<BLANKLINE>
predicted_func_score
0 1.800479
1 -0.653051
2 0.420818
3 -0.901211
4 1.560311
5 -1.073869
6 -0.420818
7 1.379661
8 0.992495
0 0.101030
1 1.707195
2 -0.464124
3 -0.476895
4 0.098285
5 2.171319
6 0.464124
7 0.565154
8 2.223789
We now have access to the predicted (and gamma-corrected) functional scores
given the model's current parameters.
@@ -189,13 +192,13 @@ class Model:
given our initialized parameters
>>> model.loss
Array(4.7124467, dtype=float32)
Array(7.19312981, dtype=float64)
Next, we fit the model with some chosen hyperparameters.
>>> model.fit(maxiter=1000, scale_coeff_lasso_shift=1e-5, warn_unconverged=False)
>>> model.loss
Array(6.0517805e-06, dtype=float32)
Array(1.18200934e-05, dtype=float64)
The model tunes its parameters in place, and the subsequent call to retrieve
the loss reflects the model's loss given its updated parameters.
@@ -330,7 +333,10 @@ def __init__(
)

self._name = name if isinstance(name, str) else f"Model-{Model.counter}"

# None of the following are set until the fit() method is called.
self._state = None
self._convergence_trajectory = None
self._converged = False
Model.counter += 1

@@ -412,6 +418,14 @@ def conditional_loss(self) -> float:
ret["total"] = sum(ret.values())
return ret

@property
def convergence_trajectory_df(self):
"""
The loss and ``state.error`` recorded at intervals throughout training.
Currently, this is reset each time the fit() method is called.
"""
return self._convergence_trajectory

@property
def variants_df(self):
"""
@@ -968,26 +982,30 @@ def fit_reference_beta(self, **kwargs):

def fit(
self,
lasso_shift=1e-5,
tol=1e-4,
maxiter=1000,
acceleration=True,
maxls=15,
scale_coeff_lasso_shift=1e-5,
lock_params={},
warn_unconverged=True,
upper_bound_theta_ge_scale="infer",
convergence_trajectory_resolution=100,
**kwargs,
):
r"""
Use jaxopt.ProximalGradient to optimize the model's free parameters.
Parameters
----------
lasso_shift : float
scale_coeff_lasso_shift : float
L1 penalty on the shift parameters. Defaults to 1e-5.
tol : float
Tolerance for the optimization. Defaults to 1e-4.
maxiter : int
Maximum number of iterations for the optimization. Defaults to 1000.
maxls : int
Maximum number of iterations to perform during line search. Defaults to 15.
acceleration : bool
If True, use FISTA acceleration. Defaults to True.
lock_params : dict
@@ -1006,19 +1024,14 @@ def fit(
Passing the string literal 'infer' results in the
scale being set to double the range of the training data.
Defaults to 'infer'.
convergence_trajectory_resolution : int
Record the loss and error of the optimization every this many
iterations. Defaults to 100.
**kwargs : dict
Additional keyword arguments passed to the objective function.
These include hyperparameters like a ridge penalty on beta, shift, and gamma
as well as Huber loss scaling.
"""
solver = ProximalGradient(
jax.jit(self._model_components["objective"]),
jax.jit(self._model_components["proximal"]),
tol=tol,
maxiter=maxiter,
acceleration=acceleration,
)

lock_params[f"shift_{self._data.reference}"] = jnp.zeros(
len(self._params["beta"])
)
@@ -1042,7 +1055,7 @@
for non_ref_condition in self._data.conditions:
if non_ref_condition == self._data.reference:
continue
lasso_params[f"shift_{non_ref_condition}"] = lasso_shift
lasso_params[f"shift_{non_ref_condition}"] = scale_coeff_lasso_shift

if not isinstance(upper_bound_theta_ge_scale, (float, int, type(None), str)):
raise ValueError(
@@ -1054,33 +1067,109 @@
# infer the range of the training data, and double it
# to set the upper bound of the theta scale parameter.
# see https://github.com/matsengrp/multidms/issues/143 for details
# TODO could make this a property of the Data object

if upper_bound_theta_ge_scale == "infer":
y = jnp.concatenate(list(self.data.training_data["y"].values()))
y_range = y.max() - y.min()
upper_bound_theta_ge_scale = 2 * y_range

self._params, self._state = solver.run(
compiled_objective = jax.jit(self._model_components["objective"])
compiled_proximal = jax.jit(self._model_components["proximal"])

solver = ProximalGradient(
compiled_objective,
compiled_proximal,
tol=tol,
maxiter=maxiter,
acceleration=acceleration,
maxls=maxls,
)

training_data = (self._data.training_data["X"], self._data.training_data["y"])

self._state = solver.init_state(
self._params,
hyperparams_prox=dict(
lasso_params=lasso_params,
lock_params=lock_params,
upper_bound_theta_ge_scale=upper_bound_theta_ge_scale,
),
data=(self._data.training_data["X"], self._data.training_data["y"]),
data=training_data,
**kwargs,
)

converged = self._state.error < tol
if not converged and warn_unconverged:
warnings.warn(
"Model training error did not reach the tolerance threshold. "
f"Final error: {self._state.error}, tolerance: {tol}",
RuntimeWarning,
convergence_trajectory = pd.DataFrame(
index=range(0, maxiter + 1, convergence_trajectory_resolution)
).assign(loss=onp.nan, error=onp.nan)

# TODO should step be the index?
convergence_trajectory.index.name = "step"

# record initial loss and error
convergence_trajectory.loc[0, "loss"] = float(
compiled_objective(self._params, training_data)
)

convergence_trajectory.loc[0, "error"] = float(self._state.error)

# prev_no_pen_obj_loss = jnp.inf
for i in range(maxiter):
# perform single optimization step
self._params, self._state = solver.update(
self._params,
self._state,
hyperparams_prox=dict(
lasso_params=lasso_params,
lock_params=lock_params,
upper_bound_theta_ge_scale=upper_bound_theta_ge_scale,
),
data=training_data,
**kwargs,
)
self._converged = converged
# record loss and error trajectories at regular intervals
if (i + 1) % convergence_trajectory_resolution == 0:
no_pen_obj_loss = float(compiled_objective(self._params, training_data))
convergence_trajectory.loc[i + 1, "loss"] = no_pen_obj_loss
convergence_trajectory.loc[i + 1, "error"] = float(self._state.error)

# TODO, if you wanted to
# delta_no_pen_obj_loss = -1 * (no_pen_obj_loss - prev_no_pen_obj_loss)
# if delta_no_pen_obj_loss < tol:
# self._converged = True
# break

# prev_no_pen_obj_loss = no_pen_obj_loss

# early stopping criteria
# TODO: what if we had an auxiliary attribute that we used for
# early stopping? That would probably speed things up if we
# wanted to do the above.
if self._state.error < tol:
self._converged = True
break

if not self.converged:
if warn_unconverged:
warnings.warn(
"Model training error did not reach the tolerance threshold. "
f"Final error: {self._state.error}, tolerance: {tol}",
RuntimeWarning,
)

self._convergence_trajectory = convergence_trajectory

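The loop above replaces jaxopt's one-shot solver.run() with manual init_state()/update() stepping, so the loss and error can be recorded along the way and early stopping applied. Here is a minimal self-contained sketch of that jaxopt pattern on a synthetic lasso regression; the objective, shapes, and penalty strength are invented for illustration and are not multidms code:

import jax
import jax.numpy as jnp
from jaxopt import ProximalGradient
from jaxopt.prox import prox_lasso

jax.config.update("jax_enable_x64", True)


def objective(params, data):
    """Mean squared error of a linear model."""
    X, y = data
    return jnp.mean((X @ params - y) ** 2)


# synthetic regression problem with a single informative coefficient
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 10))
true_params = jnp.zeros(10).at[0].set(3.0)
data = (X, X @ true_params)

tol, maxiter, resolution = 1e-8, 1000, 100
solver = ProximalGradient(jax.jit(objective), prox_lasso, tol=tol, maxiter=maxiter)

params = jnp.zeros(10)
state = solver.init_state(params, hyperparams_prox=1e-3, data=data)
trajectory = []
for i in range(maxiter):
    # one proximal-gradient step, mirroring the update loop in fit()
    params, state = solver.update(params, state, hyperparams_prox=1e-3, data=data)
    # record diagnostics at regular intervals
    if (i + 1) % resolution == 0:
        trajectory.append((i + 1, float(objective(params, data)), float(state.error)))
    # early stopping on the solver's own error criterion
    if state.error < tol:
        break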

def plot_pred_accuracy(
self, hue=True, show=True, saveas=None, annotate_corr=True, ax=None, **kwargs
self,
hue=True,
show=True,
saveas=None,
annotate_corr=True,
ax=None,
r=2,
**kwargs,
):
"""
Create a figure which visualizes the correlation
@@ -1124,9 +1213,10 @@ def plot_pred_accuracy(
if annotate_corr:
start_y = 0.95
for c, cdf in df.groupby("condition"):
r = pearsonr(cdf[func_score], cdf["predicted_func_score"])[0]
corr = pearsonr(cdf[func_score], cdf["predicted_func_score"])[0] ** r
metric = "pearson" if r == 1 else "R^2"
ax.annotate(
f"$r = {r:.2f}$",
f"{metric} = {corr:.2f}",
(0.01, start_y),
xycoords="axes fraction",
fontsize=12,
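Taken together, the new diagnostics can be inspected after fitting. A hypothetical usage sketch, reusing the model object from the docstring example above (it assumes matplotlib is available for the pandas plot call; dropna() discards the steps that fall between recording intervals):

model.fit(maxiter=1000, scale_coeff_lasso_shift=1e-5, warn_unconverged=False)

# loss and error recorded every convergence_trajectory_resolution steps
traj = model.convergence_trajectory_df.dropna()
traj.plot(y=["loss", "error"], logy=True)

# annotate squared Pearson correlation (R^2) via the new exponent argument
model.plot_pred_accuracy(r=2)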
