Skip to content

Commit

Permalink
Add SignalManager
Browse files Browse the repository at this point in the history
  • Loading branch information
filipcacky committed Oct 7, 2024
1 parent 1b2e7dc commit 64e3c7d
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 9 deletions.
1 change: 1 addition & 0 deletions metaflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class and related decorators.
# Runner API
if sys.version_info >= (3, 7):
from .runner.metaflow_runner import Runner
from .runner.signal_manager import SignalManager
from .runner.nbrun import NBRunner
from .runner.deployer import Deployer
from .runner.nbdeploy import NBDeployer
Expand Down
4 changes: 3 additions & 1 deletion metaflow/runner/metaflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .utils import handle_timeout, async_handle_timeout, clear_and_set_os_environ
from .subprocess_manager import CommandManager, SubprocessManager
from .signal_manager import SignalManager


class ExecutingRun(object):
Expand Down Expand Up @@ -231,6 +232,7 @@ def __init__(
env: Optional[Dict] = None,
cwd: Optional[str] = None,
file_read_timeout: int = 3600,
signal_manager: Optional[SignalManager] = None,
**kwargs
):
# these imports are required here and not at the top
Expand All @@ -257,7 +259,7 @@ def __init__(

self.cwd = cwd
self.file_read_timeout = file_read_timeout
self.spm = SubprocessManager()
self.spm = SubprocessManager(signal_manager=signal_manager)
self.top_level_kwargs = kwargs
self.api = MetaflowAPI.from_cli(self.flow_file, start)

Expand Down
128 changes: 128 additions & 0 deletions metaflow/runner/signal_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import asyncio
import signal
from types import FrameType
from typing import NewType, Mapping, Set, Callable, Dict, Optional

# Signature of every handler managed below: ``handler(signum, frame)``.
# ``frame`` is ``None`` when the signal is delivered through an asyncio loop.
# Kept as a ``NewType`` so the public name keeps its original runtime shape.
SignalHandler = NewType(
    "SignalHandler", Callable[[int, Optional[FrameType]], None]
)


class SignalManager:
    """
    A context manager for managing signal handlers.

    Several handlers can be registered for the same signal; the manager
    installs a single dispatcher per signal and fans each delivery out to all
    registered handlers.

    This class works as a context manager, restoring any overwritten
    signal handlers when the context is exited. Restoring only works for
    signals in a synchronous context (ie. hooked by `signal`); handlers hooked
    on an `asyncio` event loop are simply removed from the loop.

    Parameters
    ----------
    hook_signals : bool
        If True, the signal manager will overwrite any existing signal handlers
        in either `asyncio` or `signal`. If you already have any signal
        handling in place, you can set this to False and use `trigger_signal`
        to trigger metaflow-related signal handlers.
    event_loop : Optional[asyncio.AbstractEventLoop]
        The event loop to use for handling signals.
        If None, the current running event loop is used, if any.
    """

    hook_signals: bool
    event_loop: Optional[asyncio.AbstractEventLoop]
    # Registered handlers per signal number. Created per instance in
    # `__init__` so that managers never share registrations.
    signal_map: Dict[int, Set[SignalHandler]]
    # Whatever `signal.signal` returned when a signal was hooked, i.e. the
    # handler to put back (synchronous mode only).
    replaced_signals: Dict[int, object]

    def __init__(
        self,
        hook_signals: bool = True,
        event_loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        self.hook_signals = hook_signals
        self.signal_map = {}
        self.replaced_signals = {}

        if event_loop is not None:
            self.event_loop = event_loop
            return

        try:
            self.event_loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running: fall back to the `signal` module.
            self.event_loop = None

    def __enter__(self) -> "SignalManager":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Iterate over a snapshot: unhooking mutates the bookkeeping dicts.
        for sig in list(self.signal_map):
            self._maybe_remove_signal_handler(sig)

        # Forget all registrations so that the manager can be reused and a
        # later `add_signal_handler` hooks the signal again.
        self.signal_map.clear()

    def _handle_signal(self, signum: int, frame: Optional[FrameType]) -> None:
        """Dispatch `signum` to every handler registered for it."""
        # Snapshot the handlers so that one of them may add/remove handlers
        # while we iterate; unknown signals are a no-op.
        for handler in tuple(self.signal_map.get(signum, ())):
            handler(signum, frame)

    def _maybe_add_signal_handler(self, sig: int) -> None:
        """Install the dispatcher for `sig` if `hook_signals` is enabled."""
        if not self.hook_signals:
            return

        if self.event_loop is None:
            self.replaced_signals[sig] = signal.signal(sig, self._handle_signal)
        else:
            # The loop calls back without arguments, so bind them here; there
            # is no stack frame available in this mode.
            self.event_loop.add_signal_handler(sig, self._handle_signal, sig, None)

    def _maybe_remove_signal_handler(self, sig: int) -> None:
        """Uninstall the dispatcher for `sig` if `hook_signals` is enabled."""
        if not self.hook_signals:
            return

        if self.event_loop is not None:
            self.event_loop.remove_signal_handler(sig)
            return

        if sig not in self.replaced_signals:
            # Nothing was hooked for this signal (or it was already restored).
            return

        previous = self.replaced_signals.pop(sig)
        if previous is None:
            # `signal.signal` returns None when the previous handler was not
            # installed from Python; None cannot be passed back, so fall back
            # to the default action instead of leaving our dispatcher behind.
            previous = signal.SIG_DFL
        signal.signal(sig, previous)

    def add_signal_handler(self, sig: int, handler: SignalHandler) -> None:
        """
        Add a signal handler for the given signal.

        Parameters
        ----------
        sig: int
            The signal to handle.
        handler: SignalHandler
            The handler to call when the signal is received.
        """
        handlers = self.signal_map.get(sig)
        if handlers is None:
            # Hook first: if hooking raises, no half-registered entry is left
            # behind that would prevent a later attempt from hooking again.
            self._maybe_add_signal_handler(sig)
            handlers = self.signal_map[sig] = set()

        handlers.add(handler)

    def remove_signal_handler(self, sig: int, handler: SignalHandler) -> None:
        """
        Remove a signal handler for the given signal.

        Removing a handler that was never registered is a no-op. When the last
        handler of a signal is removed, the signal is unhooked and (in a
        synchronous context) the previously installed handler is restored.

        Parameters
        ----------
        sig: int
            The signal to handle.
        handler: SignalHandler
            The handler to remove.
        """
        handlers = self.signal_map.get(sig)
        if handlers is None:
            return

        handlers.discard(handler)
        if not handlers:
            # Do not keep swallowing the signal once nobody listens to it.
            del self.signal_map[sig]
            self._maybe_remove_signal_handler(sig)

    def trigger_signal(self, sig: int, frame: Optional[FrameType] = None) -> None:
        """
        Trigger a signal handler for the given signal.

        Does nothing if no handler is registered for `sig`.

        Parameters
        ----------
        sig : int
            The signal to handle.
        frame : Optional[FrameType]
            The frame to pass to the signal handler.
            Only used in a synchronous context.
        """
        self._handle_signal(sig, frame)
19 changes: 11 additions & 8 deletions metaflow/runner/subprocess_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import threading
from typing import Callable, Dict, Iterator, List, Optional, Tuple

from .signal_manager import SignalManager


def kill_process_and_descendants(pid, termination_timeout):
# TODO: there's a race condition that new descendants might
Expand Down Expand Up @@ -73,17 +75,17 @@ class SubprocessManager(object):
CommandManager objects, each of which manages an individual subprocess.
"""

def __init__(self):
def __init__(self, signal_manager: SignalManager):
self.commands: Dict[int, CommandManager] = {}
self.signal_manager = signal_manager or SignalManager()

try:
loop = asyncio.get_running_loop()
loop.add_signal_handler(
if self.signal_manager.event_loop is not None:
self.signal_manager.add_signal_handler(
signal.SIGINT,
lambda: asyncio.create_task(self._async_handle_sigint()),
lambda s, f: asyncio.create_task(self._async_handle_sigint()),
)
except RuntimeError:
signal.signal(signal.SIGINT, self._handle_sigint)
else:
self.signal_manager.add_signal_handler(signal.SIGINT, self._handle_sigint)

async def _async_handle_sigint(self):
pids = [
Expand Down Expand Up @@ -193,7 +195,8 @@ def get(self, pid: int) -> Optional["CommandManager"]:
return self.commands.get(pid, None)

def cleanup(self) -> None:
"""Clean up log files for all running subprocesses."""
"""Clean up signal handler and log files for all running subprocesses."""
self.signal_manager.remove_signal_handler(signal.SIGINT, self.signal_handler)

for v in self.commands.values():
v.cleanup()
Expand Down

0 comments on commit 64e3c7d

Please sign in to comment.