Skip to content

Commit

Permalink
passing simple single condition fit
Browse files Browse the repository at this point in the history
  • Loading branch information
jgallowa07 committed May 21, 2024
1 parent 1b03815 commit 73ef1ea
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
9 changes: 7 additions & 2 deletions multidms/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,9 @@ def fit(
reference-condition latent offset (TODO math? beta0[ref]) are locked.
admm_niter : int
Number of iterations to perform during the ADMM optimization.
Defaults to 50.
Defaults to 50. Note that in the case of single-condition models,
This is set to zero as the generalized
lasso ADMM optimization is not used.
admm_tau : float
ADMM step size. Defaults to 1.0.
warn_unconverged : bool
Expand Down Expand Up @@ -1106,7 +1108,10 @@ def fit(

admm_mu = 0.99 * admm_tau / eig

assert 0 < admm_mu < admm_tau / eig
if len(self.data.conditions) > 1:
assert 0 < admm_mu < admm_tau / eig

admm_niter = 0 if len(self.data.conditions) == 1 else admm_niter

lock_params[("beta0", self.data.reference)] = 0.0

Expand Down
26 changes: 13 additions & 13 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,20 +255,20 @@ def test_difference_matrix():
"""


# def test_single_condition_fit():
# """
# Simple test to see that the single-condition model
# fits without error.
# """
# data = multidms.Data(
# TEST_FUNC_SCORES.query("condition == 'a'"),
# alphabet=multidms.AAS_WITHSTOP,
# reference="a",
# assert_site_integrity=False,
# )
# model = multidms.Model(data, PRNGKey=23)
def test_single_condition_fit():
"""
Simple test to see that the single-condition model
fits without error.
"""
data = multidms.Data(
TEST_FUNC_SCORES.query("condition == 'a'"),
alphabet=multidms.AAS_WITHSTOP,
reference="a",
assert_site_integrity=False,
)
model = multidms.Model(data, PRNGKey=23)

# model.fit(maxiter=2, warn_unconverged=False)
model.fit(maxiter=2, warn_unconverged=False)


def test_fit_simple():
Expand Down

0 comments on commit 73ef1ea

Please sign in to comment.