diff --git a/mira/metamodel/__init__.py b/mira/metamodel/__init__.py index f880264e5..ef67b5aae 100644 --- a/mira/metamodel/__init__.py +++ b/mira/metamodel/__init__.py @@ -7,3 +7,4 @@ from .search import * from .ops import * from .units import * +from .utils import * diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index 00868667c..46ef46148 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -9,13 +9,15 @@ from .template_model import TemplateModel, Initial, Parameter from .templates import * -from .units import Unit +from .units import Unit, dimensionless_units __all__ = [ "stratify", "simplify_rate_laws", - "aggregate_parameters" + "aggregate_parameters", + "get_term_roles", + "counts_to_dimensionless" ] @@ -395,3 +397,66 @@ def get_term_roles(term, template, parameters): else: term_roles['other'].append(symbol.name) return dict(term_roles) + + +def counts_to_dimensionless(tm: TemplateModel, + counts_unit: str, + norm_factor: float): + """Convert all entity concentrations to dimensionless units. + + Parameters + ---------- + tm : + A template model. + counts_unit : + The unit of the counts. + norm_factor : + The normalization factor to convert counts to concentration. + + Returns + ------- + : + A template model with all entity concentrations converted to + dimensionless units. + """ + # Make a deepcopy up front so we don't change the original template model + tm = deepcopy(tm) + # Make a symbol of the counts unit for calculations + counts_unit_symbol = sympy.Symbol(counts_unit) + + initials_normalized = set() + # First we normalize concepts and their initials + for template in tm.templates: + # Since concepts can be distributed across templates, we have to go + # template by template + for concept in template.get_concepts(): + if concept.units: + # We figure out what the exponent of the coutns unit is + # if it appears in the units of the concept + (coeff, exponent) = \ + concept.units.expression.args[0].as_coeff_exponent(counts_unit_symbol) + # If the exponent is other than zero then normalization is needed + if exponent: + concept.units.expression = \ + concept.units.expression.args[0] / (counts_unit_symbol ** exponent) + # We not try to see if there is a corresponding initial condition + # for the concept and if so, we normalize it as well + if concept.name in tm.initials and concept.name not in initials_normalized: + init = tm.initials[concept.name] + if init.value is not None: + init.value /= (norm_factor ** exponent) + if init.concept.units: + init.concept.units.expression = \ + init.concept.units.expression.args[0] / (counts_unit_symbol ** exponent) + initials_normalized.add(concept.name) + # Now we do the same for parameters + for p_name, p in tm.parameters.items(): + if p.units: + (coeff, exponent) = \ + p.units.expression = \ + p.units.expression.args[0].as_coeff_exponent(counts_unit_symbol) + if exponent: + p.units /= (counts_unit_symbol ** exponent) + p.value /= (norm_factor ** exponent) + + return tm \ No newline at end of file diff --git a/mira/metamodel/template_model.py b/mira/metamodel/template_model.py index 05a00cd2f..450eab884 100644 --- a/mira/metamodel/template_model.py +++ b/mira/metamodel/template_model.py @@ -10,6 +10,8 @@ from pydantic import BaseModel, Field from .templates import * +from .units import Unit +from .utils import SympyExprStr class Initial(BaseModel): diff --git a/mira/metamodel/templates.py b/mira/metamodel/templates.py index 60f45668f..7f72337c9 100644 --- a/mira/metamodel/templates.py +++ b/mira/metamodel/templates.py @@ -17,13 +17,11 @@ "GroupedControlledProduction", "GroupedControlledDegradation", "SpecifiedTemplate", - "SympyExprStr", "templates_equal", "context_refinement", ] import logging -import os import sys from collections import ChainMap from itertools import product @@ -50,8 +48,8 @@ except ImportError: from typing_extensions import Annotated -from .units import Unit, UNIT_SYMBOLS - +from .units import Unit +from .utils import SympyExprStr IS_EQUAL = "is_equal" @@ -81,28 +79,6 @@ class Config(BaseModel): ) -class SympyExprStr(sympy.Expr): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - if isinstance(v, cls): - return v - return cls(v) - - @classmethod - def __modify_schema__(cls, field_schema): - field_schema.update(type="string", example="2*x") - - def __str__(self): - return super().__str__()[len(self.__class__.__name__)+1:-1] - - def __repr__(self): - return str(self) - - class Concept(BaseModel): """A concept is specified by its identifier(s), name, and - optionally - its context. diff --git a/tests/test_ops.py b/tests/test_ops.py index bb23f784a..6579ccb1d 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -15,7 +15,7 @@ Parameter, TemplateModel, ) -from mira.metamodel.ops import stratify, simplify_rate_law +from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless from mira.examples.sir import cities, sir, sir_2_city, sir_parameterized from mira.examples.concepts import infected, susceptible from mira.examples.chime import sviivr @@ -319,3 +319,30 @@ def _make_template(rate_law): (1 - _s('alpha')) * _s('S') * _s('A')) assert templates[1].rate_law.args[0].equals( (1 - _s('alpha')) * _s('beta') * _s('S') * _s('B')) + + +def test_counts_to_dimensionless(): + """Test that counts are converted to dimensionless.""" + from mira.metamodel import Unit + tm = _d(sir_parameterized) + + for template in tm.templates: + for concept in template.get_concepts(): + concept.units = Unit(expression=sympy.Symbol('person')) + tm.initials['susceptible_population'].value = 1e5-1 + tm.initials['infected_population'].value = 1 + tm.initials['immune_population'].value = 0 + for initial in tm.initials.values(): + initial.concept.units = Unit(expression=sympy.Symbol('person')) + + tm = counts_to_dimensionless(tm, 'person', 1e5) + for template in tm.templates: + for concept in template.get_concepts(): + assert concept.units.expression.equals(1), concept.units + + assert tm.initials['susceptible_population'].value == (1e5-1)/1e5 + assert tm.initials['infected_population'].value == 1/1e5 + assert tm.initials['immune_population'].value == 0 + + for initial in tm.initials.values(): + assert initial.concept.units.expression.equals(1)