Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bridge: Improve handling of control messages #19637

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/cockpit/beiboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def do_init(self, message):
if isinstance(message.get('superuser'), dict):
self.write_control(command='superuser-init-done')
message['superuser'] = False
self.ssh_peer.write_control(**message)
self.ssh_peer.write_control(message)


async def run(args) -> None:
Expand All @@ -306,12 +306,12 @@ async def run(args) -> None:
if bridge.packages:
message['packages'] = {p: None for p in bridge.packages.packages}

bridge.write_control(**message)
bridge.write_control(message)
bridge.ssh_peer.thaw_endpoint()
except ferny.InteractionError as exc:
sys.exit(str(exc))
except CockpitProblem as exc:
bridge.write_control(command='init', problem=exc.problem, **exc.kwargs)
bridge.write_control(exc.attrs, command='init')
return

logger.debug('Startup done. Looping until connection closes.')
Expand Down
4 changes: 3 additions & 1 deletion src/cockpit/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ def do_init(self, message: JsonObject) -> None:
def do_send_init(self) -> None:
init_args = {
'capabilities': {'explicit-superuser': True},
'command': 'init',
'os-release': self.get_os_release(),
'version': 1,
}

if self.packages is not None:
init_args['packages'] = {p: None for p in self.packages.packages}

self.write_control(command='init', version=1, **init_args)
self.write_control(init_args)

# PackagesListener interface
def packages_loaded(self) -> None:
Expand Down
81 changes: 42 additions & 39 deletions src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import asyncio
import json
import logging
from typing import BinaryIO, ClassVar, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type

from .jsonutil import JsonError, JsonObject, get_bool, get_str
from .jsonutil import JsonDocument, JsonError, JsonObject, create_object, get_bool, get_str
from .protocol import CockpitProblem
from .router import Endpoint, Router, RoutingRule

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,10 +78,8 @@ def shutdown(self):
pass # we don't hold any state


class ChannelError(Exception):
def __init__(self, problem, **kwargs):
super().__init__(f'ChannelError {problem}')
self.kwargs = dict(kwargs, problem=problem)
class ChannelError(CockpitProblem):
pass


class Channel(Endpoint):
Expand Down Expand Up @@ -130,19 +130,19 @@ def do_control(self, command, message):
elif command == 'options':
self.do_options(message)

def do_channel_control(self, channel, command, message):
def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None:
# Already closing? Ignore.
if self._close_args is not None:
return

# Catch errors and turn them into close messages
try:
self.do_control(command, message)
try:
self.do_control(command, message)
except JsonError as exc:
raise ChannelError('protocol-error', message=str(exc)) from exc
except ChannelError as exc:
self.close(**exc.kwargs)
except JsonError as exc:
logger.warning("%s %s %s: %s", self, channel, command, exc)
self.close(problem='protocol-error', message=str(exc))
self.close(exc.attrs)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
# Already closing? Ignore.
Expand All @@ -156,27 +156,27 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
self.do_close()

# At least this one really ought to be implemented...
def do_open(self, options):
def do_open(self, options: JsonObject) -> None:
raise NotImplementedError

# ... but many subclasses may reasonably want to ignore some of these.
def do_ready(self):
def do_ready(self) -> None:
pass

def do_done(self):
def do_done(self) -> None:
pass

def do_close(self):
def do_close(self) -> None:
self.close()

def do_options(self, message):
def do_options(self, message: JsonObject) -> None:
raise ChannelError('not-supported', message='This channel does not implement "options"')

# 'reasonable' default, overridden in other channels for receive-side flow control
def do_ping(self, message):
def do_ping(self, message: JsonObject) -> None:
self.send_pong(message)

def do_channel_data(self, channel, data):
def do_channel_data(self, channel: str, data: bytes) -> None:
# Already closing? Ignore.
if self._close_args is not None:
return
Expand All @@ -185,26 +185,26 @@ def do_channel_data(self, channel, data):
try:
self.do_data(data)
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)

