Skip to content

Commit

Permalink
Spin off units and utils
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Jul 9, 2023
1 parent 3f8513d commit c548cae
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
51 changes: 51 additions & 0 deletions mira/metamodel/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
__all__ = [
'Unit',
'person_units',
'day_units',
'per_day_units',
'dimensionless_units',
'per_day_per_person_units',
'UNIT_SYMBOLS'
]

import os
import sympy
from pydantic import BaseModel, Field
from .utils import SympyExprStr


class Unit(BaseModel):
"""A unit of measurement."""
class Config:
arbitrary_types_allowed = True
json_encoders = {
SympyExprStr: lambda e: str(e),
}
json_decoders = {
SympyExprStr: lambda e: sympy.parse_expr(e)
}

expression: SympyExprStr = Field(
description="The expression for the unit."
)


person_units = Unit(expression=sympy.Symbol('person'))
day_units = Unit(expression=sympy.Symbol('day'))
per_day_units = Unit(expression=1/sympy.Symbol('day'))
dimensionless_units = Unit(expression=sympy.Integer('1'))
per_day_per_person_units = Unit(expression=1/(sympy.Symbol('day')*sympy.Symbol('person')))


def load_units():
path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
os.pardir, 'dkg', 'resources', 'unit_names.tsv')
with open(path, 'r') as fh:
units = {}
for line in fh.readlines():
symbol = line.strip()
units[symbol] = sympy.Symbol(symbol)
return units


UNIT_SYMBOLS = load_units()
25 changes: 25 additions & 0 deletions mira/metamodel/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
__all__ = ["SympyExprStr"]

import sympy


class SympyExprStr(sympy.Expr):
@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def validate(cls, v):
if isinstance(v, cls):
return v
return cls(v)

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(type="string", example="2*x")

def __str__(self):
return super().__str__()[len(self.__class__.__name__)+1:-1]

def __repr__(self):
return str(self)

0 comments on commit c548cae

Please sign in to comment.