Skip to content

Commit

Permalink
test: additional multi model parameter test
Browse files Browse the repository at this point in the history
  • Loading branch information
hellkite500 committed Aug 28, 2024
1 parent 3c3428b commit 933b465
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/ngen_cal/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,12 @@ def multi_model_shared_params() -> Mapping[str, list[Parameter]]:
params = {'A':[p1], 'B':[p2], 'C':[p3]}

return params

@pytest.fixture
def multi_model_shared_params2() -> Mapping[str, list[Parameter]]:
p1 = Parameter(name='a', min=0, max=1, init=0)
p2 = Parameter(name='a', min=0, max=1, init=0)
p3 = Parameter(name='c', min=0, max=1, init=0)
params = {'A':[p1, p3], 'B':[p2]}

return params
13 changes: 13 additions & 0 deletions python/ngen_cal/tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,16 @@ def test_multi_params(multi_model_shared_params: Mapping[str, list[Parameter]]):
assert pa.drop('model', axis=1).equals( pb.drop('model', axis=1) )
# ensure unique params/alias are not modifed by selection
assert pa.loc['a', '1'] != pc.loc['c', '1']

def test_multi_params2(multi_model_shared_params2: Mapping[str, list[Parameter]]):
# This is essentially the path the params go through from
# creation in model.py to update in search.py
params = _params_as_df(multi_model_shared_params2)
params = pd.DataFrame(params).rename(columns={'init':'0'})
# create new iteration from old
params['1'] = params['0']
#update the parameters by index
params.loc['a', '1'] = 0.5
pa = params[ params['model'] == 'A' ].drop('model', axis=1).loc['a']
pb = params[ params['model'] == 'B' ].drop('model', axis=1).loc['a']
assert pa.equals( pb )

0 comments on commit 933b465

Please sign in to comment.