diff --git a/sardine_core/handlers/sleep_handler/__init__.py b/sardine_core/handlers/sleep_handler/__init__.py index b911357a..eec1ff14 100644 --- a/sardine_core/handlers/sleep_handler/__init__.py +++ b/sardine_core/handlers/sleep_handler/__init__.py @@ -57,6 +57,8 @@ async def sleep_until(self, deadline: NUMBER) -> None: The deadline is based on the fish bowl clock's time. """ + + # General checks if self.env is None: raise ValueError("SleepHandler must be added to a fish bowl") elif not self.env.is_running(): @@ -120,6 +122,7 @@ def _check_running(self): self._poll_task.cancel() def _create_handle(self, deadline: NUMBER) -> TimeHandle: + #TODO: document this function handle = TimeHandle(deadline) if self.env.clock.time >= deadline: diff --git a/sardine_core/run.py b/sardine_core/run.py index dde4cadb..e833de70 100644 --- a/sardine_core/run.py +++ b/sardine_core/run.py @@ -212,6 +212,8 @@ def swim( """ def decorator(func: Union[Callable, AsyncRunner], /) -> AsyncRunner: + + # This is true when the function is already running on the scheduler if isinstance(func, AsyncRunner): func.update_state(*args, **kwargs) bowl.scheduler.start_runner(func) @@ -220,15 +222,18 @@ def decorator(func: Union[Callable, AsyncRunner], /) -> AsyncRunner: if until is not None: func = for_(until)(func) + # Checks if the runner already exists, otherwise, create a new one runner = bowl.scheduler.get_runner(func.__name__) if runner is None: runner = AsyncRunner(func.__name__) + # Some AsyncRunners need to stay in the background and never be + # interrupted by any user action (c.f. Tidal Vortex Loop) if background_job: runner.background_job = True elif not runner.is_running(): - # Runner has likely stopped swimming, in which case - # we should make sure the old state doesn't pollute - # the new function when it's pushed + # Runner has likely stopped swimming, in which case we should + # make sure the old state doesn't pollute the new function + # when it's pushed runner.reset_states() # Runners normally allow the same functions to appear in the stack, @@ -238,14 +243,16 @@ def decorator(func: Union[Callable, AsyncRunner], /) -> AsyncRunner: bowl.scheduler.start_runner(runner) return runner + # We apply the 'quant' policy (start now, on next beat, bar, etc) deadline = get_deadline_from_quant(bowl.clock, quant) + # Deadline is None when the 'quant' policy is now if deadline is None: runner.push(func, *args, **kwargs) else: runner.push_deferred(deadline, func, *args, **kwargs) - # Intentionally avoid interval correction so - # the user doesn't accidentally nudge the runner + # Intentionally avoid interval correction so the user doesn't + # accidentally nudge the runner runner.swim() runner.reload() @@ -398,12 +405,8 @@ def panic(*runners: AsyncRunner) -> None: def Pat( - pattern: str, - i: int = 0, - div: int = 1, - rate: int = 1, - as_text: bool = False - ) -> Any: + pattern: str, i: int = 0, div: int = 1, rate: int = 1, as_text: bool = False +) -> Any: """ General purpose pattern interface. This function can be used to summon the global parser stored in the fish_bowl. It is generally used to pattern outside of the diff --git a/sardine_core/scheduler/async_runner.py b/sardine_core/scheduler/async_runner.py index e5f86d45..d2a3058d 100644 --- a/sardine_core/scheduler/async_runner.py +++ b/sardine_core/scheduler/async_runner.py @@ -50,10 +50,12 @@ def _extract_new_period( ) -> Union[float, int]: period = kwargs.get("p") + # Assign a default period if necessary if period is None: param = sig.parameters.get("p") period = getattr(param, "default", default_period) + # Resolve any callable period if callable(period): try: period = period() @@ -381,6 +383,7 @@ def push_deferred( index = self._deferred_state_index self._deferred_state_index += 1 + # Create a new function state and push it to the heap queue state = FunctionState(func, args, kwargs) heapq.heappush(self.deferred_states, DeferredState(deadline, index, state)) @@ -407,6 +410,18 @@ def reload(self): self._reload_event.set() def _merge_states(self, old: FunctionState, new: FunctionState) -> None: + """ + Merges the arguments and keyword arguments of two function states. + + The old function state is expected to be the most recent state, + and the new function state is the one being pushed. The new state + will inherit any arguments that were not passed to it from the old + state, while any new keyword arguments will be added to the new state. + + Args: + old (FunctionState): The most recent function state. + new (FunctionState): The function state being pushed. + """ new.args = new.args + old.args[len(new.args) :] new.kwargs = old.kwargs | new.kwargs @@ -427,14 +442,14 @@ def is_running(self) -> bool: """Returns True if the runner is running.""" return self._task is not None and not self._task.done() - def swim(self): + def swim(self) -> None: """Allows the runner to continue the next iteration. This method must be called continuously to keep the runner alive. """ self._swimming = True - def stop(self): + def stop(self) -> None: """Stops the runner's execution after the current iteration. This method takes precedence when `swim()` is also called. @@ -442,7 +457,7 @@ def stop(self): self._stop = True self.reload() - def reset_states(self): + def reset_states(self) -> None: """Clears all function states from the runner. This method can safely be called while the runner is running. @@ -454,7 +469,7 @@ def reset_states(self): # Interval shifting - def allow_interval_correction(self): + def allow_interval_correction(self) -> None: """Allows the interval to be corrected in the next iteration.""" self._can_correct_interval = True @@ -480,13 +495,14 @@ def delay_interval(self, deadline: Union[float, int], period: Union[float, int]) RuntimeError: A function must be pushed before this can be used. """ self.snap = deadline + # TODO: explain this line and what it does exactly self.interval_shift = self.clock.get_beat_time(period, time=deadline) - def _check_snap(self, time: float): + def _check_snap(self, time: float) -> None: if self.snap is not None and time + self._last_interval >= self.snap: self.snap = None - def _correct_interval(self, period: Union[float, int]): + def _correct_interval(self, period: Union[float, int]) -> None: """Checks if the interval should be corrected. Interval correction occurs when `allow_interval_correction()` @@ -507,7 +523,7 @@ def _correct_interval(self, period: Union[float, int]): self._last_interval = interval self._can_correct_interval = False - def _correct_interval_background_job(self, period: Union[float, int]): + def _correct_interval_background_job(self, period: Union[float, int]) -> None: """ Alternative version for fixed-rate background jobs. The interval or period is not indexed on the clock like with the _correct_interval @@ -531,6 +547,25 @@ def _get_next_deadline(self, period: Union[float, int]) -> float: and the current clock time has not passed the snap, it will take priority over whatever period was passed. + If this is called earlier than the expected time, we should use + the current time to avoid calculating the next beat too far ahead, + which would cause an unusually long gap between iterations. + + If this is called after the expected time has already passed, + we should continue from that iteration and ignore the current time. + This allows returning an overdue deadline potentially caused by a + high delta, letting missed iterations fire ASAP. + + Given the above requirements, this would be the ideal solution: + time = min(self.clock.time, self._expected_time) + + However, this is complicated by SleepHandler which does not guarantee + that a successful iteration will never be earlier than the deadline. + As such, we will additionally prevent the time from being sooner than + the last successful iteration. + If we ignored this and allowed deadlines earlier than the last iteration, + the above solution could potentially trigger non-missed iterations too early. + Args: period (Union[float, int]): The number of beats in the interval. @@ -538,24 +573,6 @@ def _get_next_deadline(self, period: Union[float, int]) -> float: Returns: float: The deadline for the next interval. """ - # If this is called earlier than the expected time, we should use - # the current time to avoid calculating the next beat too far ahead, - # which would cause an unusually long gap between iterations. - # - # If this is called after the expected time has already passed, - # we should continue from that iteration and ignore the current time. - # This allows returning an overdue deadline potentially caused by a - # high delta, letting missed iterations fire ASAP. - # - # Given the above requirements, this would be the ideal solution: - # time = min(self.clock.time, self._expected_time) - # - # However, this is complicated by SleepHandler which does not guarantee - # that a successful iteration will never be earlier than the deadline. - # As such, we will additionally prevent the time from being sooner than - # the last successful iteration. - # If we ignored this and allowed deadlines earlier than the last iteration, - # the above solution could potentially trigger non-missed iterations too early. time = max(self._last_expected_time, min(self.clock.time, self._expected_time)) self._check_snap(time) @@ -567,18 +584,21 @@ def _get_next_deadline(self, period: Union[float, int]) -> float: # If the interval was corrected, this should equal to: # `period * beat_duration` expected_duration = self.clock.get_beat_time(period, time=shifted_time) - return time + expected_duration # Runner loop async def _runner(self): - current_beat = ( - self.scheduler.env.clock.beat % self.scheduler.env.clock.beats_per_bar - ) + """ + This function defines the main loop for the runner. TODO: add more. + """ + + # Query time position + current_beat = self.scheduler.env.clock.beat % self.scheduler.env.clock.beats_per_bar current_bar = self.scheduler.env.clock.bar current_phase = self.scheduler.env.clock.phase + # Preparing the runner for an incoming iteration try: self._prepare() except Exception as exc: @@ -590,33 +610,49 @@ async def _runner(self): try: while self._is_ready_for_iteration(): - # self._last_interval = self._get_period(self._last_state) * self.clock.beat_duration + # We try to run the function for real now, catching generic exception! try: await self._run_once() except Exception as exc: print(f"[red][Function exception | ({self.name})]") traceback.print_exception(type(exc), exc, exc.__traceback__) - + # Revert the state to the previous one to try to 'save' the runner self._revert_state() self.swim() finally: print(f"[yellow][Stopped [red]{self.name}[/red]][/yellow]") def _prepare(self): + """ + Prepare the runner for an incoming iteration. This method is called at the + start of the runner's main loop. TODO: add more. + """ self._last_expected_time = -math.inf + + # Grab function, arguments and key arguments stored in state self._last_state = self._get_state() + + # Setting flags self._swimming = True self._stop = False + # Extract period from state period = self._get_period(self._last_state) - self._last_interval = period * self.clock.beat_duration - async def _run_once(self): + self._last_interval = (period * self.clock.beat_duration) + + async def _run_once(self) -> None: + """ + This function is called once per iteration. It is responsible for running the function + and handling any deferred states that have arrived. TODO: complete description. + """ + # TODO: documentation needed for this very complex function + self._swimming = False self._reload_event.clear() + # 1) Get the last state state = self._get_state() - if state is not None: self._maybe_print_new_state(state) self._last_state = state @@ -628,13 +664,14 @@ async def _run_once(self): kwargs = _discard_kwargs(signature, state.kwargs) period = _extract_new_period(signature, state.kwargs, self._default_period) + # TODO: what are we doing here? if not self.background_job: self._correct_interval(period) else: self._correct_interval_background_job(period) deadline = self._get_next_deadline(period) - # Push any deferred states that have or will arrive onto the stack + # 2) Push any deferred states that have arrived or will arrive onto the stack arriving_states: list[DeferredState] = [] while self.deferred_states: entry = self.deferred_states[0] @@ -649,16 +686,15 @@ async def _run_once(self): # (similar to what `push()` does) if state is not None: self._merge_states(state, entry.state) - arriving_states.append(entry) else: break + # 3) Run the function or skip to the next iteration if arriving_states: latest_entry = arriving_states[-1] self.states.extend(e.state for e in arriving_states) - # In case the new state has a faster interval than before, - # delay it so it doesn't run too early + # In case the new state has a faster interval than before, delay it so it doesn't run too early self.delay_interval( latest_entry.deadline, self._get_period(latest_entry.state), @@ -671,11 +707,13 @@ async def _run_once(self): # so it runs exactly on the deadline instead of unnecessarily # sleeping a full period. deadline = self.deferred_states[0].deadline + # interrupted is true if we are past the deadline interrupted = await self._sleep_until(deadline) return self._jump_start_iteration() # NOTE: deadline will always be defined at this point if not self.background_job: + # interrupted is true if we are past the deadline interrupted = await self._sleep_unless_jump_started(deadline) if interrupted: return self._skip_iteration() @@ -690,7 +728,7 @@ async def _run_once(self): self._last_expected_time = self._expected_time self._update_iter() - def _update_iter(self): + def _update_iter(self) -> None: """Updates the iteration number""" self._iter += self._iter_step if self._iter_limit != "inf": @@ -709,38 +747,62 @@ async def _call_func(self, func, args, kwargs): return await maybe_coro(func, *args, **kwargs) def _get_period(self, state: Optional[FunctionState]) -> Union[float, int]: - if state is None: - return 0.0 + """ + TODO: ??? + + Args: + state (Optional[FunctionState]): A function state (typically the + most recent one). + Returns: + Union[float, int]: The period to use for the next iteration. + """ + # If we don't have a state, we can't extract a period given by the user + # Extract the period from the state or assign default period if missing return _extract_new_period( inspect.signature(state.func), state.kwargs, self.period - ) + ) if state is not None else 0.0 def _get_state(self) -> Optional[FunctionState]: + """ + Returns the top-most function state, if any. + + Returns: + Optional[FunctionState]: The top-most function state. + """ return self.states[-1] if self.states else None def _is_ready_for_iteration(self) -> bool: + """ + Conditions for the runner to be ready for the next iteration. + + Returns: + bool: True if the runner is ready for the next iteration. + """ return bool( (self.states or self.deferred_states) and self._swimming # self.swim() and not self._stop # self.stop() ) - def _maybe_print_new_state(self, state: FunctionState): - current_beat = ( - self.scheduler.env.clock.beat % self.scheduler.env.clock.beats_per_bar - ) + def _maybe_print_new_state(self, state: FunctionState) -> None: + """ + This function is called when the runner is about to start a new iteration and + a new state has been pushed. It will print a message to the console indicating + how well the runner is doing (update or saved from crash). + """ + current_beat = self.scheduler.env.clock.beat % self.scheduler.env.clock.beats_per_bar current_bar = self.scheduler.env.clock.bar current_phase = self.scheduler.env.clock.phase if self._last_state is not None and state is not self._last_state: if not self._has_reverted: print( - f"[yellow][Updating [red]{self.name}[/red] at {current_bar}/{current_beat}/{current_phase:.2f}]" + f"[yellow][Updating [red]{self.name}[/red]]" ) else: print( - f"[yellow][Saving [red]{self.name}[/red] from crash ({current_bar}/{current_beat}/{current_phase:.2f})]" + f"[yellow][Saving [red]{self.name}[/red] from crash]" ) self._has_reverted = False @@ -761,11 +823,16 @@ async def _sleep_until(self, deadline: Union[float, int]) -> bool: reload_task = asyncio.create_task(self._reload_event.wait()) try: done, pending = await asyncio.wait( + # here, we are waiting for one of the two tasks to finish: + # - wait task: a task that sleeps until deadline + # - reload task: a task that waits until the runner is reloaded (wait_task, reload_task), return_when=asyncio.FIRST_COMPLETED, ) + # The other task is cancelled, we don't need it anymore for task in pending: task.cancel() + # Retrieving the result, helpful for error propagation, etc. for task in done: task.result() @@ -786,12 +853,16 @@ async def _sleep_unless_jump_started(self, deadline: Union[float, int]) -> bool: return await self._sleep_until(deadline) - def _revert_state(self): + def _revert_state(self) -> None: + """Reverts the runner to the previous state.""" if self.states: self.states.pop() self._has_reverted = True def _skip_iteration(self) -> None: + """ + TODO: document + """ self.swim() def _jump_start_iteration(self) -> None: