diff --git a/src/cockpit/channel.py b/src/cockpit/channel.py index f03cf5a49a0..9589943174e 100644 --- a/src/cockpit/channel.py +++ b/src/cockpit/channel.py @@ -144,7 +144,7 @@ def do_channel_control(self, channel: str, command: str, message: JsonObject) -> except ChannelError as exc: self.close(exc.attrs) - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', _message: JsonObject) -> None: # Already closing? Ignore. if self._close_args is not None: return diff --git a/src/cockpit/peer.py b/src/cockpit/peer.py index 7c2c7f500e7..3009e3b33fe 100644 --- a/src/cockpit/peer.py +++ b/src/cockpit/peer.py @@ -225,9 +225,9 @@ def do_channel_data(self, channel: str, data: bytes) -> None: assert self.init_future is None self.write_channel_data(channel, data) - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None: assert self.init_future is None - self.write_control(command='kill', host=host, group=group) + self.write_control(message) def do_close(self) -> None: self.close() diff --git a/src/cockpit/protocol.py b/src/cockpit/protocol.py index e42f226bc7a..23b930a1633 100644 --- a/src/cockpit/protocol.py +++ b/src/cockpit/protocol.py @@ -19,9 +19,8 @@ import json import logging import uuid -from typing import Dict, Optional -from .jsonutil import JsonError, JsonObject, JsonValue, create_object, get_str, typechecked +from .jsonutil import JsonError, JsonObject, JsonValue, create_object, get_int, get_str, typechecked logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ def __init__(self, problem: str, _msg: 'JsonObject | None' = None, **kwargs: Jso class CockpitProtocolError(CockpitProblem): - def __init__(self, message, problem='protocol-error'): + def __init__(self, message: str, problem: str = 'protocol-error'): super().__init__(problem, message=message) @@ -57,14 +56,15 @@ class CockpitProtocol(asyncio.Protocol): We need to use this because Python's SelectorEventLoop doesn't supported buffered protocols. """ - transport: Optional[asyncio.Transport] = None + transport: 'asyncio.Transport | None' = None buffer = b'' _closed: bool = False + _communication_done: 'asyncio.Future[None] | None' = None def do_ready(self) -> None: pass - def do_closed(self, exc: Optional[Exception]) -> None: + def do_closed(self, exc: 'Exception | None') -> None: pass def transport_control_received(self, command: str, message: JsonObject) -> None: @@ -87,7 +87,7 @@ def frame_received(self, frame: bytes) -> None: else: self.control_received(data) - def control_received(self, data: bytes): + def control_received(self, data: bytes) -> None: try: message = typechecked(json.loads(data), dict) command = get_str(message, 'command') @@ -103,52 +103,40 @@ def control_received(self, data: bytes): except (json.JSONDecodeError, JsonError) as exc: raise CockpitProtocolError(f'control message: {exc!s}') from exc - def consume_one_frame(self, view): + def consume_one_frame(self, data: bytes) -> int: """Consumes a single frame from view. Returns positive if a number of bytes were consumed, or negative if no work can be done because of a given number of bytes missing. """ - # Nothing to look at? Save ourselves the trouble... - if not view: - return 0 - - view = bytes(view) - # We know the length + newline is never more than 10 bytes, so just - # slice that out and deal with it directly. We don't have .index() on - # a memoryview, for example. - # From a performance standpoint, hitting the exception case is going to - # be very rare: we're going to receive more than the first few bytes of - # the packet in the regular case. The more likely situation is where - # we get "unlucky" and end up splitting the header between two read()s. - header = bytes(view[:10]) try: - newline = header.index(b'\n') + newline = data.index(b'\n') except ValueError as exc: - if len(header) < 10: + if len(data) < 10: # Let's try reading more - return len(header) - 10 + return len(data) - 10 raise CockpitProtocolError("size line is too long") from exc try: - length = int(header[:newline]) + length = int(data[:newline]) except ValueError as exc: raise CockpitProtocolError("frame size is not an integer") from exc start = newline + 1 end = start + length - if end > len(view): + if end > len(data): # We need to read more - return len(view) - end + return len(data) - end # We can consume a full frame - self.frame_received(view[start:end]) + self.frame_received(data[start:end]) return end - def connection_made(self, transport): + def connection_made(self, transport: asyncio.BaseTransport) -> None: logger.debug('connection_made(%s)', transport) + assert isinstance(transport, asyncio.Transport) self.transport = transport self.do_ready() @@ -156,13 +144,13 @@ def connection_made(self, transport): logger.debug(' but the protocol already was closed, so closing transport') transport.close() - def connection_lost(self, exc): + def connection_lost(self, exc: 'Exception | None') -> None: logger.debug('connection_lost') assert self.transport is not None self.transport = None self.close(exc) - def close(self, exc: Optional[Exception] = None) -> None: + def close(self, exc: 'Exception | None' = None) -> None: if self._closed: return self._closed = True @@ -172,7 +160,7 @@ def close(self, exc: Optional[Exception] = None) -> None: self.do_closed(exc) - def write_channel_data(self, channel, payload): + def write_channel_data(self, channel: str, payload: bytes) -> None: """Send a given payload (bytes) on channel (string)""" # Channel is certainly ascii (as enforced by .encode() below) frame_length = len(channel + '\n') + len(payload) @@ -189,10 +177,10 @@ def write_control(self, _msg: 'JsonObject | None' = None, **kwargs: JsonValue) - pretty = json.dumps(create_object(_msg, kwargs), indent=2) + '\n' self.write_channel_data('', pretty.encode()) - def data_received(self, data): + def data_received(self, data: bytes) -> None: try: self.buffer += data - while True: + while self.buffer: result = self.consume_one_frame(self.buffer) if result <= 0: return @@ -200,47 +188,38 @@ def data_received(self, data): except CockpitProtocolError as exc: self.close(exc) - def eof_received(self) -> Optional[bool]: + def eof_received(self) -> bool: return False # Helpful functionality for "server"-side protocol implementations class CockpitProtocolServer(CockpitProtocol): - init_host: Optional[str] = None - authorizations: Optional[Dict[str, asyncio.Future]] = None + init_host: 'str | None' = None + authorizations: 'dict[str, asyncio.Future[str]] | None' = None - def do_send_init(self): + def do_send_init(self) -> None: raise NotImplementedError - def do_init(self, message): + def do_init(self, message: JsonObject) -> None: pass - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None: raise NotImplementedError - def transport_control_received(self, command, message): + def transport_control_received(self, command: str, message: JsonObject) -> None: if command == 'init': - try: - if int(message['version']) != 1: - raise CockpitProtocolError('incorrect version number', 'protocol-error') - except KeyError as exc: - raise CockpitProtocolError('version field is missing', 'protocol-error') from exc - except ValueError as exc: - raise CockpitProtocolError('version field is not an int', 'protocol-error') from exc - - try: - self.init_host = message['host'] - except KeyError as exc: - raise CockpitProtocolError('missing host field', 'protocol-error') from exc + if get_int(message, 'version') != 1: + raise CockpitProtocolError('incorrect version number') + self.init_host = get_str(message, 'host') self.do_init(message) elif command == 'kill': - self.do_kill(message.get('host'), message.get('group')) + self.do_kill(get_str(message, 'host', None), get_str(message, 'group', None), message) elif command == 'authorize': self.do_authorize(message) else: raise CockpitProtocolError(f'unexpected control message {command} received') - def do_ready(self): + def do_ready(self) -> None: self.do_send_init() # authorize request/response API @@ -259,11 +238,8 @@ async def request_authorization( self.authorizations.pop(cookie) def do_authorize(self, message: JsonObject) -> None: - cookie = message.get('cookie') - response = message.get('response') - - if not isinstance(cookie, str) or not isinstance(response, str): - raise CockpitProtocolError('invalid authorize response') + cookie = get_str(message, 'cookie') + response = get_str(message, 'response') if self.authorizations is None or cookie not in self.authorizations: logger.warning('no matching authorize request') diff --git a/src/cockpit/remote.py b/src/cockpit/remote.py index 51065b5849e..ccbdd22993e 100644 --- a/src/cockpit/remote.py +++ b/src/cockpit/remote.py @@ -147,11 +147,11 @@ async def do_connect_transport(self) -> None: args = self.session.wrap_subprocess_args(['cockpit-bridge']) await self.spawn(args, []) - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None: if host == self.host: self.close() elif host is None: - super().do_kill(None, group) + super().do_kill(host, group, message) def do_authorize(self, message: JsonObject) -> None: if get_str(message, 'challenge').startswith('plain1:'): diff --git a/src/cockpit/router.py b/src/cockpit/router.py index 5252567140e..884682f0502 100644 --- a/src/cockpit/router.py +++ b/src/cockpit/router.py @@ -86,7 +86,7 @@ def do_channel_control(self, channel: str, command: str, message: JsonObject) -> def do_channel_data(self, channel: str, data: bytes) -> None: raise NotImplementedError - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None: raise NotImplementedError # interface for sending messages @@ -185,11 +185,11 @@ def shutdown_endpoint(self, endpoint: Endpoint, _msg: 'JsonObject | None' = None logger.debug(' close transport') self.transport.close() - def do_kill(self, host: Optional[str], group: Optional[str]) -> None: + def do_kill(self, host: 'str | None', group: 'str | None', message: JsonObject) -> None: endpoints = set(self.endpoints) logger.debug('do_kill(%s, %s). Considering %d endpoints.', host, group, len(endpoints)) for endpoint in endpoints: - endpoint.do_kill(host, group) + endpoint.do_kill(host, group, message) def channel_control_received(self, channel: str, command: str, message: JsonObject) -> None: # If this is an open message then we need to apply the routing rules to diff --git a/src/cockpit/transports.py b/src/cockpit/transports.py index faa7aaee349..8aabc7d83d6 100644 --- a/src/cockpit/transports.py +++ b/src/cockpit/transports.py @@ -29,14 +29,14 @@ import struct import subprocess import termios -from typing import Any, ClassVar, Deque, Dict, List, Optional, Sequence, Tuple +from typing import Any, ClassVar, Sequence from .jsonutil import JsonObject, get_int libc6 = ctypes.cdll.LoadLibrary('libc.so.6') -def prctl(*args): +def prctl(*args: int) -> None: if libc6.prctl(*args) != 0: raise OSError('prctl() failed') @@ -55,7 +55,7 @@ class _Transport(asyncio.Transport): _loop: asyncio.AbstractEventLoop _protocol: asyncio.Protocol - _queue: Optional[Deque[bytes]] + _queue: 'collections.deque[bytes] | None' _in_fd: int _out_fd: int _closing: bool @@ -67,7 +67,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, protocol: asyncio.Protocol, in_fd: int = -1, out_fd: int = -1, - extra: Optional[Dict[str, object]] = None): + extra: 'dict[str, object] | None' = None): super().__init__(extra) self._loop = loop @@ -138,7 +138,7 @@ def resume_reading(self) -> None: def _close(self) -> None: pass - def abort(self, exc: Optional[Exception] = None) -> None: + def abort(self, exc: 'Exception | None' = None) -> None: self._closing = True self._close_reader() self._remove_write_queue() @@ -162,10 +162,10 @@ def get_write_buffer_size(self) -> int: return 0 return sum(len(block) for block in self._queue) - def get_write_buffer_limits(self) -> Tuple[int, int]: + def get_write_buffer_limits(self) -> 'tuple[int, int]': return (0, 0) - def set_write_buffer_limits(self, high: Optional[int] = None, low: Optional[int] = None) -> None: + def set_write_buffer_limits(self, high: 'int | None' = None, low: 'int | None' = None) -> None: assert high is None or high == 0 assert low is None or low == 0 @@ -305,11 +305,11 @@ class SubprocessTransport(_Transport, asyncio.SubprocessTransport): data from it, making it available via the .get_stderr() method. """ - _returncode: Optional[int] = None + _returncode: 'int | None' = None - _pty_fd: Optional[int] = None - _process: Optional['subprocess.Popen[bytes]'] = None - _stderr: Optional['Spooler'] + _pty_fd: 'int | None' = None + _process: 'subprocess.Popen[bytes] | None' = None + _stderr: 'Spooler | None' @staticmethod def _create_watcher() -> asyncio.AbstractChildWatcher: @@ -363,11 +363,11 @@ def __init__(self, args: Sequence[str], *, pty: bool = False, - window: Optional[WindowSize] = None, + window: 'WindowSize | None' = None, **kwargs: Any): # go down as a team -- we don't want any leaked processes when the bridge terminates - def preexec_fn(): + def preexec_fn() -> None: prctl(SET_PDEATHSIG, signal.SIGTERM) if pty: fcntl.ioctl(0, termios.TIOCSCTTY, 0) @@ -422,7 +422,7 @@ def get_pid(self) -> int: assert self._process is not None return self._process.pid - def get_returncode(self) -> Optional[int]: + def get_returncode(self) -> 'int | None': return self._returncode def get_pipe_transport(self, fd: int) -> asyncio.Transport: @@ -502,7 +502,7 @@ class Spooler: _loop: asyncio.AbstractEventLoop _fd: int - _contents: List[bytes] + _contents: 'list[bytes]' def __init__(self, loop: asyncio.AbstractEventLoop, fd: int): self._loop = loop diff --git a/test/static-code b/test/static-code index 7f3b136c606..f1eb2bf323b 100755 --- a/test/static-code +++ b/test/static-code @@ -44,6 +44,13 @@ test_ruff() { } if [ "${WITH_PARTIAL_TREE:-0}" = 0 ]; then + mypy_strict_files=' + src/cockpit/__init__.py + src/cockpit/_version.py + src/cockpit/jsonutil.py + src/cockpit/protocol.py + src/cockpit/transports.py + ' test_mypy() { command -v mypy >/dev/null || skip 'no mypy' for pkg in systemd_ctypes ferny bei; do @@ -53,6 +60,7 @@ if [ "${WITH_PARTIAL_TREE:-0}" = 0 ]; then # test scripts individually, to avoid clashing on `__main__` # also skip integration tests, they are too big and not annotated find_scripts 'python3' "*.none" | grep -zv 'test/' | xargs -r -0 -n1 mypy --no-error-summary + mypy --no-error-summary --strict $mypy_strict_files } test_vulture() {