Skip to content

Commit

Permalink
Improve ergonomics of call and spawn
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdailis committed Jul 13, 2024
1 parent 5eaa349 commit cd80edb
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 77 deletions.
8 changes: 5 additions & 3 deletions pymerlin/_internal/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@


@contextmanager
def _context(scheduler, spawner=None, reaction_context=None):
_set_context(scheduler, spawner, reaction_context)
def _context(scheduler, spawner=None, reaction_context=None, model_type=None):
_set_context(scheduler, spawner, reaction_context, model_type=None)
yield
_clear_context()


def _set_context(context, spawner, reaction_context):
def _set_context(context, spawner, reaction_context, model_type):
_current_context.clear()
_current_context.append(context)
_current_context.append(spawner)
_current_context.append(model_type)
_globals.reaction_context = reaction_context


def _clear_context():
_current_context.clear()
_current_context.append(None)
_current_context.append(None)
_current_context.append(None)
_globals.reaction_context = None
41 changes: 33 additions & 8 deletions pymerlin/_internal/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
import warnings
from dataclasses import dataclass

from pymerlin._internal._serialized_value import from_map_str_serialized_value
from pymerlin._internal._spawn_helpers import activity_wrapper, get_topics
from pymerlin._internal._task_specification import TaskInstance


def MissionModel(cls):
"""
Expand All @@ -19,7 +23,12 @@ def MissionModel(cls):
cls.activity_types = {}

def ActivityType(func):
activity_definition = wrap(func)
if type(func) == TaskDefinition:
activity_definition = func
elif callable(func):
activity_definition = TaskDefinition(func.__name__, lambda *args, **kwargs: activity_wrapper(TaskDefinition("inner", func), args, kwargs, *get_topics(activity_definition)))
else:
raise ValueError("Cannot decorate " + repr(func) + " with @ActivityType")
if activity_definition.name in cls.activity_types:
warnings.warn("Re-defining activity type: " + activity_definition.name)
cls.activity_types[activity_definition.name] = activity_definition
Expand All @@ -29,7 +38,7 @@ def ActivityType(func):


def Task(func):
return TaskDefinition(func)
return TaskDefinition(func.__name__, func)


def Validation(validator, message=None):
Expand All @@ -52,16 +61,32 @@ class ValidationResult:


class TaskDefinition:
def __init__(self, inner):
self.inner = inner
self.name = inner.__name__
"""
TaskDefinition can produce a TaskInstance given all of the arguments for that task
"""
def __init__(self, name, func):
self.name = name
self.inner = func
self.validations = []

def add_validation(self, validation):
self.validations.insert(0, validation)

def run_task_definition(self, *args, **kwargs):
return self.inner.__call__(*args, **kwargs)
def __call__(self, *args, **kwargs):
return self.make_instance(*args, **kwargs)

def make_instance(self, *args, **kwargs) -> TaskInstance:
# inspect.getfullargspec(self.inner)
# return self.inner.__call__(*args, **kwargs)
return TaskInstance(lambda: self.inner.__call__(*args, **kwargs))
# , f"{self.name}({', '.join(f'{k}={v}' for k, v in kwargs.items())})"

def get_task_factory(self, model, args, gateway, model_type):
from pymerlin._internal._task_factory import TaskFactory
from pymerlin._internal._threaded_task import ThreadedTaskHost

# It is expected that the first argument to an activity be the mission model
return TaskFactory(lambda: ThreadedTaskHost(gateway, model_type, self.make_instance(model, **from_map_str_serialized_value(gateway, args))))


def wrap(x):
Expand All @@ -71,5 +96,5 @@ def wrap(x):
if type(x) == TaskDefinition:
return x
if callable(x):
return TaskDefinition(x)
return TaskDefinition(x.__name__, x)
raise Exception("Unhandled variant: " + str(type(x)))
11 changes: 3 additions & 8 deletions pymerlin/_internal/_directive_type.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from pymerlin._internal._decorators import TaskDefinition
from pymerlin._internal._globals import models_by_id
from pymerlin._internal._input_type import InputType
from pymerlin._internal._serialized_value import from_map_str_serialized_value
from pymerlin._internal._spawn_helpers import activity_wrapper
from pymerlin._internal._task_factory import TaskFactory
from pymerlin._internal._threaded_task import ThreadedTaskHost


class DirectiveType:
def __init__(self, gateway, activity, input_topic, output_topic):
def __init__(self, gateway, activity, input_topic, output_topic, model_type):
if type(activity) is not TaskDefinition:
raise ValueError("Activity must be of type TaskDefinition, but was: " + repr(activity))
self.gateway = gateway
self.activity = activity
self.input_topic = input_topic
self.output_topic = output_topic
self.model_type = model_type

def getInputType(self):
return InputType()
Expand All @@ -23,8 +19,7 @@ def getOutputType(self):
return None

def getTaskFactory(self, model_id, args):
task_provider = TaskDefinition(lambda: activity_wrapper(self.activity, from_map_str_serialized_value(self.gateway, args), models_by_id[model_id][0], self.input_topic, self.output_topic))
return TaskFactory(lambda: ThreadedTaskHost(self.gateway, models_by_id[model_id][1], task_provider))
return self.activity.get_task_factory(models_by_id[model_id][0], args, self.gateway, self.model_type)

class Java:
implements = ["gov.nasa.jpl.aerie.merlin.protocol.model.DirectiveType"]
2 changes: 1 addition & 1 deletion pymerlin/_internal/_globals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
models_by_id = {}

_current_context = [None, None]
_current_context = [None, None, None]

next_cell_id = 0

Expand Down
5 changes: 3 additions & 2 deletions pymerlin/_internal/_model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def spawn(coro):
new_task = ThreadedTaskHost(self.gateway, self, coro)
builder.daemon(TaskFactory(lambda: new_task))

with _context(None, spawner=spawn):
with _context(None, spawner=spawn, model_type=self):
model = self.model_class(registrar)

model._model_type = self
Expand Down Expand Up @@ -66,7 +66,8 @@ def getDirectiveTypes(self):
self.gateway,
activity_type[0], # TaskDefinition
activity_type[1], # input_topic
activity_type[2]) # output_topic
activity_type[2], # output_topic
self) # model type
for activity_type in self.activity_types
},
self.gateway._gateway_client)
Expand Down
16 changes: 8 additions & 8 deletions pymerlin/_internal/_spawn_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pymerlin._internal import _globals
from pymerlin._internal._decorators import TaskDefinition


# async def activity_wrapper(task, args, model, input_topic, output_topic):
Expand All @@ -12,18 +11,19 @@
# if output_topic is not None:
# _globals._current_context[0].emit({}, output_topic)

def activity_wrapper(task, args, model, input_topic, output_topic):
def activity_wrapper(task, args, kwargs, input_topic, output_topic):
from pymerlin._internal._decorators import TaskDefinition
if type(task) is not TaskDefinition:
raise ValueError("Hmm, why? " + repr(task))
if input_topic is not None:
_globals._current_context[0].emit(args, input_topic)
task.run_task_definition(model, **args)
if output_topic is not None:
_globals._current_context[0].emit({}, output_topic)
_globals._current_context[0].emit({}, input_topic)
task.make_instance(*args, **kwargs).run()
_globals._current_context[0].emit({}, output_topic)

def get_topics(model_type, func):
def get_topics(func):
from pymerlin._internal._decorators import TaskDefinition
if type(func) is not TaskDefinition:
raise Exception("Whoa there buddy")
model_type = _globals._current_context[2]
for activity_func, input_topic, output_topic in model_type.activity_types:
if activity_func is func:
return input_topic, output_topic
Expand Down
34 changes: 13 additions & 21 deletions pymerlin/_internal/_task_specification.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,18 @@
class TaskInstance:
def __init__(self, func, kwargs, model, validations, definition):
"""
A TaskInstance is just a lambda with extra steps
"""
def __init__(self, func):
self.func = func
self.args = kwargs
self.model = model
self.validations = validations #
self.definition = definition
# self.kwargs = kwargs

def instantiate(self):
if self.model is None:
return self.func(**self.args) #, **self.kwargs)
else:
return self.func(self.model, **self.args) # , **self.kwargs)
# def validate(self):
# return [
# validation(self.args)
# for validation in self.validations
# ]

def validate(self):
return [
validation(self.args)
for validation in self.validations
]
# def __repr__(self):
# return self.repr

def __repr__(self):
return f"{self.definition.name}({', '.join(f'{k}={v}' for k, v in self.args.items())})"

def __call__(self, *args, **kwargs):
return self.instantiate()
def run(self):
return self.func()
13 changes: 8 additions & 5 deletions pymerlin/_internal/_threaded_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from pymerlin._internal import _globals
from pymerlin._internal._condition import Condition
from pymerlin._internal._context import _set_context, _clear_context
from pymerlin._internal._decorators import TaskDefinition
from pymerlin._internal._task_factory import TaskFactory
from pymerlin._internal._task_specification import TaskInstance
from pymerlin._internal._task_status import Completed, Delayed, Awaiting, Calling

# Host-to-task message types
Expand All @@ -22,8 +22,11 @@

class ThreadedTaskHost:
def __init__(self, gateway, model_type, task_provider):
if type(task_provider) is not TaskDefinition:
if type(task_provider) is not TaskInstance:
raise ValueError(repr(task_provider))
from pymerlin._internal._model_type import ModelType
if type(model_type) is not ModelType:
raise ValueError(repr(model_type))
self.gateway = gateway
self.task_thread = _ThreadedTask(task_provider, model_type, gateway)
self.task_thread_started = False
Expand Down Expand Up @@ -81,8 +84,8 @@ def _spawn(self, task_provider):

def _run(self, scheduler):
try:
_set_context(scheduler, self._spawn, self)
result = self.task.run_task_definition()
_set_context(scheduler, self._spawn, self, self.model_type)
result = self.task.run()
self.outbox.put(Finished(result))
except TaskAbort:
self.outbox.put(Aborted())
Expand All @@ -97,7 +100,7 @@ def yield_with(self, status):
self.outbox.put(Yield(status))
request = self.inbox.get()
if type(request) == Resume:
_set_context(request.scheduler, self._spawn, self)
_set_context(request.scheduler, self._spawn, self, self.model_type)
return
elif type(request) == Abort:
self.aborting = True
Expand Down
30 changes: 12 additions & 18 deletions pymerlin/model_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pymerlin._internal._task_status
import pymerlin.duration
import pymerlin._internal._globals
from pymerlin._internal._decorators import TaskDefinition
from pymerlin._internal._spawn_helpers import activity_wrapper, get_topics
from pymerlin._internal._task_specification import TaskInstance


def delay(duration):
Expand All @@ -18,37 +18,31 @@ def delay(duration):
return _yield_with(pymerlin._internal._task_status.Delayed(duration))


def spawn_activity(model, child, args):
def spawn_activity(child):
"""
:param coro:
:return:
"""
topics = get_topics(model._model_type, child)
task_provider = TaskDefinition(lambda: activity_wrapper(
child,
args,
model,
*topics))
pymerlin._internal._globals._current_context[1](task_provider)
pymerlin._internal._globals._current_context[1](child)


def spawn_task(child, args):
"""
:param coro:
:return:
"""
pymerlin._internal._globals._current_context[1](TaskDefinition(lambda: child.run_task_definition(**args)))
pymerlin._internal._globals._current_context[1](child.make_instance(**args))


def call(model, child, args):
if type(child) is not TaskDefinition:
def call(child):
if type(child) is not TaskInstance:
raise ValueError("Should be TaskDefinition, was: " + repr(child))
task_provider = TaskDefinition(lambda: activity_wrapper(
child,
args,
model,
*get_topics(model._model_type, child)))
return _yield_with(pymerlin._internal._task_status.Calling(task_provider))
# task_provider = TaskInstance(lambda: activity_wrapper(
# child,
# args,
# model,
# *get_topics(model._model_type, child)))
return _yield_with(pymerlin._internal._task_status.Calling(child))


def wait_until(condition):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_spawn_activity():
@TestMissionModel.ActivityType
def activity(mission: TestMissionModel):
mission.counter.set(123)
spawn_activity(mission, other_activity, {})
spawn_activity(other_activity(mission))
mission.counter.set(345)
assert mission.counter.get() == 345

Expand Down Expand Up @@ -168,7 +168,7 @@ def test_call():
@TestMissionModel.ActivityType
def activity(mission: TestMissionModel):
mission.counter.set(123)
call(mission, other_activity, {})
call(other_activity(mission))
assert mission.counter.get() == 345
delay("00:00:01")

Expand All @@ -193,7 +193,7 @@ def test_call_task():
@TestMissionModel.ActivityType
def activity(mission: TestMissionModel):
mission.counter.set(123)
call(mission, subtask, {})
call(subtask(mission))
assert mission.counter.get() == 345
delay("00:00:01")

Expand Down

0 comments on commit cd80edb

Please sign in to comment.