Skip to content

Commit

Permalink
Merge pull request #276 from gyorilab/regnet_fix
Browse files Browse the repository at this point in the history
Improvements to regnet export
  • Loading branch information
bgyori authored Jan 31, 2024
2 parents 7d41dcf + 7fd450b commit 0033d04
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
33 changes: 30 additions & 3 deletions mira/modeling/amr/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
at https://github.com/DARPA-ASKEM/Model-Representations/tree/main/petrinet.
"""

__all__ = ["AMRRegNetModel", "ModelSpecification"]
__all__ = ["AMRRegNetModel", "ModelSpecification",
"template_model_to_regnet_json"]


import json
Expand All @@ -15,7 +16,7 @@

from mira.metamodel import *

from .. import Model, is_production
from .. import Model, is_production, is_conversion
from .utils import add_metadata_annotations

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -68,7 +69,13 @@ def __init__(self, model: Model):
self.states.append(state_data)

for idx, transition in enumerate(model.transitions.values()):
if isinstance(transition.template, NaturalDegradation):
# Regnets cannot represent conversions (only
# production/degradation) so we skip these
if is_conversion(transition.template):
continue
# Natural degradation corresponds to an inherent negative
# sign on the state so we have special handling for it
elif isinstance(transition.template, NaturalDegradation):
var = vmap[transition.consumed[0].key]
if transition.template.rate_law:
pnames = transition.template.get_parameter_names()
Expand All @@ -83,6 +90,8 @@ def __init__(self, model: Model):
else:
state['sign'] = False
continue
# Controlled production corresponds to an inherent positive
# sign on the state so we have special handling for it
elif isinstance(transition.template, ControlledProduction):
var = vmap[transition.produced[0].key]
if transition.template.rate_law:
Expand All @@ -98,6 +107,9 @@ def __init__(self, model: Model):
else:
state['sign'] = True
continue
# Beyond these, we can assume that the transition is a
# form of production or degradation corresponding to
# a regular transition in the regnet framework
tid = f"t{idx + 1}"
transition_dict = {'id': tid}

Expand Down Expand Up @@ -279,6 +291,21 @@ def to_json_file(
json.dump(js, fh, **kwargs)


def template_model_to_regnet_json(tm: TemplateModel):
"""Convert a template model to a RegNet JSON dict.
Parameters
----------
tm :
The template model to convert.
Returns
-------
A JSON dict representing the RegNet model.
"""
return AMRRegNetModel(Model(tm)).to_json()


class Initial(BaseModel):
expression: str
expression_mathml: str
Expand Down
28 changes: 28 additions & 0 deletions tests/test_modeling/test_regnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mira.sources import amr
from mira.modeling import Model
from mira.metamodel.ops import stratify

from mira.modeling.amr.regnet import AMRRegNetModel, \
template_model_to_regnet_json


def test_regnet_end_to_end():
url = 'https://raw.githubusercontent.com/DARPA-ASKEM/' \
'Model-Representations/main/regnet/examples/lotka_volterra.json'

model = amr.regnet.model_from_url(url)

model_2_city = stratify(
model,
key="city",
strata=[
"Toronto",
"Montreal",
],
)

# Smoke tests to make sure exports work
ex1 = AMRRegNetModel(Model(model)).to_json()
ex2 = AMRRegNetModel(Model(model_2_city)).to_json()
assert ex1 == template_model_to_regnet_json(model)
assert ex2 == template_model_to_regnet_json(model_2_city)

0 comments on commit 0033d04

Please sign in to comment.