Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

router: track endpoints more explicitly #19554

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/cockpit/beiboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,17 @@ def report_exists(files):


class DefaultRoutingRule(RoutingRule):
peer: Peer
peer: 'Peer | None'

def __init__(self, router: Router, peer: Peer):
def __init__(self, router: Router):
super().__init__(router)
self.peer = peer

def apply_rule(self, options: JsonObject) -> Peer:
def apply_rule(self, options: JsonObject) -> 'Peer | None':
return self.peer

def shutdown(self) -> None:
self.peer.close()
if self.peer is not None:
self.peer.close()


class AuthorizeResponder(ferny.AskpassHandler):
Expand Down Expand Up @@ -259,11 +259,15 @@ class SshBridge(Router):
ssh_peer: SshPeer

def __init__(self, args: argparse.Namespace):
self.ssh_peer = SshPeer(self, args.destination, args)
# By default, we route everything to the other host. We add an extra
# routing rule for the packages webserver only if we're running the
# beipack.
rule = DefaultRoutingRule(self)
super().__init__([rule])

super().__init__([
DefaultRoutingRule(self, self.ssh_peer),
])
# This needs to be created after Router.__init__ is called.
self.ssh_peer = SshPeer(self, args.destination, args)
rule.peer = self.ssh_peer

def do_send_init(self):
pass # wait for the peer to do it first
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def is_closing(self) -> bool:
return self._close_args is not None

def _close_now(self):
self.send_control('close', **self._close_args)
self.shutdown_endpoint(**self._close_args)

def _task_done(self, task):
# Strictly speaking, we should read the result and check for exceptions but:
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/channels/dbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,5 +511,5 @@ def do_data(self, data):
def do_close(self):
for slot in self.matches:
slot.cancel()
self.matches = None # error out
self.matches = []
self.close()
3 changes: 3 additions & 0 deletions src/cockpit/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
assert self.init_future is None
self.write_control(command='kill', host=host, group=group)

def do_close(self) -> None:
self.close()


class ConfiguredPeer(Peer):
config: BridgeConfig
Expand Down
34 changes: 24 additions & 10 deletions src/cockpit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Endpoint:
__endpoint_frozen_queue: Optional[ExecutionQueue] = None

def __init__(self, router: 'Router'):
router.endpoints[self] = set()
martinpitt marked this conversation as resolved.
Show resolved Hide resolved
self.router = router

def freeze_endpoint(self):
Expand All @@ -75,6 +76,9 @@ def thaw_endpoint(self):
self.__endpoint_frozen_queue = None

# interface for receiving messages
def do_close(self):
raise NotImplementedError

def do_channel_control(self, channel: str, command: str, message: JsonObject) -> None:
raise NotImplementedError

Expand All @@ -94,6 +98,7 @@ def send_channel_message(self, channel: str, **kwargs: JsonDocument) -> None:
def send_channel_control(self, channel, command, **kwargs: JsonDocument) -> None:
self.router.write_control(channel=channel, command=command, **kwargs)
if command == 'close':
self.router.endpoints[self].remove(channel)
martinpitt marked this conversation as resolved.
Show resolved Hide resolved
self.router.drop_channel(channel)

def shutdown_endpoint(self, **kwargs: JsonDocument) -> None:
Expand Down Expand Up @@ -130,13 +135,15 @@ def shutdown(self):
class Router(CockpitProtocolServer):
routing_rules: List[RoutingRule]
open_channels: Dict[str, Endpoint]
endpoints: 'dict[Endpoint, set[str]]'
_eof: bool = False

def __init__(self, routing_rules: List[RoutingRule]):
for rule in routing_rules:
rule.router = self
self.routing_rules = routing_rules
self.open_channels = {}
self.endpoints = {}

def check_rules(self, options: JsonObject) -> Endpoint:
for rule in self.routing_rules:
Expand All @@ -156,19 +163,22 @@ def drop_channel(self, channel: str) -> None:
except KeyError:
logger.error('trying to drop non-existent channel %s from %s', channel, self.open_channels)

# were we waiting to exit?
if not self.open_channels and self._eof and self.transport:
self.transport.close()

def shutdown_endpoint(self, endpoint: Endpoint, **kwargs) -> None:
channels = {key for key, value in self.open_channels.items() if value == endpoint}
channels = self.endpoints.pop(endpoint)
logger.debug('shutdown_endpoint(%s, %s) will close %s', endpoint, kwargs, channels)
for channel in channels:
self.write_control(command='close', channel=channel, **kwargs)
self.drop_channel(channel)

# were we waiting to exit?
if self._eof:
logger.debug(' %d endpoints remaining', len(self.endpoints))
if not self.endpoints and self.transport:
logger.debug(' close transport')
self.transport.close()

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
endpoints = set(self.open_channels.values())
endpoints = set(self.endpoints)
logger.debug('do_kill(%s, %s). Considering %d endpoints.', host, group, len(endpoints))
for endpoint in endpoints:
endpoint.do_kill(host, group)
Expand All @@ -189,6 +199,7 @@ def channel_control_received(self, channel: str, command: str, message: JsonObje
return

self.open_channels[channel] = endpoint
self.endpoints[endpoint].add(channel)
else:
try:
endpoint = self.open_channels[channel]
Expand All @@ -208,12 +219,15 @@ def channel_data_received(self, channel: str, data: bytes) -> None:
endpoint.do_channel_data(channel, data)

def eof_received(self) -> bool:
self._eof = True
logger.debug('eof_received(%r)', self)

for channel, endpoint in list(self.open_channels.items()):
endpoint.do_channel_control(channel, 'close', {'command': 'close', 'channel': channel})
endpoints = set(self.endpoints)
for endpoint in endpoints:
endpoint.do_close()

return bool(self.open_channels)
self._eof = True
logger.debug(' endpoints remaining: %r', self.endpoints)
return bool(self.endpoints)

def do_closed(self, exc: Optional[Exception]) -> None:
for rule in self.routing_rules:
Expand Down