Skip to content

Commit

Permalink
Added add_observable and replace_expression and their respective unit…
Browse files Browse the repository at this point in the history
… tests
  • Loading branch information
nanglo123 committed Sep 7, 2023
1 parent a1e6381 commit 979b3cb
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 27 deletions.
83 changes: 63 additions & 20 deletions mira/modeling/askenet/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mira.sources.askenet.petrinet import template_model_from_askenet_json
from .petrinet import template_model_to_petrinet_json
from mira.metamodel.io import mathml_to_expression
from mira.metamodel.template_model import Parameter, Distribution
from mira.metamodel.template_model import Parameter, Distribution, Observable
from mira.metamodel.templates import Concept


Expand Down Expand Up @@ -78,6 +78,17 @@ def remove_observable_or_parameter(tm, replaced_id, replacement_value=None):
return tm


@amr_to_mira
def add_observable(tm, new_id, new_display_name, new_rate_law):
if new_id in tm.observables:
print('This observable id is already present')
return tm
rate_law_sympy = mathml_to_expression(new_rate_law)
new_observable = Observable(name=new_id, display_name=new_display_name, expression=rate_law_sympy)
tm.observables[new_id] = new_observable
return tm


@amr_to_mira
def replace_parameter_id(tm, old_id, new_id):
"""Replace the ID of a parameter."""
Expand All @@ -101,6 +112,30 @@ def replace_parameter_id(tm, old_id, new_id):
return tm


# Resolve issue where only parameters are added only when they are present in rate laws.
@amr_to_mira
def add_parameter(tm, parameter_id: str, expression_xml: str, value: float, distribution_type: str,
min_value: float, max_value: float):
distribution = Distribution(type=distribution_type,
parameters={
'maximum': max_value,
'minimum': min_value
})
sympy_expression = mathml_to_expression(expression_xml)
data = {
'name': parameter_id,
'value': value,
'distribution': distribution,
'units': {'expression': sympy_expression,
'expression_mathml': expression_xml}
}

new_param = Parameter(**data)
tm.parameters[parameter_id] = new_param

return tm


@amr_to_mira
def replace_initial_id(tm, old_id, new_id):
"""Replace the ID of an initial."""
Expand Down Expand Up @@ -136,6 +171,20 @@ def remove_transition(tm, transition_id):
return tm


# @amr_to_mira
# def add_transition(tm, rate_law, src_id=None, tgt_id=None):
# if not src_id and not tgt_id:
# print("You must pass in at least one of source and target id")
# return tm
# sympy_expression = mathml_to_expression(rate_law)
# if src_id is None and tgt_id is not None:
# pass
# if src_id is not None and tgt_id is None:
# pass
# else:
# pass


@amr_to_mira
# rate law is of type Sympy Expression
def replace_rate_law_sympy(tm, transition_id, new_rate_law):
Expand All @@ -150,28 +199,22 @@ def replace_rate_law_mathml(tm, transition_id, new_rate_law):
return replace_rate_law_sympy(tm, transition_id, new_rate_law_sympy)


# Resolve issue where only parameters are added only when they are present in rate laws.
# currently initials don't support expressions so only implement the following 2 methods for observables
# if we are seeking to replace an expression in an initial, return current template model
@amr_to_mira
def add_parameter(tm, parameter_id: str, expression_xml: str, value: float, distribution_type: str,
min_value: float, max_value: float):
distribution = Distribution(type=distribution_type,
parameters={
'maximum': max_value,
'minimum': min_value
})
sympy_expression = mathml_to_expression(expression_xml)
data = {
'name': parameter_id,
'value': value,
'distribution': distribution,
'units': {'expression': sympy_expression,
'expression_mathml': expression_xml}
}
def replace_expression_sympy(tm, object_id, new_expression_sympy, initial_flag):
if initial_flag:
return tm
else:
for obs, observable in tm.observables.items():
if obs == object_id:
observable.expression = SympyExprStr(new_expression_sympy)
return tm

new_param = Parameter(**data)
tm.parameters[parameter_id] = new_param

return tm
def replace_expression_mathml(tm, object_id, new_expression_mathml, initial_flag):
new_expression_sympy = mathml_to_expression(new_expression_mathml)
return replace_expression_sympy(tm, object_id, new_expression_sympy, initial_flag)


@amr_to_mira
Expand Down
63 changes: 56 additions & 7 deletions tests/test_modeling/test_askenet_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,24 @@ def test_remove_observable_or_parameter(self):

