From 8d28bffcb41320fd72f20c0c6d4ab1302077d9e7 Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Tue, 21 Nov 2023 09:51:39 +0100 Subject: [PATCH 1/4] bridge: clean up handling of control messages Remove the general purpose mechanism for sending JSON-formatted messages from the protocol level, leaving only the mechanism for sending control messages. Non-control JSON-formatted messages (such as those used for D-Bus replies) are now handled at the channel level and treated as normal data, subject to flow control. The function for doing so has been renamed from send_message() to send_json() for clarity, and returns a boolean (with the same meaning as send_data()), although nobody currently pays any attention to this. This removes one of the main interaction points between endpoints and the router. send_json() uses a class attribute defined on Channel to encode the JSON data. By default, this is the default json.JSONEncoder, but channel subclasses can provide their own encoder (D-Bus needs the encoder provided by systemd_ctypes). This is all driven by a new function in jsonutil which defines the exact way in which we intend to handle keyword args when building control messages. We make a clear distinction between data which is meant to be handled "verbatim" and data which is meant to be subject to processing rules (our '_' to '-' replacements, etc). Start using the new jsonutil function for all control message handling. That means that all control-message-creating paths now offer the opportunity to pass a verbatim dictionary as well as a set of kwargs. Propagate this change upwards, throughout. For forwarding messages from peers this is substantially cleaner because it means we don't rewrite their messages. The improved typing results in some extra mypy errors which require some hinting at call sites. 
One note: ideally, we'd use `/` in the argument list of all of these functions to ensure that the "verbatim object" field that we add everywhere can only be passed positionally. This is unfortunately only available in Python 3.8 and later. To work around that issue, we add a `_msg` field everywhere with its name chosen never to clash with an actual kwarg we might use (such as `message`). Unfortunately, `mypy` isn't yet convinced that this is a purely positional argument, and in places where we write `**kwargs` it has no way to know that one of the kwargs won't in fact be `_msg`, so it complains that the types don't match. In the cases where that happens, we can manually specify `None` to avoid that problem. --- src/cockpit/beiboot.py | 6 ++-- src/cockpit/bridge.py | 4 ++- src/cockpit/channel.py | 18 +++++++---- src/cockpit/channels/dbus.py | 51 ++++++++++++++++-------------- src/cockpit/channels/filesystem.py | 8 ++--- src/cockpit/channels/metrics.py | 11 ++----- src/cockpit/channels/packages.py | 6 ++-- src/cockpit/jsonutil.py | 25 +++++++++++++++ src/cockpit/peer.py | 8 ++--- src/cockpit/protocol.py | 40 +++++++++-------------- src/cockpit/remote.py | 6 ++-- src/cockpit/router.py | 17 +++++----- src/cockpit/superuser.py | 2 +- 13 files changed, 111 insertions(+), 91 deletions(-) diff --git a/src/cockpit/beiboot.py b/src/cockpit/beiboot.py index 68c153452bb..c42e249d1e6 100644 --- a/src/cockpit/beiboot.py +++ b/src/cockpit/beiboot.py @@ -281,7 +281,7 @@ def do_init(self, message): if isinstance(message.get('superuser'), dict): self.write_control(command='superuser-init-done') message['superuser'] = False - self.ssh_peer.write_control(**message) + self.ssh_peer.write_control(message) async def run(args) -> None: @@ -306,12 +306,12 @@ async def run(args) -> None: if bridge.packages: message['packages'] = {p: None for p in bridge.packages.packages} - bridge.write_control(**message) + bridge.write_control(message) bridge.ssh_peer.thaw_endpoint() except ferny.InteractionError as exc: 
sys.exit(str(exc)) except CockpitProblem as exc: - bridge.write_control(command='init', problem=exc.problem, **exc.kwargs) + bridge.write_control(exc.attrs, command='init') return logger.debug('Startup done. Looping until connection closes.') diff --git a/src/cockpit/bridge.py b/src/cockpit/bridge.py index c2f3175cb45..26ca88645be 100644 --- a/src/cockpit/bridge.py +++ b/src/cockpit/bridge.py @@ -138,13 +138,15 @@ def do_init(self, message: JsonObject) -> None: def do_send_init(self) -> None: init_args = { 'capabilities': {'explicit-superuser': True}, + 'command': 'init', 'os-release': self.get_os_release(), + 'version': 1, } if self.packages is not None: init_args['packages'] = {p: None for p in self.packages.packages} - self.write_control(command='init', version=1, **init_args) + self.write_control(init_args) # PackagesListener interface def packages_loaded(self) -> None: diff --git a/src/cockpit/channel.py b/src/cockpit/channel.py index f4aaacbf951..7c27dee6ee8 100644 --- a/src/cockpit/channel.py +++ b/src/cockpit/channel.py @@ -16,10 +16,11 @@ # along with this program. If not, see . 
import asyncio +import json import logging from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type -from .jsonutil import JsonError, JsonObject, get_bool, get_str +from .jsonutil import JsonDocument, JsonError, JsonObject, create_object, get_bool, get_str from .router import Endpoint, Router, RoutingRule logger = logging.getLogger(__name__) @@ -276,14 +277,17 @@ def do_resume_send(self) -> None: """Called to indicate that the channel may start sending again.""" # change to `raise NotImplementedError` after everyone implements it - def send_message(self, **kwargs): - self.send_channel_message(self.channel, **kwargs) + json_encoder: ClassVar[json.JSONEncoder] = json.JSONEncoder(indent=2) - def send_control(self, command, **kwargs): - self.send_channel_control(self.channel, command=command, **kwargs) + def send_json(self, **kwargs: JsonDocument) -> bool: + pretty = self.json_encoder.encode(create_object(None, kwargs)) + '\n' + return self.send_data(pretty.encode()) - def send_pong(self, message): - self.send_channel_control(**dict(message, command='pong')) + def send_control(self, command: str, **kwargs: JsonDocument) -> None: + self.send_channel_control(self.channel, command, None, **kwargs) + + def send_pong(self, message: JsonObject) -> None: + self.send_channel_control(self.channel, 'pong', message) class ProtocolChannel(Channel, asyncio.Protocol): diff --git a/src/cockpit/channels/dbus.py b/src/cockpit/channels/dbus.py index 1be2ecb585b..b770ad176a9 100644 --- a/src/cockpit/channels/dbus.py +++ b/src/cockpit/channels/dbus.py @@ -41,6 +41,7 @@ import traceback import xml.etree.ElementTree as ET +from cockpit._vendor import systemd_ctypes from cockpit._vendor.systemd_ctypes import Bus, BusError, introspection from ..channel import Channel, ChannelError @@ -166,6 +167,7 @@ def notify_update(notify, path, interface_name, props): class DBusChannel(Channel): + json_encoder = systemd_ctypes.JSONEncoder(indent=2) payload = 
'dbus-json3' matches = None @@ -179,7 +181,7 @@ def send_owner(owner): # notifications. cockpit.js relies on that. if self.owner != owner: self.owner = owner - self.send_message(owner=owner) + self.send_json(owner=owner) def handler(message): name, old, new = message.get_body() @@ -207,7 +209,7 @@ def handler(message): "StartServiceByName", "su", self.name, 0) except BusError as start_error: logger.debug("Failed to start service '%s': %s", self.name, start_error.message) - self.send_message(owner=None) + self.send_json(owner=None) else: logger.debug("Failed to get owner of service '%s': %s", self.name, error.message) else: @@ -325,15 +327,17 @@ async def do_call(self, message): logger.debug('Doing introspection request for %s %s', iface, method) signature = await self.cache.get_signature(iface, method, self.bus, self.name, path) except BusError as error: - self.send_message(error=[error.name, [f'Introspection: {error.message}']], id=cookie) + self.send_json(error=[error.name, [f'Introspection: {error.message}']], id=cookie) return except KeyError: - self.send_message(error=["org.freedesktop.DBus.Error.UnknownMethod", - [f"Introspection data for method {iface} {method} not available"]], - id=cookie) + self.send_json( + error=[ + "org.freedesktop.DBus.Error.UnknownMethod", + [f"Introspection data for method {iface} {method} not available"]], + id=cookie) return except Exception as exc: - self.send_message(error=['python.error', [f'Introspection: {exc!s}']], id=cookie) + self.send_json(error=['python.error', [f'Introspection: {exc!s}']], id=cookie) return try: @@ -343,15 +347,16 @@ async def do_call(self, message): # watch processing, wait for that to be done. async with self.watch_processing_lock: # TODO: stop hard-coding the endian flag here. 
- self.send_message(reply=[reply.get_body()], id=cookie, - flags="<" if flags is not None else None, - type=reply.get_signature(True)) # noqa: FBT003 + self.send_json( + reply=[reply.get_body()], id=cookie, + flags="<" if flags is not None else None, + type=reply.get_signature(True)) # noqa: FBT003 except BusError as error: # actually, should send the fields from the message body - self.send_message(error=[error.name, [error.message]], id=cookie) + self.send_json(error=[error.name, [error.message]], id=cookie) except Exception: logger.exception("do_call(%s): generic exception", message) - self.send_message(error=['python.error', [traceback.format_exc()]], id=cookie) + self.send_json(error=['python.error', [traceback.format_exc()]], id=cookie) async def do_add_match(self, message): add_match = message['add-match'] @@ -360,7 +365,7 @@ async def do_add_match(self, message): async def match_hit(message): logger.debug('got match') async with self.watch_processing_lock: - self.send_message(signal=[ + self.send_json(signal=[ message.get_path(), message.get_interface(), message.get_member(), @@ -388,14 +393,14 @@ async def handler(message): if mm: meta.update({name: mm}) notify_update(notify, path, name, props) - self.send_message(meta=meta) - self.send_message(notify=notify) + self.send_json(meta=meta) + self.send_json(notify=notify) elif member == "InterfacesRemoved": (path, interfaces) = message.get_body() logger.debug('interfaces removed %s %s', path, interfaces) async with self.watch_processing_lock: notify = {path: {name: None for name in interfaces}} - self.send_message(notify=notify) + self.send_json(notify=notify) self.add_async_signal_handler(handler, path=path, @@ -432,7 +437,7 @@ async def handler(message): props[inv] = reply notify = {} notify_update(notify, path, name, props) - self.send_message(notify=notify) + self.send_json(notify=notify) this_meta = await self.cache.introspect_path(self.bus, self.name, path) if interface_name is not None: @@ -471,8 +476,8 
@@ async def do_watch(self, message): if path is None or cookie is None: logger.debug('ignored incomplete watch request %s', message) - self.send_message(error=['x.y.z', ['Not Implemented']], id=cookie) - self.send_message(reply=[], id=cookie) + self.send_json(error=['x.y.z', ['Not Implemented']], id=cookie) + self.send_json(reply=[], id=cookie) return try: @@ -482,12 +487,12 @@ async def do_watch(self, message): await self.setup_path_watch(path, interface_name, recursive, meta, notify) if recursive: await self.setup_objectmanager_watch(path, interface_name, meta, notify) - self.send_message(meta=meta) - self.send_message(notify=notify) - self.send_message(reply=[], id=message['id']) + self.send_json(meta=meta) + self.send_json(notify=notify) + self.send_json(reply=[], id=message['id']) except BusError as error: logger.debug("do_watch(%s) caught D-Bus error: %s", message, error.message) - self.send_message(error=[error.name, [error.message]], id=cookie) + self.send_json(error=[error.name, [error.message]], id=cookie) async def do_meta(self, message): self.cache.inject(message['meta']) diff --git a/src/cockpit/channels/filesystem.py b/src/cockpit/channels/filesystem.py index 2d303af0ee5..1efb879d9ad 100644 --- a/src/cockpit/channels/filesystem.py +++ b/src/cockpit/channels/filesystem.py @@ -61,7 +61,7 @@ def send_entry(self, event, entry): else: mode = 'special' - self.send_message(event=event, path=entry.name, type=mode) + self.send_json(event=event, path=entry.name, type=mode) def do_open(self, options): path = options.get('path') @@ -239,20 +239,20 @@ def do_inotify_event(self, mask, _cookie, name): # file inside watched directory changed path = os.path.join(self._path, name.decode()) tag = tag_from_path(path) - self.send_message(event=event, path=path, tag=tag, type=type_) + self.send_json(event=event, path=path, tag=tag, type=type_) else: # the watched path itself changed; filter out duplicate events tag = tag_from_path(self._path) if tag == self._tag: return 
self._tag = tag - self.send_message(event=event, path=self._path, tag=self._tag, type=type_) + self.send_json(event=event, path=self._path, tag=self._tag, type=type_) def do_identity_changed(self, fd, err): logger.debug("do_identity_changed(%s): fd %s, err %s", self._path, str(fd), err) self._tag = tag_from_fd(fd) if fd else '-' if self._active: - self.send_message(event='created' if fd else 'deleted', path=self._path, tag=self._tag) + self.send_json(event='created' if fd else 'deleted', path=self._path, tag=self._tag) def do_open(self, options): self._path = options['path'] diff --git a/src/cockpit/channels/metrics.py b/src/cockpit/channels/metrics.py index 82cc5003fa5..ffc402e57b0 100644 --- a/src/cockpit/channels/metrics.py +++ b/src/cockpit/channels/metrics.py @@ -24,6 +24,7 @@ from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union from ..channel import AsyncChannel, ChannelError +from ..jsonutil import JsonList from ..samples import SAMPLERS, SampleDescription, Sampler, Samples logger = logging.getLogger(__name__) @@ -88,7 +89,7 @@ def parse_options(self, options): self.samplers = {cls() for cls in sampler_classes} def send_meta(self, samples: Samples, timestamp: float): - metrics = [] + metrics: JsonList = [] for metricinfo in self.metrics: if metricinfo.desc.instanced: metrics.append({ @@ -105,13 +106,7 @@ def send_meta(self, samples: Samples, timestamp: float): 'semantics': metricinfo.desc.semantics }) - meta = { - 'timestamp': timestamp * 1000, - 'interval': self.interval, - 'source': 'internal', - 'metrics': metrics - } - self.send_message(**meta) + self.send_json(source='internal', interval=self.interval, timestamp=timestamp * 1000, metrics=metrics) self.need_meta = False def sample(self): diff --git a/src/cockpit/channels/packages.py b/src/cockpit/channels/packages.py index 90445cf91b9..912aab4fd73 100644 --- a/src/cockpit/channels/packages.py +++ b/src/cockpit/channels/packages.py @@ -35,7 +35,7 @@ class PackagesChannel(AsyncChannel): 
def http_error(self, status: int, message: str) -> None: template = read_cockpit_data_file('fail.html') - self.send_message(status=status, reason='ERROR', headers={'Content-Type': 'text/html; charset=utf-8'}) + self.send_json(status=status, reason='ERROR', headers={'Content-Type': 'text/html; charset=utf-8'}) self.send_data(template.replace(b'@@message@@', message.encode('utf-8'))) self.done() self.close() @@ -58,7 +58,7 @@ async def run(self, options: JsonObject) -> None: # Note: we can't cache documents right now. See # https://github.com/cockpit-project/cockpit/issues/19071 # for future plans. - out_headers = { + out_headers: JsonObject = { 'Cache-Control': 'no-cache, no-store', 'Content-Type': document.content_type, } @@ -97,5 +97,5 @@ async def run(self, options: JsonObject) -> None: self.http_error(500, f'Internal error: {exc!s}') else: - self.send_message(status=200, reason='OK', headers=out_headers) + self.send_json(status=200, reason='OK', headers=out_headers) await self.sendfile(document.data) diff --git a/src/cockpit/jsonutil.py b/src/cockpit/jsonutil.py index 481838b6cea..63370d273dc 100644 --- a/src/cockpit/jsonutil.py +++ b/src/cockpit/jsonutil.py @@ -104,3 +104,28 @@ def get_objv(obj: JsonObject, key: str, constructor: Callable[[JsonObject], T]) def as_objv(value: JsonDocument) -> Sequence[T]: return tuple(constructor(typechecked(item, dict)) for item in typechecked(value, list)) return _get(obj, as_objv, key, ()) + + +def create_object(message: 'JsonObject | None', kwargs: JsonObject) -> JsonObject: + """Constructs a JSON object based on message and kwargs. + + If only message is given, it is returned, unmodified. If message is None, + it is equivalent to an empty dictionary. A copy is always made. + + If kwargs are present, then any underscore ('_') present in a key name is + rewritten to a dash ('-'). This is intended to bridge between the required + Python syntax when providing kwargs and idiomatic JSON (which uses '-' for + attributes). 
These values override values in message. + + The idea is that `message` should be used for passing data along, and + kwargs used for data originating at a given call site, possibly including + modifications to an original message. + """ + result = dict(message or {}) + + for key, value in kwargs.items(): + # rewrite '_' (necessary in Python syntax kwargs list) to '-' (idiomatic JSON) + json_key = key.replace('_', '-') + result[json_key] = value + + return result diff --git a/src/cockpit/peer.py b/src/cockpit/peer.py index 2e3df678477..002388465f7 100644 --- a/src/cockpit/peer.py +++ b/src/cockpit/peer.py @@ -118,7 +118,7 @@ def _connect_task_done(task: asyncio.Task) -> None: if init_host is not None: logger.debug(' sending init message back, host %s', init_host) # Send "init" back - self.write_control(command='init', version=1, host=init_host, **kwargs) + self.write_control(None, command='init', version=1, host=init_host, **kwargs) # Thaw the queued messages self.thaw_endpoint() @@ -182,7 +182,7 @@ def do_closed(self, exc: Optional[Exception]) -> None: else: self.shutdown_endpoint(problem='terminated', message=f'Peer exited with status {exc.exit_code}') elif isinstance(exc, CockpitProblem): - self.shutdown_endpoint(problem=exc.problem, **exc.kwargs) + self.shutdown_endpoint(exc.attrs) else: self.shutdown_endpoint(problem='internal-error', message=f"[{exc.__class__.__name__}] {exc!s}") @@ -209,7 +209,7 @@ def process_exited(self) -> None: def channel_control_received(self, channel: str, command: str, message: JsonObject) -> None: if self.init_future is not None: raise CockpitProtocolError('Received unexpected channel control message before init') - self.send_channel_control(**message) + self.send_channel_control(channel, command, message) def channel_data_received(self, channel: str, data: bytes) -> None: if self.init_future is not None: @@ -219,7 +219,7 @@ def channel_data_received(self, channel: str, data: bytes) -> None: # Forwarding data: from the router to the 
peer def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None: assert self.init_future is None - self.write_control(**message) + self.write_control(message) def do_channel_data(self, channel: str, data: bytes) -> None: assert self.init_future is None diff --git a/src/cockpit/protocol.py b/src/cockpit/protocol.py index 44584ad0513..d65133823a3 100644 --- a/src/cockpit/protocol.py +++ b/src/cockpit/protocol.py @@ -19,11 +19,9 @@ import json import logging import uuid -from typing import ClassVar, Dict, Optional +from typing import Dict, Optional -from cockpit._vendor import systemd_ctypes - -from .jsonutil import JsonError, JsonObject, get_str, typechecked +from .jsonutil import JsonDocument, JsonError, JsonObject, create_object, get_str, typechecked logger = logging.getLogger(__name__) @@ -40,10 +38,11 @@ class CockpitProblem(Exception): It is usually thrown in response to some violation of expected protocol when parsing messages, connecting to a peer, or opening a channel. """ - def __init__(self, problem: str, **kwargs): - super().__init__(kwargs.get('message') or problem) + def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: + self.attrs = create_object(_msg, kwargs) + self.attrs['problem'] = problem self.problem = problem - self.kwargs = kwargs + super().__init__(get_str(self.attrs, 'message', problem)) class CockpitProtocolError(CockpitProblem): @@ -57,7 +56,6 @@ class CockpitProtocol(asyncio.Protocol): We need to use this because Python's SelectorEventLoop doesn't supported buffered protocols. 
""" - json_encoder: ClassVar[json.JSONEncoder] = systemd_ctypes.JSONEncoder(indent=2) transport: Optional[asyncio.Transport] = None buffer = b'' _closed: bool = False @@ -191,21 +189,11 @@ def write_channel_data(self, channel, payload): else: logger.debug('cannot write to closed transport') - def write_message(self, _channel, **kwargs): - """Format kwargs as a JSON blob and send as a message - Any kwargs with '_' in their names will be converted to '-' - """ - for name in list(kwargs): - if '_' in name: - kwargs[name.replace('_', '-')] = kwargs[name] - del kwargs[name] - - logger.debug('sending message %s %s', _channel, kwargs) - pretty = CockpitProtocol.json_encoder.encode(kwargs) + '\n' - self.write_channel_data(_channel, pretty.encode('utf-8')) - - def write_control(self, **kwargs): - self.write_message('', **kwargs) + def write_control(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: + """Write a control message. See jsonutil.create_object() for details.""" + logger.debug('sending control message %r %r', _msg, kwargs) + pretty = json.dumps(create_object(_msg, kwargs), indent=2) + '\n' + self.write_channel_data('', pretty.encode()) def data_received(self, data): try: @@ -269,14 +257,16 @@ def do_ready(self): self.do_send_init() # authorize request/response API - async def request_authorization(self, challenge: str, timeout: Optional[int] = None, **kwargs: object) -> str: + async def request_authorization( + self, challenge: str, timeout: 'int | None' = None, **kwargs: JsonDocument + ) -> str: if self.authorizations is None: self.authorizations = {} cookie = str(uuid.uuid4()) future = asyncio.get_running_loop().create_future() try: self.authorizations[cookie] = future - self.write_control(command='authorize', challenge=challenge, cookie=cookie, **kwargs) + self.write_control(None, command='authorize', challenge=challenge, cookie=cookie, **kwargs) return await asyncio.wait_for(future, timeout) finally: self.authorizations.pop(cookie) diff 
--git a/src/cockpit/remote.py b/src/cockpit/remote.py index c4facc36493..23f947566a0 100644 --- a/src/cockpit/remote.py +++ b/src/cockpit/remote.py @@ -110,7 +110,7 @@ async def do_connect_transport(self) -> None: # containing the key that would need to be accepted. That will # cause the front-end to present a dialog. _reason, host, algorithm, key, fingerprint = responder.hostkeys_seen[0] - error_args = {'host_key': f'{host} {algorithm} {key}', 'host_fingerprint': fingerprint} + error_args: JsonObject = {'host-key': f'{host} {algorithm} {key}', 'host-fingerprint': fingerprint} else: error_args = {} @@ -124,12 +124,12 @@ async def do_connect_transport(self) -> None: logger.debug('SshPeer got a %s %s; private %s, seen hostkeys %r; raising %s with extra args %r', type(exc), exc, self.private, responder.hostkeys_seen, error, error_args) - raise PeerError(error, error=error, auth_method_results={}, **error_args) from exc + raise PeerError(error, error_args, error=error, auth_method_results={}) from exc except ferny.SshAuthenticationError as exc: logger.debug('authentication to host %s failed: %s', host, exc) - results = {method: 'not-provided' for method in exc.methods} + results: JsonObject = {method: 'not-provided' for method in exc.methods} if 'password' in results and self.password is not None: if responder.password_attempts == 0: results['password'] = 'not-tried' diff --git a/src/cockpit/router.py b/src/cockpit/router.py index 3ac0a633ac0..737f33df7e7 100644 --- a/src/cockpit/router.py +++ b/src/cockpit/router.py @@ -92,17 +92,16 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None: def send_channel_data(self, channel: str, data: bytes) -> None: self.router.write_channel_data(channel, data) - def send_channel_message(self, channel: str, **kwargs: JsonDocument) -> None: - self.router.write_message(channel, **kwargs) - - def send_channel_control(self, channel, command, **kwargs: JsonDocument) -> None: - self.router.write_control(channel=channel, 
command=command, **kwargs) + def send_channel_control( + self, channel: str, command: str, _msg: 'JsonObject | None', **kwargs: JsonDocument + ) -> None: + self.router.write_control(_msg, channel=channel, command=command, **kwargs) if command == 'close': self.router.endpoints[self].remove(channel) self.router.drop_channel(channel) - def shutdown_endpoint(self, **kwargs: JsonDocument) -> None: - self.router.shutdown_endpoint(self, **kwargs) + def shutdown_endpoint(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: + self.router.shutdown_endpoint(self, _msg, **kwargs) class RoutingError(Exception): @@ -163,11 +162,11 @@ def drop_channel(self, channel: str) -> None: except KeyError: logger.error('trying to drop non-existent channel %s from %s', channel, self.open_channels) - def shutdown_endpoint(self, endpoint: Endpoint, **kwargs) -> None: + def shutdown_endpoint(self, endpoint: Endpoint, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: channels = self.endpoints.pop(endpoint) logger.debug('shutdown_endpoint(%s, %s) will close %s', endpoint, kwargs, channels) for channel in channels: - self.write_control(command='close', channel=channel, **kwargs) + self.write_control(_msg, command='close', channel=channel, **kwargs) self.drop_channel(channel) # were we waiting to exit? diff --git a/src/cockpit/superuser.py b/src/cockpit/superuser.py index aa48706ed79..317ee98896b 100644 --- a/src/cockpit/superuser.py +++ b/src/cockpit/superuser.py @@ -222,7 +222,7 @@ def init(self, params: JsonObject) -> None: self._init_task = asyncio.create_task(self.go(name, responder)) self._init_task.add_done_callback(self._init_done) - def _init_done(self, task): + def _init_done(self, task: 'asyncio.Task[None]') -> None: logger.debug('superuser init done! 
%s', task.exception()) self.router.write_control(command='superuser-init-done') del self._init_task From 8b304421dea95feaa24cd381f79871cbc5515ce5 Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Wed, 22 Nov 2023 14:15:27 +0100 Subject: [PATCH 2/4] bridge: rewrite HTTP channel using AsyncChannel This fixes many issues, including: - Based on AsyncChannel: The new code is based on AsyncChannel and handles blocking operations by using .run_in_executor() instead of creating a worker thread. All control flow now happens in a straight line in the run() function. - Implements flow control: The previous code would read many small blocks in a thread and flood the main thread with requests to write them through immediately, without any attempt at flow control. The new code will block if the AsyncChannel send queue backs up. - Better type safety: we use our new JSON helpers to do this. The argument parsing is also less redundant in general since we don't have to do up-front checking and then look everything up again later in the worker thread. - Fewer manual exit paths: use ChannelError instead of manual closes. - Fix header handling: The header handling was incorrect before, compared with the C bridge: we were supposed to always remove 'Connection' and 'Transfer-Encoding', and to remove 'Content-Length' and 'Range' for non-binary channels, but we ended up getting the logic wrong, removing 'Content-Length' and 'Range' for binary channels and 'Connection' and 'Transfer-Encoding' for text. The behaviour now matches the C bridge. - Safety fixes: The previous code had at least one case of unsafely sending data from the worker thread directly to the channel. This is no longer possible under the new design. 
--- src/cockpit/channels/http.py | 222 +++++++++++++++++------------------ 1 file changed, 108 insertions(+), 114 deletions(-) diff --git a/src/cockpit/channels/http.py b/src/cockpit/channels/http.py index 6051658fe31..f0dea1ec4d1 100644 --- a/src/cockpit/channels/http.py +++ b/src/cockpit/channels/http.py @@ -20,29 +20,49 @@ import logging import socket import ssl -import threading -from ..channel import Channel +from ..channel import AsyncChannel, ChannelError +from ..jsonutil import JsonObject, get_dict, get_int, get_object, get_str, typechecked logger = logging.getLogger(__name__) -class HttpChannel(Channel): +class HttpChannel(AsyncChannel): payload = 'http-stream2' - def create_connection(self): - opt_address = self.options.get('address') or 'localhost' - opt_port = self.options.get('port') - opt_unix = self.options.get('unix') - opt_tls = self.options.get('tls') - logger.debug('connecting to %s:%s; tls: %s', opt_address, opt_port or opt_unix, opt_tls) + @staticmethod + def get_headers(response: http.client.HTTPResponse, binary: 'str | None') -> JsonObject: + # Never send these headers + remove = {'Connection', 'Transfer-Encoding'} + + if binary != 'raw': + # Only send these headers for raw binary streams + remove.update({'Content-Length', 'Range'}) + + return {key: value for key, value in response.getheaders() if key not in remove} + + @staticmethod + def create_client(options: JsonObject) -> http.client.HTTPConnection: + opt_address = get_str(options, 'address', 'localhost') + opt_tls = get_dict(options, 'tls', None) + opt_unix = get_str(options, 'unix', None) + opt_port = get_int(options, 'port', None) + + if opt_tls is not None and opt_unix is not None: + raise ChannelError('protocol-error', message='TLS on Unix socket is not supported') + if opt_port is None and opt_unix is None: + raise ChannelError('protocol-error', message='no "port" or "unix" option for channel') + if opt_port is not None and opt_unix is not None: + raise 
ChannelError('protocol-error', message='cannot specify both "port" and "unix" options') if opt_tls is not None: - if 'authority' in opt_tls: - if 'data' in opt_tls['authority']: - context = ssl.create_default_context(cadata=opt_tls['authority']['data']) + authority = get_dict(opt_tls, 'authority', None) + if authority is not None: + data = get_str(authority, 'data', None) + if data is not None: + context = ssl.create_default_context(cadata=data) else: - context = ssl.create_default_context(cafile=opt_tls['authority']['file']) + context = ssl.create_default_context(cafile=get_str(authority, 'file')) else: context = ssl.create_default_context() @@ -50,115 +70,89 @@ def create_connection(self): context.check_hostname = False context.verify_mode = ssl.VerifyMode.CERT_NONE - connection = http.client.HTTPSConnection(opt_address, opt_port, context=context) + # See https://github.com/python/typeshed/issues/11057 + return http.client.HTTPSConnection(opt_address, port=opt_port, context=context) # type: ignore[arg-type] + + else: + return http.client.HTTPConnection(opt_address, port=opt_port) + + @staticmethod + def connect(connection: http.client.HTTPConnection, opt_unix: 'str | None') -> None: + # Blocks. Runs in a thread. + if opt_unix: + # create the connection's socket so that it won't call .connect() internally (which only supports TCP) + connection.sock = socket.socket(socket.AF_UNIX) + connection.sock.connect(opt_unix) else: - connection = http.client.HTTPConnection(opt_address, opt_port) + # explicitly call connect(), so that we can do proper error handling + connection.connect() + + @staticmethod + def request( + connection: http.client.HTTPConnection, method: str, path: str, headers: 'dict[str, str]', body: bytes + ) -> http.client.HTTPResponse: + # Blocks. Runs in a thread. 
+ connection.request(method, path, headers=headers or {}, body=body) + return connection.getresponse() + + async def run(self, options: JsonObject) -> None: + logger.debug('open %s', options) - try: - if opt_unix: - # create the connection's socket so that it won't call .connect() internally (which only supports TCP) - connection.sock = socket.socket(socket.AF_UNIX) - connection.sock.connect(opt_unix) - else: - # explicitly call connect(), so that we can do proper error handling - connection.connect() - except (OSError, IOError) as e: - logger.error('Failed to open %s:%s: %s %s', opt_address, opt_port or opt_unix, type(e), e) - problem = 'unknown-hostkey' if isinstance(e, ssl.SSLCertVerificationError) else 'not-found' - self.close(problem=problem, message=str(e)) - return None + binary = get_str(options, 'binary', None) + method = get_str(options, 'method') + path = get_str(options, 'path') + headers = get_object(options, 'headers', lambda d: {k: typechecked(v, str) for k, v in d.items()}, None) - return connection + if 'connection' in options: + raise ChannelError('protocol-error', message='connection sharing is not implemented on this bridge') - def read_send_response(self, response): - """Completely read the response and send it to the channel""" + loop = asyncio.get_running_loop() + connection = self.create_client(options) + self.ready() + + body = b'' while True: - # we want to stream data blocks as soon as they come in - block = response.read1(4096) - if not block: - logger.debug('reading response done') - # this returns immediately and does not read anything more, but updates the http.client's - # internal state machine to "response done" - block = response.read() - assert block == b'' + data = await self.read() + if data == b'': break - logger.debug('read block of size %i', len(block)) - self.loop.call_soon_threadsafe(self.send_data, block) - - def parse_headers(self, http_msg): - headers = dict(http_msg) - remove = ['Connection', 'Transfer-Encoding'] - if 
self.options.get('binary'): - remove = ['Content-Length', 'Range'] - for h in remove: - try: - del headers[h] - except KeyError: - pass - return headers - - def request(self): - connection = self.create_connection() - if not connection: - # make_connection does the error reporting - return - - connection.request(self.options.get('method'), - self.options.get('path'), - headers=self.options.get('headers') or {}, - body=self.body) - try: - response = connection.getresponse() - self.loop.call_soon_threadsafe(lambda: self.send_control( - command='response', status=response.status, reason=response.reason, - headers=self.parse_headers(response.headers))) - self.read_send_response(response) - except (http.client.HTTPException, OSError) as error: - msg = str(error) - logger.debug('HTTP reading response failed: %s', msg) - self.loop.call_soon_threadsafe(lambda: self.close(problem='terminated', message=msg)) - return - finally: - connection.close() - - self.loop.call_soon_threadsafe(self.done) - self.loop.call_soon_threadsafe(self.close) - logger.debug('closed') - - def do_open(self, options): - logger.debug('open %s', options) - # TODO: generic JSON validation - if not options.get('method'): - self.close(problem='protocol-error', message='missing or empty "method" field in HTTP stream request') - return - if options.get('path') is None: - self.close(problem='protocol-error', message='missing "path" field in HTTP stream request') - return - if options.get('tls') is not None and options.get('unix'): - self.close(problem='protocol-error', message='TLS on Unix socket is not supported') - return - if options.get('connection') is not None: - self.close(problem='protocol-error', message='connection sharing is not implemented on this bridge') - return - - opt_port = options.get('port') - opt_unix = options.get('unix') - if opt_port is None and opt_unix is None: - self.close(problem='protocol-error', message='no "port" or "unix" option for channel') - return - if opt_port is not 
None and opt_unix is not None: - self.close(problem='protocol-error', message='cannot specify both "port" and "unix" options') - return + body += data - self.options = options - self.body = b'' + # Connect in a thread and handle errors + try: + await loop.run_in_executor(None, self.connect, connection, get_str(options, 'unix', None)) + except ssl.SSLCertVerificationError as exc: + raise ChannelError('unknown-hostkey', message=str(exc)) from exc + except (OSError, IOError) as exc: + raise ChannelError('not-found', message=str(exc)) from exc - self.ready() + # Submit request in a thread and handle errors + try: + response = await loop.run_in_executor(None, self.request, connection, method, path, headers or {}, body) + except (http.client.HTTPException, OSError) as exc: + raise ChannelError('terminated', message=str(exc)) from exc - def do_data(self, data): - self.body += data + self.send_control(command='response', + status=response.status, + reason=response.reason, + headers=self.get_headers(response, binary)) - def do_done(self): - self.loop = asyncio.get_running_loop() - threading.Thread(target=self.request, daemon=True).start() + # Receive the body and finish up + try: + while True: + block = await loop.run_in_executor(None, response.read1, self.BLOCK_SIZE) + if not block: + break + await self.write(block) + + logger.debug('reading response done') + # this returns immediately and does not read anything more, but updates the http.client's + # internal state machine to "response done" + block = response.read() + assert block == b'' + + await loop.run_in_executor(None, connection.close) + except (http.client.HTTPException, OSError) as exc: + raise ChannelError('terminated', message=str(exc)) from exc + + self.done() From 6606fe29aaf79e246285282eabb36c1d772eed89 Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Wed, 22 Nov 2023 12:10:56 +0100 Subject: [PATCH 3/4] bridge: unify ChannelError and CockpitProblem These two exception types are extremely similar and 
would benefit from being unified. In order to do that, we need to make `Channel.close()` operate on a dictionary (as is found in CockpitProblem.attrs) rather than a set of kwargs (as was previously stored on ChannelError). Add a couple of type signatures around to help us make sure we get this right. --- src/cockpit/channel.py | 63 +++++++++++++++--------------- src/cockpit/channels/dbus.py | 2 +- src/cockpit/channels/filesystem.py | 2 +- src/cockpit/protocol.py | 3 +- test/pytest/test_peer.py | 2 +- 5 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/cockpit/channel.py b/src/cockpit/channel.py index 7c27dee6ee8..dcf0b68409c 100644 --- a/src/cockpit/channel.py +++ b/src/cockpit/channel.py @@ -21,6 +21,7 @@ from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type from .jsonutil import JsonDocument, JsonError, JsonObject, create_object, get_bool, get_str +from .protocol import CockpitProblem from .router import Endpoint, Router, RoutingRule logger = logging.getLogger(__name__) @@ -77,10 +78,8 @@ def shutdown(self): pass # we don't hold any state -class ChannelError(Exception): - def __init__(self, problem, **kwargs): - super().__init__(f'ChannelError {problem}') - self.kwargs = dict(kwargs, problem=problem) +class ChannelError(CockpitProblem): + pass class Channel(Endpoint): @@ -131,19 +130,19 @@ def do_control(self, command, message): elif command == 'options': self.do_options(message) - def do_channel_control(self, channel, command, message): + def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None: # Already closing? Ignore. 
if self._close_args is not None: return # Catch errors and turn them into close messages try: - self.do_control(command, message) + try: + self.do_control(command, message) + except JsonError as exc: + raise ChannelError('protocol-error', message=str(exc)) from exc except ChannelError as exc: - self.close(**exc.kwargs) - except JsonError as exc: - logger.warning("%s %s %s: %s", self, channel, command, exc) - self.close(problem='protocol-error', message=str(exc)) + self.close(exc.attrs) def do_kill(self, host: Optional[str], group: Optional[str]) -> None: # Already closing? Ignore. @@ -157,27 +156,27 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None: self.do_close() # At least this one really ought to be implemented... - def do_open(self, options): + def do_open(self, options: JsonObject) -> None: raise NotImplementedError # ... but many subclasses may reasonably want to ignore some of these. - def do_ready(self): + def do_ready(self) -> None: pass - def do_done(self): + def do_done(self) -> None: pass - def do_close(self): + def do_close(self) -> None: self.close() - def do_options(self, message): + def do_options(self, message: JsonObject) -> None: raise ChannelError('not-supported', message='This channel does not implement "options"') # 'reasonable' default, overridden in other channels for receive-side flow control - def do_ping(self, message): + def do_ping(self, message: JsonObject) -> None: self.send_pong(message) - def do_channel_data(self, channel, data): + def do_channel_data(self, channel: str, data: bytes) -> None: # Already closing? Ignore. if self._close_args is not None: return @@ -186,26 +185,26 @@ def do_channel_data(self, channel, data): try: self.do_data(data) except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) - def do_data(self, _data): + def do_data(self, _data: bytes) -> None: # By default, channels can't receive data. 
self.close() # output - def ready(self, **kwargs): + def ready(self, **kwargs: JsonDocument) -> None: self.thaw_endpoint() self.send_control(command='ready', **kwargs) - def done(self): + def done(self) -> None: self.send_control(command='done') # tasks and close management def is_closing(self) -> bool: return self._close_args is not None - def _close_now(self): - self.shutdown_endpoint(**self._close_args) + def _close_now(self) -> None: + self.shutdown_endpoint(self._close_args) def _task_done(self, task): # Strictly speaking, we should read the result and check for exceptions but: @@ -228,7 +227,7 @@ def create_task(self, coroutine, name=None): task.add_done_callback(self._task_done) return task - def close(self, **kwargs): + def close(self, close_args: 'JsonObject | None' = None) -> None: """Requests the channel to be closed. After you call this method, you won't get anymore `.do_*()` calls. @@ -239,7 +238,7 @@ def close(self, **kwargs): if self._close_args is not None: # close already requested return - self._close_args = kwargs + self._close_args = close_args or {} if not self._tasks: self._close_now() @@ -323,24 +322,24 @@ async def create_transport(self, loop: asyncio.AbstractEventLoop, options: JsonO """ raise NotImplementedError - def do_open(self, options): + def do_open(self, options: JsonObject) -> None: loop = asyncio.get_running_loop() self._create_transport_task = asyncio.create_task(self.create_transport(loop, options)) self._create_transport_task.add_done_callback(self.create_transport_done) - def create_transport_done(self, task): + def create_transport_done(self, task: 'asyncio.Task[asyncio.Transport]') -> None: assert task is self._create_transport_task self._create_transport_task = None try: transport = task.result() except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) return self.connection_made(transport) self.ready() - def connection_made(self, transport: asyncio.BaseTransport): + def connection_made(self, 
transport: asyncio.BaseTransport) -> None: assert isinstance(transport, asyncio.Transport) self._transport = transport @@ -348,7 +347,7 @@ def _get_close_args(self) -> JsonObject: return {} def connection_lost(self, exc: Optional[Exception]) -> None: - self.close(**self._get_close_args()) + self.close(self._get_close_args()) def do_data(self, data: bytes) -> None: assert self._transport is not None @@ -449,7 +448,7 @@ async def run_wrapper(self, options): await self.run(options) self.close() except ChannelError as exc: - self.close(**exc.kwargs) + self.close(exc.attrs) async def read(self): while True: @@ -525,4 +524,4 @@ def do_resume_send(self) -> None: pass except StopIteration as stop: self.done() - self.close(**stop.value or {}) + self.close(stop.value) diff --git a/src/cockpit/channels/dbus.py b/src/cockpit/channels/dbus.py index b770ad176a9..664f181b3c3 100644 --- a/src/cockpit/channels/dbus.py +++ b/src/cockpit/channels/dbus.py @@ -261,7 +261,7 @@ async def get_ready(): if self.owner: self.ready(unique_name=self.owner) else: - self.close(problem="not-found") + self.close({'problem': 'not-found'}) self.create_task(get_ready()) else: self.ready() diff --git a/src/cockpit/channels/filesystem.py b/src/cockpit/channels/filesystem.py index 1efb879d9ad..58bcb32af5f 100644 --- a/src/cockpit/channels/filesystem.py +++ b/src/cockpit/channels/filesystem.py @@ -198,7 +198,7 @@ def do_done(self): self._tempfile = None self.done() - self.close(tag=tag_from_path(self._path)) + self.close({'tag': tag_from_path(self._path)}) def do_close(self): if self._tempfile is not None: diff --git a/src/cockpit/protocol.py b/src/cockpit/protocol.py index d65133823a3..734e2938589 100644 --- a/src/cockpit/protocol.py +++ b/src/cockpit/protocol.py @@ -38,10 +38,11 @@ class CockpitProblem(Exception): It is usually thrown in response to some violation of expected protocol when parsing messages, connecting to a peer, or opening a channel. 
""" + attrs: JsonObject + def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None: self.attrs = create_object(_msg, kwargs) self.attrs['problem'] = problem - self.problem = problem super().__init__(get_str(self.attrs, 'message', problem)) diff --git a/test/pytest/test_peer.py b/test/pytest/test_peer.py index cba57540e5c..a4ae02b24f8 100644 --- a/test/pytest/test_peer.py +++ b/test/pytest/test_peer.py @@ -210,5 +210,5 @@ async def do_connect_transport(self) -> None: peer = BrokenPipePeer(specific_error=True) with pytest.raises(ChannelError) as raises: await peer.start() - assert raises.value.kwargs == {'message': 'kaputt', 'problem': 'not-supported'} + assert raises.value.attrs == {'message': 'kaputt', 'problem': 'not-supported'} peer.close() From 53a1cdbd88b7e83fbd522b493525ebe196738368 Mon Sep 17 00:00:00 2001 From: Allison Karlitskaya Date: Wed, 22 Nov 2023 16:28:51 +0100 Subject: [PATCH 4/4] bridge: unify RoutingError and CockpitProblem --- src/cockpit/router.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/cockpit/router.py b/src/cockpit/router.py index 737f33df7e7..526910b7437 100644 --- a/src/cockpit/router.py +++ b/src/cockpit/router.py @@ -20,7 +20,7 @@ from typing import Dict, List, Optional from .jsonutil import JsonDocument, JsonObject -from .protocol import CockpitProtocolError, CockpitProtocolServer +from .protocol import CockpitProblem, CockpitProtocolError, CockpitProtocolServer logger = logging.getLogger(__name__) @@ -104,10 +104,8 @@ def shutdown_endpoint(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocu self.router.shutdown_endpoint(self, _msg, **kwargs) -class RoutingError(Exception): - def __init__(self, problem, **kwargs): - self.problem = problem - self.kwargs = kwargs +class RoutingError(CockpitProblem): + pass class RoutingRule: @@ -194,7 +192,7 @@ def channel_control_received(self, channel: str, command: str, message: JsonObje logger.debug('Trying to find 
endpoint for new channel %s payload=%s', channel, message.get('payload')) endpoint = self.check_rules(message) except RoutingError as exc: - self.write_control(command='close', channel=channel, problem=exc.problem, **exc.kwargs) + self.write_control(exc.attrs, command='close', channel=channel) return self.open_channels[channel] = endpoint