latest before moving spike analysis (#141)
* conditional loss
jgallowa07 authored Mar 11, 2024
1 parent e376173 commit 1325dcf
Showing 5 changed files with 2,210 additions and 274 deletions.
8 changes: 0 additions & 8 deletions multidms/data.py
@@ -599,7 +599,6 @@ def targets(self) -> dict:
"""The functional scores for each variant in the training data."""
return self._training_data["y"]

- # TODO, rename mutparser
@property
def mutparser(self) -> MutationParser:
"""
@@ -608,7 +607,6 @@ def mutparser(self) -> MutationParser:
"""
return self._mutparser

- # TODO, rename
@property
def parse_mut(self) -> MutationParser:
"""
@@ -618,7 +616,6 @@ def parse_mut(self) -> MutationParser:
"""
return self.mutparser.parse_mut

- # TODO, document rename issue
@property
def parse_muts(self) -> partial:
"""
@@ -628,11 +625,6 @@ def parse_muts(self) -> partial:
"""
return self._parse_muts

- # TODO should this be cached? how does caching interact with the way in
- # which we applying this function in parallel?
- # although, unless the variants are un-collapsed, this cache will be
- # pretty useless.
- # although it could be useful for the Model.add_phenotypes_to_df method.
def convert_subs_wrt_ref_seq(self, condition, aa_subs):
"""
Convert amino acid substitutions to be with respect to the reference sequence.
64 changes: 56 additions & 8 deletions multidms/model.py
@@ -209,7 +209,7 @@ def __init__(
epistatic_model=multidms.biophysical.sigmoidal_global_epistasis,
output_activation=multidms.biophysical.identity_activation,
conditional_shifts=True,
- alpha_d=False, # TODO raise issue to be squashed in this PR
alpha_d=False,
gamma_corrected=False,
PRNGKey=0,
init_beta_naught=0.0,
@@ -375,6 +375,25 @@ def loss(self) -> float:
data = (self.data.training_data["X"], self.data.training_data["y"])
return jax.jit(self.model_components["objective"])(self.params, data, **kwargs)

@property
def conditional_loss(self) -> float:
"""Compute loss individually for each condition."""
kwargs = {
"scale_coeff_ridge_beta": 0.0,
"scale_coeff_ridge_shift": 0.0,
"scale_coeff_ridge_gamma": 0.0,
"scale_ridge_alpha_d": 0.0,
}

X, y = self.data.training_data["X"], self.data.training_data["y"]
loss_fxn = jax.jit(self.model_components["objective"])
ret = {}
for condition in self.data.conditions:
condition_data = ({condition: X[condition]}, {condition: y[condition]})
ret[condition] = float(loss_fxn(self.params, condition_data, **kwargs))
ret["total"] = sum(ret.values())
return ret
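
A minimal usage sketch of the new property (illustrative only: `model` stands for an already-fitted multidms.Model instance and is not defined in this diff):

# Sketch, not part of this commit: per-condition training loss plus a summed "total" entry.
per_condition = model.conditional_loss
for condition, loss in per_condition.items():
    print(condition, loss)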

@property
def variants_df(self):
"""
@@ -546,7 +565,7 @@ def get_mutations_df(

return mutations_df[col_order]

- def get_df_loss(self, df, error_if_unknown=False, verbose=False):
def get_df_loss(self, df, error_if_unknown=False, verbose=False, conditional=False):
"""
Get the loss of the model on a given data frame.
@@ -563,10 +582,13 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):
in the loss calculation. If `True`, raise an error.
verbose : bool
If True, print the number of valid and invalid variants.
conditional : bool
If True, return the loss for each condition as a dictionary.
If False, return the total loss.
Returns
-------
- float
float or dict
The loss of the model on the given data frame.
"""
substitutions_col = "aa_substitutions"
@@ -579,8 +601,11 @@ def get_df_loss(self, df, error_if_unknown=False, verbose=False):
if condition_col not in df.columns:
raise ValueError("`df` lacks `condition_col` " f"{condition_col}")

- X, y = {}, {}
loss_fxn = jax.jit(self.model_components["objective"])

ret = {}
for condition, condition_df in df.groupby(condition_col):
X, y = {}, {}
variant_subs = condition_df[substitutions_col]
if condition not in self.data.reference_sequence_conditions:
variant_subs = condition_df.apply(
@@ -592,14 +617,23 @@

# build binary variants as csr matrix, make prediction, and append
valid, invalid = 0, 0  # counts of variants with recognized / unrecognized substitutions
- binary_variants = []
# binary_variants = []
variant_targets = []
row_ind = [] # row indices of elements that are one
col_ind = [] # column indices of elements that are one

for subs, target in zip(variant_subs, condition_df[func_score_col]):
try:
- binary_variants.append(ref_bmap.sub_str_to_binary(subs))
# binary_variants.append(ref_bmap.sub_str_to_binary(subs))
# variant_targets.append(target)
# valid += 1

for isub in ref_bmap.sub_str_to_indices(subs):
row_ind.append(valid)
col_ind.append(isub)
variant_targets.append(target)
valid += 1

except ValueError:
if error_if_unknown:
raise ValueError(
@@ -615,12 +649,26 @@
f"{valid}, n invalid variants: {invalid}"
)

# X[condition] = sparse.BCOO.from_scipy_sparse(
# scipy.sparse.csr_matrix(onp.vstack(binary_variants))
# )
X[condition] = sparse.BCOO.from_scipy_sparse(
- scipy.sparse.csr_matrix(onp.vstack(binary_variants))
scipy.sparse.csr_matrix(
(onp.ones(len(row_ind), dtype="int8"), (row_ind, col_ind)),
shape=(valid, ref_bmap.binarylength),
dtype="int8",
)
)

y[condition] = jnp.array(variant_targets)

- return self.model_components["objective"](self.params, (X, y))
ret[condition] = float(loss_fxn(self.params, (X, y)))

ret["total"] = sum(ret.values())

if not conditional:
return ret["total"]
return ret
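
The rewrite above builds the sparse variant matrix directly from coordinate lists instead of stacking dense binary rows with onp.vstack. A self-contained sketch of that construction pattern, with made-up indices and shape (none of these values come from the package):

# Sketch, not part of this commit: one-hot variant matrix from (row, col) coordinates.
import numpy as onp
import scipy.sparse
from jax.experimental import sparse

row_ind = [0, 0, 1]   # hypothetical: variant 0 has mutations 2 and 5, variant 1 has mutation 7
col_ind = [2, 5, 7]
n_variants, n_mutations = 2, 10

X_csr = scipy.sparse.csr_matrix(
    (onp.ones(len(row_ind), dtype="int8"), (row_ind, col_ind)),
    shape=(n_variants, n_mutations),
    dtype="int8",
)
X_bcoo = sparse.BCOO.from_scipy_sparse(X_csr)  # same conversion used in get_df_loss above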

def add_phenotypes_to_df(
self,
124 changes: 102 additions & 22 deletions multidms/model_collection.py
@@ -396,8 +396,19 @@ def __init__(self, fit_models):
)
all_mutations = set.union(all_mutations, set(fit.data.mutations))

- # add the final training loss to the fit_models dataframe
- fit_models["training_loss"] = fit_models.step_loss.apply(lambda x: x[-1])
# initialize empty columns for conditional loss
# note: DataFrame.assign returns a new dataframe, so keep the result;
# the total column is named to match what the loop below writes
fit_models = fit_models.assign(
**{
f"{condition}_loss_training": onp.nan
for condition in first_dataset.conditions
},
total_loss_training=onp.nan,
)
# assign conditional loss columns
for idx, fit in fit_models.iterrows():
conditional_loss = fit.model.conditional_loss
for condition, loss in conditional_loss.items():
fit_models.loc[idx, f"{condition}_loss_training"] = loss
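
For orientation, a sketch of the training-loss columns this loop produces, assuming two hypothetical conditions named "Delta" and "Omicron" (the names are illustrative, not taken from this commit):

# Sketch, not part of this commit: columns written by the loop above.
expected = ["Delta_loss_training", "Omicron_loss_training", "total_loss_training"]
assert all(col in fit_models.columns for col in expected)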

self._site_map_union = site_map_union
self._conditions = first_dataset.conditions
@@ -432,7 +443,6 @@ def all_mutations(self) -> tuple:
"""The mutations shared by each fitting dataset."""
return self._all_mutations

- # TODO remove verbose everywhere
@lru_cache(maxsize=10)
def split_apply_combine_muts(
self,
@@ -482,32 +492,52 @@ def split_apply_combine_muts(
A dataframe containing the aggregated mutational parameter values
"""
print("cache miss - this could take a moment")
queried_fits = (
self.fit_models.query(query) if query is not None else self.fit_models
)
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")

if groupby is None:
- groupby = tuple(
- set(self.fit_models.columns)
- - set(["model", "data", "step_loss", "verbose"])
# groupby = tuple(
# set(queried_fits.columns)
# - set(
# ["model", "dataset_name", "verbose"]
# + [col for col in queried_fits.columns if "loss" in col]
# )
# )
ret = (
pd.concat(
[
fit["model"].get_mutations_df(return_split=False, **kwargs)
for _, fit in queried_fits.iterrows()
],
join="inner", # the columns will always match based on class req.
)
.query(
f"mutation.isin({list(self.shared_mutations)})"
if inner_merge_dataset_muts
else "mutation.notna()"
)
.groupby("mutation")
.aggregate(aggregate_func)
)
return ret

elif isinstance(groupby, str):
groupby = tuple([groupby])

elif isinstance(groupby, tuple):
- if not all(feature in self.fit_models.columns for feature in groupby):
if not all(feature in queried_fits.columns for feature in groupby):
raise ValueError(
f"invalid groupby, values must be in {self.fit_models.columns}"
)
else:
raise ValueError(
"invalid groupby, must be tuple with values "
f"in {self.fit_models.columns}"
f"in {queried_fits.columns}"
)

- queried_fits = (
- self.fit_models.query(query) if query is not None else self.fit_models
- )
- if len(queried_fits) == 0:
- raise ValueError("invalid query, no fits returned")

ret = pd.concat(
[
pd.concat(
@@ -566,20 +596,69 @@ def add_validation_loss(self, test_data, overwrite=False):
# check there's a testing dataframe for each unique dataset_name
assert set(test_data.keys()) == set(self.fit_models["dataset_name"].unique())

if "validation_loss" in self.fit_models.columns and not overwrite:
validation_cols_exist = onp.any(
[
f"{condition}_loss_validation" in self.fit_models.columns
for condition in self.conditions
]
)
if validation_cols_exist and not overwrite:
raise ValueError(
"validation_loss already exists in self.fit_models, set overwrite=True "
"to overwrite"
)

self.fit_models["validation_loss"] = onp.nan
self.fit_models = self.fit_models.assign(
**{
f"{condition}_loss_validation": onp.nan for condition in self.conditions
},
total_loss_validation=onp.nan,
)

for idx, fit in self.fit_models.iterrows():
self.fit_models.loc[idx, "validation_loss"] = fit["model"].get_df_loss(
test_data[fit["dataset_name"]]
conditional_df_loss = fit.model.get_df_loss(
test_data[fit["dataset_name"]], conditional=True
)
for condition, loss in conditional_df_loss.items():
self.fit_models.loc[idx, f"{condition}_loss_validation"] = loss

return None
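
A minimal sketch of calling the updated method; `mc`, `held_out_df`, and the dataset name "spike_dms" are hypothetical stand-ins, not objects defined in this commit:

# Sketch, not part of this commit: attach per-condition validation loss columns.
test_data = {"spike_dms": held_out_df}  # keys must match the "dataset_name" values in mc.fit_models
mc.add_validation_loss(test_data, overwrite=True)
# afterwards: one "{condition}_loss_validation" column per condition, plus "total_loss_validation"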

def get_conditional_loss_df(self, query=None):
"""
Return a long-form dataframe with columns
"dataset_name", "scale_coeff_lasso_shift",
"split" ("training" or "validation"),
"loss" (the loss value), and "condition".

Parameters
----------
query : str, optional
The query to apply to the fit_models dataframe
before formatting the loss dataframe. The default is None.
"""
if query is not None:
queried_fits = self.fit_models.query(query)
else:
queried_fits = self.fit_models
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")

id_vars = ["dataset_name", "scale_coeff_lasso_shift"]
value_vars = [
c for c in queried_fits.columns if "loss" in c and c != "step_loss"
]
loss_df = queried_fits.melt(
id_vars=id_vars,
value_vars=value_vars,
var_name="condition",
value_name="loss",
).assign(
split=lambda x: x.condition.str.split("_").str.get(-1),
condition=lambda x: x.condition.str.split("_").str[:-2].str.join("_"),
)
return loss_df
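
To make the melt step concrete: each wide loss column name is split on underscores, the last token becomes "split" and everything before the trailing "loss_<split>" becomes "condition". A tiny sketch with a hypothetical column name:

# Sketch, not part of this commit: how "Delta_loss_validation" is parsed by the assign() above.
parts = "Delta_loss_validation".split("_")
split = parts[-1]                 # -> "validation"
condition = "_".join(parts[:-2])  # -> "Delta"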

def mut_param_heatmap(
self,
query=None,
@@ -652,7 +731,11 @@ def mut_param_heatmap(
if len(queried_fits) == 0:
raise ValueError("invalid query, no fits returned")
shouldbe_uniform = list(
- set(queried_fits.columns) - set(["model", "dataset_name", "step_loss"])
set(queried_fits.columns)
- set(
["model", "dataset_name"]
+ [col for col in queried_fits.columns if "loss" in col]
)
)
if len(queried_fits.groupby(list(shouldbe_uniform)).groups) > 1:
raise ValueError(
@@ -921,9 +1004,6 @@ def mut_type(mut):
return "stop" if mut.endswith("*") else "nonsynonymous"

# apply, drop, and melt
- # TODO This throws deprecation warning
- # because of the include_groups argument ...
- # set to False, and lose the drop call after ...
sparsity_df = (
df.drop(columns=to_throw)
.assign(mut_type=lambda x: x.mutation.apply(mut_type))
