Skip to content

Commit

Permalink
AIP-72: Extract WatchedSubprocess code into more methods (#44201)
Browse files Browse the repository at this point in the history
Refactored the `WatchedSubprocess` class into multiple methods, just to make the flow of the code a little easier (for me) to understand.
  • Loading branch information
kaxil authored Nov 19, 2024
1 parent a825c95 commit b9cd39a
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 69 deletions.
172 changes: 103 additions & 69 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def start(
if pid == 0:
# Parent ends of the sockets are closed by the OS as they are set as non-inheritable

# Run the child entryoint
# Run the child entrypoint
_fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target)

proc = cls(
Expand All @@ -308,40 +308,56 @@ def start(
proc.kill(signal.SIGKILL)
raise

# TODO: Use logging providers to handle the chunked upload for us
task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()
proc._register_pipes(read_msgs, read_logs)

# proc.selector is a way of registering a handler/callback to be called when the given IO channel has
# Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
# other end of the pair open
proc._close_unused_sockets(child_stdout, child_stdin, child_comms, child_logs)

# Tell the task process what it needs to do!
proc._send_startup_message(ti, path, child_comms)
return proc

def _register_pipes(self, read_msgs, read_logs):
"""Register handlers for subprocess communication channels."""
# self.selector is a way of registering a handler/callback to be called when the given IO channel has
# activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better
# alternatives are used automatically) -- this is a way of having "event-based" code, but without
# needing full async, to read and process output from each socket as it is received.

cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO))
proc.selector.register(read_stdout, selectors.EVENT_READ, cb)

cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR))
proc.selector.register(read_stderr, selectors.EVENT_READ, cb)
# TODO: Use logging providers to handle the chunked upload for us
logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind()

proc.selector.register(
read_logs,
self.selector.register(
self.stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout")
)
self.selector.register(
self.stderr,
selectors.EVENT_READ,
make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)),
self._create_socket_handler(logger, "stderr", log_level=logging.ERROR),
)
proc.selector.register(
read_msgs,
self.selector.register(
read_logs,
selectors.EVENT_READ,
make_buffered_socket_reader(proc.handle_requests(log=log)),
make_buffered_socket_reader(process_log_messages_from_subprocess(logger)),
)
self.selector.register(
read_msgs, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log))
)

# Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the
# other end of the pair open
child_stdout.close()
child_stdin.close()
child_comms.close()
child_logs.close()
@staticmethod
def _create_socket_handler(logger, channel, log_level=logging.INFO) -> Callable[[socket], bool]:
    """
    Build the selector callback for one output channel of the subprocess.

    ``logger`` is bound with ``chan=channel``, handed to ``forward_to_log`` together with
    ``log_level``, and the result is wrapped by ``make_buffered_socket_reader``.
    """
    channel_logger = logger.bind(chan=channel)
    forwarder = forward_to_log(channel_logger, level=log_level)
    return make_buffered_socket_reader(forwarder)

# Tell the task process what it needs to do!
@staticmethod
def _close_unused_sockets(*sockets):
    """Call ``close()`` on each of the given sockets, in the order they were passed."""
    for unused_end in sockets:
        unused_end.close()

def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket):
"""Send startup message to the subprocess."""
msg = StartupDetails(
ti=ti,
file=str(path),
Expand All @@ -350,10 +366,8 @@ def start(

# Send the message to tell the process what it needs to execute
log.debug("Sending", msg=msg)
feed_stdin.write(msg.model_dump_json().encode())
feed_stdin.write(b"\n")

return proc
self.stdin.write(msg.model_dump_json().encode())
self.stdin.write(b"\n")

def kill(self, signal: signal.Signals = signal.SIGINT):
if self._exit_code is not None:
Expand All @@ -366,58 +380,78 @@ def wait(self) -> int:
if self._exit_code is not None:
return self._exit_code

# Until we have a selector for the process, don't poll for more than 10s, just in case it exists but
# doesn't produce any output
max_poll_interval = 10

try:
while self._exit_code is None or len(self.selector.get_map()):
last_heartbeat_ago = time.monotonic() - self._last_heartbeat
# Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible
# so we notice the subprocess finishing as quick as we can.
max_wait_time = max(
0, # Make sure this value is never negative,
min(
# Ensure we heartbeat _at most_ 75% through time the zombie threshold time
SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75,
max_poll_interval,
),
)
events = self.selector.select(timeout=max_wait_time)
for key, _ in events:
socket_handler = key.data
need_more = socket_handler(key.fileobj)

if not need_more:
self.selector.unregister(key.fileobj)
key.fileobj.close() # type: ignore[union-attr]

if self._exit_code is None:
try:
self._exit_code = self._process.wait(timeout=0)
log.debug("Task process exited", exit_code=self._exit_code)
except psutil.TimeoutExpired:
pass

if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL:
# Avoid heartbeating too frequently
continue

try:
self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
self._last_heartbeat = time.monotonic()
except Exception:
log.warning("Couldn't heartbeat", exc_info=True)
# TODO: If we couldn't heartbeat for X times the interval, kill ourselves
pass
self._monitor_subprocess()
finally:
self.selector.close()

# self._monitor_subprocess() will set the exit code when the process has finished
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
return self._exit_code

def _monitor_subprocess(self):
    """
    Monitor the subprocess until it exits.

    Every pass of the loop:

    - waits on the selector for activity on the registered channels
    - runs the registered handler for each ready channel, unregistering and closing any
      channel whose handler reports it needs no more data
    - checks whether the subprocess has exited
    - sends a heartbeat to the client if one is due

    The loop keeps running while no exit code has been recorded or the selector still has
    channels registered.
    """
    # Until we have a selector for the process, don't poll for more than 10s, just in case it exists but
    # doesn't produce any output
    max_poll_interval = 10

    while self._exit_code is None or len(self.selector.get_map()):
        since_heartbeat = time.monotonic() - self._last_heartbeat

        # Wait in a syscall (`select`) for as long as possible so we notice the subprocess finishing as
        # quick as we can, while still ensuring we heartbeat _at most_ 75% through the zombie threshold.
        # NOTE(review): `*` binds tighter than `-`, so this is
        # SLOWEST_HEARTBEAT_INTERVAL - (since_heartbeat * 0.75) -- confirm that is the intended formula.
        until_heartbeat = SLOWEST_HEARTBEAT_INTERVAL - since_heartbeat * 0.75
        # Clamp to the poll interval, and make sure the timeout is never negative.
        wait_for = max(0, min(until_heartbeat, max_poll_interval))

        for key, _ in self.selector.select(timeout=wait_for):
            handler = key.data
            if handler(key.fileobj):
                # Handler wants more data; leave the channel registered.
                continue
            self.selector.unregister(key.fileobj)
            key.fileobj.close()  # type: ignore[union-attr]

        self._check_subprocess_exit()
        self._send_heartbeat_if_needed()

def _check_subprocess_exit(self):
    """Record the subprocess's exit code if it has exited; do nothing if it is still running."""
    if self._exit_code is not None:
        # Exit code already recorded; nothing to check.
        return

    try:
        # timeout=0 makes this a poll rather than a blocking wait.
        exit_code = self._process.wait(timeout=0)
    except psutil.TimeoutExpired:
        # Still running.
        return

    self._exit_code = exit_code
    log.debug("Task process exited", exit_code=self._exit_code)

def _send_heartbeat_if_needed(self):
    """
    Send a heartbeat to the client once FASTEST_HEARTBEAT_INTERVAL has elapsed since the last one.

    A failed heartbeat is logged as a warning and otherwise ignored; the last-heartbeat
    timestamp is only advanced when the call succeeds.
    """
    elapsed = time.monotonic() - self._last_heartbeat
    if elapsed < FASTEST_HEARTBEAT_INTERVAL:
        # Avoid heartbeating too frequently
        return

    try:
        self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid)
        self._last_heartbeat = time.monotonic()
    except Exception:
        log.warning("Failed to send heartbeat", exc_info=True)
        # TODO: If we couldn't heartbeat for X times the interval, kill ourselves

@property
def final_state(self):
"""
Expand Down
3 changes: 3 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pytest
from uuid6 import uuid7

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.comms import StartupDetails
from airflow.sdk.execution_time.task_runner import CommsDecoder, parse
Expand Down Expand Up @@ -70,3 +71,5 @@ def test_parse(test_dags_dir: Path):

assert ti.task
assert ti.task.dag
assert isinstance(ti.task, BaseOperator)
assert isinstance(ti.task.dag, DAG)

0 comments on commit b9cd39a

Please sign in to comment.