Skip to content

Commit

Permalink
bridge: clean up handling of control messages
Browse files Browse the repository at this point in the history
Remove the general purpose mechanism for sending JSON-formatted
messages from the protocol level, leaving only the mechanism for sending
control messages.  Non-control JSON-formatted messages (such as those
used for D-Bus replies) are now handled at the channel level and treated
as normal data, subject to flow control.  This removes one of the main
interaction points between endpoints and the router.

Add a pair of functions to jsonutil which define the exact way in which
we intend to handle keyword args when building control messages.  We
make a clear distinction between data which is meant to be handled
"verbatim" and data which is meant to be subject to processing rules
(our '_' to '-' replacements, etc).

Start using the new jsonutil functions for all control message handling.
That means that all control-message-creating paths now offer the
opportunity to pass a verbatim dictionary as well as a set of kwargs.
Propagate this change upwards, throughout. For forwarding messages from
peers this is substantially cleaner because it means we don't rewrite
their messages.  The improved typing results in some extra mypy errors
which requires some hinting at call sites.

One note: ideally, we'd use `/` in the argument list of all of these
functions to ensure that the "verbatim object" field that we add
everywhere can only be passed positionally.  This is unfortunately only
in Python 3.8.  To work around that issue, we add a `_msg` field
everywhere with its name chosen never to clash with an actual kwarg we
might use (such as `message`).  Unfortunately, `mypy` isn't yet
convinced that this is a purely positional argument, and in places that
we write `**kwargs` it has no way to know that one of the kwargs won't
in fact be `_msg`, so it complains that the types don't match.  In the
cases where that happens, we can manually specify `None` to avoid that
problem.
  • Loading branch information
allisonkarlitskaya committed Nov 21, 2023
1 parent d9e56a3 commit 3ec3155
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 62 deletions.
6 changes: 3 additions & 3 deletions src/cockpit/beiboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def do_init(self, message):
if isinstance(message.get('superuser'), dict):
self.write_control(command='superuser-init-done')
message['superuser'] = False
self.ssh_peer.write_control(**message)
self.ssh_peer.write_control(message)


async def run(args) -> None:
Expand All @@ -306,12 +306,12 @@ async def run(args) -> None:
if bridge.packages:
message['packages'] = {p: None for p in bridge.packages.packages}

bridge.write_control(**message)
bridge.write_control(message)
bridge.ssh_peer.thaw_endpoint()
except ferny.InteractionError as exc:
sys.exit(str(exc))
except CockpitProblem as exc:
bridge.write_control(command='init', problem=exc.problem, **exc.kwargs)
bridge.write_control(exc.attrs, command='init')
return

logger.debug('Startup done. Looping until connection closes.')
Expand Down
4 changes: 3 additions & 1 deletion src/cockpit/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ def do_init(self, message: JsonObject) -> None:
def do_send_init(self) -> None:
init_args = {
'capabilities': {'explicit-superuser': True},
'command': 'init',
'os-release': self.get_os_release(),
'version': 1,
}

if self.packages is not None:
init_args['packages'] = {p: None for p in self.packages.packages}

self.write_control(command='init', version=1, **init_args)
self.write_control(init_args)

# PackagesListener interface
def packages_loaded(self) -> None:
Expand Down
14 changes: 7 additions & 7 deletions src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type

from .jsonutil import JsonError, JsonObject, get_bool, get_str
from .jsonutil import JsonDocument, JsonError, JsonObject, get_bool, get_str, print_object
from .router import Endpoint, Router, RoutingRule

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -276,14 +276,14 @@ def do_resume_send(self) -> None:
"""Called to indicate that the channel may start sending again."""
# change to `raise NotImplementedError` after everyone implements it

