diff --git a/mira/modeling/amr/regnet.py b/mira/modeling/amr/regnet.py index a57b3ff02..9e87ba039 100644 --- a/mira/modeling/amr/regnet.py +++ b/mira/modeling/amr/regnet.py @@ -2,7 +2,8 @@ at https://github.com/DARPA-ASKEM/Model-Representations/tree/main/petrinet. """ -__all__ = ["AMRRegNetModel", "ModelSpecification"] +__all__ = ["AMRRegNetModel", "ModelSpecification", + "template_model_to_regnet_json"] import json @@ -15,7 +16,7 @@ from mira.metamodel import * -from .. import Model, is_production +from .. import Model, is_production, is_conversion from .utils import add_metadata_annotations logger = logging.getLogger(__name__) @@ -68,7 +69,13 @@ def __init__(self, model: Model): self.states.append(state_data) for idx, transition in enumerate(model.transitions.values()): - if isinstance(transition.template, NaturalDegradation): + # Regnets cannot represent conversions (only + # production/degradation) so we skip these + if is_conversion(transition.template): + continue + # Natural degradation corresponds to an inherent negative + # sign on the state so we have special handling for it + elif isinstance(transition.template, NaturalDegradation): var = vmap[transition.consumed[0].key] if transition.template.rate_law: pnames = transition.template.get_parameter_names() @@ -83,6 +90,8 @@ def __init__(self, model: Model): else: state['sign'] = False continue + # Controlled production corresponds to an inherent positive + # sign on the state so we have special handling for it elif isinstance(transition.template, ControlledProduction): var = vmap[transition.produced[0].key] if transition.template.rate_law: @@ -98,6 +107,9 @@ def __init__(self, model: Model): else: state['sign'] = True continue + # Beyond these, we can assume that the transition is a + # form of production or degradation corresponding to + # a regular transition in the regnet framework tid = f"t{idx + 1}" transition_dict = {'id': tid} @@ -279,6 +291,21 @@ def to_json_file( json.dump(js, fh, **kwargs) +def template_model_to_regnet_json(tm: TemplateModel): + """Convert a template model to a RegNet JSON dict. + + Parameters + ---------- + tm : + The template model to convert. + + Returns + ------- + A JSON dict representing the RegNet model. + """ + return AMRRegNetModel(Model(tm)).to_json() + + class Initial(BaseModel): expression: str expression_mathml: str diff --git a/tests/test_modeling/test_regnet.py b/tests/test_modeling/test_regnet.py new file mode 100644 index 000000000..a215f0851 --- /dev/null +++ b/tests/test_modeling/test_regnet.py @@ -0,0 +1,28 @@ +from mira.sources import amr +from mira.modeling import Model +from mira.metamodel.ops import stratify + +from mira.modeling.amr.regnet import AMRRegNetModel, \ + template_model_to_regnet_json + + +def test_regnet_end_to_end(): + url = 'https://raw.githubusercontent.com/DARPA-ASKEM/' \ + 'Model-Representations/main/regnet/examples/lotka_volterra.json' + + model = amr.regnet.model_from_url(url) + + model_2_city = stratify( + model, + key="city", + strata=[ + "Toronto", + "Montreal", + ], + ) + + # Smoke tests to make sure exports work + ex1 = AMRRegNetModel(Model(model)).to_json() + ex2 = AMRRegNetModel(Model(model_2_city)).to_json() + assert ex1 == template_model_to_regnet_json(model) + assert ex2 == template_model_to_regnet_json(model_2_city) \ No newline at end of file