From 6f2e81638ae6ead5b516e06539ec3fe02b2af1a7 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 19 Nov 2024 22:17:44 +0000 Subject: [PATCH] AIP-72: Add a basic test for a task run This PR adds a very basic test to parse & run. We will start adding more things here and porting things from core. --- .../airflow/sdk/execution_time/task_runner.py | 3 +- task_sdk/tests/dags/super_basic_run.py | 35 +++++++++++++++++++ .../tests/execution_time/test_task_runner.py | 15 +++++++- 3 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 task_sdk/tests/dags/super_basic_run.py diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index e00efe4597b7b..0b757f7c128f2 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -164,7 +164,8 @@ def run(ti: RuntimeTaskInstance, log: Logger): except SystemExit: ... except BaseException: - ... + # TODO: Handle TI handle failure + raise def finalize(log: Logger): ... diff --git a/task_sdk/tests/dags/super_basic_run.py b/task_sdk/tests/dags/super_basic_run.py new file mode 100644 index 0000000000000..2988d85418ac0 --- /dev/null +++ b/task_sdk/tests/dags/super_basic_run.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.definitions.dag import dag + + +class CustomOperator(BaseOperator): + def execute(self, context): + task_id = context["task_instance"].task_id + print(f"Hello World {task_id}!") + + +@dag() +def super_basic_run(): + CustomOperator(task_id="hello") + + +super_basic_run() diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 7f2ea2060f10c..4666f2049e806 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -20,6 +20,7 @@ import uuid from pathlib import Path from socket import socketpair +from unittest import mock import pytest from uuid6 import uuid7 @@ -27,7 +28,7 @@ from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.execution_time.comms import StartupDetails -from airflow.sdk.execution_time.task_runner import CommsDecoder, parse +from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run class TestCommsDecoder: @@ -73,3 +74,15 @@ def test_parse(test_dags_dir: Path): assert ti.task.dag assert isinstance(ti.task, BaseOperator) assert isinstance(ti.task.dag, DAG) + + +def test_run_basic(test_dags_dir: Path): + """Test running a basic task.""" + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1), + file=str(test_dags_dir / "super_basic_run.py"), + requests_fd=0, + ) + + ti = parse(what) + run(ti, log=mock.MagicMock())