def do_data(self, _data):
def do_data(self, _data: bytes) -> None:
# By default, channels can't receive data.
self.close()

# output
def ready(self, **kwargs):
def ready(self, **kwargs: JsonDocument) -> None:
self.thaw_endpoint()
self.send_control(command='ready', **kwargs)

def done(self):
def done(self) -> None:
self.send_control(command='done')

# tasks and close management
def is_closing(self) -> bool:
return self._close_args is not None

def _close_now(self):
self.shutdown_endpoint(**self._close_args)
def _close_now(self) -> None:
self.shutdown_endpoint(self._close_args)

def _task_done(self, task):
# Strictly speaking, we should read the result and check for exceptions but:
Expand All @@ -227,7 +227,7 @@ def create_task(self, coroutine, name=None):
task.add_done_callback(self._task_done)
return task

def close(self, **kwargs):
def close(self, close_args: 'JsonObject | None' = None) -> None:
"""Requests the channel to be closed.

After you call this method, you won't get anymore `.do_*()` calls.
Expand All @@ -238,7 +238,7 @@ def close(self, **kwargs):
if self._close_args is not None:
# close already requested
return
self._close_args = kwargs
self._close_args = close_args or {}
if not self._tasks:
self._close_now()

Expand Down Expand Up @@ -276,14 +276,17 @@ def do_resume_send(self) -> None:
"""Called to indicate that the channel may start sending again."""
# change to `raise NotImplementedError` after everyone implements it

def send_message(self, **kwargs):
self.send_channel_message(self.channel, **kwargs)
json_encoder: ClassVar[json.JSONEncoder] = json.JSONEncoder(indent=2)

def send_json(self, **kwargs: JsonDocument) -> bool:
pretty = self.json_encoder.encode(create_object(None, kwargs)) + '\n'
return self.send_data(pretty.encode())

def send_control(self, command, **kwargs):
self.send_channel_control(self.channel, command=command, **kwargs)
def send_control(self, command: str, **kwargs: JsonDocument) -> None:
self.send_channel_control(self.channel, command, None, **kwargs)

def send_pong(self, message):
self.send_channel_control(**dict(message, command='pong'))
def send_pong(self, message: JsonObject) -> None:
self.send_channel_control(self.channel, 'pong', message)


class ProtocolChannel(Channel, asyncio.Protocol):
Expand Down Expand Up @@ -319,32 +322,32 @@ async def create_transport(self, loop: asyncio.AbstractEventLoop, options: JsonO
"""
raise NotImplementedError

def do_open(self, options):
def do_open(self, options: JsonObject) -> None:
loop = asyncio.get_running_loop()
self._create_transport_task = asyncio.create_task(self.create_transport(loop, options))
self._create_transport_task.add_done_callback(self.create_transport_done)

def create_transport_done(self, task):
def create_transport_done(self, task: 'asyncio.Task[asyncio.Transport]') -> None:
assert task is self._create_transport_task
self._create_transport_task = None
try:
transport = task.result()
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)
return

self.connection_made(transport)
self.ready()

def connection_made(self, transport: asyncio.BaseTransport):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
assert isinstance(transport, asyncio.Transport)
self._transport = transport

def _get_close_args(self) -> JsonObject:
return {}

def connection_lost(self, exc: Optional[Exception]) -> None:
self.close(**self._get_close_args())
self.close(self._get_close_args())

def do_data(self, data: bytes) -> None:
assert self._transport is not None
Expand Down Expand Up @@ -445,7 +448,7 @@ async def run_wrapper(self, options):
await self.run(options)
self.close()
except ChannelError as exc:
self.close(**exc.kwargs)
self.close(exc.attrs)

async def read(self):
while True:
Expand Down Expand Up @@ -521,4 +524,4 @@ def do_resume_send(self) -> None:
pass
except StopIteration as stop:
self.done()
self.close(**stop.value or {})
self.close(stop.value)
Loading