Skip to content

Commit

Permalink
Further changed unit tests to be more informative and make point of f…
Browse files Browse the repository at this point in the history
…ailure more apparent and fixed bug where changed parameter ids not showing in parameter list of new_amr
  • Loading branch information
nanglo123 committed Sep 1, 2023
1 parent bbafbd8 commit e9cb3dc
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 84 deletions.
17 changes: 16 additions & 1 deletion mira/modeling/askenet/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def replace_parameter_id(tm, old_id, new_id):
observable.expression = SympyExprStr(
observable.expression.args[0].subs(sympy.Symbol(old_id),
sympy.Symbol(new_id)))

for key, param in copy.deepcopy(tm.parameters).items():
if param.name == old_id:
try:
popped_param = tm.parameters.pop(param.name)
popped_param.name = new_id
tm.parameters[new_id] = popped_param
except KeyError:
print('Old id: {} is not present in parameter dictionary of the template model'.format(old_id))
return tm


Expand All @@ -85,7 +94,7 @@ def replace_initial_id(tm, old_id, new_id):
tm.initials = {
(new_id if k == old_id else k): v for k, v in tm.initials.items()
}
return tm
return tm


# Remove state
Expand Down Expand Up @@ -128,6 +137,12 @@ def replace_rate_law_mathml(tm, transition_id, new_rate_law):
return replace_rate_law_sympy(tm, transition_id, new_rate_law_sympy)


@amr_to_mira
def add_parameter(tm, parameter_id: str, description: str, expression_xml: str, value: float, distribution_type: str,
parameters: dict[str, float]):
pass


@amr_to_mira
def stratify(*args, **kwargs):
return tmops.stratify(*args, **kwargs)
Expand Down
170 changes: 87 additions & 83 deletions tests/test_modeling/test_askenet_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ def setUpClass(cls):
'https://raw.githubusercontent.com/DARPA-ASKEM/'
'Model-Representations/main/petrinet/examples/sir.json').json()

def test_replace_state_id(self):
'''These unit tests are conducted by zipping through lists of each key in a amr file
(e.g. parameters, observables, etc). Zipping in this manner assumes that the order (assuming insertion order)
is preserved before and after mira operation for an amr key
'''

def test_replace_state_id(self):
old_id = 'S'
new_id = 'X'
amr = _d(self.sir_amr)
Expand All @@ -34,24 +38,24 @@ def test_replace_state_id(self):

self.assertEqual(len(old_model_states), len(new_model_states))
for old_state, new_state in zip(old_model_states, new_model_states):

# output states missing description field
if old_state['id'] == old_id:
self.assertEqual(new_state['id'], new_id)
self.assertEqual(old_state['name'], new_state['name'])
self.assertEqual(old_state['grounding']['identifiers'], new_state['grounding']['identifiers'])
self.assertEqual(old_state['units'], new_state['units'])

self.assertEqual(len(old_model_transitions), len(new_model_transitions))
for old_transition, new_transition in zip(old_model_transitions, new_model_transitions):
if old_id in old_transition['input']:
new_value_in_transition_input = new_id in new_transition['input']
old_value_out_transition_input = old_id not in new_transition['input']
equal_length_input = (len(old_transition['input']) == len(new_transition['input']))

self.assertTrue(new_value_in_transition_input and old_value_out_transition_input and equal_length_input)
if old_id in old_transition['output']:
new_value_in_transition_output = new_id in new_transition['output']
old_value_out_transition_output = old_id not in new_transition['output']
equal_length_output = (len(old_transition['output']) == len(new_transition['output']))

self.assertTrue(
new_value_in_transition_output and old_value_out_transition_output and equal_length_output)
# output transitions are missing a description and ['properties']['name'] field is abbreviated in output amr
for old_transition, new_transition in zip(old_model_transitions, new_model_transitions):
if old_id in old_transition['input'] or old_id in old_transition['output']:
self.assertIn(new_id, new_transition['input'])
self.assertNotIn(old_id, new_transition['output'])
self.assertEqual(len(old_transition['input']), len(new_transition['input']))
self.assertEqual(len(old_transition['output']), len(new_transition['output']))
self.assertEqual(old_transition['id'], new_transition['id'])

