Skip to content

Commit

Permalink
Implement concurrent effects and autonomous evolution
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdailis committed Jul 5, 2024
1 parent 3715f42 commit 9f3838a
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 32 deletions.
11 changes: 8 additions & 3 deletions pymerlin/_internal/_cell_type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from pymerlin._internal._effect_trait import EffectTrait
from pymerlin.duration import Duration, MICROSECONDS


class CellType:
def __init__(self, gateway):
def __init__(self, gateway, evolution=None):
self.gateway = gateway
self.evolution = evolution

def getEffectType(self):
"""
Expand All @@ -18,8 +20,11 @@ def duplicate(self, state):
def apply(self, state, effect):
state.setValue(effect.apply(state.getValue()))

def step(self, state, duration):
pass
def step(self, state, java_duration):
if self.evolution is None:
return
duration = Duration.of(java_duration.dividedBy(self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.types.Duration.MICROSECOND), MICROSECONDS)
state.setValue(self.evolution(state.getValue(), duration))

def getExpiry(self, state):
return self.gateway.jvm.java.util.Optional.empty()
Expand Down
3 changes: 2 additions & 1 deletion pymerlin/_internal/_condition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pymerlin import model_actions
from pymerlin._internal._context import _context
from pymerlin._internal._querier_adapter import QuerierAdapter


Expand All @@ -11,7 +12,7 @@ def __init__(self, gateway, func):
self.func = func

def nextSatisfied(self, querier, horizon):
with model_actions._context(QuerierAdapter(querier)):
with _context(QuerierAdapter(querier)):
if self.func():
return self.gateway.jvm.java.util.Optional.of(self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.types.Duration.ZERO)
else:
Expand Down
13 changes: 11 additions & 2 deletions pymerlin/_internal/_effect_trait.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from pymerlin._internal._registrar import FunctionalEffect


class EffectTrait:
def empty(self):
return 0
return FunctionalEffect(lambda x: x)
def sequentially(self, prefix, suffix):
return suffix
def concurrently(self, left, right):
return 0
def try_both(x):
res1 = left.apply(right.apply(x))
res2 = right.apply(left.apply(x))
if res1 != res2:
raise Exception("Concurrent composition of non-commutative effects")
return res1
return FunctionalEffect(try_both)
class Java:
implements = ["gov.nasa.jpl.aerie.merlin.protocol.model.EffectTrait"]
6 changes: 3 additions & 3 deletions pymerlin/_internal/_model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def set_gateway(self, gateway):
gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.driver.Topic()))

def instantiate(self, start_time, config, builder):
cell_type = CellType(self.gateway)

registrar = Registrar()
model = self.model_class(registrar)

for cell_ref, initial_value in registrar.cells:
default_cell_type = CellType(self.gateway)
for cell_ref, initial_value, evolution in registrar.cells:
cell_type = default_cell_type if evolution is None else CellType(self.gateway, evolution=evolution)
topic = self.gateway.jvm.gov.nasa.jpl.aerie.merlin.protocol.driver.Topic()
cell_id = builder.allocate(self.gateway.jvm.org.apache.commons.lang3.mutable.MutableObject(initial_value),
cell_type, self.gateway.jvm.java.util.function.Function.identity(), topic)
Expand Down
27 changes: 21 additions & 6 deletions pymerlin/_internal/_registrar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pymerlin import model_actions
from pymerlin._internal import _globals


class Registrar:
Expand All @@ -7,9 +7,9 @@ def __init__(self):
self.resources = []
self.topics = []

def cell(self, initial_value):
def cell(self, initial_value, evolution=None):
ref = CellRef()
self.cells.append((ref, initial_value))
self.cells.append((ref, initial_value, evolution))
return ref

def resource(self, name, f):
Expand All @@ -31,12 +31,15 @@ class CellRef:
"""
A reference to an allocated piece of simulation state
"""

def __init__(self):
self.id = None
self.topic = None

def emit(self, event):
model_actions._current_context[0].emit(event, self.topic)
if callable(event):
event = FunctionalEffect(event)
_globals._current_context[0].emit(event, self.topic)

def set_value(self, new_value):
self.emit(set_value(new_value))
Expand All @@ -45,7 +48,7 @@ def add(self, addend):
self.emit(add_number(addend))

def get(self):
return model_actions._current_context[0].get(self.id).getValue()
return _globals._current_context[0].get(self.id).getValue()


class set_value:
Expand All @@ -60,5 +63,17 @@ def apply(self, state):
class Java:
implements = ["java.util.function.Function"]


class FunctionalEffect:
def __init__(self, f):
self.f = f

def apply(self, state):
return self.f(state)

class Java:
implements = ["java.util.function.Function"]


def add_number(addend):
pass
pass
25 changes: 12 additions & 13 deletions pymerlin/model_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,44 @@
with the `await` keyword - for example, `await delay("01:00:00")`
"""


