From b9cd39ab900fd913e8ade77475f1fce5e10b1ea7 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 19 Nov 2024 22:14:47 +0000 Subject: [PATCH] AIP-72: Extract `WatchedSubprocess` code into more methods (#44201) Refactord the `WatchedSubprocess` class to multiple methods for just making it a little more easier (for me) to understand flow of code. --- .../airflow/sdk/execution_time/supervisor.py | 172 +++++++++++------- .../tests/execution_time/test_task_runner.py | 3 + 2 files changed, 106 insertions(+), 69 deletions(-) diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index f2715ad3e5da9..1d72f7be633e3 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -284,7 +284,7 @@ def start( if pid == 0: # Parent ends of the sockets are closed by the OS as they are set as non-inheritable - # Run the child entryoint + # Run the child entrypoint _fork_main(child_stdin, child_stdout, child_stderr, child_logs.fileno(), target) proc = cls( @@ -308,40 +308,56 @@ def start( proc.kill(signal.SIGKILL) raise - # TODO: Use logging providers to handle the chunked upload for us - task_logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() + proc._register_pipes(read_msgs, read_logs) - # proc.selector is a way of registering a handler/callback to be called when the given IO channel has + # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the + # other end of the pair open + proc._close_unused_sockets(child_stdout, child_stdin, child_comms, child_logs) + + # Tell the task process what it needs to do! + proc._send_startup_message(ti, path, child_comms) + return proc + + def _register_pipes(self, read_msgs, read_logs): + """Register handlers for subprocess communication channels.""" + # self.selector is a way of registering a handler/callback to be called when the given IO channel has # activity to read on (https://www.man7.org/linux/man-pages/man2/select.2.html etc, but better # alternatives are used automatically) -- this is a way of having "event-based" code, but without # needing full async, to read and process output from each socket as it is received. - cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stdout"), level=logging.INFO)) - proc.selector.register(read_stdout, selectors.EVENT_READ, cb) - - cb = make_buffered_socket_reader(forward_to_log(task_logger.bind(chan="stderr"), level=logging.ERROR)) - proc.selector.register(read_stderr, selectors.EVENT_READ, cb) + # TODO: Use logging providers to handle the chunked upload for us + logger: FilteringBoundLogger = structlog.get_logger(logger_name="task").bind() - proc.selector.register( - read_logs, + self.selector.register( + self.stdout, selectors.EVENT_READ, self._create_socket_handler(logger, "stdout") + ) + self.selector.register( + self.stderr, selectors.EVENT_READ, - make_buffered_socket_reader(process_log_messages_from_subprocess(task_logger)), + self._create_socket_handler(logger, "stderr", log_level=logging.ERROR), ) - proc.selector.register( - read_msgs, + self.selector.register( + read_logs, selectors.EVENT_READ, - make_buffered_socket_reader(proc.handle_requests(log=log)), + make_buffered_socket_reader(process_log_messages_from_subprocess(logger)), + ) + self.selector.register( + read_msgs, selectors.EVENT_READ, make_buffered_socket_reader(self.handle_requests(log)) ) - # Close the remaining parent-end of the sockets we've passed to the child via fork. We still have the - # other end of the pair open - child_stdout.close() - child_stdin.close() - child_comms.close() - child_logs.close() + @staticmethod + def _create_socket_handler(logger, channel, log_level=logging.INFO) -> Callable[[socket], bool]: + """Create a socket handler that forwards logs to a logger.""" + return make_buffered_socket_reader(forward_to_log(logger.bind(chan=channel), level=log_level)) - # Tell the task process what it needs to do! + @staticmethod + def _close_unused_sockets(*sockets): + """Close unused ends of sockets after fork.""" + for sock in sockets: + sock.close() + def _send_startup_message(self, ti: TaskInstance, path: str | os.PathLike[str], child_comms: socket): + """Send startup message to the subprocess.""" msg = StartupDetails( ti=ti, file=str(path), @@ -350,10 +366,8 @@ def start( # Send the message to tell the process what it needs to execute log.debug("Sending", msg=msg) - feed_stdin.write(msg.model_dump_json().encode()) - feed_stdin.write(b"\n") - - return proc + self.stdin.write(msg.model_dump_json().encode()) + self.stdin.write(b"\n") def kill(self, signal: signal.Signals = signal.SIGINT): if self._exit_code is not None: @@ -366,58 +380,78 @@ def wait(self) -> int: if self._exit_code is not None: return self._exit_code - # Until we have a selector for the process, don't poll for more than 10s, just in case it exists but - # doesn't produce any output - max_poll_interval = 10 - try: - while self._exit_code is None or len(self.selector.get_map()): - last_heartbeat_ago = time.monotonic() - self._last_heartbeat - # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible - # so we notice the subprocess finishing as quick as we can. - max_wait_time = max( - 0, # Make sure this value is never negative, - min( - # Ensure we heartbeat _at most_ 75% through time the zombie threshold time - SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75, - max_poll_interval, - ), - ) - events = self.selector.select(timeout=max_wait_time) - for key, _ in events: - socket_handler = key.data - need_more = socket_handler(key.fileobj) - - if not need_more: - self.selector.unregister(key.fileobj) - key.fileobj.close() # type: ignore[union-attr] - - if self._exit_code is None: - try: - self._exit_code = self._process.wait(timeout=0) - log.debug("Task process exited", exit_code=self._exit_code) - except psutil.TimeoutExpired: - pass - - if last_heartbeat_ago < FASTEST_HEARTBEAT_INTERVAL: - # Avoid heartbeating too frequently - continue - - try: - self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid) - self._last_heartbeat = time.monotonic() - except Exception: - log.warning("Couldn't heartbeat", exc_info=True) - # TODO: If we couldn't heartbeat for X times the interval, kill ourselves - pass + self._monitor_subprocess() finally: self.selector.close() + # self._monitor_subprocess() will set the exit code when the process has finished + # If it hasn't, assume it's failed + self._exit_code = self._exit_code if self._exit_code is not None else 1 + self.client.task_instances.finish( id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) ) return self._exit_code + def _monitor_subprocess(self): + """ + Monitor the subprocess until it exits. + + This function: + + - Polls the subprocess for output + - Sends heartbeats to the client to keep the task alive + - Checks if the subprocess has exited + """ + # Until we have a selector for the process, don't poll for more than 10s, just in case it exists but + # doesn't produce any output + max_poll_interval = 10 + + while self._exit_code is None or len(self.selector.get_map()): + last_heartbeat_ago = time.monotonic() - self._last_heartbeat + # Monitor the task to see if it's done. Wait in a syscall (`select`) for as long as possible + # so we notice the subprocess finishing as quick as we can. + max_wait_time = max( + 0, # Make sure this value is never negative, + min( + # Ensure we heartbeat _at most_ 75% through time the zombie threshold time + SLOWEST_HEARTBEAT_INTERVAL - last_heartbeat_ago * 0.75, + max_poll_interval, + ), + ) + events = self.selector.select(timeout=max_wait_time) + for key, _ in events: + socket_handler = key.data + need_more = socket_handler(key.fileobj) + + if not need_more: + self.selector.unregister(key.fileobj) + key.fileobj.close() # type: ignore[union-attr] + + self._check_subprocess_exit() + self._send_heartbeat_if_needed() + + def _check_subprocess_exit(self): + """Check if the subprocess has exited.""" + if self._exit_code is None: + try: + self._exit_code = self._process.wait(timeout=0) + log.debug("Task process exited", exit_code=self._exit_code) + except psutil.TimeoutExpired: + pass + + def _send_heartbeat_if_needed(self): + """Send a heartbeat to the client if heartbeat interval has passed.""" + if time.monotonic() - self._last_heartbeat >= FASTEST_HEARTBEAT_INTERVAL: + try: + self.client.task_instances.heartbeat(self.ti_id, pid=self._process.pid) + self._last_heartbeat = time.monotonic() + except Exception: + log.warning("Failed to send heartbeat", exc_info=True) + # TODO: If we couldn't heartbeat for X times the interval, kill ourselves + pass + @property def final_state(self): """ diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 40c112170c6cd..7f2ea2060f10c 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -24,6 +24,7 @@ import pytest from uuid6 import uuid7 +from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.execution_time.comms import StartupDetails from airflow.sdk.execution_time.task_runner import CommsDecoder, parse @@ -70,3 +71,5 @@ def test_parse(test_dags_dir: Path): assert ti.task assert ti.task.dag + assert isinstance(ti.task, BaseOperator) + assert isinstance(ti.task.dag, DAG)