old_semantics_ode = amr['semantics']['ode']
new_semantics_ode = new_amr['semantics']['ode']
Expand All @@ -62,19 +66,16 @@ def test_replace_state_id(self):
# this test doesn't account for if the expression semantic is preserved (e.g. same type of operations)
# would pass test if we call replace_state_id(I,J) and old expression is "I*X" and new expression is "J+X"
for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
if (old_id in old_rate['expression']) or (old_id in old_rate['expression_mathml']):
new_value_in_rate_expression = (new_id in new_rate['expression'])
old_value_out_rate_expression = (old_id not in new_rate['expression'])
expression_flag = (new_value_in_rate_expression and old_value_out_rate_expression)
if old_id in old_rate['expression'] or old_id in old_rate['expression_mathml']:
self.assertIn(new_id, new_rate['expression'])
self.assertNotIn(old_id, new_rate['expression'])

new_value_in_rate_mathml = (new_id in new_rate['expression_mathml'])
old_value_out_rate_mathml = (old_id not in new_rate['expression_mathml'])
mathml_flag = (new_value_in_rate_mathml and old_value_out_rate_mathml)
self.assertIn(new_id, new_rate['expression_mathml'])
self.assertNotIn(old_id, new_rate['expression_mathml'])

self.assertTrue(expression_flag and mathml_flag)
self.assertEqual(old_rate['target'], new_rate['target'])

# for initials, the dict representing states in the initials list have changed expression values in new_amr
# from variables to float values
# initials have float values substituted in for state ids in their expression and expression_mathml field
old_semantics_ode_initials = old_semantics_ode['initials']
new_semantics_ode_initials = new_semantics_ode['initials']

Expand All @@ -88,37 +89,34 @@ def test_replace_state_id(self):
# This is due to initial expressions vs values
assert len(old_semantics_ode_parameters) == 5
assert len(new_semantics_ode_parameters) == 2
# state id and name for each parameter dict in list of parameters has a '0' appended to it (not 'S' but 'S0')
# so test for equality with 0 and subscript ₀ appended to old state id (assuming that 0 and subscript are
# timestamps that will be constantly changing) and then taste if new state id and name field for each
# parameter is equal to just state id

# Currently this test passes no matter what as zip function only iterates over the shorter of two lists
# since output only has 2 entries (parameter entries of beta and gamma) as opposed to 5 from
# old amr, this test will pass
# old amr, this test will pass as this loop only makes 2 iterations
for old_params, new_params in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
if ((old_id + '0') in old_params['id']) or ((old_id + '₀') in old_params['name']):
self.assertEqual(new_id, new_params['id'])
self.assertEqual(new_id, new_params['name'])
# test to see if old_id/new_id in name/id field and not for equality because these fields
# may contain subscripts or timestamps appended to the old_id/new_id
if old_id in old_params['id'] and old_id in old_params['name']:
self.assertIn(new_id, new_params['id'])
self.assertIn(new_id, new_params['name'])

old_semantics_ode_observables = old_semantics_ode['observables']
new_semantics_ode_observables = new_semantics_ode['observables']
self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))

# each observable dict in list of observables in new amr does not have the states key which is a list of states
# cannot test states field
# expression for each observable has an extra space between states in new_amr
for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables):
if (old_id in old_observable['states']) or (old_id in old_observable['expression']) or \
(old_id in old_observable['expression_mathml']):
new_value_in_observable_expression = (new_id in new_observable['expression'])
old_value_out_observable_expression = (old_id not in new_observable['expression'])
expression_flag = (new_value_in_observable_expression and old_value_out_observable_expression)
if old_id in old_observable['states'] and old_id in old_observable['expression'] and \
old_id in old_observable['expression_mathml']:
self.assertIn(new_id, new_observable['expression'])
self.assertNotIn(old_id, new_observable['expression'])

new_value_in_observable_mathml = new_id in new_observable['expression_mathml']
old_value_out_observable_mathml = old_id not in new_observable['expression_mathml']
mathml_flag = (new_value_in_observable_mathml and old_value_out_observable_mathml)
self.assertIn(new_id, new_observable['expression_mathml'])
self.assertNotIn(old_id, new_observable['expression_mathml'])