import asyncio

from pymerlin._internal._task_status import Delayed, Calling, Awaiting
from pymerlin.duration import Duration
from pymerlin._internal._globals import _current_context, _yield_callback
import pymerlin._internal._task_status
import pymerlin.duration
import pymerlin._internal._globals


async def delay(duration):
if type(duration) is str:
duration = Duration.from_string(duration)
elif type(duration) is not Duration:
duration = pymerlin.duration.Duration.from_string(duration)
elif type(duration) is not pymerlin.duration.Duration:
raise Exception("Argument to delay must be a Duration or a string representing a duration")
return await _yield_with(Delayed(duration))
return await _yield_with(pymerlin._internal._task_status.Delayed(duration))


def spawn(child):
"""
:param coro:
:return:
"""
_current_context[1](child)

pymerlin._internal._globals._current_context[1](child)


async def call(child):
return await _yield_with(Calling(child))
return await _yield_with(pymerlin._internal._task_status.Calling(child))


async def wait_until(condition):
"""
:param condition: A function returning True or False
"""
return await _yield_with(Awaiting(condition))
return await _yield_with(pymerlin._internal._task_status.Awaiting(condition))


async def _yield_with(status):
loop = asyncio.get_running_loop()
loop.set_debug(True)
future = loop.create_future()
_yield_callback[0].__call__(status, future)
pymerlin._internal._globals._yield_callback[0].__call__(status, future)

return await future
return await future
41 changes: 37 additions & 4 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@ class TestMissionModel:
def __init__(self, registrar: Registrar):
self.list = registrar.cell([])
self.counter = registrar.cell(0)
# self.linear = registrar.cell(line(0, 1))
self.linear = registrar.cell((0, 1), evolution=linear_evolution)

def linear_evolution(x, d):
initial = x[0]
rate = x[1]
delta = rate * d.micros / 1_000_000
return initial + delta, rate


class LinearCell:
Expand Down Expand Up @@ -114,8 +120,6 @@ async def activity(mission: TestMissionModel):
simulate(TestMissionModel, Schedule.build(("00:00:00", Directive("activity"))), "24:00:00")
# TODO make sure ample information is extracted from the java exception

print()


def test_spawn():
"""
Expand Down Expand Up @@ -189,5 +193,34 @@ async def other_activity(mission: TestMissionModel):
Span("other_activity", Duration.of(15, SECONDS), Duration.ZERO)]


def test_concurrent_effects():
"""
Make sure the counter gets incremented when expected by observing it from another activity
"""

@TestMissionModel.ActivityType
async def activity(mission: TestMissionModel):
mission.counter.emit(lambda x: x + 1)
assert mission.counter.get() == 1
await delay("00:00:00")
assert mission.counter.get() == 2

simulate(TestMissionModel,
Schedule.build(("00:00:00", Directive("activity")), ("00:00:00", Directive("activity"))),
"24:00:00")


def test_autonomous_condition():
pass
"""
Check that a task is resumed correctly when a condition becomes true
"""

@TestMissionModel.ActivityType
async def activity(mission: TestMissionModel):
assert mission.linear.get()[0] == 0
await delay("00:00:01")
assert mission.linear.get()[0] == 1
await delay("00:00:02")
assert mission.linear.get()[0] == 3

simulate(TestMissionModel, Schedule.build(("00:00:00", Directive("activity"))), "24:00:00")

0 comments on commit 9f3838a

Please sign in to comment.