From bf7d169936f82d51e882cd781d74729214e910cb Mon Sep 17 00:00:00 2001 From: Niklas Fiekas Date: Fri, 20 Dec 2024 20:19:56 +0100 Subject: [PATCH] wip cancel engine commands --- chess/engine.py | 27 ++++++++++++--------------- test.py | 21 +++++++++++++++++++++ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/chess/engine.py b/chess/engine.py index b979b278..e2ba1fcc 100644 --- a/chess/engine.py +++ b/chess/engine.py @@ -946,9 +946,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) - assert command.state == CommandState.NEW, command.state if self.next_command is not None: - self.next_command.result.cancel() - self.next_command.finished.cancel() - self.next_command.set_finished() + self.next_command._cancel() self.next_command = command @@ -957,20 +955,14 @@ def previous_command_finished() -> None: if self.command is not None: cmd = self.command - def cancel_if_cancelled(result: asyncio.Future[T]) -> None: - if result.cancelled(): - cmd._cancel() - - cmd.result.add_done_callback(cancel_if_cancelled) - cmd._start() + if cmd.state == CommandState.NEW: + cmd._start() cmd.add_finished_callback(previous_command_finished) - if self.command is None: - previous_command_finished() - elif not self.command.result.done(): - self.command.result.cancel() - elif not self.command.result.cancelled(): + if self.command is not None: self.command._cancel() + else: + previous_command_finished() return await command.result @@ -1233,7 +1225,12 @@ def set_finished(self) -> None: self._dispatch_finished() def _cancel(self) -> None: - if self.state != CommandState.CANCELLING and self.state != CommandState.DONE: + if self.state == CommandState.NEW: + self.state = CommandState.DONE + self.result.cancel() + self.finished.cancel() + self._dispatch_finished() + elif self.state != CommandState.CANCELLING and self.state != CommandState.DONE: assert self.state == CommandState.ACTIVE, self.state self.state = CommandState.CANCELLING self.cancel() diff --git a/test.py b/test.py index 6db84d25..024f80d7 100755 --- a/test.py +++ b/test.py @@ -3142,6 +3142,27 @@ def test_sf_quit(self): with self.assertRaises(chess.engine.EngineTerminatedError), engine: engine.ping() + @catchAndSkip(FileNotFoundError, "need stockfish") + def test_sf_cancel(self): + class TerminateTaskGroup(Exception): + pass + + async def terminate_task_group(): + await asyncio.sleep(0.001) + raise TerminateTaskGroup() + + async def main(): + try: + async with asyncio.TaskGroup() as group: + _, engine = await chess.engine.popen_uci("stockfish") + group.create_task(engine.analyse(chess.Board(), chess.engine.Limit())) + group.create_task(engine.analyse(chess.Board(), chess.engine.Limit())) + group.create_task(terminate_task_group()) + except* TerminateTaskGroup: + pass + + asyncio.run(main()) + @catchAndSkip(FileNotFoundError, "need fairy-stockfish") def test_fairy_sf_initialize(self): with chess.engine.SimpleEngine.popen_uci("fairy-stockfish", setpgrp=True, debug=True):