def send_message(self, **kwargs):
self.send_channel_message(self.channel, **kwargs)
def send_message(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
self.send_channel_data(self.channel, print_object(_msg, kwargs))

def send_control(self, command, **kwargs):
self.send_channel_control(self.channel, command=command, **kwargs)
def send_control(self, command: str, **kwargs: JsonDocument) -> None:
self.send_channel_control(self.channel, command, None, **kwargs)

def send_pong(self, message):
self.send_channel_control(**dict(message, command='pong'))
def send_pong(self, message: JsonObject) -> None:
self.send_channel_control(self.channel, 'pong', message)


class ProtocolChannel(Channel, asyncio.Protocol):
Expand Down
11 changes: 3 additions & 8 deletions src/cockpit/channels/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union

from ..channel import AsyncChannel, ChannelError
from ..jsonutil import JsonList
from ..samples import SAMPLERS, SampleDescription, Sampler, Samples

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -88,7 +89,7 @@ def parse_options(self, options):
self.samplers = {cls() for cls in sampler_classes}

def send_meta(self, samples: Samples, timestamp: float):
metrics = []
metrics: JsonList = []
for metricinfo in self.metrics:
if metricinfo.desc.instanced:
metrics.append({
Expand All @@ -105,13 +106,7 @@ def send_meta(self, samples: Samples, timestamp: float):
'semantics': metricinfo.desc.semantics
})

meta = {
'timestamp': timestamp * 1000,
'interval': self.interval,
'source': 'internal',
'metrics': metrics
}
self.send_message(**meta)
self.send_message(source='internal', interval=self.interval, timestamp=timestamp * 1000, metrics=metrics)
self.need_meta = False

def sample(self):
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/channels/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def run(self, options: JsonObject) -> None:
# Note: we can't cache documents right now. See
# https://github.com/cockpit-project/cockpit/issues/19071
# for future plans.
out_headers = {
out_headers: JsonObject = {
'Cache-Control': 'no-cache, no-store',
'Content-Type': document.content_type,
}
Expand Down
39 changes: 39 additions & 0 deletions src/cockpit/jsonutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from enum import Enum
from typing import Callable, Dict, List, Optional, Sequence, Type, TypeVar, Union

from cockpit._vendor import systemd_ctypes

JSON_ENCODER = systemd_ctypes.JSONEncoder(indent=2)

JsonList = List['JsonDocument']
JsonObject = Dict[str, 'JsonDocument']
JsonLiteral = Union[str, float, bool, None]
Expand Down Expand Up @@ -104,3 +108,38 @@ def get_objv(obj: JsonObject, key: str, constructor: Callable[[JsonObject], T])
def as_objv(value: JsonDocument) -> Sequence[T]:
return tuple(constructor(typechecked(item, dict)) for item in typechecked(value, list))
return _get(obj, as_objv, key, ())


def create_object(message: 'JsonObject | None', kwargs: JsonObject) -> JsonObject:
"""Constructs a JSON object based on message and kwargs.
See print_object() for details."""
result = dict(message or {})

for key, value in kwargs.items():
# rewrite '_' (necessary in Python syntax kwargs list) to '-' (idiomatic JSON)
json_key = key.replace('_', '-')
result[json_key] = value

return result


def print_object(message: 'JsonObject | None', kwargs: JsonObject) -> bytes:
"""Construct and pretty-print a JSON object based on message and kwargs
If only message is given, it is pretty-printed, unmodified. If message is
None, it is equivalent to an empty dictionary.
If kwargs are present, then any underscore ('_') present in a key name is
rewritten to a dash ('-'). This is intended to bridge between the required
Python syntax when providing kwargs and idiomatic JSON (which uses '-' for
attributes). These values override values in message. If the value is a
kwarg is None, it is taken as an instruction to delete a key in the
original message, if present.
The idea is that `message` should be used for passing data along, and
kwargs used for data originating at a given call site, possibly including
modifications to an original message.
"""
pretty = JSON_ENCODER.encode(create_object(message, kwargs))
return pretty.encode()
8 changes: 4 additions & 4 deletions src/cockpit/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _connect_task_done(task: asyncio.Task) -> None:
if init_host is not None:
logger.debug(' sending init message back, host %s', init_host)
# Send "init" back
self.write_control(command='init', version=1, host=init_host, **kwargs)
self.write_control(None, command='init', version=1, host=init_host, **kwargs)

# Thaw the queued messages
self.thaw_endpoint()
Expand Down Expand Up @@ -182,7 +182,7 @@ def do_closed(self, exc: Optional[Exception]) -> None:
else:
self.shutdown_endpoint(problem='terminated', message=f'Peer exited with status {exc.exit_code}')
elif isinstance(exc, CockpitProblem):
self.shutdown_endpoint(problem=exc.problem, **exc.kwargs)
self.shutdown_endpoint(exc.attrs)
else:
self.shutdown_endpoint(problem='internal-error',
message=f"[{exc.__class__.__name__}] {exc!s}")
Expand All @@ -209,7 +209,7 @@ def process_exited(self) -> None:
def channel_control_received(self, channel: str, command: str, message: JsonObject) -> None:
if self.init_future is not None:
raise CockpitProtocolError('Received unexpected channel control message before init')
self.send_channel_control(**message)
self.send_channel_control(channel, command, message)

def channel_data_received(self, channel: str, data: bytes) -> None:
if self.init_future is not None:
Expand All @@ -219,7 +219,7 @@ def channel_data_received(self, channel: str, data: bytes) -> None:
# Forwarding data: from the router to the peer
def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None:
assert self.init_future is None
self.write_control(**message)
self.write_control(message)

def do_channel_data(self, channel: str, data: bytes) -> None:
assert self.init_future is None
Expand Down
39 changes: 14 additions & 25 deletions src/cockpit/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
import json
import logging
import uuid
from typing import ClassVar, Dict, Optional
from typing import Dict, Optional

from cockpit._vendor import systemd_ctypes

from .jsonutil import JsonError, JsonObject, get_str, typechecked
from .jsonutil import JsonDocument, JsonError, JsonObject, create_object, get_str, print_object, typechecked

logger = logging.getLogger(__name__)

Expand All @@ -40,10 +38,11 @@ class CockpitProblem(Exception):
It is usually thrown in response to some violation of expected protocol
when parsing messages, connecting to a peer, or opening a channel.
"""
def __init__(self, problem: str, **kwargs):
super().__init__(kwargs.get('message') or problem)
def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
self.attrs = create_object(_msg, kwargs)
self.attrs['problem'] = problem
self.problem = problem
self.kwargs = kwargs
super().__init__(get_str(self.attrs, 'message', problem))


class CockpitProtocolError(CockpitProblem):
Expand All @@ -57,7 +56,6 @@ class CockpitProtocol(asyncio.Protocol):
We need to use this because Python's SelectorEventLoop doesn't supported
buffered protocols.
"""
json_encoder: ClassVar[json.JSONEncoder] = systemd_ctypes.JSONEncoder(indent=2)
transport: Optional[asyncio.Transport] = None
buffer = b''
_closed: bool = False
Expand Down Expand Up @@ -191,21 +189,10 @@ def write_channel_data(self, channel, payload):
else:
logger.debug('cannot write to closed transport')

def write_message(self, _channel, **kwargs):
"""Format kwargs as a JSON blob and send as a message
Any kwargs with '_' in their names will be converted to '-'
"""
for name in list(kwargs):
if '_' in name:
kwargs[name.replace('_', '-')] = kwargs[name]
del kwargs[name]

logger.debug('sending message %s %s', _channel, kwargs)
pretty = CockpitProtocol.json_encoder.encode(kwargs) + '\n'
self.write_channel_data(_channel, pretty.encode('utf-8'))

def write_control(self, **kwargs):
self.write_message('', **kwargs)
def write_control(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
"""Write a control message. See print_object() for details."""
logger.debug('sending control message %r %r', _msg, kwargs)
self.write_channel_data('', print_object(_msg, kwargs))

def data_received(self, data):
try:
Expand Down Expand Up @@ -269,14 +256,16 @@ def do_ready(self):
self.do_send_init()

# authorize request/response API
async def request_authorization(self, challenge: str, timeout: Optional[int] = None, **kwargs: object) -> str:
async def request_authorization(
self, challenge: str, timeout: 'int | None' = None, **kwargs: JsonDocument
) -> str:
if self.authorizations is None:
self.authorizations = {}
cookie = str(uuid.uuid4())
future = asyncio.get_running_loop().create_future()
try:
self.authorizations[cookie] = future
self.write_control(command='authorize', challenge=challenge, cookie=cookie, **kwargs)
self.write_control(None, command='authorize', challenge=challenge, cookie=cookie, **kwargs)
return await asyncio.wait_for(future, timeout)
finally:
self.authorizations.pop(cookie)
Expand Down
6 changes: 3 additions & 3 deletions src/cockpit/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def do_connect_transport(self) -> None:
# containing the key that would need to be accepted. That will
# cause the front-end to present a dialog.
_reason, host, algorithm, key, fingerprint = responder.hostkeys_seen[0]
error_args = {'host_key': f'{host} {algorithm} {key}', 'host_fingerprint': fingerprint}
error_args: JsonObject = {'host-key': f'{host} {algorithm} {key}', 'host-fingerprint': fingerprint}
else:
error_args = {}

Expand All @@ -124,12 +124,12 @@ async def do_connect_transport(self) -> None:

logger.debug('SshPeer got a %s %s; private %s, seen hostkeys %r; raising %s with extra args %r',
type(exc), exc, self.private, responder.hostkeys_seen, error, error_args)
raise PeerError(error, error=error, auth_method_results={}, **error_args) from exc
raise PeerError(error, error_args, error=error, auth_method_results={}) from exc

except ferny.SshAuthenticationError as exc:
logger.debug('authentication to host %s failed: %s', host, exc)

results = {method: 'not-provided' for method in exc.methods}
results: JsonObject = {method: 'not-provided' for method in exc.methods}
if 'password' in results and self.password is not None:
if responder.password_attempts == 0:
results['password'] = 'not-tried'
Expand Down
17 changes: 8 additions & 9 deletions src/cockpit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,16 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
def send_channel_data(self, channel: str, data: bytes) -> None:
self.router.write_channel_data(channel, data)

def send_channel_message(self, channel: str, **kwargs: JsonDocument) -> None:
self.router.write_message(channel, **kwargs)

def send_channel_control(self, channel, command, **kwargs: JsonDocument) -> None:
self.router.write_control(channel=channel, command=command, **kwargs)
def send_channel_control(
self, channel: str, command: str, _msg: 'JsonObject | None', **kwargs: JsonDocument
) -> None:
self.router.write_control(_msg, channel=channel, command=command, **kwargs)
if command == 'close':
self.router.endpoints[self].remove(channel)
self.router.drop_channel(channel)

def shutdown_endpoint(self, **kwargs: JsonDocument) -> None:
self.router.shutdown_endpoint(self, **kwargs)
def shutdown_endpoint(self, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
self.router.shutdown_endpoint(self, _msg, **kwargs)


class RoutingError(Exception):
Expand Down Expand Up @@ -163,11 +162,11 @@ def drop_channel(self, channel: str) -> None:
except KeyError:
logger.error('trying to drop non-existent channel %s from %s', channel, self.open_channels)

def shutdown_endpoint(self, endpoint: Endpoint, **kwargs) -> None:
def shutdown_endpoint(self, endpoint: Endpoint, _msg: 'JsonObject | None' = None, **kwargs: JsonDocument) -> None:
channels = self.endpoints.pop(endpoint)
logger.debug('shutdown_endpoint(%s, %s) will close %s', endpoint, kwargs, channels)
for channel in channels:
self.write_control(command='close', channel=channel, **kwargs)
self.write_control(_msg, command='close', channel=channel, **kwargs)
self.drop_channel(channel)

# were we waiting to exit?
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/superuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def init(self, params: JsonObject) -> None:
self._init_task = asyncio.create_task(self.go(name, responder))
self._init_task.add_done_callback(self._init_done)

def _init_done(self, task):
def _init_done(self, task: 'asyncio.Task[None]') -> None:
logger.debug('superuser init done! %s', task.exception())
self.router.write_control(command='superuser-init-done')
del self._init_task
Expand Down

0 comments on commit 3ec3155

Please sign in to comment.