Skip to content

Commit

Permalink
Added remove_X where X is an observable or parameter method and its a…
Browse files Browse the repository at this point in the history
…ccompanying unit test
  • Loading branch information
nanglo123 committed Sep 6, 2023
1 parent be2ee6d commit a5f6172
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 9 deletions.
11 changes: 11 additions & 0 deletions mira/modeling/askenet/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ def replace_observable_id(tm, old_id, new_id, display_name):
return tm


@amr_to_mira
def remove_observable_or_parameter(tm, replaced_id, replacement_value=None):
if replacement_value:
tm.substitute_parameter(replaced_id, replacement_value)
else:
for obs, observable in copy.deepcopy(tm.observables).items():
if obs == replaced_id:
tm.observables.pop(obs)
return tm


@amr_to_mira
def replace_parameter_id(tm, old_id, new_id):
"""Replace the ID of a parameter."""
Expand Down
60 changes: 51 additions & 9 deletions tests/test_modeling/test_askenet_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import requests
import pandas as pd
from copy import deepcopy as _d
from mira.modeling.askenet.ops import *
# from sympy.parsing.sympy_parser import parse_expr
Expand Down Expand Up @@ -94,7 +95,7 @@ def test_replace_state_id(self):
# since output only has 2 entries (parameter entries of beta and gamma) as opposed to 5 from
# old amr, this test will pass as this loop only makes 2 iterations
for old_params, new_params in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
# test to see if old_id/new_id in name/id field and not for equality because these fields
# test to see if old_id/new_id in name/id field and not for id/name equality because these fields
# may contain subscripts or timestamps appended to the old_id/new_id
if old_id in old_params['id'] and old_id in old_params['name']:
self.assertIn(new_id, new_params['id'])
Expand Down Expand Up @@ -152,6 +153,44 @@ def test_replace_observable_id(self):
self.assertEqual(new_observable['id'], new_id)
self.assertEqual(new_observable['name'], new_display_name)

def test_remove_observable_or_parameter(self):

old_amr_obs = _d(self.sir_amr)
old_amr_param = _d(self.sir_amr)

replaced_observable_id = 'noninf'
new_amr_obs = remove_observable_or_parameter(old_amr_obs, replaced_observable_id)
for new_observable in new_amr_obs['semantics']['ode']['observables']:
self.assertNotEqual(new_observable['id'], replaced_observable_id)

replaced_param_id = 'beta'
replacement_value = 5
new_amr_param = remove_observable_or_parameter(old_amr_param, replaced_param_id, replacement_value)
for new_param in new_amr_param['semantics']['ode']['parameters']:
self.assertNotEqual(new_param['id'], replaced_param_id)
for old_rate, new_rate in zip(old_amr_param['semantics']['ode']['rates'],
new_amr_param['semantics']['ode']['rates']):
if replaced_param_id in old_rate['expression'] and replaced_param_id in old_rate['expression_mathml']:
self.assertNotIn(replaced_param_id, new_rate['expression'])
self.assertIn(str(replacement_value), new_rate['expression'])

self.assertNotIn(replaced_param_id, new_rate['expression_mathml'])
self.assertIn(str(replacement_value), new_rate['expression_mathml'])

self.assertEqual(old_rate['target'], new_rate['target'])

# currently don't support expressions for initials
for old_obs, new_obs in zip(old_amr_param['semantics']['ode']['observables'],
new_amr_param['semantics']['ode']['observables']):
if replaced_param_id in old_obs['expression'] and replaced_param_id in old_obs['expression_mathml']:
self.assertNotIn(replaced_param_id, new_obs['expression'])
self.assertIn(str(replacement_value), new_obs['expression'])

self.assertNotIn(replaced_param_id, new_obs['expression_mathml'])
self.assertIn(str(replacement_value), new_obs['expression_mathml'])

self.assertEqual(old_obs['id'], new_obs['id'])

# current bug is that it doesn't return the changed parameter in new_amr['semantics']['ode']['parameters']
# expected 2 returned parameters in list of parameters, only got 1 (the 1 that wasn't changed)
def test_replace_parameter_id(self):
Expand Down Expand Up @@ -197,6 +236,7 @@ def test_replace_parameter_id(self):
self.assertNotIn(old_id, new_observable['expression_mathml'])

# zip method iterates over length of the smaller iterable (new_semantics_ode_parameters)
# non-state parameters are listed first in input amr
for old_parameter, new_parameter in zip(old_semantics_ode_parameters, new_semantics_ode_parameters):
if old_parameter['id'] == old_id:
self.assertEqual(new_parameter['id'], new_id)
Expand All @@ -206,6 +246,13 @@ def test_replace_parameter_id(self):
self.assertEqual(mathml_to_expression(old_parameter['units']['expression_mathml']),
mathml_to_expression(new_parameter['units']['expression_mathml']))

def test_add_parameter(self):
amr = _d(self.sir_amr)

new_amr = add_parameter(amr, parameter_id='sigma',
expression_xml="<apply><times/><ci>E</ci><ci>delta</ci></apply>",
value=.5, distribution_type='Uniform1', min_value=.05, max_value=.8)

# def test_replace_initial_id(self):
# old_id = 'S'
# new_id = 'TEST'
Expand Down Expand Up @@ -251,8 +298,10 @@ def test_remove_state(self):
self.assertNotEquals(removed_state_id, new_initial['target'])

# parameters that are associated in an expression with a removed state are not present in output amr
# (e.g.) if there exists an expression: 'S*I*beta' and we remove S, then beta is no longer present in output
# list of parameters
for new_parameter in new_semantics_ode_parameters:
self.assertTrue(removed_state_id not in new_parameter['id'])
self.assertNotIn(removed_state_id, new_parameter['id'])

# output observable expressions that originally contained targeted state still exist with targeted state removed
# (e.g. 'S+R' -> 'R') if 'S' is the removed state
Expand Down Expand Up @@ -299,13 +348,6 @@ def test_replace_rate_law_mathml(self):
self.assertEqual(new_rate['expression_mathml'], xml_str)
self.assertEqual(new_rate['expression'], sstr(sympy_expression))

def test_add_parameter(self):
amr = _d(self.sir_amr)

new_amr = add_parameter(amr, parameter_id='sigma',
expression_xml="<apply><times/><ci>E</ci><ci>delta</ci></apply>",
value=.5, distribution_type='Uniform1', min_value=.05, max_value=.8)

def test_stratify(self):
amr = _d(self.sir_amr)
new_amr = stratify(amr, key='city', strata=['boston', 'nyc'])
Expand Down

0 comments on commit a5f6172

Please sign in to comment.