From 25a150965809c7085e43de51bf2f59c7de8ec61f Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Wed, 22 Nov 2023 12:10:56 +0100 Subject: [PATCH] bridge: unify ChannelError and CockpitProblem These two exception types are extremely similar and would benefit from being unified. In order to do that, we need to make `Channel.close()` operate on a dictionary (as is found in CockpitProblem.attrs) rather than a set of kwargs (as was previously stored on ChannelError). Add a couple of type signatures around to help us make sure we get this right. --- src/cockpit/channel.py | 63 +++++++++++++++--------------- src/cockpit/channels/dbus.py | 2 +- src/cockpit/channels/filesystem.py | 2 +- src/cockpit/protocol.py | 3 +- test/pytest/test_peer.py | 2 +- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/cockpit/channel.py b/src/cockpit/channel.py index 5d7f88be1372..4b7891e35e36 100644 --- a/src/cockpit/channel.py +++ b/src/cockpit/channel.py @@ -20,6 +20,7 @@ from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type from .jsonutil import JsonDocument, JsonError, JsonObject, get_bool, get_str, print_object +from .protocol import CockpitProblem from .router import Endpoint, Router, RoutingRule logger = logging.getLogger(__name__) @@ -76,10 +77,8 @@ def shutdown(self): pass # we don't hold any state -class ChannelError(Exception): - def __init__(self, problem, **kwargs): - super().__init__(f'ChannelError {problem}') - self.kwargs = dict(kwargs, problem=problem) +class ChannelError(CockpitProblem): + pass class Channel(Endpoint): @@ -130,19 +129,19 @@ def do_control(self, command, message): elif command == 'options': self.do_options(message) - def do_channel_control(self, channel, command, message): + def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None: # Already closing? Ignore. if self._close_args is not None: return # Catch errors and turn them into close messages try: - self.do_control(command, message) + try: + self.do_control(command, message) + except JsonError as exc: + raise ChannelError('protocol-error', message=str(exc)) from exc except ChannelError as exc: - self.close(**exc.kwargs) - except JsonError as exc: - logger.warning("%s %s %s: %s", self, channel, command, exc) - self.close(problem='protocol-error', message=str(exc)) + self.close(exc.attrs) def do_kill(self, host: Optional[str], group: Optional[str]) -> None: # Already closing? Ignore. @@ -156,27 +155,27 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None: self.do_close() # At least this one really ought to be implemented... - def do_open(self, options): + def do_open(self, options: JsonObject) -> None: raise NotImplementedError # ... but many subclasses may reasonably want to ignore some of these. - def do_ready(self): + def do_ready(self) -> None: pass - def do_done(self): + def do_done(self) -> None: pass - def do_close(self): + def do_close(self) -> None: self.close() - def do_options(self, message): + def do_options(self, message: JsonObject) -> None: raise ChannelError('not-supported', message='This channel does not implement "options"') # 'reasonable' default, overridden in other channels for receive-side flow control - def do_ping(self, message): + def do_ping(self, message: JsonObject) -> None: self.send_pong(message) - def do_channel_data(self, channel, data): + def do_channel_data(self, channel: str, data: bytes) -> None: # Already closing? Ignore. if self._close_args is not None: return @@ -185,26 +184,26 @@ def do_channel_data(self, channel, data): try: self.do_data(data) except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) - def do_data(self, _data): + def do_data(self, _data: bytes) -> None: # By default, channels can't receive data. self.close() # output - def ready(self, **kwargs): + def ready(self, **kwargs: JsonDocument) -> None: self.thaw_endpoint() self.send_control(command='ready', **kwargs) - def done(self): + def done(self) -> None: self.send_control(command='done') # tasks and close management def is_closing(self) -> bool: return self._close_args is not None - def _close_now(self): - self.shutdown_endpoint(**self._close_args) + def _close_now(self) -> None: + self.shutdown_endpoint(self._close_args) def _task_done(self, task): # Strictly speaking, we should read the result and check for exceptions but: @@ -227,7 +226,7 @@ def create_task(self, coroutine, name=None): task.add_done_callback(self._task_done) return task - def close(self, **kwargs): + def close(self, close_args: 'JsonObject | None' = None) -> None: """Requests the channel to be closed. After you call this method, you won't get anymore `.do_*()` calls. @@ -238,7 +237,7 @@ def close(self, **kwargs): if self._close_args is not None: # close already requested return - self._close_args = kwargs + self._close_args = close_args or {} if not self._tasks: self._close_now() @@ -319,24 +318,24 @@ async def create_transport(self, loop: asyncio.AbstractEventLoop, options: JsonO """ raise NotImplementedError - def do_open(self, options): + def do_open(self, options: JsonObject) -> None: loop = asyncio.get_running_loop() self._create_transport_task = asyncio.create_task(self.create_transport(loop, options)) self._create_transport_task.add_done_callback(self.create_transport_done) - def create_transport_done(self, task): + def create_transport_done(self, task: 'asyncio.Task[asyncio.Transport]') -> None: assert task is self._create_transport_task self._create_transport_task = None try: transport = task.result() except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) return self.connection_made(transport) self.ready() - def connection_made(self, transport: asyncio.BaseTransport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: assert isinstance(transport, asyncio.Transport) self._transport = transport @@ -344,7 +343,7 @@ def _get_close_args(self) -> JsonObject: return {} def connection_lost(self, exc: Optional[Exception]) -> None: - self.close(**self._get_close_args()) + self.close(self._get_close_args()) def do_data(self, data: bytes) -> None: assert self._transport is not None @@ -445,7 +444,7 @@ async def run_wrapper(self, options): await self.run(options) self.close() except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) async def read(self): while True: @@ -521,4 +520,4 @@ def do_resume_send(self) -> None: pass except StopIteration as stop: self.done() - self.close(**stop.value or {}) + self.close(stop.value) diff --git a/src/cockpit/channels/dbus.py b/src/cockpit/channels/dbus.py index 1be2ecb585be..430fd94f09c9 100644 --- a/src/cockpit/channels/dbus.py +++ b/src/cockpit/channels/dbus.py @@ -259,7 +259,7 @@ async def get_ready(): if self.owner: self.ready(unique_name=self.owner) else: - self.close(problem="not-found") + self.close({'problem': 'not-found'}) self.create_task(get_ready()) else: self.ready() diff --git a/src/cockpit/channels/filesystem.py b/src/cockpit/channels/filesystem.py index 2d303af0ee58..f7a5346b978a 100644 --- a/src/cockpit/channels/filesystem.py +++ b/src/cockpit/channels/filesystem.py @@ -198,7 +198,7 @@ def do_done(self): self._tempfile = None self.done() - self.close(tag=tag_from_path(self._path)) + self.close({'tag': tag_from_path(self._path)}) def do_close(self): if self._tempfile is not None: diff --git a/src/cockpit/protocol.py b/src/cockpit/protocol.py index 6c5d9a6660c3..2d071b930ae2 100644 --- a/src/cockpit/protocol.py +++ b/src/cockpit/protocol.py @@ -38,10 +38,11 @@ class CockpitProblem(Exception): It is usually thrown in response to some violation of expected protocol when parsing messages, connecting to a peer, or opening a channel. """ + attrs: JsonObject + def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: self.attrs = create_object(_msg, kwargs) self.attrs['problem'] = problem - self.problem = problem super().__init__(get_str(self.attrs, 'message', problem)) diff --git a/test/pytest/test_peer.py b/test/pytest/test_peer.py index cba57540e5c7..a4ae02b24f86 100644 --- a/test/pytest/test_peer.py +++ b/test/pytest/test_peer.py @@ -210,5 +210,5 @@ async def do_connect_transport(self) -> None: peer = BrokenPipePeer(specific_error=True) with pytest.raises(ChannelError) as raises: await peer.start() - assert raises.value.kwargs == {'message': 'kaputt', 'problem': 'not-supported'} + assert raises.value.attrs == {'message': 'kaputt', 'problem': 'not-supported'} peer.close()