From eba276ebf2464b4230047156aa9eaf4892347aa1 Mon Sep 17 00:00:00 2001 From: Abigail Emery Date: Wed, 30 Aug 2023 13:52:19 +0100 Subject: [PATCH] Fix propagation of skipped components(#180) This is an option for fixing the propagation of skipped devices. It is not particularly elegant, but it does avoid changing the event routing and allows us to utilise the existing framework flow. --- src/tickit/core/management/schedulers/base.py | 36 +++++++++++++------ src/tickit/core/management/ticker.py | 34 ++++++++++++++---- src/tickit/core/typedefs.py | 20 +++++++++++ tests/core/management/test_ticker.py | 22 +++++++++++- 4 files changed, 94 insertions(+), 18 deletions(-) diff --git a/src/tickit/core/management/schedulers/base.py b/src/tickit/core/management/schedulers/base.py index fda3d57b1..257621b6e 100644 --- a/src/tickit/core/management/schedulers/base.py +++ b/src/tickit/core/management/schedulers/base.py @@ -15,6 +15,7 @@ Interrupt, Output, SimTime, + Skip, StopComponent, ) from tickit.utils.topic_naming import input_topic, output_topic @@ -69,24 +70,37 @@ async def update_component(self, input: Input) -> None: """ await self.state_producer.produce(input_topic(input.target), input) - async def handle_message(self, message: ComponentOutput) -> None: + async def skip_component(self, skip: Skip) -> None: + """Sends a message to itself that a given component update has been skipped. + + Args: + skip (Skip): The Skip information to be included in the message sent to the + scheduler. + """ + await self.state_producer.produce(output_topic(skip.source), skip) + + async def handle_message(self, message: Union[ComponentOutput, Skip]) -> None: """Handle messages received by the state consumer. - An asynchronous callback which handles Interrupt, Output and ComponentException - messages received by the state consumer. For Outputs, changes are propagated - and wakeups scheduled if required. For interrupts handling is deferred. For - exceptions, a StopComponent message is produced to each component in the system - to facilitate shut down. + An asynchronous callback which handles Interrupt, Output, ComponentException + and Skip messages received by the state consumer. For Outputs, changes are + propagated and wakeups scheduled if required. Skips are also propagated. For + interrupts handling is deferred. For exceptions, a StopComponent message is + produced to each component in the system to facilitate shut down. Args: - message (Union[Interrupt, Output, ComponentException]): An Interrupt, + message (Union[ComponentOutput, Skip]): An Interrupt, Output or ComponentException received by the state consumer. """ - LOGGER.debug(f"Scheduler ({type(self).__name__}) got {message}") + if not isinstance(message, Skip): + LOGGER.debug(f"Scheduler ({type(self).__name__}) got {message}") + if isinstance(message, Output): await self.ticker.propagate(message) if message.call_at is not None: self.add_wakeup(message.source, message.call_at) + elif isinstance(message, Skip): + await self.ticker.propagate(message) elif isinstance(message, Interrupt): await self.schedule_interrupt(message.source) elif isinstance(message, ComponentException): @@ -99,14 +113,16 @@ async def setup(self) -> None: subscribed to the output topics of each component in the system, a state producer to produce component inputs. """ - self.ticker = Ticker(self._wiring, self.update_component) + self.ticker = Ticker(self._wiring, self.update_component, self.skip_component) self.state_consumer: StateConsumer[ComponentOutput] = self._state_consumer_cls( self.handle_message ) await self.state_consumer.subscribe( {output_topic(component) for component in self.ticker.components} ) - self.state_producer: StateProducer[ComponentInput] = self._state_producer_cls() + self.state_producer: StateProducer[ + Union[ComponentInput, Skip] + ] = self._state_producer_cls() def add_wakeup(self, component: ComponentID, when: SimTime) -> None: """Adds a wakeup to the mapping. diff --git a/src/tickit/core/management/ticker.py b/src/tickit/core/management/ticker.py index 01c3f0936..f169f3db6 100644 --- a/src/tickit/core/management/ticker.py +++ b/src/tickit/core/management/ticker.py @@ -16,7 +16,15 @@ from immutables import Map from tickit.core.management.event_router import EventRouter, InverseWiring, Wiring -from tickit.core.typedefs import Changes, ComponentID, Input, Output, PortID, SimTime +from tickit.core.typedefs import ( + Changes, + ComponentID, + Input, + Output, + PortID, + SimTime, + Skip, +) LOGGER = logging.getLogger(__name__) @@ -32,6 +40,7 @@ def __init__( self, wiring: Union[Wiring, InverseWiring], update_component: Callable[[Input], Coroutine[Any, Any, None]], + skip_component: Callable[[Skip], Coroutine[Any, Any, None]], ) -> None: """Ticker constructor which creates an event router and performs initial setup. @@ -45,6 +54,7 @@ def __init__( """ self.event_router = EventRouter(wiring) self.update_component = update_component + self.skip_component = skip_component self.to_update: Dict[ComponentID, Optional[asyncio.Task]] = dict() self.finished: asyncio.Event = asyncio.Event() @@ -108,7 +118,7 @@ def required_dependencies(component) -> Set[ComponentID]: ) updating: Dict[ComponentID, asyncio.Task] = dict() - skipping: Set[ComponentID] = set() + for component, task in self.to_update.items(): if task is not None or required_dependencies(component): continue @@ -116,17 +126,27 @@ def required_dependencies(component) -> Set[ComponentID]: updating[component] = asyncio.create_task( self.update_component( Input( - component, self.time, Changes(Map(self.inputs[component])) + component, + self.time, + Changes(Map(self.inputs[component])), ) ) ) else: - skipping.add(component) + LOGGER.debug(f"Skipping {component}") + updating[component] = asyncio.create_task( + self.skip_component( + Skip( + source=component, + time=self.time, + changes=Changes(Map()), + ) + ) + ) + self.to_update.update(updating) - for component in skipping: - del self.to_update[component] - async def propagate(self, output: Output) -> None: + async def propagate(self, output: Union[Output, Skip]) -> None: """Propagates the output of an updated component. An asynchronous message which propagates the output of an updated component by diff --git a/src/tickit/core/typedefs.py b/src/tickit/core/typedefs.py index 99ee6fade..021ec9566 100644 --- a/src/tickit/core/typedefs.py +++ b/src/tickit/core/typedefs.py @@ -78,6 +78,26 @@ class Output: call_at: Optional[SimTime] +@dataclass(frozen=True) +class Skip: + """An immutable data container for skipping Component Updates. + + This mimics a Component output but is produced and consumed by the scheduler for + situations where a components inputs has not changed, therefore does not need + updating but this skipping needs to propgate through the graph. + + Args: + source: The Component whos update will be skipped + time: The simulation time at which the component skipping is to be handled. + changes: The changes to the component outputs, which will always be an empty + map. + """ + + source: ComponentID + time: SimTime + changes: Changes + + @dataclass(frozen=True) class Interrupt: """An immutable data container for scheduling Component interrupts. diff --git a/tests/core/management/test_ticker.py b/tests/core/management/test_ticker.py index 066217c34..4365cf8d0 100644 --- a/tests/core/management/test_ticker.py +++ b/tests/core/management/test_ticker.py @@ -15,6 +15,7 @@ Output, PortID, SimTime, + Skip, ) @@ -39,7 +40,7 @@ def inverse_wiring(inverse_wiring_struct: Inverse_Wiring_Struct) -> InverseWirin @pytest.fixture def ticker(inverse_wiring: InverseWiring) -> Ticker: - return Ticker(inverse_wiring, AsyncMock()) + return Ticker(inverse_wiring, AsyncMock(), AsyncMock()) def test_ticker_components_returns_components(ticker: Ticker): @@ -81,6 +82,25 @@ async def test_ticker_schedule_possible_updates_passes_inputs(ticker: Ticker): ) +@pytest.mark.asyncio +async def test_ticker_schedule_possible_updates_skips_components_with_no_input_changes( + ticker: Ticker, +): + ticker.time = SimTime(10) + ticker.roots = set() + ticker.to_update = {ComponentID("Mid1"): None, ComponentID("In1"): None} + ticker.inputs = defaultdict(dict, {ComponentID("Mid1"): {}}) + ticker.update_component = AsyncMock() + ticker.skip_component = AsyncMock() + + await ticker.schedule_possible_updates() + + ticker.skip_component.assert_called_once_with( + Skip(ComponentID("Mid1"), SimTime(10), Changes(Map())) + ) + ticker.update_component.assert_not_called() + + @pytest.mark.asyncio async def test_ticker_propagate_raises_unexpected_output(ticker: Ticker): ticker.time = SimTime(42)