self.assertTrue(expression_flag and mathml_flag)
self.assertEqual(old_observable['id'], new_observable['id'])

def test_replace_transition_id(self):
old_id = 'inf'
Expand Down Expand Up @@ -150,13 +148,14 @@ def test_replace_observable_id(self):

for old_observable, new_observable in zip(old_semantics_observables, new_semantics_observables):
if old_observable['id'] == old_id:
self.assertEqual(new_observable['id'], new_id) and self.assertEqual(new_observable['name'], new_id)
self.assertEqual(new_observable['id'], new_id)
self.assertEqual(new_observable['name'], new_id)

# current bug is that it doesn't return the changed parameter in new_amr['semantics']['ode']['parameters']
# expected 2 returned parameters in list of parameters, only got 1 (the 1 that wasn't changed)
def test_replace_parameter_id(self):
old_id = 'beta'
new_id = 'zeta'
new_id = 'TEST'
amr = _d(self.sir_amr)
new_amr = replace_parameter_id(amr, old_id, new_id)

Expand All @@ -166,44 +165,51 @@ def test_replace_parameter_id(self):
old_semantics_ode_observables = amr['semantics']['ode']['observables']
new_semantics_ode_observables = new_amr['semantics']['ode']['observables']

self.assertEqual(len(old_semantics_ode_rates), len(new_semantics_ode_rates))
old_semantics_ode_parameters = amr['semantics']['ode']['parameters']
new_semantics_ode_parameters = new_amr['semantics']['ode']['parameters']

for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
if old_id in old_rate['expression'] and old_id in old_rate['expression_mathml']:
new_value_in_rate_expression = new_id in new_rate['expression']
old_value_out_rate_expression = old_id not in new_rate['expression']
new_model_states = new_amr['model']['states']

expression_flag = (new_value_in_rate_expression and old_value_out_rate_expression)
self.assertEqual(len(old_semantics_ode_rates), len(new_semantics_ode_rates))
self.assertEqual(len(old_semantics_ode_observables), len(new_semantics_ode_observables))

new_value_in_rate_mathml = new_id in new_rate['expression_mathml']
old_value_out_rate_mathml = old_id not in new_rate['expression_mathml']
# Since states are removed from list of parameters after replacing parameter id, test to see if length of
# parameters list of input amr - # of states is equal to length of parameters list of output amr
self.assertEqual(len(old_semantics_ode_parameters) - len(new_model_states), len(new_semantics_ode_parameters))

mathml_flag = (new_value_in_rate_mathml and old_value_out_rate_mathml)
for old_rate, new_rate in zip(old_semantics_ode_rates, new_semantics_ode_rates):
if old_id in old_rate['expression'] and old_id in old_rate['expression_mathml']:
self.assertIn(new_id, new_rate['expression'])
self.assertNotIn(old_id, new_rate['expression'])

self.assertTrue(expression_flag and mathml_flag)
self.assertIn(new_id, new_rate['expression_mathml'])
self.assertNotIn(old_id, new_rate['expression_mathml'])

# don't test states field for a parameter as it is assumed that replace_parameter_id will only be used with
# parameters such as gamma or beta (i.e. non-states)
for old_observable, new_observable in zip(old_semantics_ode_observables, new_semantics_ode_observables):

if (old_id in old_observable['states'] and
old_id in old_observable['expression'] and old_id in new_observable['expression_mathml']):
new_value_in_observable_expression = new_id in new_observable['expression']
old_value_out_observable_expression = old_id not in new_observable['expression']

expression_flag = (new_value_in_observable_expression and old_value_out_observable_expression)

new_value_in_observable_mathml = new_id in new_observable['expression_mathml']
old_value_out_observable_mathml = old_id not in new_observable['expression_mathml']

mathml_flag = (new_value_in_observable_mathml and old_value_out_observable_mathml)

self.assertTrue(expression_flag and mathml_flag)
if old_id in old_observable['expression'] and old_id in new_observable['expression_mathml']:
self.assertIn(new_id, new_observable['expression'])
self.assertNotIn(old_id, new_observable['expression'])

