Skip to content

Commit

Permalink
Implement model normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Jul 9, 2023
1 parent f8aa50c commit 25cac16
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 29 deletions.
1 change: 1 addition & 0 deletions mira/metamodel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .search import *
from .ops import *
from .units import *
from .utils import *
69 changes: 67 additions & 2 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@

from .template_model import TemplateModel, Initial, Parameter
from .templates import *
from .units import Unit
from .units import Unit, dimensionless_units


__all__ = [
"stratify",
"simplify_rate_laws",
"aggregate_parameters"
"aggregate_parameters",
"get_term_roles",
"counts_to_dimensionless"
]


Expand Down Expand Up @@ -395,3 +397,66 @@ def get_term_roles(term, template, parameters):
else:
term_roles['other'].append(symbol.name)
return dict(term_roles)


def counts_to_dimensionless(tm: TemplateModel,
counts_unit: str,
norm_factor: float):
"""Convert all entity concentrations to dimensionless units.
Parameters
----------
tm :
A template model.
counts_unit :
The unit of the counts.
norm_factor :
The normalization factor to convert counts to concentration.
Returns
-------
:
A template model with all entity concentrations converted to
dimensionless units.
"""
# Make a deepcopy up front so we don't change the original template model
tm = deepcopy(tm)
# Make a symbol of the counts unit for calculations
counts_unit_symbol = sympy.Symbol(counts_unit)

initials_normalized = set()
# First we normalize concepts and their initials
for template in tm.templates:
# Since concepts can be distributed across templates, we have to go
# template by template
for concept in template.get_concepts():
if concept.units:
# We figure out what the exponent of the coutns unit is
# if it appears in the units of the concept
(coeff, exponent) = \
concept.units.expression.args[0].as_coeff_exponent(counts_unit_symbol)
# If the exponent is other than zero then normalization is needed
if exponent:
concept.units.expression = \
concept.units.expression.args[0] / (counts_unit_symbol ** exponent)
# We not try to see if there is a corresponding initial condition
# for the concept and if so, we normalize it as well
if concept.name in tm.initials and concept.name not in initials_normalized:
init = tm.initials[concept.name]
if init.value is not None:
init.value /= (norm_factor ** exponent)
if init.concept.units:
init.concept.units.expression = \
init.concept.units.expression.args[0] / (counts_unit_symbol ** exponent)
initials_normalized.add(concept.name)
# Now we do the same for parameters
for p_name, p in tm.parameters.items():
if p.units:
(coeff, exponent) = \
p.units.expression = \
p.units.expression.args[0].as_coeff_exponent(counts_unit_symbol)
if exponent:
p.units /= (counts_unit_symbol ** exponent)
p.value /= (norm_factor ** exponent)

return tm
2 changes: 2 additions & 0 deletions mira/metamodel/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pydantic import BaseModel, Field

from .templates import *
from .units import Unit
from .utils import SympyExprStr


class Initial(BaseModel):
Expand Down
28 changes: 2 additions & 26 deletions mira/metamodel/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
"GroupedControlledProduction",
"GroupedControlledDegradation",
"SpecifiedTemplate",
"SympyExprStr",
"templates_equal",
"context_refinement",
]

import logging
import os
import sys
from collections import ChainMap
from itertools import product
Expand All @@ -50,8 +48,8 @@
except ImportError:
from typing_extensions import Annotated

from .units import Unit, UNIT_SYMBOLS

from .units import Unit
from .utils import SympyExprStr


IS_EQUAL = "is_equal"
Expand Down Expand Up @@ -81,28 +79,6 @@ class Config(BaseModel):
)


class SympyExprStr(sympy.Expr):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v):
if isinstance(v, cls):
return v
return cls(v)

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string", example="2*x")

def __str__(self):
return super().__str__()[len(self.__class__.__name__)+1:-1]

def __repr__(self):
return str(self)


class Concept(BaseModel):
"""A concept is specified by its identifier(s), name, and - optionally -
its context.
Expand Down
29 changes: 28 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Parameter,
TemplateModel,
)
from mira.metamodel.ops import stratify, simplify_rate_law
from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless
from mira.examples.sir import cities, sir, sir_2_city, sir_parameterized
from mira.examples.concepts import infected, susceptible
from mira.examples.chime import sviivr
Expand Down Expand Up @@ -319,3 +319,30 @@ def _make_template(rate_law):
(1 - _s('alpha')) * _s('S') * _s('A'))
assert templates[1].rate_law.args[0].equals(
(1 - _s('alpha')) * _s('beta') * _s('S') * _s('B'))


def test_counts_to_dimensionless():
"""Test that counts are converted to dimensionless."""
from mira.metamodel import Unit
tm = _d(sir_parameterized)

for template in tm.templates:
for concept in template.get_concepts():
concept.units = Unit(expression=sympy.Symbol('person'))
tm.initials['susceptible_population'].value = 1e5-1
tm.initials['infected_population'].value = 1
tm.initials['immune_population'].value = 0
for initial in tm.initials.values():
initial.concept.units = Unit(expression=sympy.Symbol('person'))

tm = counts_to_dimensionless(tm, 'person', 1e5)
for template in tm.templates:
for concept in template.get_concepts():
assert concept.units.expression.equals(1), concept.units

assert tm.initials['susceptible_population'].value == (1e5-1)/1e5
assert tm.initials['infected_population'].value == 1/1e5
assert tm.initials['immune_population'].value == 0

for initial in tm.initials.values():
assert initial.concept.units.expression.equals(1)

0 comments on commit 25cac16

Please sign in to comment.