self.assertEqual(old_obs['id'], new_obs['id'])

# current bug is that it doesn't return the changed parameter in new_amr['semantics']['ode']['parameters']
# expected 2 returned parameters in list of parameters, only got 1 (the 1 that wasn't changed)
def test_add_observable(self):
amr = _d(self.sir_amr)
new_id = 'testinf'
new_display_name = 'DISPLAY_TEST'
xml_expression = "<apply><times/><ci>E</ci><ci>delta</ci></apply>"
new_amr = add_observable(amr, new_id, new_display_name, xml_expression)

# Create a dict out of a list of observable dict entries to easier test for addition of new observables
new_observable_dict = {}
for observable in new_amr['semantics']['ode']['observables']:
name = observable.pop('id')
new_observable_dict[name] = observable

self.assertIn(new_id, new_observable_dict)
self.assertEqual(new_display_name, new_observable_dict[new_id]['name'])
self.assertEqual(xml_expression, new_observable_dict[new_id]['expression_mathml'])
self.assertEqual(sstr(mathml_to_expression(xml_expression)), new_observable_dict[new_id]['expression'])

def test_replace_parameter_id(self):
old_id = 'beta'
new_id = 'TEST'
Expand Down Expand Up @@ -319,6 +335,14 @@ def test_remove_transition(self):
for new_transition in new_model_transition:
self.assertNotEquals(removed_transition, new_transition['id'])

# def test_add_transition(self):
# new_transition_src_id = 'test'
# new_transition_tgt_id = 'MORE'
# expression_xml = '<apply><plus/><ci>X</ci><cn>8</cn></apply>'
# amr = _d(self.sir_amr)
#
# new_amr = add_transition(amr, expression_xml, src_id=new_transition_src_id)

def test_replace_rate_law_sympy(self):
transition_id = 'inf'
target_expression_xml_str = '<apply><plus/><ci>X</ci><cn>8</cn></apply>'
Expand All @@ -336,17 +360,42 @@ def test_replace_rate_law_sympy(self):
def test_replace_rate_law_mathml(self):
amr = _d(self.sir_amr)
transition_id = 'inf'
xml_str = "<apply><times/><ci>E</ci><ci>delta</ci></apply>"
sympy_expression = mathml_to_expression(xml_str)
target_expression_xml_str = "<apply><times/><ci>E</ci><ci>delta</ci></apply>"
target_expression_sympy = mathml_to_expression(target_expression_xml_str)

new_amr = replace_rate_law_mathml(amr, transition_id, xml_str)
new_amr = replace_rate_law_mathml(amr, transition_id, target_expression_xml_str)

new_semantics_ode_rates = new_amr['semantics']['ode']['rates']

for new_rate in new_semantics_ode_rates:
if new_rate['target'] == transition_id:
self.assertEqual(new_rate['expression_mathml'], xml_str)
self.assertEqual(new_rate['expression'], sstr(sympy_expression))
self.assertEqual(sstr(target_expression_sympy), new_rate['expression'])
self.assertEqual(target_expression_xml_str, new_rate['expression_mathml'])

# Following 2 unit tests only test for replacing expressions in observables, not initials
def test_replace_expression_sympy(self):
object_id = 'noninf'
amr = _d(self.sir_amr)
target_expression_xml_str = "<apply><times/><ci>E</ci><ci>beta</ci></apply>"
target_expression_sympy = mathml_to_expression(target_expression_xml_str)
new_amr = replace_expression_sympy(amr, object_id, target_expression_sympy, False)

for new_obs in new_amr['semantics']['ode']['observables']:
if new_obs['id'] == object_id:
self.assertEqual(sstr(target_expression_sympy), new_obs['expression'])
self.assertEqual(target_expression_xml_str, new_obs['expression_mathml'])

def test_replace_expression_mathml(self):
object_id = 'noninf'
amr = _d(self.sir_amr)
target_expression_xml_str = "<apply><times/><ci>E</ci><ci>beta</ci></apply>"
target_expression_sympy = mathml_to_expression(target_expression_xml_str)
new_amr = replace_expression_mathml(amr, object_id, target_expression_xml_str, False)

for new_obs in new_amr['semantics']['ode']['observables']:
if new_obs['id'] == object_id:
self.assertEqual(sstr(target_expression_sympy), new_obs['expression'])
self.assertEqual(target_expression_xml_str, new_obs['expression_mathml'])

def test_stratify(self):
amr = _d(self.sir_amr)
Expand Down

0 comments on commit 979b3cb

Please sign in to comment.