self.assertIn(new_id, new_observable['expression_mathml'])
self.assertNotIn(old_id, new_observable['expression_mathml'])

# zip method iterates over length of the smaller iterable (new_semantics_ode_parameters)
for old_parameter, new_parameter in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
if old_parameter['id'] == old_id:
self.assertEqual(new_parameter['id'], new_id)
self.assertEqual(old_parameter['value'], new_parameter['value'])
self.assertEqual(old_parameter['distribution'], new_parameter['distribution'])
self.assertEqual(sstr(old_parameter['units']['expression']), new_parameter['units']['expression'])
self.assertEqual(mathml_to_expression(old_parameter['units']['expression_mathml']),
mathml_to_expression(new_parameter['units']['expression_mathml']))

# def test_replace_initial_id(self):
# old_id = 'S'
# new_id = 'TEST'
# amr = _d(self.sir_amr)
# new_amr = replace_initial_id(amr, old_id, new_id)
#
# old_semantics_ode_initials = amr['semantics']['ode']['initials']
# new_semantics_ode_initials = new_amr['semantics']['ode']['initials']
#
Expand All @@ -212,10 +218,10 @@ def test_replace_parameter_id(self):
# self.assertEqual(new_initials['target'], new_id)

def test_remove_state(self):
removed_state = 'S'
removed_state_id = 'S'
amr = _d(self.sir_amr)

new_amr = remove_state(amr, removed_state)
new_amr = remove_state(amr, removed_state_id)

new_model = new_amr['model']
new_model_states = new_model['states']
Expand All @@ -228,44 +234,42 @@ def test_remove_state(self):
new_semantics_ode_observables = new_semantics_ode['observables']

for new_state in new_model_states:
self.assertTrue(removed_state != new_state['id'])
self.assertNotEquals(removed_state_id, new_state['id'])

for new_transition in new_model_transitions:
self.assertNotIn(removed_state, new_transition['input'])
self.assertNotIn(removed_state, new_transition['output'])
self.assertNotIn(removed_state_id, new_transition['input'])
self.assertNotIn(removed_state_id, new_transition['output'])

# output rates that originally contained targeted state are removed
for new_rate in new_semantics_ode_rates:
self.assertNotIn(removed_state, new_rate['expression'])
self.assertNotIn(removed_state, new_rate['expression_mathml'])
self.assertNotIn(removed_state_id, new_rate['expression'])
self.assertNotIn(removed_state_id, new_rate['expression_mathml'])

# initials are bugged, all states removed rather than just targeted removed state in output amr
for new_initial in new_semantics_ode_initials:
self.assertTrue(removed_state != new_initial['target'])
self.assertNotEquals(removed_state_id, new_initial['target'])

# parameters that are associated in an expression with a removed state are not present in output amr
for new_parameter in new_semantics_ode_parameters:
self.assertTrue(removed_state + '0' != new_parameter['id'])
self.assertTrue(removed_state_id not in new_parameter['id'])

# output observables that originally contained targeted state still exist with targeted state removed
# output observable expressions that originally contained targeted state still exist with targeted state removed
# (e.g. 'S+R' -> 'R') if 'S' is the removed state
for new_observable in new_semantics_ode_observables:
self.assertNotIn(removed_state, new_observable['expression'])
self.assertNotIn(removed_state, new_observable['expression_mathml'])
self.assertNotIn(removed_state_id, new_observable['expression'])
self.assertNotIn(removed_state_id, new_observable['expression_mathml'])

def test_remove_transition(self):

removed_transition = 'inf'
amr = _d(self.sir_amr)

new_amr = remove_transition(amr, removed_transition)
new_model_transition = new_amr['model']['transitions']

for new_transition in new_model_transition:
self.assertTrue(removed_transition != new_transition['id'])
self.assertNotEquals(removed_transition, new_transition['id'])

def test_replace_rate_law_sympy(self):

transition_id = 'inf'
target_expression_xml_str = '<apply><plus/><ci>X</ci><cn>8</cn></apply>'
target_expression_sympy = mathml_to_expression(target_expression_xml_str)
Expand Down

0 comments on commit e9cb3dc

Please sign in to comment.