Skip to content

Commit

Permalink
protocol: add better typing hints
Browse files Browse the repository at this point in the history
This file is now clean under `mypy --strict`.
  • Loading branch information
allisonkarlitskaya committed Sep 7, 2023
1 parent 88b8466 commit e46e8b8
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
35 changes: 19 additions & 16 deletions src/cockpit/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from cockpit._vendor import systemd_ctypes

from .jsonutil import JsonError, JsonObject, get_int, get_str, typechecked
from .jsonutil import JsonDocument, JsonError, JsonObject, get_int, get_str, typechecked

logger = logging.getLogger(__name__)

Expand All @@ -40,14 +40,14 @@ class CockpitProblem(Exception):
It is usually thrown in response to some violation of expected protocol
when parsing messages, connecting to a peer, or opening a channel.
"""
def __init__(self, problem: str, **kwargs):
def __init__(self, problem: str, **kwargs: JsonDocument):
super().__init__(kwargs.get('message') or problem)
self.problem = problem
self.kwargs = kwargs


class CockpitProtocolError(CockpitProblem):
def __init__(self, message, problem='protocol-error'):
def __init__(self, message: str, problem: str = 'protocol-error'):
super().__init__(problem, message=message)


Expand All @@ -61,7 +61,7 @@ class CockpitProtocol(asyncio.Protocol):
transport: Optional[asyncio.Transport] = None
buffer = b''
_closed: bool = False
_communication_done: Optional[asyncio.Future] = None
_communication_done: 'Optional[asyncio.Future[None]]' = None

def do_ready(self) -> None:
pass
Expand Down Expand Up @@ -89,7 +89,7 @@ def frame_received(self, frame: bytes) -> None:
else:
self.control_received(data)

def control_received(self, data: bytes):
def control_received(self, data: bytes) -> None:
try:
message = typechecked(json.loads(data), dict)
command = get_str(message, 'command')
Expand Down Expand Up @@ -136,16 +136,17 @@ def consume_one_frame(self, data: bytes) -> int:
self.frame_received(data[start:end])
return end

def connection_made(self, transport):
def connection_made(self, transport: asyncio.BaseTransport) -> None:
logger.debug('connection_made(%s)', transport)
assert isinstance(transport, asyncio.Transport)
self.transport = transport
self.do_ready()

if self._closed:
logger.debug(' but the protocol already was closed, so closing transport')
transport.close()

def connection_lost(self, exc):
def connection_lost(self, exc: Optional[Exception]) -> None:
logger.debug('connection_lost')
assert self.transport is not None
self.transport = None
Expand All @@ -167,7 +168,7 @@ def close(self, exc: Optional[Exception] = None) -> None:
else:
self._communication_done.set_exception(exc)

def write_channel_data(self, channel, payload):
def write_channel_data(self, channel: str, payload: bytes) -> None:
"""Send a given payload (bytes) on channel (string)"""
# Channel is certainly ascii (as enforced by .encode() below)
frame_length = len(channel + '\n') + len(payload)
Expand All @@ -178,7 +179,7 @@ def write_channel_data(self, channel, payload):
else:
logger.debug('cannot write to closed transport')

def write_message(self, _channel, **kwargs):
def write_message(self, _channel: str, **kwargs: JsonDocument) -> None:
"""Format kwargs as a JSON blob and send as a message
Any kwargs with '_' in their names will be converted to '-'
Additionally, any None values will be dropped
Expand All @@ -188,10 +189,10 @@ def write_message(self, _channel, **kwargs):
pretty = CockpitProtocol.json_encoder.encode(kwargs) + '\n'
self.write_channel_data(_channel, pretty.encode('utf-8'))

def write_control(self, **kwargs):
def write_control(self, **kwargs: JsonDocument) -> None:
self.write_message('', **kwargs)

def data_received(self, data):
def data_received(self, data: bytes) -> None:
try:
self.buffer += data
while self.buffer:
Expand All @@ -216,12 +217,12 @@ async def communicate(self) -> None:
# Helpful functionality for "server"-side protocol implementations
class CockpitProtocolServer(CockpitProtocol):
init_host: Optional[str] = None
authorizations: Optional[Dict[str, asyncio.Future]] = None
authorizations: 'Optional[Dict[str, asyncio.Future[str]]]' = None

def do_send_init(self):
def do_send_init(self) -> None:
raise NotImplementedError

def do_init(self, message):
def do_init(self, message: JsonObject) -> None:
pass

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
Expand All @@ -240,11 +241,13 @@ def transport_control_received(self, command: str, message: JsonObject) -> None:
else:
raise CockpitProtocolError(f'unexpected control message {command} received')

def do_ready(self):
def do_ready(self) -> None:
self.do_send_init()

# authorize request/response API
async def request_authorization(self, challenge: str, timeout: Optional[int] = None, **kwargs: object) -> str:
async def request_authorization(
self, challenge: str, timeout: Optional[int] = None, **kwargs: JsonDocument
) -> str:
if self.authorizations is None:
self.authorizations = {}
cookie = str(uuid.uuid4())
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def do_connect_transport(self) -> None:
except ferny.AuthenticationError as exc:
logger.debug('authentication to host %s failed: %s', host, exc)

results = {method: 'not-provided' for method in exc.methods}
results: JsonObject = {method: 'not-provided' for method in exc.methods}
if 'password' in results and self.password is not None:
if responder.password_attempts == 0:
results['password'] = 'not-tried'
Expand Down
1 change: 1 addition & 0 deletions test/static-code
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ if [ "${WITH_PARTIAL_TREE:-0}" = 0 ]; then
src/cockpit/__init__.py
src/cockpit/_version.py
src/cockpit/jsonutil.py
src/cockpit/protocol.py
'
test_mypy() {
command -v mypy >/dev/null || skip 'no mypy'
Expand Down

0 comments on commit e46e8b8

Please sign in to comment.