Skip to content

Commit

Permalink
added plotting unit tests and ge scale ridge penalty
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallowa07 committed Jun 17, 2024
1 parent bee4544 commit c6b6845
Show file tree
Hide file tree
Showing 5 changed files with 1,099 additions and 925 deletions.
19 changes: 17 additions & 2 deletions multidms/biophysical.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwar
hinge_scale
# GAMMA
# * (jnp.logaddexp(0, (act - (lower_bound + d_params["gamma_d"])) / hinge_scale))
* (jnp.logaddexp(0, act - lower_bound / hinge_scale))
* (jnp.logaddexp(0, (act - lower_bound) / hinge_scale))
+ lower_bound
# GAMMA
# + d_params["gamma_d"]
Expand Down Expand Up @@ -413,6 +413,8 @@ def smooth_objective(
params,
data,
scale_coeff_ridge_beta=0.0,
scale_coeff_ridge_ge_scale=0.0,
scale_coeff_ridge_ge_bias=0.0,
huber_scale=1,
**kwargs,
):
Expand All @@ -432,6 +434,10 @@ def smooth_objective(
Scale parameter for Huber loss function
scale_coeff_ridge_beta : float
Ridge penalty coefficient for shift parameters
scale_coeff_ridge_ge_scale : float
Ridge penalty coefficient for global epistasis scale parameter
scale_coeff_ridge_ge_bias : float
Ridge penalty coefficient for global epistasis bias parameter
kwargs : dict
Additional keyword arguments to pass to the biophysical model function
Expand Down Expand Up @@ -474,4 +480,13 @@ def smooth_objective(

huber_cost /= len(X)

return huber_cost + beta_ridge_penalty
ge_scale_ridge_penalty = (
scale_coeff_ridge_ge_scale * (params["theta"]["ge_scale"] ** 2).sum()
)
ge_bias_ridge_penalty = (
scale_coeff_ridge_ge_bias * (params["theta"]["ge_bias"] ** 2).sum()
)

return (
huber_cost + beta_ridge_penalty + ge_scale_ridge_penalty + ge_bias_ridge_penalty
)
172 changes: 67 additions & 105 deletions multidms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import math
import warnings
from functools import lru_cache, partial, reduce, cached_property
from functools import lru_cache, partial, cached_property
from frozendict import frozendict

from multidms import Data
Expand Down Expand Up @@ -267,7 +267,10 @@ def __init__(
)

elif epistatic_model == multidms.biophysical.identity_activation:
self._scaled_data_params["theta"] = dict(ghost_param=jnp.zeros(shape=(1,)))
self._scaled_data_params["theta"] = dict(
ge_scale=jnp.zeros(shape=(1,)),
ge_bias=jnp.zeros(shape=(1,)),
)

elif epistatic_model == multidms.biophysical.nn_global_epistasis:
if n_hidden_units is None:
Expand Down Expand Up @@ -484,7 +487,7 @@ def get_variants_df(self, phenotype_as_effect=True):
based on the current state of the model.
"""
# this is what well update and return
variants_df = self._data.variants_df.copy()
variants_df = self.data.variants_df.copy()

# initialize new columns
for pheno in ["latent", "func_score"]:
Expand Down Expand Up @@ -871,20 +874,28 @@ def add_phenotypes_to_df(

return ret

def mutation_site_summary_df(self, agg_func=onp.mean, times_seen_threshold=0):
def mutation_site_summary_df(self, agg_func="mean", **kwargs):
"""
Get all single mutational attributes from self._data
updated with all model specific attributes, then aggregate
all numerical columns by "sites" using
``agg`` function. The mean values are given by default.
all numerical columns by "sites"
Parameters
----------
agg_func : str
Aggregation function to use on the numerical columns.
Defaults to "mean".
**kwargs
Additional keyword arguments to pass to get_mutations_df.
Returns
-------
pandas.DataFrame
A summary of the mutation attributes aggregated by site.
"""
numerics = ["int16", "int32", "int64", "float16", "float32", "float64"]
mut_df = self.mutations_df.select_dtypes(include=numerics)
times_seen_cols = [c for c in mut_df.columns if "times" in c]
for c in times_seen_cols:
mut_df = mut_df[mut_df[c] >= times_seen_threshold]

return mut_df.groupby("sites").aggregate(agg_func)
mut_df = self.get_mutations_df(**kwargs).select_dtypes(include=numerics)
return mut_df.groupby("sites").agg(agg_func)

def get_condition_params(self, condition=None):
"""Get the relent parameters for a model prediction"""
Expand Down Expand Up @@ -1193,7 +1204,7 @@ def plot_pred_accuracy(
between model predicted functional score of all
variants in the training with ground truth measurements.
"""
df = self.variants_df
df = self.get_variants_df(phenotype_as_effect=False)

df = df.assign(
is_wt=df["aa_substitutions"].apply(
Expand All @@ -1204,7 +1215,9 @@ def plot_pred_accuracy(
if ax is None:
fig, ax = plt.subplots(figsize=[3, 3])

func_score = "corrected_func_score" if self.gamma_corrected else "func_score"
# GAMMA
# func_score = "corrected_func_score" if self.gamma_corrected else "func_score"
func_score = "func_score"
sns.scatterplot(
data=df.sample(frac=1),
x="predicted_func_score",
Expand Down Expand Up @@ -1240,13 +1253,14 @@ def plot_pred_accuracy(
c=self._data.condition_colors[c],
)
start_y += -0.05
ax.set_ylabel("functional score")
# ax.set_ylabel("functional score")
ax.set_xlabel("predicted functional score")

ax.axhline(0, color="k", ls="--", lw=2)
ax.axvline(0, color="k", ls="--", lw=2)

ax.set_ylabel("functional score + gamma$_{d}$")
# ax.set_ylabel("functional score + gamma$_{d}$")
ax.set_ylabel("measured functional score")
plt.tight_layout()
if saveas:
fig.savefig(saveas)
Expand All @@ -1262,7 +1276,7 @@ def plot_epistasis(
gamma corrected ground truth measurements
of all samples in the training set.
"""
df = self.variants_df
df = self.get_variants_df(phenotype_as_effect=False)

df = df.assign(
is_wt=df["aa_substitutions"].apply(
Expand All @@ -1273,7 +1287,9 @@ def plot_epistasis(
if ax is None:
fig, ax = plt.subplots(figsize=[3, 3])

func_score = "corrected_func_score" if self.gamma_corrected else "func_score"
# GAMMA
# func_score = "corrected_func_score" if self.gamma_corrected else "func_score"
func_score = "func_score"
sns.scatterplot(
data=df.sample(frac=sample),
x="predicted_latent",
Expand Down Expand Up @@ -1305,7 +1321,7 @@ def plot_epistasis(
ax.axhline(0, color="k", ls="--", lw=2)
ax.set_xlim([xlb, xub])
ax.set_ylim([ylb, yub])
ax.set_ylabel("functional score")
ax.set_ylabel("measured functional score")
ax.set_xlabel("predicted latent phenotype")
plt.tight_layout()

Expand All @@ -1316,10 +1332,10 @@ def plot_epistasis(
return ax

def plot_param_hist(
self, param, show=True, saveas=False, times_seen_threshold=3, ax=None, **kwargs
self, param, show=True, saveas=False, times_seen_threshold=0, ax=None, **kwargs
):
"""Plot the histogram of a parameter."""
mut_effects_df = self.mutations_df
mut_effects_df = self.get_mutations_df()

if ax is None:
fig, ax = plt.subplots(figsize=[3, 3])
Expand Down Expand Up @@ -1372,18 +1388,18 @@ def plot_param_hist(
return ax

def plot_param_heatmap(
self, param, show=True, saveas=False, times_seen_threshold=3, ax=None, **kwargs
self, param, show=True, saveas=False, times_seen_threshold=0, ax=None, **kwargs
):
"""
Plot the heatmap of a parameters
associated with specific sites and substitutions.
"""
if not param.startswith("beta") and not param.startswith("S"):
if not param.startswith("beta") and not param.startswith("shift"):
raise ValueError(
"Parameter to visualize must be an existing beta, or shift parameter"
)

mut_effects_df = self.mutations_df
mut_effects_df = self.get_mutations_df()

if ax is None:
fig, ax = plt.subplots(figsize=[12, 3])
Expand Down Expand Up @@ -1419,8 +1435,8 @@ def plot_shifts_by_site(
condition,
show=True,
saveas=False,
times_seen_threshold=3,
agg_func=onp.mean,
times_seen_threshold=0,
agg_func="mean",
ax=None,
**kwargs,
):
Expand Down Expand Up @@ -1470,65 +1486,6 @@ def plot_shifts_by_site(
plt.show()
return ax

def plot_fit_param_comp_scatter(
self,
other,
self_param="beta",
other_param="beta",
figsize=[5, 4],
saveas=None,
show=True,
site_agg_func=None,
):
"""Plot a scatter plot of the parameter values of two models"""
if not site_agg_func:
dfs = [self.mutations_df, other.mutations_df]
else:
dfs = [
self.mutation_site_summary_df(agg=site_agg_func).reset_index(),
other.mutation_site_summary_df(agg=site_agg_func).reset_index(),
]

combine_on = "mutation" if site_agg_func is None else "sites"
comb_mut_effects = reduce(
lambda l, r: pd.merge(l, r, how="inner", on=combine_on), # noqa: E741
dfs,
)
comb_mut_effects["is_stop"] = [
True if "*" in s else False for s in comb_mut_effects[combine_on]
]

same = self_param == other_param
x = f"{self_param}_x" if same else self_param
y = f"{other_param}_y" if same else other_param

fig, ax = plt.subplots(figsize=figsize)
r = pearsonr(comb_mut_effects[x], comb_mut_effects[y])[0]
sns.scatterplot(
data=comb_mut_effects,
x=x,
y=y,
hue="is_stop",
alpha=0.6,
palette="deep",
ax=ax,
)

xlb, xub = [-1, 1] + onp.quantile(comb_mut_effects[x], [0.00, 1.0])
ylb, yub = [-1, 1] + onp.quantile(comb_mut_effects[y], [0.00, 1.0])
min1 = min(xlb, ylb)
max1 = max(xub, yub)
ax.plot([min1, max1], [min1, max1], ls="--", c="k")
ax.annotate(f"$r = {r:.2f}$", (0.7, 0.1), xycoords="axes fraction", fontsize=12)
plt.tight_layout()

if saveas:
fig.saveas(saveas)
if show:
plt.show()

return fig, ax

def mut_param_heatmap(
self,
mut_param="shift",
Expand All @@ -1548,7 +1505,13 @@ def mut_param_heatmap(
muts_df = self.get_mutations_df(
times_seen_threshold=times_seen_threshold,
phenotype_as_effect=phenotype_as_effect,
return_split=False,
return_split=True,
).rename(
columns={
"wts": "wildtype",
"muts": "mutant",
"sites": "site",
}
)

# drop columns which are not the mutational parameter of interest
Expand All @@ -1558,9 +1521,9 @@ def mut_param_heatmap(
muts_df.drop(drop_cols, axis=1, inplace=True)

# add in the mutation annotations
muts_df["wildtype"], muts_df["site"], muts_df["mutant"] = zip(
*muts_df.reset_index()["mutation"].map(self.data.parse_mut)
)
# muts_df["wildtype"], muts_df["site"], muts_df["mutant"] = zip(
# *muts_df.reset_index()["mutation"].map(self.data.parse_mut)
# )

# no longer need mutation annotation
muts_df.reset_index(drop=True, inplace=True)
Expand Down Expand Up @@ -1597,22 +1560,21 @@ def mut_param_heatmap(
# melt conditions and stats cols, beta is already "tall"
# note that we must rename conditions with "." in the
# name to "_" to avoid altair errors
if mut_param == "beta":
muts_df_tall = muts_df.assign(condition=reference.replace(".", "_"))
else:
muts_df_tall = muts_df.melt(
id_vars=["wildtype", "site", "mutant"] + addtl_tooltip_stats,
value_vars=[c for c in muts_df.columns if c.startswith(mut_param)],
var_name="condition",
value_name=mut_param,
)
muts_df_tall.condition.replace(
{
f"{mut_param}_{condition}": condition.replace(".", "_")
for condition in conditions
},
inplace=True,
)
# if mut_param == f"beta_{reference}":
# muts_df_tall = muts_df.assign(condition=reference.replace(".", "_"))
# else:
muts_df_tall = muts_df.melt(
id_vars=["wildtype", "site", "mutant"] + addtl_tooltip_stats,
value_vars=[c for c in muts_df.columns if c.startswith(mut_param)],
var_name="condition",
value_name=mut_param,
)
muts_df_tall["condition"] = muts_df_tall.condition.replace(
{
f"{mut_param}_{condition}": condition.replace(".", "_")
for condition in conditions
},
)

# add in condition colors, rename for altair
condition_colors = {
Expand Down
Loading

0 comments on commit c6b6845

Please sign in to comment.