Skip to content

Commit

Permalink
bridge: unify ChannelError and CockpitProblem
Browse files Browse the repository at this point in the history
These two exception types are extremely similar and would benefit from
being unified.  In order to do that, we need to make `Channel.close()`
operate on a dictionary (as is found in CockpitProblem.attrs) rather
than a set of kwargs (as was previously stored on ChannelError).

Add a couple of type signatures around to help us make sure we get this
right.
  • Loading branch information
allisonkarlitskaya committed Nov 22, 2023
1 parent 88c29d5 commit 25a1509
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 36 deletions.
63 changes: 31 additions & 32 deletions src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type

from .jsonutil import JsonDocument, JsonError, JsonObject, get_bool, get_str, print_object
from .protocol import CockpitProblem
from .router import Endpoint, Router, RoutingRule

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,10 +77,8 @@ def shutdown(self):
pass # we don't hold any state


class ChannelError(Exception):
def __init__(self, problem, **kwargs):
super().__init__(f'ChannelError {problem}')
self.kwargs = dict(kwargs, problem=problem)
class ChannelError(CockpitProblem):
pass


class Channel(Endpoint):
Expand Down Expand Up @@ -130,19 +129,19 @@ def do_control(self, command, message):
elif command == 'options':
self.do_options(message)

def do_channel_control(self, channel, command, message):
def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None:
# Already closing? Ignore.
if self._close_args is not None:
return

# Catch errors and turn them into close messages
try:
self.do_control(command, message)
try:
self.do_control(command, message)
except JsonError as exc:
raise ChannelError('protocol-error', message=str(exc)) from exc
except ChannelError as exc:
self.close(**exc.kwargs)
except JsonError as exc:
logger.warning("%s %s %s: %s", self, channel, command, exc)
self.close(problem='protocol-error', message=str(exc))
self.close(exc.attrs)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
# Already closing? Ignore.
Expand All @@ -156,27 +155,27 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
self.do_close()

# At least this one really ought to be implemented...
def do_open(self, options):
def do_open(self, options: JsonObject) -> None:
raise NotImplementedError

# ... but many subclasses may reasonably want to ignore some of these.
def do_ready(self):
def do_ready(self) -> None:
pass

def do_done(self):
def do_done(self) -> None:
pass

def do_close(self):
def do_close(self) -> None:
self.close()

def do_options(self, message):
def do_options(self, message: JsonObject) -> None:
raise ChannelError('not-supported', message='This channel does not implement "options"')

# 'reasonable' default, overridden in other channels for receive-side flow control
def do_ping(self, message):
def do_ping(self, message: JsonObject) -> None:
self.send_pong(message)

def do_channel_data(self, channel, data):
def do_channel_data(self, channel: str, data: bytes) -> None:
# Already closing? Ignore.
if self._close_args is not None:
return
Expand All @@ -185,26 +184,26 @@ def do_channel_data(self, channel, data):
try:
self.do_data(data)
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)

def do_data(self, _data):
def do_data(self, _data: bytes) -> None:
# By default, channels can't receive data.
self.close()

# output
def ready(self, **kwargs):
def ready(self, **kwargs: JsonDocument) -> None:
self.thaw_endpoint()
self.send_control(command='ready', **kwargs)

def done(self):
def done(self) -> None:
self.send_control(command='done')

# tasks and close management
def is_closing(self) -> bool:
return self._close_args is not None

def _close_now(self):
self.shutdown_endpoint(**self._close_args)
def _close_now(self) -> None:
self.shutdown_endpoint(self._close_args)

def _task_done(self, task):
# Strictly speaking, we should read the result and check for exceptions but:
Expand All @@ -227,7 +226,7 @@ def create_task(self, coroutine, name=None):
task.add_done_callback(self._task_done)
return task

def close(self, **kwargs):
def close(self, close_args: 'JsonObject | None' = None) -> None:
"""Requests the channel to be closed.
After you call this method, you won't get anymore `.do_*()` calls.
Expand All @@ -238,7 +237,7 @@ def close(self, **kwargs):
if self._close_args is not None:
# close already requested
return
self._close_args = kwargs
self._close_args = close_args or {}
if not self._tasks:
self._close_now()

Expand Down Expand Up @@ -319,32 +318,32 @@ async def create_transport(self, loop: asyncio.AbstractEventLoop, options: JsonO
"""
raise NotImplementedError

def do_open(self, options):
def do_open(self, options: JsonObject) -> None:
loop = asyncio.get_running_loop()
self._create_transport_task = asyncio.create_task(self.create_transport(loop, options))
self._create_transport_task.add_done_callback(self.create_transport_done)

def create_transport_done(self, task):
def create_transport_done(self, task: 'asyncio.Task[asyncio.Transport]') -> None:
assert task is self._create_transport_task
self._create_transport_task = None
try:
transport = task.result()
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)
return

self.connection_made(transport)
self.ready()

def connection_made(self, transport: asyncio.BaseTransport):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
assert isinstance(transport, asyncio.Transport)
self._transport = transport

def _get_close_args(self) -> JsonObject:
return {}

def connection_lost(self, exc: Optional[Exception]) -> None:
self.close(**self._get_close_args())
self.close(self._get_close_args())

def do_data(self, data: bytes) -> None:
assert self._transport is not None
Expand Down Expand Up @@ -445,7 +444,7 @@ async def run_wrapper(self, options):
await self.run(options)
self.close()
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)

async def read(self):
while True:
Expand Down Expand Up @@ -521,4 +520,4 @@ def do_resume_send(self) -> None:
pass
except StopIteration as stop:
self.done()
self.close(**stop.value or {})
self.close(stop.value)
2 changes: 1 addition & 1 deletion src/cockpit/channels/dbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ async def get_ready():
if self.owner:
self.ready(unique_name=self.owner)
else:
self.close(problem="not-found")
self.close({'problem': 'not-found'})
self.create_task(get_ready())
else:
self.ready()
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/channels/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def do_done(self):
self._tempfile = None

self.done()
self.close(tag=tag_from_path(self._path))
self.close({'tag': tag_from_path(self._path)})

def do_close(self):
if self._tempfile is not None:
Expand Down
3 changes: 2 additions & 1 deletion src/cockpit/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ class CockpitProblem(Exception):
It is usually thrown in response to some violation of expected protocol
when parsing messages, connecting to a peer, or opening a channel.
"""
attrs: JsonObject

def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
self.attrs = create_object(_msg, kwargs)
self.attrs['problem'] = problem
self.problem = problem
super().__init__(get_str(self.attrs, 'message', problem))


Expand Down
2 changes: 1 addition & 1 deletion test/pytest/test_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,5 @@ async def do_connect_transport(self) -> None:
peer = BrokenPipePeer(specific_error=True)
with pytest.raises(ChannelError) as raises:
await peer.start()
assert raises.value.kwargs == {'message': 'kaputt', 'problem': 'not-supported'}
assert raises.value.attrs == {'message': 'kaputt', 'problem': 'not-supported'}
peer.close()

0 comments on commit 25a1509

Please sign in to comment.