diff --git a/pymerlin/_internal/_cell_type.py b/pymerlin/_internal/_cell_type.py index 6ecb535..ef5faac 100644 --- a/pymerlin/_internal/_cell_type.py +++ b/pymerlin/_internal/_cell_type.py @@ -1,9 +1,11 @@ from pymerlin._internal._effect_trait import EffectTrait +from pymerlin.duration import Duration, MICROSECONDS class CellType: - def __init__(self, gateway): + def __init__(self, gateway, evolution=None): self.gateway = gateway + self.evolution = evolution def getEffectType(self): """ @@ -18,8 +20,11 @@ def duplicate(self, state): def apply(self, state, effect): state.setValue(effect.apply(state.getValue())) - def step(self, state, duration): - pass + def step(self, state, java_duration): + if self.evolution is None: + return + duration = Duration.of(java_duration.dividedBy(self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.types.Duration.MICROSECOND), MICROSECONDS) + state.setValue(self.evolution(state.getValue(), duration)) def getExpiry(self, state): return self.gateway.jvm.java.util.Optional.empty() diff --git a/pymerlin/_internal/_condition.py b/pymerlin/_internal/_condition.py index 50ed450..34a1ddb 100644 --- a/pymerlin/_internal/_condition.py +++ b/pymerlin/_internal/_condition.py @@ -1,4 +1,5 @@ from pymerlin import model_actions +from pymerlin._internal._context import _context from pymerlin._internal._querier_adapter import QuerierAdapter @@ -11,7 +12,7 @@ def __init__(self, gateway, func): self.func = func def nextSatisfied(self, querier, horizon): - with model_actions._context(QuerierAdapter(querier)): + with _context(QuerierAdapter(querier)): if self.func(): return self.gateway.jvm.java.util.Optional.of(self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.types.Duration.ZERO) else: diff --git a/pymerlin/_internal/_effect_trait.py b/pymerlin/_internal/_effect_trait.py index 1eba73d..8afc770 100644 --- a/pymerlin/_internal/_effect_trait.py +++ b/pymerlin/_internal/_effect_trait.py @@ -1,9 +1,18 @@ +from pymerlin._internal._registrar import FunctionalEffect + + class EffectTrait: def empty(self): - return 0 + return FunctionalEffect(lambda x: x) def sequentially(self, prefix, suffix): return suffix def concurrently(self, left, right): - return 0 + def try_both(x): + res1 = left.apply(right.apply(x)) + res2 = right.apply(left.apply(x)) + if res1 != res2: + raise Exception("Concurrent composition of non-commutative effects") + return res1 + return FunctionalEffect(try_both) class Java: implements = ["gov.nasa.jpl.aerie.merlin.protocol.model.EffectTrait"] diff --git a/pymerlin/_internal/_model_type.py b/pymerlin/_internal/_model_type.py index fedb46c..325b9ab 100644 --- a/pymerlin/_internal/_model_type.py +++ b/pymerlin/_internal/_model_type.py @@ -23,12 +23,12 @@ def set_gateway(self, gateway): gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.driver.Topic())) def instantiate(self, start_time, config, builder): - cell_type = CellType(self.gateway) - registrar = Registrar() model = self.model_class(registrar) - for cell_ref, initial_value in registrar.cells: + default_cell_type = CellType(self.gateway) + for cell_ref, initial_value, evolution in registrar.cells: + cell_type = default_cell_type if evolution is None else CellType(self.gateway, evolution=evolution) topic = self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.driver.Topic() cell_id = builder.allocate(self.gateway.jvm.org.apache.commons.lang3.mutable.MutableObject(initial_value), cell_type, self.gateway.jvm.java.util.function.Function.identity(), topic) diff --git a/pymerlin/_internal/_registrar.py b/pymerlin/_internal/_registrar.py index c300205..6187aa6 100644 --- a/pymerlin/_internal/_registrar.py +++ b/pymerlin/_internal/_registrar.py @@ -1,4 +1,4 @@ -from pymerlin import model_actions +from pymerlin._internal import _globals class Registrar: @@ -7,9 +7,9 @@ def __init__(self): self.resources = [] self.topics = [] - def cell(self, initial_value): + def cell(self, initial_value, evolution=None): ref = CellRef() - self.cells.append((ref, initial_value)) + self.cells.append((ref, initial_value, evolution)) return ref def resource(self, name, f): @@ -31,12 +31,15 @@ class CellRef: """ A reference to an allocated piece of simulation state """ + def __init__(self): self.id = None self.topic = None def emit(self, event): - model_actions._current_context[0].emit(event, self.topic) + if callable(event): + event = FunctionalEffect(event) + _globals._current_context[0].emit(event, self.topic) def set_value(self, new_value): self.emit(set_value(new_value)) @@ -45,7 +48,7 @@ def add(self, addend): self.emit(add_number(addend)) def get(self): - return model_actions._current_context[0].get(self.id).getValue() + return _globals._current_context[0].get(self.id).getValue() class set_value: @@ -60,5 +63,17 @@ def apply(self, state): class Java: implements = ["java.util.function.Function"] + +class FunctionalEffect: + def __init__(self, f): + self.f = f + + def apply(self, state): + return self.f(state) + + class Java: + implements = ["java.util.function.Function"] + + def add_number(addend): - pass \ No newline at end of file + pass diff --git a/pymerlin/model_actions.py b/pymerlin/model_actions.py index 0b0a12b..3af9fa1 100644 --- a/pymerlin/model_actions.py +++ b/pymerlin/model_actions.py @@ -3,19 +3,19 @@ with the `await` keyword - for example, `await delay("01:00:00")` """ - import asyncio -from pymerlin._internal._task_status import Delayed, Calling, Awaiting -from pymerlin.duration import Duration -from pymerlin._internal._globals import _current_context, _yield_callback +import pymerlin._internal._task_status +import pymerlin.duration +import pymerlin._internal._globals + async def delay(duration): if type(duration) is str: - duration = Duration.from_string(duration) - elif type(duration) is not Duration: + duration = pymerlin.duration.Duration.from_string(duration) + elif type(duration) is not pymerlin.duration.Duration: raise Exception("Argument to delay must be a Duration or a string representing a duration") - return await _yield_with(Delayed(duration)) + return await _yield_with(pymerlin._internal._task_status.Delayed(duration)) def spawn(child): @@ -23,25 +23,24 @@ def spawn(child): :param coro: :return: """ - _current_context[1](child) - + pymerlin._internal._globals._current_context[1](child) async def call(child): - return await _yield_with(Calling(child)) + return await _yield_with(pymerlin._internal._task_status.Calling(child)) async def wait_until(condition): """ :param condition: A function returning True or False """ - return await _yield_with(Awaiting(condition)) + return await _yield_with(pymerlin._internal._task_status.Awaiting(condition)) async def _yield_with(status): loop = asyncio.get_running_loop() loop.set_debug(True) future = loop.create_future() - _yield_callback[0].__call__(status, future) + pymerlin._internal._globals._yield_callback[0].__call__(status, future) - return await future \ No newline at end of file + return await future diff --git a/tests/test_simulation.py b/tests/test_simulation.py index e0d8f1e..23ad19a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -15,7 +15,13 @@ class TestMissionModel: def __init__(self, registrar: Registrar): self.list = registrar.cell([]) self.counter = registrar.cell(0) - # self.linear = registrar.cell(line(0, 1)) + self.linear = registrar.cell((0, 1), evolution=linear_evolution) + +def linear_evolution(x, d): + initial = x[0] + rate = x[1] + delta = rate * d.micros / 1_000_000 + return initial + delta, rate class LinearCell: @@ -114,8 +120,6 @@ async def activity(mission: TestMissionModel): simulate(TestMissionModel, Schedule.build(("00:00:00", Directive("activity"))), "24:00:00") # TODO make sure ample information is extracted from the java exception - print() - def test_spawn(): """ @@ -189,5 +193,34 @@ async def other_activity(mission: TestMissionModel): Span("other_activity", Duration.of(15, SECONDS), Duration.ZERO)] +def test_concurrent_effects(): + """ + Make sure the counter gets incremented when expected by observing it from another activity + """ + + @TestMissionModel.ActivityType + async def activity(mission: TestMissionModel): + mission.counter.emit(lambda x: x + 1) + assert mission.counter.get() == 1 + await delay("00:00:00") + assert mission.counter.get() == 2 + + simulate(TestMissionModel, + Schedule.build(("00:00:00", Directive("activity")), ("00:00:00", Directive("activity"))), + "24:00:00") + + def test_autonomous_condition(): - pass + """ + Check that a task is resumed correctly when a condition becomes true + """ + + @TestMissionModel.ActivityType + async def activity(mission: TestMissionModel): + assert mission.linear.get()[0] == 0 + await delay("00:00:01") + assert mission.linear.get()[0] == 1 + await delay("00:00:02") + assert mission.linear.get()[0] == 3 + + simulate(TestMissionModel, Schedule.build(("00:00:00", Directive("activity"))), "24:00:00")