From c548caeb0901fceb7644186ecd6843f6db540781 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Sun, 9 Jul 2023 01:00:47 -0400 Subject: [PATCH] Spin off units and utils --- mira/metamodel/units.py | 51 +++++++++++++++++++++++++++++++++++++++++ mira/metamodel/utils.py | 25 ++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 mira/metamodel/units.py create mode 100644 mira/metamodel/utils.py diff --git a/mira/metamodel/units.py b/mira/metamodel/units.py new file mode 100644 index 000000000..f9386c6c5 --- /dev/null +++ b/mira/metamodel/units.py @@ -0,0 +1,51 @@ +__all__ = [ + 'Unit', + 'person_units', + 'day_units', + 'per_day_units', + 'dimensionless_units', + 'per_day_per_person_units', + 'UNIT_SYMBOLS' +] + +import os +import sympy +from pydantic import BaseModel, Field +from .utils import SympyExprStr + + +class Unit(BaseModel): + """A unit of measurement.""" + class Config: + arbitrary_types_allowed = True + json_encoders = { + SympyExprStr: lambda e: str(e), + } + json_decoders = { + SympyExprStr: lambda e: sympy.parse_expr(e) + } + + expression: SympyExprStr = Field( + description="The expression for the unit." + ) + + +person_units = Unit(expression=sympy.Symbol('person')) +day_units = Unit(expression=sympy.Symbol('day')) +per_day_units = Unit(expression=1/sympy.Symbol('day')) +dimensionless_units = Unit(expression=sympy.Integer('1')) +per_day_per_person_units = Unit(expression=1/(sympy.Symbol('day')*sympy.Symbol('person'))) + + +def load_units(): + path = os.path.join(os.path.dirname(os.path.abspath(__file__)), + os.pardir, 'dkg', 'resources', 'unit_names.tsv') + with open(path, 'r') as fh: + units = {} + for line in fh.readlines(): + symbol = line.strip() + units[symbol] = sympy.Symbol(symbol) + return units + + +UNIT_SYMBOLS = load_units() diff --git a/mira/metamodel/utils.py b/mira/metamodel/utils.py new file mode 100644 index 000000000..3a54d99bb --- /dev/null +++ b/mira/metamodel/utils.py @@ -0,0 +1,25 @@ +__all__ = ["SympyExprStr"] + +import sympy + + +class SympyExprStr(sympy.Expr): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + if isinstance(v, cls): + return v + return cls(v) + + @classmethod + def __modify_schema__(cls, field_schema): + field_schema.update(type="string", example="2*x") + + def __str__(self): + return super().__str__()[len(self.__class__.__name__)+1:-1] + + def __repr__(self): + return str(self)