From 73ef1ea15e9ea539d5cdff7afc08ac9edbb24227 Mon Sep 17 00:00:00 2001 From: jgallowa07 Date: Tue, 21 May 2024 16:05:17 -0700 Subject: [PATCH] passing simple single condition fit --- multidms/model.py | 9 +++++++-- tests/test_data.py | 26 +++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/multidms/model.py b/multidms/model.py index 1f7f5ac..e242268 100644 --- a/multidms/model.py +++ b/multidms/model.py @@ -1030,7 +1030,9 @@ def fit( reference-condition latent offset (TODO math? beta0[ref]) are locked. admm_niter : int Number of iterations to perform during the ADMM optimization. - Defaults to 50. + Defaults to 50. Note that in the case of single-condition models, + This is set to zero as the generalized + lasso ADMM optimization is not used. admm_tau : float ADMM step size. Defaults to 1.0. warn_unconverged : bool @@ -1106,7 +1108,10 @@ def fit( admm_mu = 0.99 * admm_tau / eig - assert 0 < admm_mu < admm_tau / eig + if len(self.data.conditions) > 1: + assert 0 < admm_mu < admm_tau / eig + + admm_niter = 0 if len(self.data.conditions) == 1 else admm_niter lock_params[("beta0", self.data.reference)] = 0.0 diff --git a/tests/test_data.py b/tests/test_data.py index e09cbb1..8ec6dfb 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -255,20 +255,20 @@ def test_difference_matrix(): """ -# def test_single_condition_fit(): -# """ -# Simple test to see that the single-condition model -# fits without error. -# """ -# data = multidms.Data( -# TEST_FUNC_SCORES.query("condition == 'a'"), -# alphabet=multidms.AAS_WITHSTOP, -# reference="a", -# assert_site_integrity=False, -# ) -# model = multidms.Model(data, PRNGKey=23) +def test_single_condition_fit(): + """ + Simple test to see that the single-condition model + fits without error. + """ + data = multidms.Data( + TEST_FUNC_SCORES.query("condition == 'a'"), + alphabet=multidms.AAS_WITHSTOP, + reference="a", + assert_site_integrity=False, + ) + model = multidms.Model(data, PRNGKey=23) -# model.fit(maxiter=2, warn_unconverged=False) + model.fit(maxiter=2, warn_unconverged=False) def test_fit_simple():