diff --git a/.github/workflows/airflow-operator.yml b/.github/workflows/airflow-operator.yml index 4868ed39b58..cc4704cfa56 100644 --- a/.github/workflows/airflow-operator.yml +++ b/.github/workflows/airflow-operator.yml @@ -16,7 +16,6 @@ on: - 'internal/jobservice/*' - 'pkg/api/*.proto' - 'pkg/api/jobservice/*.proto' - - 'scripts/build-airflow-operator.sh' - 'scripts/build-python-client.sh' - 'third_party/airflow/**' - './magefiles/tests.go' @@ -37,7 +36,6 @@ on: - 'internal/jobservice/*' - 'pkg/api/*.proto' - 'pkg/api/jobservice/*.proto' - - 'scripts/build-airflow-operator.sh' - 'scripts/build-python-client.sh' - 'third_party/airflow/**' diff --git a/build/airflow-operator/Dockerfile b/build/airflow-operator/Dockerfile index a3d774b30d6..87a2e81a5cb 100644 --- a/build/airflow-operator/Dockerfile +++ b/build/airflow-operator/Dockerfile @@ -1,5 +1,5 @@ ARG PLATFORM=x86_64 -ARG BASE_IMAGE=python:3.8.18-bookworm +ARG BASE_IMAGE=python:3.10.14-bookworm FROM --platform=$PLATFORM ${BASE_IMAGE} RUN mkdir /proto diff --git a/developer/env/docker/server.env b/developer/env/docker/server.env index 6b52f9b6342..a5b4496abe4 100644 --- a/developer/env/docker/server.env +++ b/developer/env/docker/server.env @@ -1,3 +1,3 @@ ARMADA_QUEUECACHEREFRESHPERIOD="1s" ARMADA_CORSALLOWEDORIGINS="http://localhost:3000,http://localhost:10000,http://example.com:10000" - +ARMADA_QUERYAPI_POSTGRES_CONNECTION_HOST=postgres diff --git a/docker-compose.yaml b/docker-compose.yaml index 4e96216ef27..68af1ce7aa2 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -61,6 +61,7 @@ services: depends_on: - lookoutv2-migration - eventingester + - lookoutingesterv2 working_dir: /app env_file: - developer/env/docker/server.env diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 048667a2562..665d6c0e82c 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -12,66 +12,61 @@ This class provides integration with Airflow and Armada ## 
armada.operators.armada module -### _class_ armada.operators.armada.ArmadaOperator(name, armada_channel_args, job_service_channel_args, armada_queue, job_request_items, lookout_url_template=None, poll_interval=30, \*\*kwargs) -Bases: `BaseOperator` +### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, \*\*kwargs) +Bases: `BaseOperator`, `LoggingMixin` -Implementation of an ArmadaOperator for airflow. +An Airflow operator that manages Job submission to Armada. -Airflow operators inherit from BaseOperator. +This operator submits a job to an Armada cluster, polls for its completion, +and handles job cancellation if the Airflow task is killed. * **Parameters** - * **name** (*str*) – The name of the airflow task + * **name** (*str*) – - * **armada_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. + * **channel_args** (*GrpcChannelArgs*) – - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. + * **armada_queue** (*str*) – - * **armada_queue** (*str*) – The queue name for Armada. + * **job_request** (*JobSubmitRequestItem*) – - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – A PodSpec that is used by Armada for submitting a job + * **job_set_prefix** (*str** | **None*) – - * **lookout_url_template** (*str** | **None*) – A URL template to be used to provide users - a valid link to the related lookout job in this operator’s log. - The format should be: - “[https://lookout.armada.domain/jobs](https://lookout.armada.domain/jobs)?job_id=” where will - be replaced with the actual job ID. 
+ * **lookout_url_template** (*str** | **None*) – - * **poll_interval** (*int*) – How often to poll jobservice to get status. + * **poll_interval** (*int*) – + * **container_logs** (*str** | **None*) – -* **Returns** - an armada operator instance + * **k8s_token_retriever** (*TokenRetriever** | **None*) – + * **deferrable** (*bool*) – -#### execute(context) -Executes the Armada Operator. -Runs an Armada job and calls the job_service_client for polling. + * **job_acknowledgement_timeout** (*int*) – -* **Parameters** - **context** – The airflow context. +#### _property_ client(_: ArmadaClien_ ) +#### execute(context) +Submits the job to Armada and polls for completion. -* **Returns** +* **Parameters** - None + **context** (*Context*) – The execution context provided by Airflow. @@ -81,20 +76,11 @@ Runs an Armada job and calls the job_service_client for polling. -#### render_template_fields(context, jinja_env=None) -Template all attributes listed in *self.template_fields*. - -This mutates the attributes in-place and is irreversible. - - -* **Parameters** - - - * **context** (*Context*) – Context dict with values to apply on content. - - - * **jinja_env** (*Environment** | **None*) – Jinja’s environment to use for rendering. +#### on_kill() +Override this method to clean up subprocesses when a task instance gets killed. +Any use of the threading, subprocess or multiprocessing module within an +operator needs to be cleaned up, or it will leave ghost processes behind. * **Return type** @@ -103,133 +89,36 @@ This mutates the attributes in-place and is irreversible. 
-#### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) -## armada.operators.armada_deferrable module - - -### _class_ armada.operators.armada_deferrable.ArmadaDeferrableOperator(name, armada_channel_args, job_service_channel_args, armada_queue, job_request_items, lookout_url_template=None, poll_interval=30, \*\*kwargs) -Bases: `BaseOperator` - -Implementation of a deferrable armada operator for airflow. - -Distinguished from ArmadaOperator by its ability to defer itself after -submitting its job_request_items. - -See -[https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html](https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html) -for more information about deferrable airflow operators. - -Airflow operators inherit from BaseOperator. - - -* **Parameters** - - - * **name** (*str*) – The name of the airflow task. - - - * **armada_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - - - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - - - * **armada_queue** (*str*) – The queue name for Armada. - - - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – A PodSpec that is used by Armada for submitting a job. - - - * **lookout_url_template** (*str** | **None*) – A URL template to be used to provide users - a valid link to the related lookout job in this operator’s log. - The format should be: - “[https://lookout.armada.domain/jobs](https://lookout.armada.domain/jobs)?job_id=” where will - be replaced with the actual job ID. - - - * **poll_interval** (*int*) – How often to poll jobservice to get status. - - - -* **Returns** - - A deferrable armada operator instance. - - - -#### execute(context) -Executes the Armada Operator. Only meant to be called by airflow. 
- -Submits an Armada job and defers itself to ArmadaJobCompleteTrigger to wait -until the job completes. - +#### pod_manager(k8s_context) * **Parameters** - **context** – The airflow context. - - - -* **Returns** - - None + **k8s_context** (*str*) – * **Return type** - None + *PodLogManager* #### render_template_fields(context, jinja_env=None) -Template all attributes listed in *self.template_fields*. - +Template all attributes listed in self.template_fields. This mutates the attributes in-place and is irreversible. +Args: -* **Parameters** - - - * **context** (*Context*) – Context dict with values to apply on content. - - - * **jinja_env** (*Environment** | **None*) – Jinja’s environment to use for rendering. - - - -* **Return type** - - None - - - -#### resume_job_complete(context, event, job_id) -Resumes this operator after deferring itself to ArmadaJobCompleteTrigger. -Only meant to be called from within Airflow. - -Reports the result of the job and returns. + context (Context): The execution context provided by Airflow. * **Parameters** - * **context** – The airflow context. + * **context** (*Context*) – Airflow Context dict with values to apply on content - * **event** (*dict*) – The payload from the TriggerEvent raised by - ArmadaJobCompleteTrigger. - - - * **job_id** (*str*) – The job ID. - - - -* **Returns** - - None + * **jinja_env** (*Environment** | **None*) – jinja’s environment to use for rendering. @@ -239,492 +128,52 @@ Reports the result of the job and returns. -#### serialize() -Get a serialized version of this object. - - -* **Returns** - - A dict of keyword arguments used when instantiating - - - -* **Return type** - - dict - - -this object. 
- - -#### template_fields(_: Sequence[str_ _ = ('job_request_items',_ ) - -### _class_ armada.operators.armada_deferrable.ArmadaJobCompleteTrigger(job_id, job_service_channel_args, armada_queue, job_set_id, airflow_task_name, poll_interval=30) -Bases: `BaseTrigger` - -An airflow trigger that monitors the job state of an armada job. - -Triggers when the job is complete. +#### template_fields(_: Sequence[str_ _ = ('job_request', 'job_set_prefix'_ ) +Initializes a new ArmadaOperator. * **Parameters** - * **job_id** (*str*) – The job ID to monitor. - - - * **job_service_channel_args** (*GrpcChannelArgsDict*) – GRPC channel arguments to be used when - creating a grpc channel to connect to the job service instance. + * **name** (*str*) – The name of the job to be submitted. - * **armada_queue** (*str*) – The name of the armada queue. + * **channel_args** (*GrpcChannelArgs*) – The gRPC channel arguments for connecting to the Armada server. - * **job_set_id** (*str*) – The ID of the job set. + * **armada_queue** (*str*) – The name of the Armada queue to which the job will be submitted. - * **airflow_task_name** (*str*) – Name of the airflow task to which this trigger - belongs. + * **job_request** (*JobSubmitRequestItem*) – The job to be submitted to Armada. - * **poll_interval** (*int*) – How often to poll jobservice to get status. + * **job_set_prefix** (*Optional**[**str**]*) – A string to prepend to the jobSet name + * **lookout_url_template** – Template for creating lookout links. If not specified -* **Returns** - An armada job complete trigger instance. - - - -#### _async_ run() -Runs the trigger. Meant to be called by an airflow triggerer process. - - -#### serialize() -Return the information needed to reconstruct this Trigger. - - -* **Returns** - - Tuple of (class path, keyword arguments needed to re-instantiate). - - - -* **Return type** - - tuple +then no tracking information will be logged. 
+:type lookout_url_template: Optional[str] +:param poll_interval: The interval in seconds between polling for job status updates. +:type poll_interval: int +:param container_logs: Name of container whose logs will be published to stdout. +:type container_logs: Optional[str] +:param k8s_token_retriever: A serialisable Kubernetes token retriever object. We use +this to read logs from Kubernetes pods. +:type k8s_token_retriever: Optional[TokenRetriever] +:param deferrable: Whether the operator should run in a deferrable mode, allowing +for asynchronous execution. +:type deferrable: bool +:param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be +acknowledged by Armada. +:type job_acknowledgement_timeout: int +:param kwargs: Additional keyword arguments to pass to the BaseOperator. +## armada.operators.armada_deferrable module ## armada.operators.jobservice module - -### _class_ armada.operators.jobservice.JobServiceClient(channel) -Bases: `object` - -The JobService Client - -Implementation of gRPC stubs from JobService - - -* **Parameters** - - **channel** – gRPC channel used for authentication. See - [https://grpc.github.io/grpc/python/grpc.html](https://grpc.github.io/grpc/python/grpc.html) - for more information. - - - -* **Returns** - - a job service client instance - - - -#### get_job_status(queue, job_set_id, job_id) -Get job status of a given job in a queue and job_set_id. 
- -Uses the GetJobStatus rpc to get a status of your job - - -* **Parameters** - - - * **queue** (*str*) – The name of the queue - - - * **job_set_id** (*str*) – The name of the job set (a grouping of jobs) - - - * **job_id** (*str*) – The id of the job - - - -* **Returns** - - A Job Service Request (State, Error) - - - -* **Return type** - - *JobServiceResponse* - - - -#### health() -Health Check for GRPC Request - - -* **Return type** - - *HealthCheckResponse* - - - -### armada.operators.jobservice.get_retryable_job_service_client(target, credentials=None, compression=None) -Get a JobServiceClient that has retry configured - - -* **Parameters** - - - * **target** (*str*) – grpc channel target - - - * **credentials** (*ChannelCredentials** | **None*) – grpc channel credentials (if needed) - - - * **compresion** – grpc channel compression - - - * **compression** (*Compression** | **None*) – - - - -* **Returns** - - A job service client instance - - - -* **Return type** - - *JobServiceClient* - - ## armada.operators.jobservice_asyncio module - -### _class_ armada.operators.jobservice_asyncio.JobServiceAsyncIOClient(channel) -Bases: `object` - -The JobService AsyncIO Client - -AsyncIO implementation of gRPC stubs from JobService - - -* **Parameters** - - **channel** (*Channel*) – AsyncIO gRPC channel used for authentication. See - [https://grpc.github.io/grpc/python/grpc_asyncio.html](https://grpc.github.io/grpc/python/grpc_asyncio.html) - for more information. - - - -* **Returns** - - A job service client instance - - - -#### _async_ get_job_status(queue, job_set_id, job_id) -Get job status of a given job in a queue and job_set_id. 
- -Uses the GetJobStatus rpc to get a status of your job - - -* **Parameters** - - - * **queue** (*str*) – The name of the queue - - - * **job_set_id** (*str*) – The name of the job set (a grouping of jobs) - - - * **job_id** (*str*) – The id of the job - - - -* **Returns** - - A Job Service Request (State, Error) - - - -* **Return type** - - *JobServiceResponse* - - - -#### _async_ health() -Health Check for GRPC Request - - -* **Return type** - - *HealthCheckResponse* - - - -### armada.operators.jobservice_asyncio.get_retryable_job_service_asyncio_client(target, credentials, compression) -Get a JobServiceAsyncIOClient that has retry configured - - -* **Parameters** - - - * **target** (*str*) – grpc channel target - - - * **credentials** (*ChannelCredentials** | **None*) – grpc channel credentials (if needed) - - - * **compresion** – grpc channel compression - - - * **compression** (*Compression** | **None*) – - - - -* **Returns** - - A job service asyncio client instance - - - -* **Return type** - - *JobServiceAsyncIOClient* - - ## armada.operators.utils module - - -### _class_ armada.operators.utils.JobState(value) -Bases: `Enum` - -An enumeration. - - -#### CANCELLED(_ = _ ) - -#### CONNECTION_ERR(_ = _ ) - -#### DUPLICATE_FOUND(_ = _ ) - -#### FAILED(_ = _ ) - -#### JOB_ID_NOT_FOUND(_ = _ ) - -#### RUNNING(_ = _ ) - -#### SUBMITTED(_ = _ ) - -#### SUCCEEDED(_ = _ ) - -### armada.operators.utils.airflow_error(job_state, name, job_id) -Throw an error on a terminal event if job errored out - - -* **Parameters** - - - * **job_state** (*JobState*) – A JobState enum class - - - * **name** (*str*) – The name of your armada job - - - * **job_id** (*str*) – The job id that armada assigns to it - - - -* **Returns** - - No Return or an AirflowFailException. 
- - -AirflowFailException tells Airflow Schedule to not reschedule the task - - -### armada.operators.utils.annotate_job_request_items(context, job_request_items) -Annotates the inbound job request items with Airflow context elements - - -* **Parameters** - - - * **context** – The airflow context. - - - * **job_request_items** (*List**[**JobSubmitRequestItem**]*) – The job request items to be sent to armada - - - -* **Returns** - - annotated job request items for armada - - - -* **Return type** - - *List*[*JobSubmitRequestItem*] - - - -### armada.operators.utils.default_job_status_callable(armada_queue, job_set_id, job_id, job_service_client) - -* **Parameters** - - - * **armada_queue** (*str*) – - - - * **job_set_id** (*str*) – - - - * **job_id** (*str*) – - - - * **job_service_client** (*JobServiceClient*) – - - - -* **Return type** - - *JobServiceResponse* - - - -### armada.operators.utils.get_annotation_key_prefix() -Provides the annotation key prefix, -which can be specified in env var ANNOTATION_KEY_PREFIX. -A default is provided if the env var is not defined - - -* **Returns** - - string annotation key prefix - - - -* **Return type** - - str - - - -### armada.operators.utils.job_state_from_pb(state) - -* **Return type** - - *JobState* - - - -### armada.operators.utils.search_for_job_complete(armada_queue, job_set_id, airflow_task_name, job_id, poll_interval=30, job_service_client=None, job_status_callable=, time_out_for_failure=7200) -Poll JobService cache until you get a terminated event. - -A terminated event is SUCCEEDED, FAILED or CANCELLED - - -* **Parameters** - - - * **armada_queue** (*str*) – The queue for armada - - - * **job_set_id** (*str*) – Your job_set_id - - - * **airflow_task_name** (*str*) – The name of your armada job - - - * **poll_interval** (*int*) – Polling interval for jobservice to get status. 
- - - * **job_id** (*str*) – The name of the job id that armada assigns to it - - - * **job_service_client** (*JobServiceClient** | **None*) – A JobServiceClient that is used for polling. - It is optional only for testing - - - * **job_status_callable** – A callable object for test injection. - - - * **time_out_for_failure** (*int*) – The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - - - -* **Returns** - - A tuple of JobStateEnum, message - - - -* **Return type** - - *Tuple*[*JobState*, str] - - - -### _async_ armada.operators.utils.search_for_job_complete_async(armada_queue, job_set_id, airflow_task_name, job_id, job_service_client, log, poll_interval, time_out_for_failure=7200) -Poll JobService cache asyncronously until you get a terminated event. - -A terminated event is SUCCEEDED, FAILED or CANCELLED - - -* **Parameters** - - - * **armada_queue** (*str*) – The queue for armada - - - * **job_set_id** (*str*) – Your job_set_id - - - * **airflow_task_name** (*str*) – The name of your armada job - - - * **job_id** (*str*) – The name of the job id that armada assigns to it - - - * **job_service_client** (*JobServiceAsyncIOClient*) – A JobServiceClient that is used for polling. - It is optional only for testing - - - * **poll_interval** (*int*) – How often to poll jobservice to get status. 
- - - * **time_out_for_failure** (*int*) – The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - - - -* **Returns** - - A tuple of JobStateEnum, message - - - -* **Return type** - - *Tuple*[*JobState*, str] diff --git a/magefiles/airflow.go b/magefiles/airflow.go index 8a1d970b5d7..bb160a64536 100644 --- a/magefiles/airflow.go +++ b/magefiles/airflow.go @@ -95,10 +95,5 @@ func AirflowOperator() error { return fmt.Errorf("failed to build Airflow Operator: %w", err) } - err = dockerRun("run", "--rm", "-v", "${PWD}/proto-airflow:/proto-airflow", "-v", "${PWD}:/go/src/armada", "-w", "/go/src/armada", "armada-airflow-operator-builder", "./scripts/build-airflow-operator.sh") - if err != nil { - return fmt.Errorf("failed to run build-airflow-operator.sh script: %w", err) - } - return nil } diff --git a/magefiles/tests.go b/magefiles/tests.go index 620aab1dd97..11cc3b2b6be 100644 --- a/magefiles/tests.go +++ b/magefiles/tests.go @@ -136,10 +136,6 @@ func runTest(name, outputFileName string) error { // Teste2eAirflow runs e2e tests for airflow func Teste2eAirflow() error { mg.Deps(AirflowOperator) - if err := BuildDockers("jobservice"); err != nil { - return err - } - cmd, err := go_CMD() if err != nil { return err @@ -149,29 +145,14 @@ func Teste2eAirflow() error { fmt.Println(err) } - if err := dockerRun("rm", "-f", "jobservice"); err != nil { - fmt.Println(err) - } - - err = dockerRun("run", "-d", "--name", "jobservice", "--network=kind", - "--mount", "type=bind,src=${PWD}/e2e,dst=/e2e", "gresearch/armada-jobservice", "run", "--config", - "/e2e/setup/jobservice.yaml") - if err != nil { - return err - } - err = dockerRun("run", "-v", "${PWD}/e2e:/e2e", "-v", "${PWD}/third_party/airflow:/code", - "--workdir", "/code", "-e", "ARMADA_SERVER=server", "-e", "ARMADA_PORT=50051", "-e", "JOB_SERVICE_HOST=jobservice", - "-e", "JOB_SERVICE_PORT=60003", "--entrypoint", "python3", "--network=kind", "armada-airflow-operator-builder:latest", - 
"-m", "pytest", "-v", "-s", "/code/tests/integration/test_airflow_operator_logic.py") + "--workdir", "/code", "-e", "ARMADA_SERVER=server", "-e", "ARMADA_PORT=50051", "--entrypoint", + "python3", "--network=kind", "armada-airflow-operator-builder:latest", + "-m", "pytest", "-v", "-s", "/code/test/integration/test_airflow_operator_logic.py") if err != nil { return err } - err = dockerRun("rm", "-f", "jobservice") - if err != nil { - return err - } return nil } diff --git a/scripts/build-airflow-operator.sh b/scripts/build-airflow-operator.sh deleted file mode 100755 index 531839ca025..00000000000 --- a/scripts/build-airflow-operator.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -# This script is intended to be run under the docker container at $ARMADADIR/build/python-api-client/ - -# make the python package armada.client, not pkg.api -mkdir -p proto-airflow -cp pkg/api/jobservice/jobservice.proto proto-airflow -sed -i 's/\([^\/]\)pkg\/api/\1jobservice/g' proto-airflow/*.proto - - -# generate python stubs -cd proto-airflow -python3 -m grpc_tools.protoc -I. --plugin=protoc-gen-mypy=$(which protoc-gen-mypy) --mypy_out=../third_party/airflow/armada/jobservice --python_out=../third_party/airflow/armada/jobservice --grpc_python_out=../third_party/airflow/armada/jobservice \ - jobservice.proto -cd .. -# This hideous code is because we can't use python package option in grpc. -# See https://github.com/protocolbuffers/protobuf/issues/7061 for an explanation. -# We need to import these packages as a module. 
-sed -i 's/import jobservice_pb2 as jobservice__pb2/from armada.jobservice import jobservice_pb2 as jobservice__pb2/g' third_party/airflow/armada/jobservice/*.py diff --git a/third_party/airflow/README.md b/third_party/airflow/README.md index 573b3861e5b..a09df8865c5 100644 --- a/third_party/airflow/README.md +++ b/third_party/airflow/README.md @@ -1,12 +1,112 @@ # armada-airflow-operator -An Airflow operator for interfacing with the armada client - -## Background +Armada Airflow Operator, which manages airflow jobs. This allows Armada jobs to be run as part of an Airflow DAG. + +## Overview + +The `ArmadaOperator` allows users to run an Armada Job as a task in an Airflow DAG. It handles job submission, job +state management and (optionally) log streaming back to Airflow. + +The Operator works by periodically polling Armada for the state of each job. As a result, it is only intended for DAGs +with tens or (at the limit) hundreds of concurrent jobs. + +## Installation + +`pip install armada-airflow` + +## Example Usage + +```python +from datetime import datetime + +from airflow import DAG +from armada_client.armada import submit_pb2 +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) + +from armada.operators.armada import ArmadaOperator + +def create_dummy_job(): + """ + Create a dummy job with a single container. 
+ """ + + # For information on where this comes from, + # see https://github.com/kubernetes/api/blob/master/core/v1/generated.proto + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="alpine:3.16.2", + args=["sh", "-c", "for i in $(seq 1 60); do echo $i; sleep 1; done"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="1"), + "memory": api_resource.Quantity(string="1Gi"), + }, + limits={ + "cpu": api_resource.Quantity(string="1"), + "memory": api_resource.Quantity(string="1Gi"), + }, + ), + ) + ], + ) + + return submit_pb2.JobSubmitRequestItem( + priority=1, pod_spec=pod, namespace="armada" + ) + +armada_channel_args = {"target": "127.0.0.1:50051"} + + +with DAG( + "test_new_armada_operator", + description="Example DAG Showing Usage Of ArmadaOperator", + schedule=None, + start_date=datetime(2022, 1, 1), + catchup=False, +) as dag: + armada_task = ArmadaOperator( + name="non_deferrable_task", + task_id="1", + channel_args=armada_channel_args, + armada_queue="armada", + job_request=create_dummy_job(), + container_logs="sleep", + lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=False + ) + + armada_deferrable_task = ArmadaOperator( + name="deferrable_task", + task_id="2", + channel_args=armada_channel_args, + armada_queue="armada", + job_request=create_dummy_job(), + container_logs="sleep", + lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=True + ) + + armada_task >> armada_deferrable_task +``` +## Parameters -Airflow is an open source project focused on orchestrating Direct Acylic Graphs (DAGs) across different compute platforms. To interface Airflow with Armada, you should use our armada operator. 
+| Name | Description | Notes | +|----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------| +| channel_args | A list of key-value pairs ([channel_arguments](https://grpc.github.io/grpc/python/glossary.html#term-channel_arguments) in gRPC runtime) to configure the channel. | None | +| armada_queue | Armada queue to be used for the job | Make sure that Airflow user is permissioned on this queue | +| job_request | A `JobSubmitRequestItem` that is to be submitted to Armada as part of this task | Object contains a `core_v1.PodSpec` within it | +| job_set_prefix | A prefix for the JobSet name provided to Armada when submitting the job | The JobSet name submitted will be the Airflow `run_id` prefixed with this provided prefix | +| poll_interval | Integer number of seconds representing how often Airflow will poll Armada for Job Status. Defaults to 30 Seconds | Decreasing this makes the operator more responsive but comes at the cost of increased load on the Armada Server. Please do not decrease below 10 seconds. | +| container_logs | Name of the container in your job from which you wish to stream logs. If unset then no logs will be streamed | Only use this if you are running relatively few (<50) concurrent jobs | +| deferrable | Flag to specify whether to run the operator in Airflow Deferrable Mode | Defaults to True | -## Airflow +# Contributing The [airflow documentation](https://airflow.apache.org/) was used for setting up a simple test server. @@ -48,13 +148,13 @@ You can install the package via `pip3 install third_party/airflow`. You can use our tox file that streamlines development lifecycle. For development, you can install black, tox, mypy and flake8. 
-`python3.8 -m tox -e py38` will run unit tests. +`python3.10 -m tox -e py310` will run unit tests. -`python3.8 -m tox -e format` will run a format check +`python3.10 -m tox -e format` will run black on your code. -`python3.8 -m tox -e format-code` will run black on your code. +`python3.10 -m tox -e format-check` will run a format check. -`python3.8 -m tox -e docs` will generate a new sphinx doc. +`python3.10 -m tox -e docs` will generate a new sphinx doc. ## Releasing the client Armada-airflow releases are automated via Github Actions, for contributors with sufficient access to run them. diff --git a/third_party/airflow/armada/auth.py b/third_party/airflow/armada/auth.py new file mode 100644 index 00000000000..ca90b521ecf --- /dev/null +++ b/third_party/airflow/armada/auth.py @@ -0,0 +1,11 @@ +from typing import Dict, Any, Tuple, Protocol + + +""" We use this interface for objects fetching Kubernetes auth tokens. Since + it's used within the Trigger, it must be serialisable.""" + + +class TokenRetriever(Protocol): + def get_token(self) -> str: ... + + def serialize(self) -> Tuple[str, Dict[str, Any]]: ... diff --git a/third_party/airflow/armada/logs/log_consumer.py b/third_party/airflow/armada/logs/log_consumer.py new file mode 100644 index 00000000000..8fd8c31d3ef --- /dev/null +++ b/third_party/airflow/armada/logs/log_consumer.py @@ -0,0 +1,252 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import queue +from datetime import timedelta +from http.client import HTTPResponse +from typing import Generator, TYPE_CHECKING, Callable, Awaitable, List + +from aiohttp.client_exceptions import ClientResponse +from airflow.utils.timezone import utcnow +from kubernetes.client import V1Pod +from kubernetes_asyncio.client import V1Pod as aio_V1Pod + +from armada.logs.utils import container_is_running, get_container_status + +if TYPE_CHECKING: + from urllib3.response import HTTPResponse # noqa: F811 + + +class PodLogsConsumerAsync: + """ + Responsible for pulling pod logs from a stream asynchronously, checking the + container status before reading data. + + This class contains a workaround for the issue + https://github.com/apache/airflow/issues/23497. + + :param response: HTTP response with logs + :param pod: Pod instance from Kubernetes client + :param read_pod_async: Callable returning a pod object that can be awaited on, + given (pod name, namespace) as arguments + :param container_name: Name of the container that we're reading logs from + :param post_termination_timeout: (Optional) The period of time in seconds + representing for how long time + logs are available after the container termination. + :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. + The container status is cached to reduce API calls. 
+ + :meta private: + """ + + def __init__( + self, + response: ClientResponse, + pod_name: str, + namespace: str, + read_pod_async: Callable[[str, str], Awaitable[aio_V1Pod]], + container_name: str, + post_termination_timeout: int = 120, + read_pod_cache_timeout: int = 120, + ): + self.response = response + self.pod_name = pod_name + self.namespace = namespace + self._read_pod_async = read_pod_async + self.container_name = container_name + self.post_termination_timeout = post_termination_timeout + self.last_read_pod_at = None + self.read_pod_cache = None + self.read_pod_cache_timeout = read_pod_cache_timeout + self.log_queue = queue.Queue() + + def __aiter__(self): + return self + + async def __anext__(self): + r"""Yield log items divided by the '\n' symbol.""" + if not self.log_queue.empty(): + return self.log_queue.get() + + incomplete_log_item: List[bytes] = [] + if await self.logs_available(): + async for data_chunk in self.response.content: + if b"\n" in data_chunk: + log_items = data_chunk.split(b"\n") + for x in self._extract_log_items(incomplete_log_item, log_items): + if x is not None: + self.log_queue.put(x) + incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) + else: + incomplete_log_item.append(data_chunk) + if not await self.logs_available(): + break + else: + self.response.close() + raise StopAsyncIteration + if incomplete_log_item: + item = b"".join(incomplete_log_item) + if item is not None: + self.log_queue.put(item) + + # Prevents method from returning None + if not self.log_queue.empty(): + return self.log_queue.get() + + self.response.close() + raise StopAsyncIteration + + @staticmethod + def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): + yield b"".join(incomplete_log_item) + log_items[0] + b"\n" + for x in log_items[1:-1]: + yield x + b"\n" + + @staticmethod + def _save_incomplete_log_item(sub_chunk: bytes): + return [sub_chunk] if [sub_chunk] else [] + + async def logs_available(self): + 
remote_pod = await self.read_pod() + if container_is_running(pod=remote_pod, container_name=self.container_name): + return True + container_status = get_container_status( + pod=remote_pod, container_name=self.container_name + ) + state = container_status.state if container_status else None + terminated = state.terminated if state else None + if terminated: + termination_time = terminated.finished_at + if termination_time: + return ( + termination_time + timedelta(seconds=self.post_termination_timeout) + > utcnow() + ) + return False + + async def read_pod(self): + _now = utcnow() + if ( + self.read_pod_cache is None + or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) + < _now + ): + self.read_pod_cache = await self._read_pod_async( + self.pod_name, self.namespace + ) + self.last_read_pod_at = _now + return self.read_pod_cache + + +class PodLogsConsumer: + """ + Responsible for pulling pod logs from a stream, checking the container status + before reading data. + + This class is a workaround for the issue + https://github.com/apache/airflow/issues/23497. + + :param response: HTTP response with logs + :param pod_name: Name of the pod (``namespace`` is its Kubernetes namespace) + :param read_pod: Callable returning a pod object given (pod name, namespace) as + arguments + :param container_name: Name of the container that we're reading logs from + :param post_termination_timeout: (Optional) The period of time in seconds + for which + logs remain available after the container terminates. + :param read_pod_cache_timeout: (Optional) The container's status cache lifetime. + The container status is cached to reduce API calls.
+ + :meta private: + """ + + def __init__( + self, + response: HTTPResponse, + pod_name: str, + namespace: str, + read_pod: Callable[[str, str], V1Pod], + container_name: str, + post_termination_timeout: int = 120, + read_pod_cache_timeout: int = 120, + ): + self.response = response + self.pod_name = pod_name + self.namespace = namespace + self._read_pod = read_pod + self.container_name = container_name + self.post_termination_timeout = post_termination_timeout + self.last_read_pod_at = None + self.read_pod_cache = None + self.read_pod_cache_timeout = read_pod_cache_timeout + + def __iter__(self) -> Generator[bytes, None, None]: + r"""Yield log items divided by the '\n' symbol.""" + incomplete_log_item: List[bytes] = [] + if self.logs_available(): + for data_chunk in self.response.stream(amt=None, decode_content=True): + if b"\n" in data_chunk: + log_items = data_chunk.split(b"\n") + yield from self._extract_log_items(incomplete_log_item, log_items) + incomplete_log_item = self._save_incomplete_log_item(log_items[-1]) + else: + incomplete_log_item.append(data_chunk) + if not self.logs_available(): + break + if incomplete_log_item: + yield b"".join(incomplete_log_item) + + @staticmethod + def _extract_log_items(incomplete_log_item: List[bytes], log_items: List[bytes]): + yield b"".join(incomplete_log_item) + log_items[0] + b"\n" + for x in log_items[1:-1]: + yield x + b"\n" + + @staticmethod + def _save_incomplete_log_item(sub_chunk: bytes): + return [sub_chunk] if [sub_chunk] else [] + + def logs_available(self): + remote_pod = self.read_pod() + if container_is_running(pod=remote_pod, container_name=self.container_name): + return True + container_status = get_container_status( + pod=remote_pod, container_name=self.container_name + ) + state = container_status.state if container_status else None + terminated = state.terminated if state else None + if terminated: + termination_time = terminated.finished_at + if termination_time: + return ( + termination_time + 
timedelta(seconds=self.post_termination_timeout) + > utcnow() + ) + return False + + def read_pod(self): + _now = utcnow() + if ( + self.read_pod_cache is None + or self.last_read_pod_at + timedelta(seconds=self.read_pod_cache_timeout) + < _now + ): + self.read_pod_cache = self._read_pod(self.pod_name, self.namespace) + self.last_read_pod_at = _now + return self.read_pod_cache diff --git a/third_party/airflow/armada/logs/pod_log_manager.py b/third_party/airflow/armada/logs/pod_log_manager.py new file mode 100644 index 00000000000..20e8e51c852 --- /dev/null +++ b/third_party/airflow/armada/logs/pod_log_manager.py @@ -0,0 +1,550 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import asyncio +import math +import time +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, cast, Optional + +import pendulum +import tenacity +from kubernetes import client, watch, config +from kubernetes_asyncio import client as async_client, config as async_config +from kubernetes.client.rest import ApiException +from pendulum import DateTime +from pendulum.parsing.exceptions import ParserError +from urllib3.exceptions import HTTPError as BaseHTTPError + +from airflow.exceptions import AirflowException +from airflow.utils.log.logging_mixin import LoggingMixin + +from armada.auth import TokenRetriever +from armada.logs.log_consumer import PodLogsConsumer, PodLogsConsumerAsync +from armada.logs.utils import container_is_running + +if TYPE_CHECKING: + from kubernetes.client.models.v1_pod import V1Pod + + +class PodPhase: + """ + Possible pod phases. + + See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase. + """ + + PENDING = "Pending" + RUNNING = "Running" + FAILED = "Failed" + SUCCEEDED = "Succeeded" + + terminal_states = {FAILED, SUCCEEDED} + + +@dataclass +class PodLoggingStatus: + """Return the status of the pod and last log time when exiting from + `fetch_container_logs`.""" + + running: bool + last_log_time: DateTime | None + + +class PodLogManagerAsync(LoggingMixin): + """Monitor logs of Kubernetes pods asynchronously.""" + + def __init__( + self, + k8s_context: str, + token_retriever: Optional[TokenRetriever] = None, + ): + """ + Create the launcher. 
+ + :param k8s_context: kubernetes context + :param token_retriever: Retrieves auth tokens + """ + super().__init__() + self._k8s_context = k8s_context + self._watch = watch.Watch() + self._k8s_client = None + self._token_retriever = token_retriever + + async def _refresh_k8s_auth_token(self, interval=60 * 5): + if self._token_retriever is not None: + while True: + await asyncio.sleep(interval) + self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( + f"Bearer {self._token_retriever.get_token()}" + ) + + async def k8s_client(self) -> async_client: + await async_config.load_kube_config(context=self._k8s_context) + asyncio.create_task(self._refresh_k8s_auth_token()) + return async_client.CoreV1Api() + + async def fetch_container_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + *, + follow=False, + since_time: DateTime | None = None, + post_termination_timeout: int = 120, + ) -> PodLoggingStatus: + """ + Follow the logs of container and stream to airflow logging. Doesn't block whilst + logs are being fetched. + + Returns when container exits. + + Between when the pod starts and logs being available, there might be a delay due + to CSR not approved + and signed yet. In such situation, ApiException is thrown. This is why we are + retrying on this + specific exception. + """ + # Can't await in constructor, so instantiating here + if self._k8s_client is None: + self._k8s_client = await self.k8s_client() + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(ApiException), + stop=tenacity.stop_after_attempt(10), + wait=tenacity.wait_fixed(1), + ) + async def consume_logs( + *, + since_time: DateTime | None = None, + follow: bool = True, + logs: PodLogsConsumerAsync | None, + ) -> tuple[DateTime | None, PodLogsConsumerAsync | None]: + """ + Try to follow container logs until container completes. + + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. 
+ + Returns the last timestamp observed in logs. + """ + last_captured_timestamp = None + try: + logs = await self._read_pod_logs( + pod_name=pod_name, + namespace=namespace, + container_name=container_name, + timestamps=True, + since_seconds=( + math.ceil((pendulum.now() - since_time).total_seconds()) + if since_time + else None + ), + follow=follow, + post_termination_timeout=post_termination_timeout, + ) + message_to_log = None + message_timestamp = None + progress_callback_lines = [] + try: + async for raw_line in logs: + line = raw_line.decode("utf-8", errors="backslashreplace") + line_timestamp, message = self._parse_log_line(line) + if line_timestamp: # detect new log line + if message_to_log is None: # first line in the log + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines.append(line) + else: # previous log line is complete + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines = [line] + else: # continuation of the previous log line + message_to_log = f"{message_to_log}\n{message}" + progress_callback_lines.append(line) + finally: + if message_to_log is not None: + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + except BaseHTTPError as e: + self.log.warning( + "Reading of logs interrupted for container %r with error %r; will " + "retry. " + "Set log level to DEBUG for traceback.", + container_name, + e, + ) + self.log.debug( + "Traceback for interrupted logs read for pod %r", + pod_name, + exc_info=True, + ) + return last_captured_timestamp or since_time, logs + + # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + # loop as we do here. But in a long-running process we might temporarily lose + # connectivity. + # So the looping logic is there to let us resume following the logs. 
+ logs = None + last_log_time = since_time + while True: + last_log_time, logs = await consume_logs( + since_time=last_log_time, + follow=follow, + logs=logs, + ) + if not await self._container_is_running_async( + pod_name, namespace, container_name=container_name + ): + return PodLoggingStatus(running=False, last_log_time=last_log_time) + if not follow: + return PodLoggingStatus(running=True, last_log_time=last_log_time) + else: + self.log.warning( + "Pod %s log read interrupted but container %s still running", + pod_name, + container_name, + ) + time.sleep(1) + + def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: + """ + Parse K8s log line and returns the final state. + + :param line: k8s log line + :return: timestamp and log message + """ + timestamp, sep, message = line.strip().partition(" ") + if not sep: + return None, line + try: + last_log_time = cast(DateTime, pendulum.parse(timestamp)) + except ParserError: + return None, line + return last_log_time, message + + async def _container_is_running_async( + self, pod_name: str, namespace: str, container_name: str + ) -> bool: + """Read pod and checks if container is running.""" + remote_pod = await self.read_pod(pod_name, namespace) + return container_is_running(pod=remote_pod, container_name=container_name) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + async def _read_pod_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + tail_lines: int | None = None, + timestamps: bool = False, + since_seconds: int | None = None, + follow=True, + post_termination_timeout: int = 120, + ) -> PodLogsConsumerAsync: + """Read log from the POD.""" + additional_kwargs = {} + if since_seconds: + additional_kwargs["since_seconds"] = since_seconds + + if tail_lines: + additional_kwargs["tail_lines"] = tail_lines + + try: + logs = await self._k8s_client.read_namespaced_pod_log( + name=pod_name, + namespace=namespace, + 
container=container_name, + follow=follow, + timestamps=timestamps, + _preload_content=False, + **additional_kwargs, + ) + except BaseHTTPError: + self.log.exception("There was an error reading the kubernetes API.") + raise + + return PodLogsConsumerAsync( + response=logs, + pod_name=pod_name, + namespace=namespace, + read_pod_async=self.read_pod, + container_name=container_name, + post_termination_timeout=post_termination_timeout, + ) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + async def read_pod(self, pod_name: str, namespace: str) -> V1Pod: + """Read POD information.""" + try: + return await self._k8s_client.read_namespaced_pod(pod_name, namespace) + except BaseHTTPError as e: + raise AirflowException( + f"There was an error reading the kubernetes API: {e}" + ) + + +class PodLogManager(LoggingMixin): + """Monitor logs of Kubernetes pods.""" + + def __init__( + self, k8s_context: str, token_retriever: Optional[TokenRetriever] = None + ): + """ + Create the launcher. + + :param k8s_context: kubernetes context + :param token_retriever: Retrieves auth tokens + """ + super().__init__() + self._k8s_context = k8s_context + self._watch = watch.Watch() + self._token_retriever = token_retriever + + def _refresh_k8s_auth_token(self): + if self._token_retriever is not None: + self._k8s_client.api_client.configuration.api_key["BearerToken"] = ( + f"Bearer {self._token_retriever.get_token()}" + ) + + @cached_property + def _k8s_client(self) -> client: + config.load_kube_config(context=self._k8s_context) + return client.CoreV1Api() + + def fetch_container_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + *, + follow=False, + since_time: DateTime | None = None, + post_termination_timeout: int = 120, + ) -> PodLoggingStatus: + """ + Follow the logs of container and stream to airflow logging. + + Returns when container exits. 
+ + Between when the pod starts and logs being available, there might be a delay due + to CSR not approved + and signed yet. In such situation, ApiException is thrown. This is why we are + retrying on this + specific exception. + """ + + @tenacity.retry( + retry=tenacity.retry_if_exception_type(ApiException), + stop=tenacity.stop_after_attempt(10), + wait=tenacity.wait_fixed(1), + ) + def consume_logs( + *, + since_time: DateTime | None = None, + follow: bool = True, + logs: PodLogsConsumer | None, + ) -> tuple[DateTime | None, PodLogsConsumer | None]: + """ + Try to follow container logs until container completes. + + For a long-running container, sometimes the log read may be interrupted + Such errors of this kind are suppressed. + + Returns the last timestamp observed in logs. + """ + last_captured_timestamp = None + try: + logs = self._read_pod_logs( + pod_name=pod_name, + namespace=namespace, + container_name=container_name, + timestamps=True, + since_seconds=( + math.ceil((pendulum.now() - since_time).total_seconds()) + if since_time + else None + ), + follow=follow, + post_termination_timeout=post_termination_timeout, + ) + message_to_log = None + message_timestamp = None + progress_callback_lines = [] + try: + for raw_line in logs: + line = raw_line.decode("utf-8", errors="backslashreplace") + line_timestamp, message = self._parse_log_line(line) + if line_timestamp: # detect new log line + if message_to_log is None: # first line in the log + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines.append(line) + else: # previous log line is complete + self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + message_to_log = message + message_timestamp = line_timestamp + progress_callback_lines = [line] + else: # continuation of the previous log line + message_to_log = f"{message_to_log}\n{message}" + progress_callback_lines.append(line) + finally: + if message_to_log is not None: 
+ self.log.info("[%s] %s", container_name, message_to_log) + last_captured_timestamp = message_timestamp + except BaseHTTPError as e: + self.log.warning( + "Reading of logs interrupted for container %r with error %r; will " + "retry. " + "Set log level to DEBUG for traceback.", + container_name, + e, + ) + self.log.debug( + "Traceback for interrupted logs read for pod %r", + pod_name, + exc_info=True, + ) + return last_captured_timestamp or since_time, logs + + # note: `read_pod_logs` follows the logs, so we shouldn't necessarily *need* to + # loop as we do here. But in a long-running process we might temporarily lose + # connectivity. + # So the looping logic is there to let us resume following the logs. + logs = None + last_log_time = since_time + while True: + last_log_time, logs = consume_logs( + since_time=last_log_time, + follow=follow, + logs=logs, + ) + if not self._container_is_running( + pod_name, namespace, container_name=container_name + ): + return PodLoggingStatus(running=False, last_log_time=last_log_time) + if not follow: + return PodLoggingStatus(running=True, last_log_time=last_log_time) + else: + self.log.warning( + "Pod %s log read interrupted but container %s still running", + pod_name, + container_name, + ) + time.sleep(1) + self._refresh_k8s_auth_token() + + def _parse_log_line(self, line: str) -> tuple[DateTime | None, str]: + """ + Parse K8s log line and returns the final state. 
+ + :param line: k8s log line + :return: timestamp and log message + """ + timestamp, sep, message = line.strip().partition(" ") + if not sep: + return None, line + try: + last_log_time = cast(DateTime, pendulum.parse(timestamp)) + except ParserError: + return None, line + return last_log_time, message + + def _container_is_running( + self, pod_name: str, namespace: str, container_name: str + ) -> bool: + """Read pod and checks if container is running.""" + remote_pod = self.read_pod(pod_name, namespace) + return container_is_running(pod=remote_pod, container_name=container_name) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + def _read_pod_logs( + self, + pod_name: str, + namespace: str, + container_name: str, + tail_lines: int | None = None, + timestamps: bool = False, + since_seconds: int | None = None, + follow=True, + post_termination_timeout: int = 120, + ) -> PodLogsConsumer: + """Read log from the POD.""" + additional_kwargs = {} + if since_seconds: + additional_kwargs["since_seconds"] = since_seconds + + if tail_lines: + additional_kwargs["tail_lines"] = tail_lines + + try: + logs = self._k8s_client.read_namespaced_pod_log( + name=pod_name, + namespace=namespace, + container=container_name, + follow=follow, + timestamps=timestamps, + _preload_content=False, + **additional_kwargs, + ) + except BaseHTTPError: + self.log.exception("There was an error reading the kubernetes API.") + raise + + return PodLogsConsumer( + response=logs, + pod_name=pod_name, + namespace=namespace, + read_pod=self.read_pod, + container_name=container_name, + post_termination_timeout=post_termination_timeout, + ) + + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(), + reraise=True, + ) + def read_pod(self, pod_name: str, namespace: str) -> V1Pod: + """Read POD information.""" + try: + return self._k8s_client.read_namespaced_pod(pod_name, namespace) + except BaseHTTPError as 
e: + raise AirflowException( + f"There was an error reading the kubernetes API: {e}" + ) diff --git a/third_party/airflow/armada/logs/utils.py b/third_party/airflow/armada/logs/utils.py new file mode 100644 index 00000000000..ade71ba5fbe --- /dev/null +++ b/third_party/airflow/armada/logs/utils.py @@ -0,0 +1,55 @@ +# Copyright 2016-2024 The Apache Software Foundation +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING + +from kubernetes.client import V1Pod, V1ContainerStatus + +if TYPE_CHECKING: + from kubernetes.client.models.v1_container_status import ( # noqa: F811 + V1ContainerStatus, + ) + from kubernetes.client.models.v1_pod import V1Pod # noqa: F811 + + +def get_container_status(pod: V1Pod, container_name: str) -> V1ContainerStatus: + """Retrieve container status.""" + container_statuses = pod.status.container_statuses if pod and pod.status else None + if container_statuses: + # In general the variable container_statuses can store multiple items matching + # different containers. + # The following generator expression yields all items that have name equal to + # the container_name. + # The function next() here calls the generator to get only the first value. 
If + # there's nothing found + # then None is returned. + return next((x for x in container_statuses if x.name == container_name), None) + return None + + +def container_is_running(pod: V1Pod, container_name: str) -> bool: + """ + Examine V1Pod ``pod`` to determine whether ``container_name`` is running. + + If that container is present and running, returns True. Returns False otherwise. + """ + container_status = get_container_status(pod, container_name) + if not container_status: + return False + return container_status.state.running is not None diff --git a/third_party/airflow/armada/model.py b/third_party/airflow/armada/model.py new file mode 100644 index 00000000000..80e6e0d0a77 --- /dev/null +++ b/third_party/airflow/armada/model.py @@ -0,0 +1,83 @@ +import importlib +from typing import Tuple, Any, Optional, Sequence, Dict + +import grpc + + +""" This class exists so that we can retain our connection to the Armada Query API + when using the deferrable Armada Airflow Operator. Airflow requires any state + within deferrable operators be serialisable, unfortunately grpc.Channel isn't + itself serialisable.""" + + +class GrpcChannelArgs: + def __init__( + self, + target: str, + options: Optional[Sequence[Tuple[str, Any]]] = None, + compression: Optional[grpc.Compression] = None, + auth: Optional[grpc.AuthMetadataPlugin] = None, + auth_details: Optional[Dict[str, Any]] = None, + ): + self.target = target + self.options = options + self.compression = compression + if auth: + self.auth = auth + elif auth_details: + classpath, kwargs = auth_details + module_path, class_name = classpath.rsplit( + ".", 1 + ) # Split the classpath to module and class name + module = importlib.import_module( + module_path + ) # Dynamically import the module + cls = getattr(module, class_name) # Get the class from the module + self.auth = cls( + **kwargs + ) # Instantiate the class with the deserialized kwargs + else: + self.auth = None + + def serialize(self) -> Dict[str, Any]: + 
auth_details = self.auth.serialize() if self.auth else None + return { + "target": self.target, + "options": self.options, + "compression": self.compression, + "auth_details": auth_details, + } + + def channel(self) -> grpc.Channel: + if self.auth is None: + return grpc.insecure_channel( + target=self.target, options=self.options, compression=self.compression + ) + + return grpc.secure_channel( + target=self.target, + options=self.options, + compression=self.compression, + credentials=grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), + grpc.metadata_call_credentials(self.auth), + ), + ) + + def aio_channel(self) -> grpc.aio.Channel: + if self.auth is None: + return grpc.aio.insecure_channel( + target=self.target, + options=self.options, + compression=self.compression, + ) + + return grpc.aio.secure_channel( + target=self.target, + options=self.options, + compression=self.compression, + credentials=grpc.composite_channel_credentials( + grpc.ssl_channel_credentials(), + grpc.metadata_call_credentials(self.auth), + ), + ) diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 33475651275..cb9fd361c27 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -16,151 +16,331 @@ # specific language governing permissions and limitations # under the License. 
-import logging -from typing import Optional, List, Sequence +import os +import time +from functools import lru_cache, cached_property +from typing import Optional, Sequence, Any, Dict -from airflow.models import BaseOperator +import jinja2 +from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.utils.context import Context +from airflow.models import BaseOperator +from airflow.utils.context import Context +from airflow.utils.log.logging_mixin import LoggingMixin +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.typings import JobState from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient - -from armada.operators.grpc import GrpcChannelArgsDict, GrpcChannelArguments -from armada.operators.jobservice import ( - JobServiceClient, - default_jobservice_channel_options, -) -from armada.operators.utils import ( - airflow_error, - search_for_job_complete, - annotate_job_request_items, -) -from armada.jobservice import jobservice_pb2 - from google.protobuf.json_format import MessageToDict, ParseDict -import jinja2 - +from armada_client.client import ArmadaClient +from armada.auth import TokenRetriever +from armada.logs.pod_log_manager import PodLogManager +from armada.model import GrpcChannelArgs +from armada.triggers.armada import ArmadaTrigger -armada_logger = logging.getLogger("airflow.task") +class ArmadaOperator(BaseOperator, LoggingMixin): + """ + An Airflow operator that manages Job submission to Armada. -class ArmadaOperator(BaseOperator): + This operator submits a job to an Armada cluster, polls for its completion, + and handles job cancellation if the Airflow task is killed. """ - Implementation of an ArmadaOperator for airflow. - - Airflow operators inherit from BaseOperator. 
- - :param name: The name of the airflow task - :param armada_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - :param job_service_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - :param armada_queue: The queue name for Armada. - :param job_request_items: A PodSpec that is used by Armada for submitting a job - :param lookout_url_template: A URL template to be used to provide users - a valid link to the related lookout job in this operator's log. - The format should be: - "https://lookout.armada.domain/jobs?job_id=" where will - be replaced with the actual job ID. - :param poll_interval: How often to poll jobservice to get status. - :return: an armada operator instance + + template_fields: Sequence[str] = ("job_request", "job_set_prefix") + """ +Initializes a new ArmadaOperator. - template_fields: Sequence[str] = ("job_request_items",) +:param name: The name of the job to be submitted. +:type name: str +:param channel_args: The gRPC channel arguments for connecting to the Armada server. +:type channel_args: GrpcChannelArgs +:param armada_queue: The name of the Armada queue to which the job will be submitted. +:type armada_queue: str +:param job_request: The job to be submitted to Armada. +:type job_request: JobSubmitRequestItem +:param job_set_prefix: A string to prepend to the jobSet name +:type job_set_prefix: Optional[str] +:param lookout_url_template: Template for creating lookout links. If not specified +then no tracking information will be logged. +:type lookout_url_template: Optional[str] +:param poll_interval: The interval in seconds between polling for job status updates. +:type poll_interval: int +:param container_logs: Name of container whose logs will be published to stdout. +:type container_logs: Optional[str] +:param k8s_token_retriever: A serialisable Kubernetes token retriever object. 
We use +this to read logs from Kubernetes pods. +:type k8s_token_retriever: Optional[TokenRetriever] +:param deferrable: Whether the operator should run in a deferrable mode, allowing +for asynchronous execution. +:type deferrable: bool +:param job_acknowledgement_timeout: The timeout in seconds to wait for a job to be +acknowledged by Armada. +:type job_acknowledgement_timeout: int +:param kwargs: Additional keyword arguments to pass to the BaseOperator. +""" def __init__( self, name: str, - armada_channel_args: GrpcChannelArgsDict, - job_service_channel_args: GrpcChannelArgsDict, + channel_args: GrpcChannelArgs, armada_queue: str, - job_request_items: List[JobSubmitRequestItem], + job_request: JobSubmitRequestItem, + job_set_prefix: Optional[str] = "", lookout_url_template: Optional[str] = None, poll_interval: int = 30, + container_logs: Optional[str] = None, + k8s_token_retriever: Optional[TokenRetriever] = None, + deferrable: bool = conf.getboolean( + "operators", "default_deferrable", fallback=False + ), + job_acknowledgement_timeout: int = 5 * 60, **kwargs, ) -> None: super().__init__(**kwargs) self.name = name - self.armada_channel_args = GrpcChannelArguments(**armada_channel_args) - - if "options" not in job_service_channel_args: - job_service_channel_args["options"] = default_jobservice_channel_options - - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) + self.channel_args = channel_args self.armada_queue = armada_queue - self.job_request_items = job_request_items + self.job_request = job_request + self.job_set_prefix = job_set_prefix self.lookout_url_template = lookout_url_template self.poll_interval = poll_interval + self.container_logs = container_logs + self.k8s_token_retriever = k8s_token_retriever + self.deferrable = deferrable + self.job_acknowledgement_timeout = job_acknowledgement_timeout + self.job_id = None + self.job_set_id = None + + if self.container_logs and self.k8s_token_retriever is None: + 
self.log.warning( + "Token refresh mechanism not configured, airflow may stop retrieving " + "logs from Kubernetes" + ) def execute(self, context) -> None: """ - Executes the Armada Operator. - - Runs an Armada job and calls the job_service_client for polling. + Submits the job to Armada and polls for completion. - :param context: The airflow context. - - :return: None + :param context: The execution context provided by Airflow. + :type context: Context """ - job_service_client = JobServiceClient(self.job_service_channel_args.channel()) - # Health Check - health = job_service_client.health() - if health.status != jobservice_pb2.HealthCheckResponse.SERVING: - armada_logger.warn("Armada Job Service is not health") - # This allows us to use a unique id from airflow - # and have all jobs in a dag correspond to same jobset - job_set_id = context["run_id"] - - armada_client = ArmadaClient(channel=self.armada_channel_args.channel()) - job = armada_client.submit_jobs( - queue=self.armada_queue, - job_set_id=job_set_id, - job_request_items=annotate_job_request_items( - context, self.job_request_items - ), - ) + # We take the job_set_id from Airflow's run_id. This means that all jobs in the + # dag will be in the same jobset. + self.job_set_id = f"{self.job_set_prefix}{context['run_id']}" + self._annotate_job_request(context, self.job_request) - try: - job_id = job.job_response_items[0].job_id - except Exception: - raise AirflowException("Armada has issues submitting job") - - armada_logger.info("Running Armada job %s with id %s", self.name, job_id) - - lookout_url = self._get_lookout_url(job_id) - if len(lookout_url) > 0: - armada_logger.info("Lookout URL: %s", lookout_url) - - job_state, job_message = search_for_job_complete( - job_service_client=job_service_client, - armada_queue=self.armada_queue, - job_set_id=job_set_id, - airflow_task_name=self.name, - job_id=job_id, - poll_interval=self.poll_interval, + # Submit job or reattach to previously submitted job. 
We always do this + # synchronously. + self.job_id = self._reattach_or_submit_job( + context, self.armada_queue, self.job_set_id, self.job_request ) - armada_logger.info( - "Armada Job finished with %s and message: %s", job_state, job_message - ) - airflow_error(job_state, self.name, job_id) - def _get_lookout_url(self, job_id: str) -> str: - if self.lookout_url_template is None: - return "" - return self.lookout_url_template.replace("", job_id) + # Wait until finished + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=ArmadaTrigger( + job_id=self.job_id, + armada_queue=self.armada_queue, + job_set_id=self.job_set_id, + channel_args=self.channel_args, + poll_interval=self.poll_interval, + tracking_message=self._trigger_tracking_message(), + job_acknowledgement_timeout=self.job_acknowledgement_timeout, + container_logs=self.container_logs, + k8s_token_retriever=self.k8s_token_retriever, + job_request_namespace=self.job_request.namespace, + ), + method_name="_execute_complete", + ) + else: + self._poll_for_termination(self._trigger_tracking_message()) + + @cached_property + def client(self) -> ArmadaClient: + return ArmadaClient(channel=self.channel_args.channel()) + + @lru_cache(maxsize=None) + def pod_manager(self, k8s_context: str) -> PodLogManager: + return PodLogManager( + k8s_context=k8s_context, token_retriever=self.k8s_token_retriever + ) def render_template_fields( self, context: Context, jinja_env: Optional[jinja2.Environment] = None, ) -> None: - self.job_request_items = [ - MessageToDict(x, preserving_proto_field_name=True) - for x in self.job_request_items - ] + """ + Template all attributes listed in self.template_fields. + This mutates the attributes in-place and is irreversible. + + Args: + context (Context): The execution context provided by Airflow. + :param context: Airflow Context dict wi1th values to apply on content + :param jinja_env: jinja’s environment to use for rendering. 
+ """ + self.job_request = MessageToDict( + self.job_request, preserving_proto_field_name=True + ) super().render_template_fields(context, jinja_env) - self.job_request_items = [ - ParseDict(x, JobSubmitRequestItem()) for x in self.job_request_items - ] + self.job_request = ParseDict(self.job_request, JobSubmitRequestItem()) + + def _cancel_job(self) -> None: + try: + result = self.client.cancel_jobs( + queue=self.armada_queue, + job_set_id=self.job_set_id, + job_id=self.job_id, + ) + if len(list(result.cancelled_ids)) > 0: + self.log.info(f"Cancelled job with id {result.cancelled_ids}") + else: + self.log.warning(f"Failed to cancel job with id {self.job_id}") + except Exception as e: + self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}") + + def on_kill(self) -> None: + if self.job_id is not None: + self.log.info( + f"on_kill called, cancelling job with id {self.job_id} in queue " + f"{self.armada_queue}" + ) + self._cancel_job() + + def _trigger_tracking_message(self): + if self.lookout_url_template: + return ( + f"Job details available at " + f'{self.lookout_url_template.replace("", self.job_id)}' + ) + + return "" + + def _execute_complete(self, _: Context, event: Dict[str, Any]): + if event["status"] == "error": + raise AirflowException(event["response"]) + + def _reattach_or_submit_job( + self, + context: Context, + queue: str, + job_set_id: str, + job_request: JobSubmitRequestItem, + ) -> str: + ti = context["ti"] + existing_id = ti.xcom_pull( + dag_id=ti.dag_id, task_ids=ti.task_id, key=f"{ti.try_number}" + ) + if existing_id is not None: + self.log.info( + f"Attached to existing job with id {existing_id['armada_job_id']}" + ) + return existing_id["armada_job_id"] + + job_id = self._submit_job(queue, job_set_id, job_request) + self.log.info(f"Submitted job with id {job_id}") + ti.xcom_push(key=f"{ti.try_number}", value={"armada_job_id": job_id}) + return job_id + + def _submit_job( + self, queue: str, job_set_id: str, job_request: 
JobSubmitRequestItem + ) -> str: + resp = self.client.submit_jobs(queue, job_set_id, [job_request]) + num_responses = len(resp.job_response_items) + + # We submitted exactly one job to armada, so we expect a single response + if num_responses != 1: + raise AirflowException( + f"No valid received from Armada (expected 1 job to be created " + f"but got {num_responses}" + ) + job = resp.job_response_items[0] + + # Throw if armada told us we had submitted something bad + if job.error: + raise AirflowException(f"Error submitting job to Armada: {job.error}") + + return job.job_id + + def _poll_for_termination(self, tracking_message: str) -> None: + last_log_time = None + run_details = None + state = JobState.UNKNOWN + + start_time = time.time() + job_acknowledged = False + while state.is_active(): + response = self.client.get_job_status([self.job_id]) + state = JobState(response.job_states[self.job_id]) + self.log.info( + f"job {self.job_id} is in state: {state.name}. {tracking_message}" + ) + + if state != JobState.UNKNOWN: + job_acknowledged = True + + if ( + not job_acknowledged + and int(time.time() - start_time) > self.job_acknowledgement_timeout + ): + self.log.info( + f"Job {self.job_id} not acknowledged by the Armada server within " + f"timeout ({self.job_acknowledgement_timeout}), terminating" + ) + self.on_kill() + return + + if self.container_logs and not run_details: + if state == JobState.RUNNING or state.is_terminal(): + run_details = self._get_latest_job_run_details(self.job_id) + + if run_details: + try: + # pod_name format is sufficient for now. 
Ideally pod name should be + # retrieved from queryapi + log_status = self.pod_manager( + run_details.cluster + ).fetch_container_logs( + pod_name=f"armada-{self.job_id}-0", + namespace=self.job_request.namespace, + container_name=self.container_logs, + since_time=last_log_time, + ) + last_log_time = log_status.last_log_time + except Exception as e: + self.log.warning(f"Error fetching logs {e}") + + time.sleep(self.poll_interval) + + self.log.info(f"job {self.job_id} terminated with state: {state.name}") + if state != JobState.SUCCEEDED: + raise AirflowException( + f"job {self.job_id} did not succeed. Final status was {state.name}" + ) + + def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]: + job_details = self.client.get_job_details([job_id]).job_details[job_id] + if job_details and job_details.latest_run_id: + for run in job_details.job_runs: + if run.run_id == job_details.latest_run_id: + return run + return None + + @staticmethod + def _annotate_job_request(context, request: JobSubmitRequestItem): + if "ANNOTATION_KEY_PREFIX" in os.environ: + annotation_key_prefix = f'{os.environ.get("ANNOTATION_KEY_PREFIX")}' + else: + annotation_key_prefix = "armadaproject.io/" + + task_id = context["ti"].task_id + run_id = context["run_id"] + dag_id = context["dag"].dag_id + + request.annotations[annotation_key_prefix + "taskId"] = task_id + request.annotations[annotation_key_prefix + "taskRunId"] = run_id + request.annotations[annotation_key_prefix + "dagId"] = dag_id diff --git a/third_party/airflow/armada/operators/armada_deferrable.py b/third_party/airflow/armada/operators/armada_deferrable.py deleted file mode 100644 index f7aa1413637..00000000000 --- a/third_party/airflow/armada/operators/armada_deferrable.py +++ /dev/null @@ -1,301 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from typing import Optional, Sequence, List - -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.context import Context - -from armada_client.armada.submit_pb2 import JobSubmitRequestItem -from armada_client.client import ArmadaClient - -from armada.operators.jobservice import ( - JobServiceClient, - default_jobservice_channel_options, -) -from armada.operators.grpc import GrpcChannelArgsDict, GrpcChannelArguments -from armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.operators.utils import ( - airflow_error, - search_for_job_complete_async, - annotate_job_request_items, -) -from armada.jobservice import jobservice_pb2 - -from google.protobuf.json_format import MessageToDict, ParseDict - -import jinja2 - - -armada_logger = logging.getLogger("airflow.task") - - -class ArmadaDeferrableOperator(BaseOperator): - """ - Implementation of a deferrable armada operator for airflow. - - Distinguished from ArmadaOperator by its ability to defer itself after - submitting its job_request_items. 
- - See - https://airflow.apache.org/docs/apache-airflow/stable/authoring-and-scheduling/deferring.html - for more information about deferrable airflow operators. - - Airflow operators inherit from BaseOperator. - - :param name: The name of the airflow task. - :param armada_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the armada server instance. - :param job_service_channel_args: GRPC channel arguments to be used when creating - a grpc channel to connect to the job service instance. - :param armada_queue: The queue name for Armada. - :param job_request_items: A PodSpec that is used by Armada for submitting a job. - :param lookout_url_template: A URL template to be used to provide users - a valid link to the related lookout job in this operator's log. - The format should be: - "https://lookout.armada.domain/jobs?job_id=" where will - be replaced with the actual job ID. - :param poll_interval: How often to poll jobservice to get status. - :return: A deferrable armada operator instance. - """ - - template_fields: Sequence[str] = ("job_request_items",) - - def __init__( - self, - name: str, - armada_channel_args: GrpcChannelArgsDict, - job_service_channel_args: GrpcChannelArgsDict, - armada_queue: str, - job_request_items: List[JobSubmitRequestItem], - lookout_url_template: Optional[str] = None, - poll_interval: int = 30, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.name = name - self.armada_channel_args = GrpcChannelArguments(**armada_channel_args) - - if "options" not in job_service_channel_args: - job_service_channel_args["options"] = default_jobservice_channel_options - - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) - self.armada_queue = armada_queue - self.job_request_items = job_request_items - self.lookout_url_template = lookout_url_template - self.poll_interval = poll_interval - - def serialize(self) -> dict: - """ - Get a serialized version of this object. 
- - :return: A dict of keyword arguments used when instantiating - this object. - """ - - return { - "task_id": self.task_id, - "name": self.name, - "armada_channel_args": self.armada_channel_args.serialize(), - "job_service_channel_args": self.job_service_channel_args.serialize(), - "armada_queue": self.armada_queue, - "job_request_items": self.job_request_items, - "lookout_url_template": self.lookout_url_template, - "poll_interval": self.poll_interval, - } - - def execute(self, context) -> None: - """ - Executes the Armada Operator. Only meant to be called by airflow. - - Submits an Armada job and defers itself to ArmadaJobCompleteTrigger to wait - until the job completes. - - :param context: The airflow context. - - :return: None - """ - self.job_request_items = annotate_job_request_items( - context=context, job_request_items=self.job_request_items - ) - job_service_client = JobServiceClient(self.job_service_channel_args.channel()) - - # Health Check - health = job_service_client.health() - if health.status != jobservice_pb2.HealthCheckResponse.SERVING: - armada_logger.warn("Armada Job Service is not healthy.") - else: - armada_logger.debug("Jobservice is healthy.") - - armada_client = ArmadaClient(channel=self.armada_channel_args.channel()) - - armada_logger.debug("Submitting job(s).") - # This allows us to use a unique id from airflow - # and have all jobs in a dag correspond to same jobset - job = armada_client.submit_jobs( - queue=self.armada_queue, - job_set_id=context["run_id"], - job_request_items=self.job_request_items, - ) - - try: - job_id = job.job_response_items[0].job_id - except Exception: - raise AirflowException("Error submitting job(s) to Armada") - - armada_logger.info("Running Armada job %s with id %s", self.name, job_id) - - lookout_url = self._get_lookout_url(job_id) - if len(lookout_url) > 0: - armada_logger.info("Lookout URL: %s", lookout_url) - - # TODO: configurable timeout? 
- self.defer( - trigger=ArmadaJobCompleteTrigger( - job_id=job_id, - job_service_channel_args=self.job_service_channel_args.serialize(), - armada_queue=self.armada_queue, - job_set_id=context["run_id"], - airflow_task_name=self.name, - poll_interval=self.poll_interval, - ), - method_name="resume_job_complete", - kwargs={"job_id": job_id}, - ) - - def resume_job_complete(self, context, event: dict, job_id: str) -> None: - """ - Resumes this operator after deferring itself to ArmadaJobCompleteTrigger. - Only meant to be called from within Airflow. - - Reports the result of the job and returns. - - :param context: The airflow context. - :param event: The payload from the TriggerEvent raised by - ArmadaJobCompleteTrigger. - :param job_id: The job ID. - :return: None - """ - - job_state = event["job_state"] - job_message = event["job_message"] - - armada_logger.info( - "Armada Job finished with %s and message: %s", job_state, job_message - ) - airflow_error(job_state, self.name, job_id) - - def _get_lookout_url(self, job_id: str) -> str: - if self.lookout_url_template is None: - return "" - return self.lookout_url_template.replace("", job_id) - - def render_template_fields( - self, - context: Context, - jinja_env: Optional[jinja2.Environment] = None, - ) -> None: - self.job_request_items = [ - MessageToDict(x, preserving_proto_field_name=True) - for x in self.job_request_items - ] - super().render_template_fields(context, jinja_env) - self.job_request_items = [ - ParseDict(x, JobSubmitRequestItem()) for x in self.job_request_items - ] - - -class ArmadaJobCompleteTrigger(BaseTrigger): - """ - An airflow trigger that monitors the job state of an armada job. - - Triggers when the job is complete. - - :param job_id: The job ID to monitor. - :param job_service_channel_args: GRPC channel arguments to be used when - creating a grpc channel to connect to the job service instance. - :param armada_queue: The name of the armada queue. - :param job_set_id: The ID of the job set. 
- :param airflow_task_name: Name of the airflow task to which this trigger - belongs. - :param poll_interval: How often to poll jobservice to get status. - :return: An armada job complete trigger instance. - """ - - def __init__( - self, - job_id: str, - job_service_channel_args: GrpcChannelArgsDict, - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - poll_interval: int = 30, - ) -> None: - super().__init__() - self.job_id = job_id - self.job_service_channel_args = GrpcChannelArguments(**job_service_channel_args) - self.armada_queue = armada_queue - self.job_set_id = job_set_id - self.airflow_task_name = airflow_task_name - self.poll_interval = poll_interval - - def serialize(self) -> tuple: - return ( - "armada.operators.armada_deferrable.ArmadaJobCompleteTrigger", - { - "job_id": self.job_id, - "job_service_channel_args": self.job_service_channel_args.serialize(), - "armada_queue": self.armada_queue, - "job_set_id": self.job_set_id, - "airflow_task_name": self.airflow_task_name, - "poll_interval": self.poll_interval, - }, - ) - - def __eq__(self, o): - return ( - self.task_id == o.task_id - and self.job_id == o.job_id - and self.job_service_channel_args == o.job_service_channel_args - and self.armada_queue == o.armada_queue - and self.job_set_id == o.job_set_id - and self.airflow_task_name == o.airflow_task_name - and self.poll_interval == o.poll_interval - ) - - async def run(self): - """ - Runs the trigger. Meant to be called by an airflow triggerer process. 
- """ - job_service_client = JobServiceAsyncIOClient( - channel=self.job_service_channel_args.aio_channel() - ) - - job_state, job_message = await search_for_job_complete_async( - armada_queue=self.armada_queue, - job_set_id=self.job_set_id, - airflow_task_name=self.airflow_task_name, - job_id=self.job_id, - job_service_client=job_service_client, - log=self.log, - poll_interval=self.poll_interval, - ) - yield TriggerEvent({"job_state": job_state, "job_message": job_message}) diff --git a/third_party/airflow/armada/operators/grpc.py b/third_party/airflow/armada/operators/grpc.py deleted file mode 100644 index 3e146ccce07..00000000000 --- a/third_party/airflow/armada/operators/grpc.py +++ /dev/null @@ -1,149 +0,0 @@ -import importlib -from typing import Optional, Sequence, Tuple, Any, TypedDict - -import grpc - - -class CredentialsCallbackDict(TypedDict): - """ - Helper class to provide stronger type checking on Credential callback args. - """ - - module_name: str - function_name: str - function_kwargs: dict - - -class GrpcChannelArgsDict(TypedDict): - """ - Helper class to provide stronger type checking on Grpc channel arugments. - """ - - target: str - options: Optional[Sequence[Tuple[str, Any]]] - compression: Optional[grpc.Compression] - credentials_callback_args: Optional[CredentialsCallbackDict] - - -class CredentialsCallback(object): - """ - Allows the use of an arbitrary callback function to get grpc credentials. - - :param module_name: The fully qualified python module name where the - function is located. - :param function_name: The name of the function to be called. - :param function_kwargs: Keyword arguments to function_name in a dictionary. 
- """ - - def __init__( - self, - module_name: str, - function_name: str, - function_kwargs: dict, - ) -> None: - self.module_name = module_name - self.function_name = function_name - self.function_kwargs = function_kwargs - - def call(self): - """Do the callback to get grpc credentials.""" - module = importlib.import_module(self.module_name) - func = getattr(module, self.function_name) - return func(**self.function_kwargs) - - -class GrpcChannelArguments(object): - """ - A Serializable GRPC Arguments Object. - - :param target: Target keyword argument used - when instantiating a grpc channel. - :param credentials_callback_args: Arguments to CredentialsCallback to use - when instantiating a grpc channel that takes credentials. - :param options: options keyword argument used - when instantiating a grpc channel. - :param compression: compression keyword argument used - when instantiating a grpc channel. - :return: a GrpcChannelArguments instance - """ - - def __init__( - self, - target: str, - options: Optional[Sequence[Tuple[str, Any]]] = None, - compression: Optional[grpc.Compression] = None, - credentials_callback_args: CredentialsCallbackDict = None, - ) -> None: - self.target = target - self.options = options - self.compression = compression - self.credentials_callback = None - self.credentials_callback_args = credentials_callback_args - - if credentials_callback_args is not None: - self.credentials_callback = CredentialsCallback(**credentials_callback_args) - - def __eq__(self, o): - return ( - self.target == o.target - and self.options == o.options - and self.compression == o.compression - and self.credentials_callback_args == o.credentials_callback_args - ) - - def channel(self) -> grpc.Channel: - """ - Create a grpc.Channel based on arguments supplied to this object. - - :return: Return grpc.insecure_channel if credentials is None. Otherwise - returns grpc.secure_channel. 
- """ - - if self.credentials_callback is None: - return grpc.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, - ) - return grpc.secure_channel( - target=self.target, - credentials=self.credentials_callback.call(), - options=self.options, - compression=self.compression, - ) - - def aio_channel(self) -> grpc.aio.Channel: - """ - Create a grpc.aio.Channel (asyncio) based on arguments supplied to this object. - - :return: Return grpc.aio.insecure_channel if credentials is None. Otherwise - returns grpc.aio.secure_channel. - """ - - if self.credentials_callback is None: - return grpc.aio.insecure_channel( - target=self.target, - options=self.options, - compression=self.compression, - ) - return grpc.aio.secure_channel( - target=self.target, - credentials=self.credentials_callback.call(), - options=self.options, - compression=self.compression, - ) - - def serialize(self) -> dict: - """ - Get a serialized version of this object. - - :return: A dict of keyword arguments used when calling - a grpc channel or instantiating this object. - """ - - return { - "target": self.target, - "credentials_callback_args": self.credentials_callback_args, - "options": self.options, - "compression": self.compression, - } diff --git a/third_party/airflow/armada/operators/jobservice.py b/third_party/airflow/armada/operators/jobservice.py deleted file mode 100644 index c6445286064..00000000000 --- a/third_party/airflow/armada/operators/jobservice.py +++ /dev/null @@ -1,97 +0,0 @@ -import json -from typing import Optional - -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 - -import grpc -from google.protobuf import empty_pb2 - -default_jobservice_channel_options = [ - ( - "grpc.service_config", - json.dumps( - { - "methodConfig": [ - { - "name": [{"service": "jobservice.JobService"}], - "retryPolicy": { - "maxAttempts": 6 * 5, # A little under 5 minutes. 
- "initialBackoff": "0.1s", - "maxBackoff": "10s", - "backoffMultiplier": 2, - "retryableStatusCodes": ["UNAVAILABLE"], - }, - } - ] - } - ), - ) -] - - -class JobServiceClient: - """ - The JobService Client - - Implementation of gRPC stubs from JobService - - :param channel: gRPC channel used for authentication. See - https://grpc.github.io/grpc/python/grpc.html - for more information. - :return: a job service client instance - """ - - def __init__(self, channel): - self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) - - def get_job_status( - self, queue: str, job_set_id: str, job_id: str - ) -> jobservice_pb2.JobServiceResponse: - """Get job status of a given job in a queue and job_set_id. - - Uses the GetJobStatus rpc to get a status of your job - - :param queue: The name of the queue - :param job_set_id: The name of the job set (a grouping of jobs) - :param job_id: The id of the job - :return: A Job Service Request (State, Error) - """ - job_service_request = jobservice_pb2.JobServiceRequest( - queue=queue, job_set_id=job_set_id, job_id=job_id - ) - return self.job_stub.GetJobStatus(job_service_request) - - def health(self) -> jobservice_pb2.HealthCheckResponse: - """Health Check for GRPC Request""" - return self.job_stub.Health(request=empty_pb2.Empty()) - - -def get_retryable_job_service_client( - target: str, - credentials: Optional[grpc.ChannelCredentials] = None, - compression: Optional[grpc.Compression] = None, -) -> JobServiceClient: - """ - Get a JobServiceClient that has retry configured - - :param target: grpc channel target - :param credentials: grpc channel credentials (if needed) - :param compresion: grpc channel compression - - :return: A job service client instance - """ - channel = None - if credentials is None: - channel = grpc.insecure_channel( - target=target, - options=default_jobservice_channel_options, - compression=compression, - ) - else: - channel = grpc.secure_channel( - target=target, - credentials=credentials, - 
options=default_jobservice_channel_options, - compression=compression, - ) - return JobServiceClient(channel) diff --git a/third_party/airflow/armada/operators/jobservice_asyncio.py b/third_party/airflow/armada/operators/jobservice_asyncio.py deleted file mode 100644 index a40b9fc14a0..00000000000 --- a/third_party/airflow/armada/operators/jobservice_asyncio.py +++ /dev/null @@ -1,80 +0,0 @@ -from armada.jobservice import ( - jobservice_pb2_grpc, - jobservice_pb2, -) -from armada.operators.jobservice import default_jobservice_channel_options - -import grpc -from typing import Optional - -from google.protobuf import empty_pb2 - - -class JobServiceAsyncIOClient: - """ - The JobService AsyncIO Client - - AsyncIO implementation of gRPC stubs from JobService - - :param channel: AsyncIO gRPC channel used for authentication. See - https://grpc.github.io/grpc/python/grpc_asyncio.html - for more information. - :return: A job service client instance - """ - - def __init__(self, channel: grpc.aio.Channel) -> None: - self.job_stub = jobservice_pb2_grpc.JobServiceStub(channel) - - async def get_job_status( - self, queue: str, job_set_id: str, job_id: str - ) -> jobservice_pb2.JobServiceResponse: - """Get job status of a given job in a queue and job_set_id. 
- - Uses the GetJobStatus rpc to get a status of your job - - :param queue: The name of the queue - :param job_set_id: The name of the job set (a grouping of jobs) - :param job_id: The id of the job - :return: A Job Service Request (State, Error) - """ - job_service_request = jobservice_pb2.JobServiceRequest( - queue=queue, job_set_id=job_set_id, job_id=job_id - ) - response = await self.job_stub.GetJobStatus(job_service_request) - return response - - async def health(self) -> jobservice_pb2.HealthCheckResponse: - """Health Check for GRPC Request""" - response = await self.job_stub.Health(request=empty_pb2.Empty()) - return response - - -def get_retryable_job_service_asyncio_client( - target: str, - credentials: Optional[grpc.ChannelCredentials], - compression: Optional[grpc.Compression], -) -> JobServiceAsyncIOClient: - """ - Get a JobServiceAsyncIOClient that has retry configured - - :param target: grpc channel target - :param credentials: grpc channel credentials (if needed) - :param compresion: grpc channel compression - - :return: A job service asyncio client instance - """ - channel = None - if credentials is None: - channel = grpc.aio.insecure_channel( - target=target, - options=default_jobservice_channel_options, - compression=compression, - ) - else: - channel = grpc.aio.secure_channel( - target=target, - credentials=credentials, - options=default_jobservice_channel_options, - compression=compression, - ) - return JobServiceAsyncIOClient(channel) diff --git a/third_party/airflow/armada/operators/utils.py b/third_party/airflow/armada/operators/utils.py deleted file mode 100644 index 1ab7fa35d04..00000000000 --- a/third_party/airflow/armada/operators/utils.py +++ /dev/null @@ -1,289 +0,0 @@ -import asyncio -import logging -import os -import time - -from airflow.exceptions import AirflowException -from typing import List, Optional, Tuple -from enum import Enum - -from armada.operators.jobservice import JobServiceClient -from 
armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.jobservice import jobservice_pb2 -from armada_client.armada import submit_pb2 - - -class JobState(Enum): - SUBMITTED = 0 - DUPLICATE_FOUND = 1 - RUNNING = 2 - FAILED = 3 - SUCCEEDED = 4 - CANCELLED = 5 - JOB_ID_NOT_FOUND = 6 - CONNECTION_ERR = 7 - - -_pb_to_job_state = { - jobservice_pb2.JobServiceResponse.SUBMITTED: JobState.SUBMITTED, - jobservice_pb2.JobServiceResponse.DUPLICATE_FOUND: JobState.DUPLICATE_FOUND, - jobservice_pb2.JobServiceResponse.RUNNING: JobState.RUNNING, - jobservice_pb2.JobServiceResponse.FAILED: JobState.FAILED, - jobservice_pb2.JobServiceResponse.SUCCEEDED: JobState.SUCCEEDED, - jobservice_pb2.JobServiceResponse.CANCELLED: JobState.CANCELLED, - jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND: JobState.JOB_ID_NOT_FOUND, - # NOTE(Clif): For whatever reason CONNECTION_ERR is not present in the - # generated protobuf. - 7: JobState.CONNECTION_ERR, -} - - -def job_state_from_pb(state) -> JobState: - return _pb_to_job_state[state] - - -def airflow_error(job_state: JobState, name: str, job_id: str): - """Throw an error on a terminal event if job errored out - - :param job_state: A JobState enum class - :param name: The name of your armada job - :param job_id: The job id that armada assigns to it - :return: No Return or an AirflowFailException. - - AirflowFailException tells Airflow Schedule to not reschedule the task - - """ - if job_state == JobState.SUCCEEDED: - return - if ( - job_state == JobState.FAILED - or job_state == JobState.CANCELLED - or job_state == JobState.JOB_ID_NOT_FOUND - ): - job_message = job_state.name - # AirflowException allows operator-level retries. AirflowFailException - # does *not*. 
- raise AirflowException(f"The Armada job {name}:{job_id} {job_message}") - - -def default_job_status_callable( - armada_queue: str, - job_set_id: str, - job_id: str, - job_service_client: JobServiceClient, -) -> jobservice_pb2.JobServiceResponse: - return job_service_client.get_job_status( - queue=armada_queue, job_id=job_id, job_set_id=job_set_id - ) - - -armada_logger = logging.getLogger("airflow.task") - - -def search_for_job_complete( - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - job_id: str, - poll_interval: int = 30, - job_service_client: Optional[JobServiceClient] = None, - job_status_callable=default_job_status_callable, - time_out_for_failure: int = 7200, -) -> Tuple[JobState, str]: - """ - - Poll JobService cache until you get a terminated event. - - A terminated event is SUCCEEDED, FAILED or CANCELLED - - :param armada_queue: The queue for armada - :param job_set_id: Your job_set_id - :param airflow_task_name: The name of your armada job - :param poll_interval: Polling interval for jobservice to get status. - :param job_id: The name of the job id that armada assigns to it - :param job_service_client: A JobServiceClient that is used for polling. - It is optional only for testing - :param job_status_callable: A callable object for test injection. - :param time_out_for_failure: The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - :return: A tuple of JobStateEnum, message - """ - start_time = time.time() - # Overwrite time_out_for_failure by environment variable for configuration - armada_time_out_env = os.getenv("ARMADA_AIRFLOW_TIME_OUT_JOB_ID") - if armada_time_out_env: - time_out_for_failure = int(armada_time_out_env) - while True: - # The else statement is for testing purposes. 
- # We want to allow a test callable to be passed - if job_service_client: - job_status_return = job_status_callable( - armada_queue=armada_queue, - job_id=job_id, - job_set_id=job_set_id, - job_service_client=job_service_client, - ) - else: - job_status_return = job_status_callable( - armada_queue=armada_queue, job_id=job_id, job_set_id=job_set_id - ) - - job_state = job_state_from_pb(job_status_return.state) - armada_logger.debug(f"Got job state '{job_state.name}' for job {job_id}") - - time.sleep(poll_interval) - if job_state == JobState.SUCCEEDED: - job_message = f"Armada {airflow_task_name}:{job_id} succeeded" - break - if job_state == JobState.FAILED: - job_message = ( - f"Armada {airflow_task_name}:{job_id} failed\n" - f"failed with reason {job_status_return.error}" - ) - break - if job_state == JobState.CANCELLED: - job_message = f"Armada {airflow_task_name}:{job_id} cancelled" - break - if job_state == JobState.CONNECTION_ERR: - log_messages = ( - f"Armada {airflow_task_name}:{job_id} connection error (will retry)" - f"failed with reason {job_status_return.error}" - ) - armada_logger.warning(log_messages) - continue - - if job_state == JobState.JOB_ID_NOT_FOUND: - end_time = time.time() - time_elasped = int(end_time) - int(start_time) - if time_elasped > time_out_for_failure: - job_state = JobState.JOB_ID_NOT_FOUND - job_message = ( - f"Armada {airflow_task_name}:{job_id} could not find a job id and\n" - f"hit a timeout" - ) - break - - return job_state, job_message - - -def annotate_job_request_items( - context, job_request_items: List[submit_pb2.JobSubmitRequestItem] -) -> List[submit_pb2.JobSubmitRequestItem]: - """ - Annotates the inbound job request items with Airflow context elements - - :param context: The airflow context. 
- - :param job_request_items: The job request items to be sent to armada - - :return: annotated job request items for armada - """ - task_instance = context["ti"] - task_id = task_instance.task_id - run_id = context["run_id"] - dag_id = context["dag"].dag_id - - for item in job_request_items: - item.annotations[get_annotation_key_prefix() + "taskId"] = task_id - item.annotations[get_annotation_key_prefix() + "taskRunId"] = run_id - item.annotations[get_annotation_key_prefix() + "dagId"] = dag_id - - return job_request_items - - -ANNOTATION_KEY_PREFIX = "armadaproject.io/" - - -def get_annotation_key_prefix() -> str: - """ - Provides the annotation key prefix, - which can be specified in env var ANNOTATION_KEY_PREFIX. - A default is provided if the env var is not defined - - :return: string annotation key prefix - """ - env_var_name = "ANNOTATION_KEY_PREFIX" - if env_var_name in os.environ: - return f"{os.environ.get(env_var_name)}" - else: - return ANNOTATION_KEY_PREFIX - - -async def search_for_job_complete_async( - armada_queue: str, - job_set_id: str, - airflow_task_name: str, - job_id: str, - job_service_client: JobServiceAsyncIOClient, - log, - poll_interval: int, - time_out_for_failure: int = 7200, -) -> Tuple[JobState, str]: - """ - - Poll JobService cache asyncronously until you get a terminated event. - - A terminated event is SUCCEEDED, FAILED or CANCELLED - - :param armada_queue: The queue for armada - :param job_set_id: Your job_set_id - :param airflow_task_name: The name of your armada job - :param job_id: The name of the job id that armada assigns to it - :param job_service_client: A JobServiceClient that is used for polling. - It is optional only for testing - :param poll_interval: How often to poll jobservice to get status. 
- :param time_out_for_failure: The amount of time a job - can be in job_id_not_found - before we decide it was a invalid job - :return: A tuple of JobStateEnum, message - """ - start_time = time.time() - # Overwrite time_out_for_failure by environment variable for configuration - armada_time_out_env = os.getenv("ARMADA_AIRFLOW_TIME_OUT_JOB_ID") - if armada_time_out_env: - time_out_for_failure = int(armada_time_out_env) - while True: - job_status_return = await job_service_client.get_job_status( - queue=armada_queue, - job_id=job_id, - job_set_id=job_set_id, - ) - - job_state = job_state_from_pb(job_status_return.state) - log.debug(f"Got job state '{job_state.name}' for job {job_id}") - - await asyncio.sleep(poll_interval) - - if job_state == JobState.SUCCEEDED: - job_message = f"Armada {airflow_task_name}:{job_id} succeeded" - break - if job_state == JobState.FAILED: - job_message = ( - f"Armada {airflow_task_name}:{job_id} failed\n" - f"failed with reason {job_status_return.error}" - ) - break - if job_state == JobState.CANCELLED: - job_message = f"Armada {airflow_task_name}:{job_id} cancelled" - break - if job_state == JobState.CONNECTION_ERR: - log_messages = ( - f"Armada {airflow_task_name}:{job_id} connection error (will retry)" - f"failed with reason {job_status_return.error}" - ) - log.warning(log_messages) - continue - - if job_state == JobState.JOB_ID_NOT_FOUND: - end_time = time.time() - time_elasped = int(end_time) - int(start_time) - if time_elasped > time_out_for_failure: - job_state = JobState.JOB_ID_NOT_FOUND - job_message = ( - f"Armada {airflow_task_name}:{job_id} could not find a job id and\n" - f"hit a timeout" - ) - break - - return job_state, job_message diff --git a/third_party/airflow/armada/provider.yaml b/third_party/airflow/armada/provider.yaml deleted file mode 100644 index ff95bce9210..00000000000 --- a/third_party/airflow/armada/provider.yaml +++ /dev/null @@ -1,38 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# 
or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - ---- -package-name: apache-airflow-providers-armada -name: Armada -description: | - `ArmadaOperator `__ -versions: - -0.3.14 - -additional-dependencies: - - apache-airflow>=2.2.0 - -integrations: - - integration-name: Armada - external-doc-url: https://armadaproject.io/ - logo: /integration-logos/armada/armada.png - tags: [software] - -operators: - - integration-name: Armada - python-modules: - - airflow.providers.armada.operators.armada \ No newline at end of file diff --git a/third_party/airflow/armada/triggers/armada.py b/third_party/airflow/armada/triggers/armada.py new file mode 100644 index 00000000000..284fe305169 --- /dev/null +++ b/third_party/airflow/armada/triggers/armada.py @@ -0,0 +1,269 @@ +import asyncio +import importlib +import time +from functools import cached_property +from typing import AsyncIterator, Any, Optional, Tuple, Dict + +from airflow.triggers.base import BaseTrigger, TriggerEvent +from armada_client.armada.job_pb2 import JobRunDetails +from armada_client.typings import JobState + +from armada_client.asyncio_client import ArmadaAsyncIOClient +from armada.auth import TokenRetriever +from armada.logs.pod_log_manager import PodLogManagerAsync +from armada.model import GrpcChannelArgs +from pendulum import DateTime + + 
class ArmadaTrigger(BaseTrigger):
    """
    An Airflow Trigger that can asynchronously manage an Armada job.

    The trigger polls Armada for the state of a single job until the job is no
    longer active, optionally fetches container logs while doing so, and reports
    the outcome as a TriggerEvent.
    """

    def __init__(
        self,
        job_id: str,
        armada_queue: str,
        job_set_id: str,
        poll_interval: int,
        tracking_message: str,
        job_acknowledgement_timeout: int,
        job_request_namespace: str,
        channel_args: Optional[GrpcChannelArgs] = None,
        channel_args_details: Optional[Dict[str, Any]] = None,
        container_logs: Optional[str] = None,
        k8s_token_retriever: Optional[TokenRetriever] = None,
        k8s_token_retriever_details: Optional[Tuple[str, Dict[str, Any]]] = None,
        last_log_time: Optional[DateTime] = None,
    ) -> None:
        """
        Initializes an instance of ArmadaTrigger, which is an Airflow trigger for
        managing Armada jobs asynchronously.

        :param job_id: The unique identifier of the job to be monitored.
        :type job_id: str
        :param armada_queue: The Armada queue under which the job was submitted.
                             Required for job cancellation.
        :type armada_queue: str
        :param job_set_id: The unique identifier of the job set under which the job
                           was submitted. Required for job cancellation.
        :type job_set_id: str
        :param poll_interval: The interval, in seconds, at which the job status will be
                              checked.
        :type poll_interval: int
        :param tracking_message: A message to log or display for tracking the job
                                 status.
        :type tracking_message: str
        :param job_acknowledgement_timeout: The timeout, in seconds, to wait for the job
                                            to be acknowledged by Armada.
        :type job_acknowledgement_timeout: int
        :param job_request_namespace: The Kubernetes namespace under which the job was
                                      submitted.
        :type job_request_namespace: str
        :param channel_args: The arguments to configure the gRPC channel. Takes
                             precedence over channel_args_details when both are given.
        :type channel_args: GrpcChannelArgs, optional
        :param channel_args_details: Keyword arguments from which a GrpcChannelArgs is
                                     rebuilt. Only used when the trigger is rehydrated
                                     after serialization.
        :type channel_args_details: dict[str, Any], optional
        :param container_logs: Name of container from which to retrieve logs
        :type container_logs: str, optional
        :param k8s_token_retriever: An optional instance of type TokenRetriever, used to
                                    refresh the Kubernetes auth token
        :type k8s_token_retriever: TokenRetriever, optional
        :param k8s_token_retriever_details: A (classpath, kwargs) pair from which the
                                            TokenRetriever is re-created. Only used
                                            when the trigger is rehydrated after
                                            serialization; when given it replaces
                                            k8s_token_retriever.
        :type k8s_token_retriever_details: Tuple[str, Dict[str, Any]], optional
        :param last_log_time: where to resume logs from
        :type last_log_time: DateTime, optional
        :raises ValueError: if neither channel_args nor channel_args_details is
                            provided.
        """
        super().__init__()
        self.job_id = job_id
        self.armada_queue = armada_queue
        self.job_set_id = job_set_id
        self.poll_interval = poll_interval
        self.tracking_message = tracking_message
        self.job_acknowledgement_timeout = job_acknowledgement_timeout
        self.container_logs = container_logs
        self.last_log_time = last_log_time
        self.job_request_namespace = job_request_namespace
        self._pod_manager = None
        self.k8s_token_retriever = k8s_token_retriever

        if channel_args:
            self.channel_args = channel_args
        elif channel_args_details:
            self.channel_args = GrpcChannelArgs(**channel_args_details)
        else:
            # Raising a plain string is itself a TypeError in Python 3; raise a real
            # exception so the caller sees which argument is missing.
            raise ValueError(
                "must provide either channel_args or channel_args_details"
            )

        if k8s_token_retriever_details:
            classpath, kwargs = k8s_token_retriever_details
            # Split the classpath into module and class name
            module_path, class_name = classpath.rsplit(".", 1)
            # Dynamically import the module and look the class up on it
            module = importlib.import_module(module_path)
            cls = getattr(module, class_name)
            # Instantiate the class with the deserialized kwargs
            self.k8s_token_retriever = cls(**kwargs)

    def serialize(self) -> tuple:
        """
        Serialises the state of this Trigger.
        When the Trigger is re-hydrated, these values will be passed to init() as kwargs

        :return: A tuple of the trigger's classpath and its keyword arguments.
        """
        k8s_token_retriever_details = (
            self.k8s_token_retriever.serialize() if self.k8s_token_retriever else None
        )
        return (
            "armada.triggers.armada.ArmadaTrigger",
            {
                "job_id": self.job_id,
                "armada_queue": self.armada_queue,
                "job_set_id": self.job_set_id,
                "channel_args_details": self.channel_args.serialize(),
                "poll_interval": self.poll_interval,
                "tracking_message": self.tracking_message,
                "job_acknowledgement_timeout": self.job_acknowledgement_timeout,
                "container_logs": self.container_logs,
                "k8s_token_retriever_details": k8s_token_retriever_details,
                "last_log_time": self.last_log_time,
                "job_request_namespace": self.job_request_namespace,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the Trigger Asynchronously. This will poll Armada until the Job reaches a
        terminal state

        Any exception raised while polling is converted into an "error" event rather
        than propagated.
        """
        try:
            response = await self._poll_for_termination(self.job_id)
            yield TriggerEvent(response)
        except Exception as exc:
            yield TriggerEvent(
                {
                    "status": "error",
                    "job_id": self.job_id,
                    "response": f"Job {self.job_id} did not succeed. Error was {exc}",
                }
            )

    async def _cancel_job(self) -> None:
        """
        Ask Armada to cancel this trigger's job.

        Cannot call on_kill from trigger, will asynchronously cancel jobs instead.
        This is best effort: a failed or empty cancellation is logged as a warning
        and never raised.
        """
        try:
            result = await self.client.cancel_jobs(
                queue=self.armada_queue,
                job_set_id=self.job_set_id,
                job_id=self.job_id,
            )
            if list(result.cancelled_ids):
                self.log.info(f"Cancelled job with id {result.cancelled_ids}")
            else:
                self.log.warning(f"Failed to cancel job with id {self.job_id}")
        except Exception as e:
            self.log.warning(f"Failed to cancel job with id {self.job_id}: {e}")

    async def _poll_for_termination(self, job_id: str) -> Dict[str, Any]:
        """
        Poll Armada for the state of job_id until the state is no longer active.

        :param job_id: The id of the job to poll.
        :return: A dict with "status" ("success" or "error"), "job_id" and a
                 human-readable "response" message.
        """
        state = JobState.UNKNOWN
        start_time = time.time()
        job_acknowledged = False
        run_details = None

        # Poll for terminal state
        while state.is_active():
            resp = await self.client.get_job_status([job_id])
            state = JobState(resp.job_states[job_id])
            self.log.info(
                f"Job {job_id} is in state: {state.name}. {self.tracking_message}"
            )

            # Any state other than UNKNOWN counts as Armada acknowledging the job.
            if state != JobState.UNKNOWN:
                job_acknowledged = True

            if (
                not job_acknowledged
                and int(time.time() - start_time) > self.job_acknowledgement_timeout
            ):
                # Give up on a job Armada never reported on, and try to cancel it.
                await self._cancel_job()
                return {
                    "status": "error",
                    "job_id": job_id,
                    "response": f"Job {job_id} not acknowledged within timeout "
                    f"{self.job_acknowledgement_timeout}.",
                }

            # Run details are looked up once, as soon as the job is running or done.
            if self.container_logs and not run_details:
                if state == JobState.RUNNING or state.is_terminal():
                    run_details = await self._get_latest_job_run_details(self.job_id)

            if run_details:
                try:
                    log_status = await self.pod_manager(
                        run_details.cluster
                    ).fetch_container_logs(
                        pod_name=f"armada-{self.job_id}-0",
                        namespace=self.job_request_namespace,
                        container_name=self.container_logs,
                        since_time=self.last_log_time,
                    )
                    # Remember where we got to so the next fetch resumes from there.
                    self.last_log_time = log_status.last_log_time
                except Exception as e:
                    # Log retrieval problems must not abort the polling loop.
                    self.log.exception(e)

            if state.is_active():
                self.log.debug(f"Sleeping for {self.poll_interval} seconds")
                await asyncio.sleep(self.poll_interval)

        self.log.info(f"Job {job_id} terminated with state:{state.name}")
        if state != JobState.SUCCEEDED:
            return {
                "status": "error",
                "job_id": job_id,
                "response": f"Job {job_id} did not succeed. Final status was "
                f"{state.name}",
            }
        return {
            "status": "success",
            "job_id": job_id,
            "response": f"Job {job_id} succeeded",
        }

    @cached_property
    def client(self) -> ArmadaAsyncIOClient:
        """Async Armada client, built from channel_args on first access and cached."""
        return ArmadaAsyncIOClient(channel=self.channel_args.aio_channel())

    def pod_manager(self, k8s_context: str) -> PodLogManagerAsync:
        """
        Return the PodLogManagerAsync used to fetch container logs.

        The manager is created on the first call; later calls return that same
        instance, whatever k8s_context they pass.

        :param k8s_context: The Kubernetes context handed to the manager on creation.
        """
        if self._pod_manager is None:
            self._pod_manager = PodLogManagerAsync(
                k8s_context=k8s_context, token_retriever=self.k8s_token_retriever
            )

        return self._pod_manager

    async def _get_latest_job_run_details(self, job_id) -> Optional[JobRunDetails]:
        """
        Look up the details of the latest run of job_id.

        :param job_id: The id of the job whose runs are inspected.
        :return: The run whose run_id equals the job's latest_run_id, or None when
                 there is no such run.
        """
        resp = await self.client.get_job_details([job_id])
        job_details = resp.job_details[job_id]
        if job_details and job_details.latest_run_id:
            for run in job_details.job_runs:
                if run.run_id == job_details.latest_run_id:
                    return run
        return None

    def __eq__(self, other):
        """
        Compare two triggers by job id, serialized channel args, poll interval and
        tracking message.
        """
        # NOTE: defining __eq__ without __hash__ leaves instances unhashable.
        if not isinstance(other, ArmadaTrigger):
            return False
        return (
            self.job_id == other.job_id
            and self.channel_args.serialize() == other.channel_args.serialize()
            and self.poll_interval == other.poll_interval
            and self.tracking_message == other.tracking_message
        )
- 1) Define arguments for armada and jobservice grpc channels. + 1) Define arguments for armada grpc channel. 2) Define your ArmadaOperator tasks that you want to run. 3) Generate a DAG definition. """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") op = BashOperator(task_id="dummy", bash_command="echo Hello World!") armada = ArmadaOperator( task_id="armada", name="armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="busybox"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ - This task is used to verify that if an Armada Job - fails we are correctly telling Airflow that it failed. - """ + This task is used to verify that if an Armada Job + fails we are correctly telling Airflow that it failed. 
+ """ bad_armada = ArmadaOperator( task_id="armada_fail", name="armada_fail", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="nonexistant"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) good_armada = ArmadaOperator( task_id="good_armada", name="good_armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_container(image="busybox"), + channel_args=armada_channel_args, + job_request=submit_sleep_container(image="busybox")[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ - Airflow syntax to say - Run op first and then run armada and bad_armada in parallel - If all jobs are successful, run good_armada. - """ + Airflow syntax to say + Run op first and then run armada and bad_armada in parallel + If all jobs are successful, run good_armada. + """ op >> [armada, bad_armada] >> good_armada diff --git a/third_party/airflow/examples/big_armada.py b/third_party/airflow/examples/big_armada.py index dc64cdc76b2..5979e391f0b 100644 --- a/third_party/airflow/examples/big_armada.py +++ b/third_party/airflow/examples/big_armada.py @@ -1,5 +1,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator + +from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -62,11 +64,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaDeferrableOperator requires grpc.channel arguments for armada and - the jobservice. + The ArmadaDeferrableOperator requires grpc.channel arguments for armada. 
""" - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow @@ -75,8 +75,7 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient for each task. + The Airflow operator needs a queue for Armada You should reuse them for all your tasks. This job will use the podspec defined above. """ @@ -84,9 +83,8 @@ def submit_sleep_job(): task_id="armada1", name="armada1", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -94,9 +92,8 @@ def submit_sleep_job(): task_id="armada2", name="armada2", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -104,9 +101,8 @@ def submit_sleep_job(): task_id="armada3", name="armada3", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -114,9 +110,8 @@ def submit_sleep_job(): task_id="armada4", name="armada4", armada_queue="test", - job_service_channel_args=job_service_channel_args, - 
armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -124,9 +119,8 @@ def submit_sleep_job(): task_id="armada5", name="armada5", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -134,9 +128,8 @@ def submit_sleep_job(): task_id="armada6", name="armada6", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -144,9 +137,8 @@ def submit_sleep_job(): task_id="armada7", name="armada7", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -154,9 +146,8 @@ def submit_sleep_job(): task_id="armada8", name="armada8", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -164,9 +155,8 @@ def submit_sleep_job(): task_id="armada9", name="armada9", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + 
job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -174,9 +164,8 @@ def submit_sleep_job(): task_id="armada10", name="armada10", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -184,9 +173,8 @@ def submit_sleep_job(): task_id="armada11", name="armada11", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -194,18 +182,16 @@ def submit_sleep_job(): task_id="armada12", name="armada12", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) armada13 = ArmadaOperator( task_id="armada13", name="armada13", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -213,9 +199,8 @@ def submit_sleep_job(): task_id="armada14", name="armada14", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -223,9 +208,8 @@ def submit_sleep_job(): 
task_id="armada15", name="armada15", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -233,9 +217,8 @@ def submit_sleep_job(): task_id="armada16", name="armada16", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -243,9 +226,8 @@ def submit_sleep_job(): task_id="armada17", name="armada17", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -253,18 +235,16 @@ def submit_sleep_job(): task_id="armada18", name="armada18", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) armada19 = ArmadaOperator( task_id="armada19", name="armada19", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) @@ -272,9 +252,8 @@ def submit_sleep_job(): task_id="armada20", name="armada20", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, 
- job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) diff --git a/third_party/airflow/examples/hello_armada.py b/third_party/airflow/examples/hello_armada.py index 53c20c78038..0f59932d96c 100644 --- a/third_party/airflow/examples/hello_armada.py +++ b/third_party/airflow/examples/hello_armada.py @@ -1,5 +1,7 @@ from airflow import DAG from airflow.operators.bash import BashOperator + +from armada.model import GrpcChannelArgs from armada.operators.armada import ArmadaOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 @@ -62,11 +64,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaOperator requires grpc.channel arguments for armada and - the jobservice. + The ArmadaOperator requires grpc.channel arguments for armada. """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow @@ -75,8 +75,7 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient for each task. + The Airflow operator needs queue for Armada You should reuse them for all your tasks. This job will use the podspec defined above. 
""" @@ -84,9 +83,8 @@ def submit_sleep_job(): task_id="armada", name="armada", armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submit_sleep_job(), + channel_args=armada_channel_args, + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", ) """ diff --git a/third_party/airflow/examples/hello_armada_deferrable.py b/third_party/airflow/examples/hello_armada_deferrable.py index 907242e4932..eb028d61a40 100644 --- a/third_party/airflow/examples/hello_armada_deferrable.py +++ b/third_party/airflow/examples/hello_armada_deferrable.py @@ -1,6 +1,5 @@ from airflow import DAG from airflow.operators.bash import BashOperator -from armada.operators.armada_deferrable import ArmadaDeferrableOperator from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 from armada_client.k8s.io.apimachinery.pkg.api.resource import ( @@ -13,6 +12,9 @@ import pendulum +from armada.model import GrpcChannelArgs +from armada.operators.armada import ArmadaOperator + def submit_sleep_job(): """ @@ -63,12 +65,9 @@ def submit_sleep_job(): default_args={"retries": 2}, ) as dag: """ - The ArmadaDeferrableOperatorOperator requires grpc.aio.channel arguments + The ArmadaOperator requires GrpcChannelArgs arguments """ - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = { - "target": "127.0.0.1:60003", - } + armada_channel_args = GrpcChannelArgs(target="127.0.0.1:50051") """ This defines an Airflow task that runs Hello World and it gives the airflow task name of dummy. @@ -76,19 +75,17 @@ def submit_sleep_job(): op = BashOperator(task_id="dummy", bash_command="echo Hello World!") """ This is creating an Armada task with the task_id of armada and name of armada. - The Airflow operator needs queue and job-set for Armada - You also specify the PythonClient and JobServiceClient channel arguments - for each task. 
+ The Airflow operator needs queue for Armada. This job will use the podspec defined above. """ - armada = ArmadaDeferrableOperator( + armada = ArmadaOperator( task_id="armada_deferrable", name="armada_deferrable", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, + channel_args=armada_channel_args, armada_queue="test", - job_request_items=submit_sleep_job(), + job_request=submit_sleep_job()[0], lookout_url_template="http://127.0.0.1:8089/jobs?job_id=", + deferrable=True, ) """ Airflow dag syntax for running op and then armada. diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index 8a0ffec1c0c..bde16313944 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -1,30 +1,67 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + [project] name = "armada_airflow" -version = "0.5.6" +version = "1.0.0" description = "Armada Airflow Operator" -requires-python = ">=3.7" -# Note(JayF): This dependency value is not suitable for release. Whatever -# release automation we create will have to change this to a dep on a pypi -# package, but we can't do that now because it would make development -# extremely difficult. 
-dependencies = [ - "armada-client", - "apache-airflow>=2.6.3", - "grpcio==1.58.0", - "grpcio-tools==1.58.0", - "types-protobuf==4.24.0.1", - "protobuf>=3.20,<5.0" -] +readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } -readme = "README.md" +dependencies=[ + 'armada-client==0.3.4', + 'apache-airflow>=2.6.3', + 'grpcio==1.58.0', + 'grpcio-tools==1.58.0', + 'types-protobuf==4.24.0.1', + 'kubernetes>=23.6.0', + 'kubernetes_asyncio>=24.2.3', +] +requires-python=">=3.8" +classifiers=[ + 'Programming Language :: Python :: 3', + 'Operating System :: OS Independent', +] [project.optional-dependencies] -format = ["black==23.7.0", "flake8==7.0.0", "pylint==2.17.5"] -test = ["pytest==7.3.1", "coverage==7.3.2", "pytest-asyncio==0.21.1"] +format = ["black>=24.0.0", "flake8==7.0.0", "pylint==2.17.5"] +test = ["pytest==7.3.1", "coverage==7.3.2", "pytest-asyncio==0.21.1", + "pytest-mock>=3.14.0"] # note(JayF): sphinx-jekyll-builder was broken by sphinx-markdown-builder 0.6 -- so pin to 0.5.5 docs = ["sphinx==7.1.2", "sphinx-jekyll-builder==0.3.0", "sphinx-toolbox==3.2.0b1", "sphinx-markdown-builder==0.5.5"] -[build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +[project.urls] +repository='https://github.com/armadaproject/armada' + +[tools.setuptools.packages.find] +include = ["armada_airflow*"] + +[tool.black] +line-length = 88 +target-version = ['py310'] +include = ''' +/( + armada + | test +)/ +''' +exclude = ''' +/( + \.git + | venv + | build + | dist + | new + | .tox + | docs + | armada_airflow.egg-info + | __pycache__* +)/ +''' + +[tool.flake8] +# These settings are recommended by upstream black to make flake8 find black +# style formatting correct. 
+max-line-length = 88 +extend-ignore = "E203" diff --git a/third_party/airflow/armada/__init__.py b/third_party/airflow/test/__init__.py similarity index 100% rename from third_party/airflow/armada/__init__.py rename to third_party/airflow/test/__init__.py diff --git a/third_party/airflow/test/integration/test_airflow_operator_logic.py b/third_party/airflow/test/integration/test_airflow_operator_logic.py new file mode 100644 index 00000000000..c2931715f70 --- /dev/null +++ b/third_party/airflow/test/integration/test_airflow_operator_logic.py @@ -0,0 +1,232 @@ +import os +import uuid +from unittest.mock import MagicMock + +import pytest +import threading + +from airflow.exceptions import AirflowException +from armada_client.typings import JobState +from armada_client.armada import ( + submit_pb2, +) +from armada_client.client import ArmadaClient +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) +import grpc +from typing import Any + +from armada.model import GrpcChannelArgs +from armada.operators.armada import ArmadaOperator + +DEFAULT_TASK_ID = "test_task_1" +DEFAULT_DAG_ID = "test_dag_1" +DEFAULT_RUN_ID = "test_run_1" +DEFAULT_QUEUE = "queue-a" +DEFAULT_NAMESPACE = "personal-anonymous" +DEFAULT_POLLING_INTERVAL = 1 +DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 10 + + +@pytest.fixture(scope="function", name="context") +def default_context() -> Any: + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_ti.xcom_pull.return_value = None + mock_ti.xcom_push.return_value = None + mock_dag = MagicMock() + mock_dag.dag_id = DEFAULT_DAG_ID + return { + "ti": mock_ti, + "run_id": DEFAULT_RUN_ID, + "dag": mock_dag, + } + + +@pytest.fixture(scope="session", name="channel_args") +def queryapi_channel_args() -> GrpcChannelArgs: + server_name = os.environ.get("ARMADA_SERVER", "localhost") + server_port = os.environ.get("ARMADA_PORT", "50051") + + 
return GrpcChannelArgs(target=f"{server_name}:{server_port}") + + +@pytest.fixture(scope="session", name="client") +def no_auth_client() -> ArmadaClient: + server_name = os.environ.get("ARMADA_SERVER", "localhost") + server_port = os.environ.get("ARMADA_PORT", "50051") + + return ArmadaClient(channel=grpc.insecure_channel(f"{server_name}:{server_port}")) + + +def sleep_pod(image: str): + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="goodsleep", + image=image, + args=["sleep", "5s"], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + requests={ + "cpu": api_resource.Quantity(string="0.2"), + "memory": api_resource.Quantity(string="64Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="0.2"), + "memory": api_resource.Quantity(string="64Mi"), + }, + ), + ) + ], + ) + return [ + submit_pb2.JobSubmitRequestItem( + priority=1, pod_spec=pod, namespace=DEFAULT_NAMESPACE + ) + ] + + +def test_success_job( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + job_set_name = f"test-{uuid.uuid1()}" + job = client.submit_jobs( + queue=DEFAULT_QUEUE, + job_set_id=job_set_name, + job_request_items=sleep_pod(image="busybox"), + ) + job_id = job.job_response_items[0].job_id + + mocker.patch( + "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", + return_value=job_id, + ) + + operator = ArmadaOperator( + task_id=DEFAULT_TASK_ID, + name="test_job_success", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + operator.execute(context) + + response = operator.client.get_job_status([job_id]) + assert JobState(response.job_states[job_id]) == JobState.SUCCEEDED + + +def test_bad_job( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + job_set_name = f"test-{uuid.uuid1()}" + 
job = client.submit_jobs( + queue=DEFAULT_QUEUE, + job_set_id=job_set_name, + job_request_items=sleep_pod(image="NOTACONTAINER"), + ) + job_id = job.job_response_items[0].job_id + + mocker.patch( + "armada.operators.armada.ArmadaOperator._reattach_or_submit_job", + return_value=job_id, + ) + + operator = ArmadaOperator( + task_id=DEFAULT_TASK_ID, + name="test_job_failure", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + try: + operator.execute(context) + pytest.fail( + "Operator did not raise AirflowException on job failure as expected" + ) + except AirflowException: # Expected + response = operator.client.get_job_status([job_id]) + assert JobState(response.job_states[job_id]) == JobState.FAILED + except Exception as e: + pytest.fail( + "Operator did not raise AirflowException on job failure as expected, " + f"raised {e} instead" + ) + + +def success_job( + task_number: int, context: Any, channel_args: GrpcChannelArgs +) -> JobState: + operator = ArmadaOperator( + task_id=f"{DEFAULT_TASK_ID}_{task_number}", + name="test_job_success", + channel_args=channel_args, + armada_queue=DEFAULT_QUEUE, + job_request=sleep_pod(image="busybox")[0], + poll_interval=DEFAULT_POLLING_INTERVAL, + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + ) + + operator.execute(context) + + response = operator.client.get_job_status([operator.job_id]) + return JobState(response.job_states[operator.job_id]) + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(5): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() 
+ threads.append(t) + + for thread in threads: + thread.join() + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution_large( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(80): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() + threads.append(t) + + for thread in threads: + thread.join() + + +@pytest.mark.skip(reason="we should not test performance in the CI.") +def test_parallel_execution_huge( + client: ArmadaClient, context: Any, channel_args: GrpcChannelArgs, mocker +): + threads = [] + success_job(task_number=0, context=context, channel_args=channel_args) + for task_number in range(500): + t = threading.Thread( + target=success_job, args=[task_number, context, channel_args] + ) + t.start() + threads.append(t) + + for thread in threads: + thread.join() diff --git a/third_party/airflow/armada/jobservice/__init__.py b/third_party/airflow/test/operators/__init__.py similarity index 100% rename from third_party/airflow/armada/jobservice/__init__.py rename to third_party/airflow/test/operators/__init__.py diff --git a/third_party/airflow/test/operators/test_armada.py b/third_party/airflow/test/operators/test_armada.py new file mode 100644 index 00000000000..f1502bc206f --- /dev/null +++ b/third_party/airflow/test/operators/test_armada.py @@ -0,0 +1,310 @@ +import unittest +from math import ceil +from unittest.mock import MagicMock, patch, PropertyMock + +from airflow.exceptions import AirflowException +from armada_client.armada import submit_pb2, job_pb2 +from armada_client.armada.submit_pb2 import JobSubmitRequestItem +from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 +from armada_client.k8s.io.apimachinery.pkg.api.resource import ( + generated_pb2 as api_resource, +) + +from armada.model import 
GrpcChannelArgs +from armada.operators.armada import ArmadaOperator +from armada.triggers.armada import ArmadaTrigger + +DEFAULT_JOB_ID = "test_job" +DEFAULT_TASK_ID = "test_task_1" +DEFAULT_DAG_ID = "test_dag_1" +DEFAULT_RUN_ID = "test_run_1" +DEFAULT_QUEUE = "test_queue_1" +DEFAULT_POLLING_INTERVAL = 30 +DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 + + +class TestArmadaOperator(unittest.TestCase): + def setUp(self): + # Set up a mock context + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_dag = MagicMock() + mock_dag.dag_id = DEFAULT_DAG_ID + self.context = { + "ti": mock_ti, + "run_id": DEFAULT_RUN_ID, + "dag": mock_dag, + } + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_execute(self, mock_client_fn, _): + test_cases = [ + { + "name": "Job Succeeds", + "statuses": [submit_pb2.RUNNING, submit_pb2.SUCCEEDED], + "success": True, + }, + { + "name": "Job Failed", + "statuses": [submit_pb2.RUNNING, submit_pb2.FAILED], + "success": False, + }, + { + "name": "Job cancelled", + "statuses": [submit_pb2.RUNNING, submit_pb2.CANCELLED], + "success": False, + }, + { + "name": "Job preempted", + "statuses": [submit_pb2.RUNNING, submit_pb2.PREEMPTED], + "success": False, + }, + { + "name": "Job Succeeds but takes a lot of transitions", + "statuses": [ + submit_pb2.SUBMITTED, + submit_pb2.RUNNING, + submit_pb2.RUNNING, + submit_pb2.RUNNING, + submit_pb2.RUNNING, + submit_pb2.RUNNING, + submit_pb2.SUCCEEDED, + ], + "success": True, + }, + ] + + for test_case in test_cases: + with self.subTest(test_case=test_case["name"]): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + 
job_response_items=[ + submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID) + ] + ) + + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in test_case["statuses"] + ] + + mock_client_fn.return_value = mock_client + self.context["ti"].xcom_pull.return_value = None + + try: + operator.execute(self.context) + self.assertTrue(test_case["success"]) + except AirflowException: + self.assertFalse(test_case["success"]) + return + + self.assertEqual(mock_client.submit_jobs.call_count, 1) + self.assertEqual( + mock_client.get_job_status.call_count, len(test_case["statuses"]) + ) + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.on_kill", new_callable=PropertyMock) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_unacknowledged_results_in_on_kill(self, mock_client_fn, mock_on_kill, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [submit_pb2.UNKNOWN, submit_pb2.UNKNOWN] + ] + + self.context["ti"].xcom_pull.return_value = None + operator.execute(self.context) + self.assertEqual(mock_on_kill.call_count, 1) + + """We call on_kill by triggering the job unacknowledged timeout""" + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_on_kill_cancels_job(self, 
mock_client_fn, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [ + submit_pb2.UNKNOWN + for _ in range( + 1 + + ceil( + DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL + ) + ) + ] + ] + + self.context["ti"].xcom_pull.return_value = None + operator.execute(self.context) + self.assertEqual(mock_client.cancel_jobs.call_count, 1) + + @patch("time.sleep", return_value=None) + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_job_reattaches(self, mock_client_fn, _): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=False, + job_acknowledgement_timeout=-1, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [ + submit_pb2.UNKNOWN + for _ in range( + 1 + + ceil( + DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT / DEFAULT_POLLING_INTERVAL + ) + ) + ] + ] + mock_client_fn.return_value = mock_client + self.context["ti"].xcom_pull.return_value = {"armada_job_id": DEFAULT_JOB_ID} + + operator.execute(self.context) + self.assertEqual(mock_client.submit_jobs.call_count, 0) + self.assertEqual(operator.job_id, DEFAULT_JOB_ID) + + +class 
TestArmadaOperatorDeferrable(unittest.IsolatedAsyncioTestCase): + def setUp(self): + # Set up a mock context + mock_ti = MagicMock() + mock_ti.task_id = DEFAULT_TASK_ID + mock_dag = MagicMock() + mock_dag.dag_id = DEFAULT_DAG_ID + self.context = { + "ti": mock_ti, + "run_id": DEFAULT_RUN_ID, + "dag": mock_dag, + } + + @patch("armada.operators.armada.ArmadaOperator.defer") + @patch("armada.operators.armada.ArmadaOperator.client", new_callable=PropertyMock) + def test_execute_deferred(self, mock_client_fn, mock_defer_fn): + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=JobSubmitRequestItem(), + task_id=DEFAULT_TASK_ID, + deferrable=True, + ) + + # Set up Mock Armada + mock_client = MagicMock() + mock_client.submit_jobs.return_value = submit_pb2.JobSubmitResponse( + job_response_items=[submit_pb2.JobSubmitResponseItem(job_id=DEFAULT_JOB_ID)] + ) + mock_client_fn.return_value = mock_client + self.context["ti"].xcom_pull.return_value = None + + operator.execute(self.context) + self.assertEqual(mock_client.submit_jobs.call_count, 1) + mock_defer_fn.assert_called_with( + timeout=operator.execution_timeout, + trigger=ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=operator.job_set_id, # Not relevant for the sake of test + channel_args=operator.channel_args, + poll_interval=operator.poll_interval, + tracking_message="", + job_acknowledgement_timeout=operator.job_acknowledgement_timeout, + job_request_namespace="default", + ), + method_name="_execute_complete", + ) + + def test_templating(self): + """Tests templating for both the job_prefix and the pod spec""" + prefix = "{{ run_id }}" + pod_arg = "{{ run_id }}" + + pod = core_v1.PodSpec( + containers=[ + core_v1.Container( + name="sleep", + image="alpine:3.16.2", + args=[pod_arg], + securityContext=core_v1.SecurityContext(runAsUser=1000), + resources=core_v1.ResourceRequirements( + 
requests={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + limits={ + "cpu": api_resource.Quantity(string="120m"), + "memory": api_resource.Quantity(string="510Mi"), + }, + ), + ) + ], + ) + job = JobSubmitRequestItem(priority=1, pod_spec=pod, namespace="armada") + + operator = ArmadaOperator( + name="test", + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + armada_queue=DEFAULT_QUEUE, + job_request=job, + job_set_prefix=prefix, + task_id=DEFAULT_TASK_ID, + deferrable=True, + ) + + operator.render_template_fields(self.context) + + self.assertEqual(operator.job_set_prefix, "test_run_1") + self.assertEqual( + operator.job_request.pod_spec.containers[0].args[0], "test_run_1" + ) diff --git a/third_party/airflow/armada/operators/__init__.py b/third_party/airflow/test/triggers/__init__.py similarity index 100% rename from third_party/airflow/armada/operators/__init__.py rename to third_party/airflow/test/triggers/__init__.py diff --git a/third_party/airflow/test/triggers/test_armada.py b/third_party/airflow/test/triggers/test_armada.py new file mode 100644 index 00000000000..29ba4f20990 --- /dev/null +++ b/third_party/airflow/test/triggers/test_armada.py @@ -0,0 +1,207 @@ +import unittest +from unittest.mock import AsyncMock, patch, PropertyMock + +from airflow.triggers.base import TriggerEvent +from armada_client.armada.submit_pb2 import JobState +from armada_client.armada import submit_pb2, job_pb2 + +from armada.model import GrpcChannelArgs +from armada.triggers.armada import ArmadaTrigger + +DEFAULT_JOB_ID = "test_job" +DEFAULT_QUEUE = "test_queue" +DEFAULT_JOB_SET_ID = "test_job_set_id" +DEFAULT_POLLING_INTERVAL = 30 +DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT = 5 * 60 + + +class AsyncMock(unittest.mock.MagicMock): # noqa: F811 + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) + + +class TestArmadaTrigger(unittest.IsolatedAsyncioTestCase): + def 
setUp(self): + self.time = 0 + + def test_serialization(self): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=30, + tracking_message="test tracking message", + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + job_request_namespace="default", + ) + classpath, kwargs = trigger.serialize() + self.assertEqual("armada.triggers.armada.ArmadaTrigger", classpath) + + rehydrated = ArmadaTrigger(**kwargs) + self.assertEqual(trigger, rehydrated) + + def _time_side_effect(self): + self.time += DEFAULT_POLLING_INTERVAL + return self.time + + @patch("time.time") + @patch("asyncio.sleep", new_callable=AsyncMock) + @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) + async def test_execute(self, mock_client_fn, _, time_time): + time_time.side_effect = self._time_side_effect + + test_cases = [ + { + "name": "Job Succeeds", + "statuses": [JobState.RUNNING, JobState.SUCCEEDED], + "expected_responses": [ + TriggerEvent( + { + "status": "success", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} succeeded", + } + ) + ], + }, + { + "name": "Job Failed", + "statuses": [JobState.RUNNING, JobState.FAILED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed. " + f"Final status was FAILED", + } + ) + ], + }, + { + "name": "Job cancelled", + "statuses": [JobState.RUNNING, JobState.CANCELLED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed." 
+ f" Final status was CANCELLED", + } + ) + ], + }, + { + "name": "Job unacknowledged", + "statuses": [JobState.UNKNOWN for _ in range(6)], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} not acknowledged wit" + f"hin timeout {DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT}.", + } + ) + ], + }, + { + "name": "Job preempted", + "statuses": [JobState.RUNNING, JobState.PREEMPTED], + "success": False, + "expected_responses": [ + TriggerEvent( + { + "status": "error", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} did not succeed." + f" Final status was PREEMPTED", + } + ) + ], + }, + { + "name": "Job Succeeds but takes a lot of transitions", + "statuses": [ + JobState.SUBMITTED, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.RUNNING, + JobState.SUCCEEDED, + ], + "success": True, + "expected_responses": [ + TriggerEvent( + { + "status": "success", + "job_id": DEFAULT_JOB_ID, + "response": f"Job {DEFAULT_JOB_ID} succeeded", + } + ) + ], + }, + ] + + for test_case in test_cases: + with self.subTest(test_case=test_case["name"]): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=DEFAULT_POLLING_INTERVAL, + tracking_message="some tracking message", + job_acknowledgement_timeout=DEFAULT_JOB_ACKNOWLEDGEMENT_TIMEOUT, + job_request_namespace="default", + ) + + # Setup Mock Armada + mock_client = AsyncMock() + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in test_case["statuses"] + ] + mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( + cancelled_ids=[DEFAULT_JOB_ID] + ) + mock_client_fn.return_value = mock_client + responses = [gen async for gen in trigger.run()] + 
self.assertEqual(test_case["expected_responses"], responses) + self.assertEqual( + len(test_case["statuses"]), mock_client.get_job_status.call_count + ) + + @patch("time.sleep", return_value=None) + @patch("armada.triggers.armada.ArmadaTrigger.client", new_callable=PropertyMock) + async def test_unacknowledged_results_in_job_cancel(self, mock_client_fn, _): + trigger = ArmadaTrigger( + job_id=DEFAULT_JOB_ID, + armada_queue=DEFAULT_QUEUE, + job_set_id=DEFAULT_JOB_SET_ID, + channel_args=GrpcChannelArgs(target="api.armadaproject.io:443"), + poll_interval=DEFAULT_POLLING_INTERVAL, + tracking_message="some tracking message", + job_acknowledgement_timeout=-1, + job_request_namespace="default", + ) + + # Set up Mock Armada + mock_client = AsyncMock() + mock_client.cancel_jobs.return_value = submit_pb2.CancellationResult( + cancelled_ids=[DEFAULT_JOB_ID] + ) + mock_client_fn.return_value = mock_client + mock_client.get_job_status.side_effect = [ + job_pb2.JobStatusResponse(job_states={DEFAULT_JOB_ID: x}) + for x in [JobState.UNKNOWN, JobState.UNKNOWN] + ] + [gen async for gen in trigger.run()] + + self.assertEqual(mock_client.cancel_jobs.call_count, 1) diff --git a/third_party/airflow/tests/__init__.py b/third_party/airflow/tests/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/third_party/airflow/tests/integration/test_airflow_operator_logic.py b/third_party/airflow/tests/integration/test_airflow_operator_logic.py deleted file mode 100644 index f65ced29a67..00000000000 --- a/third_party/airflow/tests/integration/test_airflow_operator_logic.py +++ /dev/null @@ -1,166 +0,0 @@ -import os -import uuid -import pytest -import threading - -from armada_client.armada import ( - submit_pb2, -) -from armada_client.client import ArmadaClient -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -import grpc - -from 
armada.operators.jobservice import JobServiceClient -from armada.operators.utils import JobState, search_for_job_complete - - -@pytest.fixture(scope="session", name="jobservice") -def job_service_client() -> ArmadaClient: - server_name = os.environ.get("JOB_SERVICE_HOST", "localhost") - server_port = os.environ.get("JOB_SERVICE_PORT", "60003") - - return JobServiceClient( - channel=grpc.insecure_channel(f"{server_name}:{server_port}") - ) - - -@pytest.fixture(scope="session", name="client") -def no_auth_client() -> ArmadaClient: - server_name = os.environ.get("ARMADA_SERVER", "localhost") - server_port = os.environ.get("ARMADA_PORT", "50051") - - return ArmadaClient(channel=grpc.insecure_channel(f"{server_name}:{server_port}")) - - -def sleep_pod(image: str): - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="goodsleep", - image=image, - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="0.2"), - "memory": api_resource.Quantity(string="64Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="0.2"), - "memory": api_resource.Quantity(string="64Mi"), - }, - ), - ) - ], - ) - return [ - submit_pb2.JobSubmitRequestItem( - priority=1, pod_spec=pod, namespace="personal-anonymous" - ) - ] - - -def test_success_job(client: ArmadaClient, jobservice: JobServiceClient): - job_set_name = f"test-{uuid.uuid1()}" - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="busybox"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - assert job_state == JobState.SUCCEEDED - assert job_message == f"Armada test:{job_id} succeeded" - - -def test_bad_job(client: ArmadaClient, jobservice: 
JobServiceClient): - job_set_name = f"test-{uuid.uuid1()}" - - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="NOTACONTAINER"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - assert job_state == JobState.FAILED - assert job_message.startswith(f"Armada test:{job_id} failed") - - -job_set_name = "test" - - -def success_job(client: ArmadaClient, jobservice: JobServiceClient): - job = client.submit_jobs( - queue="queue-a", - job_set_id=job_set_name, - job_request_items=sleep_pod(image="busybox"), - ) - job_id = job.job_response_items[0].job_id - - job_state, job_message = search_for_job_complete( - job_service_client=jobservice, - armada_queue="queue-a", - job_set_id=job_set_name, - airflow_task_name="test", - job_id=job_id, - ) - - assert job_state == JobState.SUCCEEDED - assert job_message == f"Armada test:{job_id} succeeded" - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution(client: ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(30): - t = threading.Thread(target=success_job, args=[client, jobservice]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution_large(client: ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(80): - t = threading.Thread(target=success_job, args=[client, jobservice]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() - - -@pytest.mark.skip(reason="we should not test performance in the CI.") -def test_parallel_execution_huge(client: 
ArmadaClient, jobservice: JobServiceClient): - threads = [] - success_job(client=client, jobservice=jobservice) - for _ in range(500): - t = threading.Thread(target=success_job, args=[client, jobservice]) - t.start() - threads.append(t) - - for thread in threads: - thread.join() diff --git a/third_party/airflow/tests/unit/armada_client_mock.py b/third_party/airflow/tests/unit/armada_client_mock.py deleted file mode 100644 index fa3fd669a7f..00000000000 --- a/third_party/airflow/tests/unit/armada_client_mock.py +++ /dev/null @@ -1,35 +0,0 @@ -from google.protobuf import empty_pb2 -from armada_client.armada import submit_pb2_grpc, submit_pb2, event_pb2, event_pb2_grpc - - -class SubmitService(submit_pb2_grpc.SubmitServicer): - def CreateQueue(self, request, context): - return empty_pb2.Empty() - - def DeleteQueue(self, request, context): - return empty_pb2.Empty() - - def GetQueue(self, request, context): - return submit_pb2.Queue(name=request.name) - - def SubmitJobs(self, request, context): - submit_items = submit_pb2.JobSubmitResponseItem(job_id="mock") - - return submit_pb2.JobSubmitResponse(job_response_items=[submit_items]) - - def GetQueueInfo(self, request, context): - return submit_pb2.QueueInfo() - - def CancelJobs(self, request, context): - return submit_pb2.CancellationResult() - - def ReprioritizeJobs(self, request, context): - return submit_pb2.JobReprioritizeResponse() - - def UpdateQueue(self, request, context): - return empty_pb2.Empty() - - -class EventService(event_pb2_grpc.EventServicer): - def Watch(self, request, context): - return event_pb2.EventMessage() diff --git a/third_party/airflow/tests/unit/job_service_mock.py b/third_party/airflow/tests/unit/job_service_mock.py deleted file mode 100644 index 62d9641a9e5..00000000000 --- a/third_party/airflow/tests/unit/job_service_mock.py +++ /dev/null @@ -1,65 +0,0 @@ -import grpc - -from armada.jobservice import jobservice_pb2, jobservice_pb2_grpc - - -# TODO - Make this a bit smarter, so we can hit 
at least one full -# loop in search_for_job_complete. -def mock_dummy_mapper_terminal(request): - if request.job_id == "test_failed": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.FAILED, error="Test Error" - ) - if request.job_id == "test_succeeded": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - if request.job_id == "test_cancelled": - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.CANCELLED - ) - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND - ) - - -class JobService(jobservice_pb2_grpc.JobServiceServicer): - def GetJobStatus(self, request, context): - return mock_dummy_mapper_terminal(request) - - def Health(self, request, context): - return jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) - - -class JobServiceOccasionalError(jobservice_pb2_grpc.JobServiceServicer): - def __init__(self): - self.get_job_status_count = 0 - self.health_count = 0 - - def GetJobStatus(self, request, context): - self.get_job_status_count += 1 - if self.get_job_status_count % 3 == 0: - context.set_code(grpc.StatusCode.UNAVAILABLE) - context.set_details("Injected error") - raise Exception("Injected error") - - if self.get_job_status_count < 5: - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.RUNNING - ) - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - - def Health(self, request, context): - self.health_count += 1 - if self.health_count % 3 == 0: - context.set_code(grpc.StatusCode.UNAVAILABLE) - context.set_details("Injected error") - raise Exception("Injected error") - - return jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) diff --git a/third_party/airflow/tests/unit/server_mock.py 
b/third_party/airflow/tests/unit/server_mock.py deleted file mode 100644 index bbadc20964f..00000000000 --- a/third_party/airflow/tests/unit/server_mock.py +++ /dev/null @@ -1,100 +0,0 @@ -from google.protobuf import empty_pb2 -from armada_client.armada import ( - submit_pb2_grpc, - submit_pb2, - event_pb2, - event_pb2_grpc, - health_pb2, -) - - -class SubmitService(submit_pb2_grpc.SubmitServicer): - def CreateQueue(self, request, context): - return empty_pb2.Empty() - - def DeleteQueue(self, request, context): - return empty_pb2.Empty() - - def GetQueue(self, request, context): - return submit_pb2.Queue(name=request.name) - - def SubmitJobs(self, request, context): - # read job_ids from request.job_request_items - job_ids = [f"job-{i}" for i in range(1, len(request.job_request_items) + 1)] - - job_response_items = [ - submit_pb2.JobSubmitResponseItem(job_id=job_id) for job_id in job_ids - ] - - return submit_pb2.JobSubmitResponse(job_response_items=job_response_items) - - def GetQueueInfo(self, request, context): - return submit_pb2.QueueInfo(name=request.name) - - def CancelJobs(self, request, context): - return submit_pb2.CancellationResult( - cancelled_ids=["job-1"], - ) - - def CancelJobSet(self, request, context): - return empty_pb2.Empty() - - def ReprioritizeJobs(self, request, context): - new_priority = request.new_priority - if len(request.job_ids) > 0: - job_id = request.job_ids[0] - results = { - f"{job_id}": new_priority, - } - - else: - queue = request.queue - job_set_id = request.job_set_id - - results = { - f"{queue}/{job_set_id}": new_priority, - } - - # convert the result dict into a list of tuples - # while also converting ints to strings - - results = [(k, str(v)) for k, v in results.items()] - - return submit_pb2.JobReprioritizeResponse(reprioritization_results=results) - - def UpdateQueue(self, request, context): - return empty_pb2.Empty() - - def CreateQueues(self, request, context): - return submit_pb2.BatchQueueCreateResponse( - 
failed_queues=[ - submit_pb2.QueueCreateResponse(queue=submit_pb2.Queue(name=queue.name)) - for queue in request.queues - ] - ) - - def UpdateQueues(self, request, context): - return submit_pb2.BatchQueueUpdateResponse( - failed_queues=[ - submit_pb2.QueueUpdateResponse(queue=submit_pb2.Queue(name=queue.name)) - for queue in request.queues - ] - ) - - def Health(self, request, context): - return health_pb2.HealthCheckResponse( - status=health_pb2.HealthCheckResponse.SERVING - ) - - -class EventService(event_pb2_grpc.EventServicer): - def GetJobSetEvents(self, request, context): - events = [event_pb2.EventStreamMessage()] - - for event in events: - yield event - - def Health(self, request, context): - return health_pb2.HealthCheckResponse( - status=health_pb2.HealthCheckResponse.SERVING - ) diff --git a/third_party/airflow/tests/unit/test_airflow_error.py b/third_party/airflow/tests/unit/test_airflow_error.py deleted file mode 100644 index 1e51c08e5ff..00000000000 --- a/third_party/airflow/tests/unit/test_airflow_error.py +++ /dev/null @@ -1,24 +0,0 @@ -from armada.operators.utils import JobState, airflow_error -from airflow.exceptions import AirflowException -import pytest - -testdata_success = [JobState.SUCCEEDED] - - -@pytest.mark.parametrize("state", testdata_success) -def test_airflow_error_successful(state): - airflow_error(state, "hello", "id") - - -testdata_error = [ - (JobState.FAILED, "The Armada job hello:id FAILED"), - (JobState.CANCELLED, "The Armada job hello:id CANCELLED"), - (JobState.JOB_ID_NOT_FOUND, "The Armada job hello:id JOB_ID_NOT_FOUND"), -] - - -@pytest.mark.parametrize("state, expected_exception_message", testdata_error) -def test_airflow_error_states(state, expected_exception_message): - with pytest.raises(AirflowException) as airflow: - airflow_error(state, "hello", "id") - assert str(airflow.value) == expected_exception_message diff --git a/third_party/airflow/tests/unit/test_airflow_operator_mock.py 
b/third_party/airflow/tests/unit/test_airflow_operator_mock.py deleted file mode 100644 index 1ab2d37ced1..00000000000 --- a/third_party/airflow/tests/unit/test_airflow_operator_mock.py +++ /dev/null @@ -1,217 +0,0 @@ -from airflow import DAG -from airflow.models.taskinstance import TaskInstance -from airflow.utils.context import Context -from armada_client.client import ArmadaClient -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) - -import grpc -from concurrent import futures -from armada_client.armada import submit_pb2_grpc, submit_pb2, event_pb2_grpc - -import pendulum -import pytest -from armada.operators.armada import ArmadaOperator, annotate_job_request_items -from armada.operators.jobservice import JobServiceClient -from armada.operators.utils import JobState, search_for_job_complete -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 -from armada_client_mock import SubmitService, EventService -from job_service_mock import JobService - - -@pytest.fixture(scope="session", autouse=True) -def server_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - submit_pb2_grpc.add_SubmitServicer_to_server(SubmitService(), server) - event_pb2_grpc.add_EventServicer_to_server(EventService(), server) - server.add_insecure_port("[::]:50099") - server.start() - - yield - server.stop(False) - - -@pytest.fixture(scope="session", autouse=True) -def job_service_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server(JobService(), server) - server.add_insecure_port("[::]:60081") - server.start() - - yield - server.stop(False) - - -tester_client = ArmadaClient( - grpc.insecure_channel( - target="127.0.0.1:50099", - ) -) -tester_jobservice = JobServiceClient(grpc.insecure_channel(target="127.0.0.1:60081")) - - -def generate_pod_spec(name: str = 
"container-1") -> core_v1.PodSpec: - ps = core_v1.PodSpec( - containers=[ - core_v1.Container( - name=name, - image="busybox", - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - return ps - - -def sleep_job(): - pod = generate_pod_spec() - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def pre_template_sleep_job(): - pod = generate_pod_spec(name="name-{{ run_id }}") - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def expected_sleep_job(): - pod = generate_pod_spec(name="name-another-run-id") - return [submit_pb2.JobSubmitRequestItem(priority=0, pod_spec=pod)] - - -def test_job_service_health(): - health = tester_jobservice.health() - assert health.status == jobservice_pb2.HealthCheckResponse.SERVING - - -def test_mock_success_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_succeeded", - ) - assert job_state == JobState.SUCCEEDED - assert job_message == "Armada test-mock:test_succeeded succeeded" - - -def test_mock_failed_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_failed", - ) - assert job_state == JobState.FAILED - assert job_message.startswith("Armada test-mock:test_failed failed") - - -def 
test_mock_cancelled_job(): - tester_client.submit_jobs( - queue="test", - job_set_id="test", - job_request_items=sleep_job(), - ) - - job_state, job_message = search_for_job_complete( - job_service_client=tester_jobservice, - armada_queue="test", - job_set_id="test", - airflow_task_name="test-mock", - job_id="test_cancelled", - ) - assert job_state == JobState.CANCELLED - assert job_message == "Armada test-mock:test_cancelled cancelled" - - -def test_annotate_job_request_items(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - job_request_items = sleep_job() - task_id = "58896abbfr9" - operator = ArmadaOperator( - task_id=task_id, - name="armada-task", - armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=job_request_items, - lookout_url_template="http://127.0.0.1:8089", - ) - - task_instance = TaskInstance(operator) - dag = DAG( - dag_id="hello_armada", - start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule="@daily", - catchup=False, - default_args={"retries": 2}, - ) - context = {"ti": task_instance, "dag": dag, "run_id": "some-run-id"} - - result = annotate_job_request_items(context, job_request_items) - assert result[0].annotations == { - "armadaproject.io/taskId": task_id, - "armadaproject.io/taskRunId": "some-run-id", - "armadaproject.io/dagId": "hello_armada", - } - - -def test_parameterize_armada_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - submitted_job_request_items = pre_template_sleep_job() - expected_job_request_items = expected_sleep_job() - task_id = "123456789ab" - operator = ArmadaOperator( - task_id=task_id, - name="armada-task", - armada_queue="test", - job_service_channel_args=job_service_channel_args, - armada_channel_args=armada_channel_args, - job_request_items=submitted_job_request_items, - 
lookout_url_template="http://127.0.0.1:8089", - ) - task_instance = TaskInstance(operator) - dag = DAG( - dag_id="hello_armada", - start_date=pendulum.datetime(2016, 1, 1, tz="UTC"), - schedule="@daily", - catchup=False, - default_args={"retries": 2}, - ) - context = Context(ti=task_instance, dag=dag, run_id="another-run-id") - - assert operator.job_request_items != expected_job_request_items - - operator.render_template_fields(context) - - assert operator.job_request_items == expected_job_request_items diff --git a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py b/third_party/airflow/tests/unit/test_armada_deferrable_operator.py deleted file mode 100644 index 0f156ed177e..00000000000 --- a/third_party/airflow/tests/unit/test_armada_deferrable_operator.py +++ /dev/null @@ -1,171 +0,0 @@ -import copy - -import pytest - -from armada_client.armada import submit_pb2 -from armada_client.k8s.io.api.core.v1 import generated_pb2 as core_v1 -from armada_client.k8s.io.apimachinery.pkg.api.resource import ( - generated_pb2 as api_resource, -) -from armada.operators.armada_deferrable import ArmadaDeferrableOperator -from armada.operators.grpc import CredentialsCallback - - -def test_serialize_armada_deferrable(): - grpc_chan_args = { - "target": "localhost:443", - "credentials_callback_args": { - "module_name": "channel_test", - "function_name": "get_credentials", - "function_kwargs": { - "example_arg": "test", - }, - }, - } - - pod = core_v1.PodSpec( - containers=[ - core_v1.Container( - name="sleep", - image="busybox", - args=["sleep", "10s"], - securityContext=core_v1.SecurityContext(runAsUser=1000), - resources=core_v1.ResourceRequirements( - requests={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - limits={ - "cpu": api_resource.Quantity(string="120m"), - "memory": api_resource.Quantity(string="510Mi"), - }, - ), - ) - ], - ) - - job_requests = [ - submit_pb2.JobSubmitRequestItem( - priority=1, - 
pod_spec=pod, - namespace="personal-anonymous", - annotations={"armadaproject.io/hello": "world"}, - ) - ] - - source = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test task", - armada_channel_args=grpc_chan_args, - job_service_channel_args=grpc_chan_args, - armada_queue="test-queue", - job_request_items=job_requests, - lookout_url_template="https://lookout.test.domain/", - poll_interval=5, - ) - - serialized = source.serialize() - assert serialized["name"] == source.name - - reconstituted = ArmadaDeferrableOperator(**serialized) - assert reconstituted == source - - -get_lookout_url_test_cases = [ - ( - "http://localhost:8089/jobs?job_id=", - "test_id", - "http://localhost:8089/jobs?job_id=test_id", - ), - ( - "https://lookout.armada.domain/jobs?job_id=", - "test_id", - "https://lookout.armada.domain/jobs?job_id=test_id", - ), - ("", "test_id", ""), - (None, "test_id", ""), -] - - -@pytest.mark.parametrize( - "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases -) -def test_get_lookout_url(lookout_url_template, job_id, expected_url): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template=lookout_url_template, - ) - - assert operator._get_lookout_url(job_id) == expected_url - - -def test_deepcopy_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - 
try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def test_deepcopy_operator_with_grpc_credentials_callback(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials_callback_args": { - "module_name": "tests.unit.test_armada_operator", - "function_name": "__example_test_callback", - "function_kwargs": { - "test_arg": "fake_arg", - }, - }, - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaDeferrableOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def __example_test_callback(foo=None): - return f"fake_cred {foo}" - - -def test_credentials_callback(): - callback = CredentialsCallback( - module_name="test_armada_operator", - function_name="__example_test_callback", - function_kwargs={"foo": "bar"}, - ) - - result = callback.call() - assert result == "fake_cred bar" diff --git a/third_party/airflow/tests/unit/test_armada_operator.py b/third_party/airflow/tests/unit/test_armada_operator.py deleted file mode 100644 index 571d634dc70..00000000000 --- a/third_party/airflow/tests/unit/test_armada_operator.py +++ /dev/null @@ -1,197 +0,0 @@ -import copy -from unittest.mock import patch, Mock - -import grpc -import pytest - -from armada.jobservice import jobservice_pb2 -from armada.operators.armada import ArmadaOperator -from armada.operators.grpc import CredentialsCallback -from armada.operators.utils import JobState - -get_lookout_url_test_cases = [ - ( - "http://localhost:8089/jobs?job_id=", - "test_id", - "http://localhost:8089/jobs?job_id=test_id", - ), - ( - "https://lookout.armada.domain/jobs?job_id=", - "test_id", - "https://lookout.armada.domain/jobs?job_id=test_id", - ), - ("", 
"test_id", ""), - (None, "test_id", ""), -] - - -@pytest.mark.parametrize( - "lookout_url_template, job_id, expected_url", get_lookout_url_test_cases -) -def test_get_lookout_url(lookout_url_template, job_id, expected_url): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template=lookout_url_template, - ) - - assert operator._get_lookout_url(job_id) == expected_url - - -def test_deepcopy_operator(): - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -@pytest.mark.skip("demonstrates how the old way of passing in credentials fails") -def test_deepcopy_operator_with_grpc_credentials(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials": grpc.composite_channel_credentials( - grpc.ssl_channel_credentials(), - grpc.metadata_call_credentials(("authorization", "fake_jwt")), - ), - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def 
test_deepcopy_operator_with_grpc_credentials_callback(): - armada_channel_args = { - "target": "127.0.0.1:50051", - "credentials_callback_args": { - "module_name": "tests.unit.test_armada_operator", - "function_name": "__example_test_callback", - "function_kwargs": { - "test_arg": "fake_arg", - }, - }, - } - job_service_channel_args = {"target": "127.0.0.1:60003"} - - operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="http://localhost:8089/jobs?job_id=", - ) - - try: - copy.deepcopy(operator) - except Exception as e: - assert False, f"{e}" - - -def __example_test_callback(foo=None): - return f"fake_cred {foo}" - - -def test_credentials_callback(): - callback = CredentialsCallback( - module_name="test_armada_operator", - function_name="__example_test_callback", - function_kwargs={"foo": "bar"}, - ) - - result = callback.call() - assert result == "fake_cred bar" - - -@patch("armada.operators.armada.search_for_job_complete") -@patch("armada.operators.armada.ArmadaClient", autospec=True) -@patch("armada.operators.armada.JobServiceClient", autospec=True) -def test_armada_operator_execute( - JobServiceClientMock, ArmadaClientMock, search_for_job_complete_mock -): - jsclient_mock = Mock() - jsclient_mock.health.return_value = jobservice_pb2.HealthCheckResponse( - status=jobservice_pb2.HealthCheckResponse.SERVING - ) - - JobServiceClientMock.return_value = jsclient_mock - - item = Mock() - item.job_id = "fake_id" - - job = Mock() - job.job_response_items = [ - item, - ] - - aclient_mock = Mock() - aclient_mock.submit_jobs.return_value = job - ArmadaClientMock.return_value = aclient_mock - - search_for_job_complete_mock.return_value = (JobState.SUCCEEDED, "No error") - - armada_channel_args = {"target": "127.0.0.1:50051"} - job_service_channel_args = {"target": "127.0.0.1:60003"} - 
- operator = ArmadaOperator( - task_id="test_task_id", - name="test_task", - armada_channel_args=armada_channel_args, - job_service_channel_args=job_service_channel_args, - armada_queue="test_queue", - job_request_items=[], - lookout_url_template="https://lookout.armada.domain/jobs?job_id=", - ) - - task_instance = Mock() - task_instance.task_id = "mock_task_id" - - dag = Mock() - dag.dag_id = "mock_dag_id" - - context = { - "run_id": "mock_run_id", - "ti": task_instance, - "dag": dag, - } - - try: - operator.execute(context) - except Exception as e: - assert False, f"{e}" - - jsclient_mock.health.assert_called() - aclient_mock.submit_jobs.assert_called() diff --git a/third_party/airflow/tests/unit/test_grpc.py b/third_party/airflow/tests/unit/test_grpc.py deleted file mode 100644 index 1e12b566067..00000000000 --- a/third_party/airflow/tests/unit/test_grpc.py +++ /dev/null @@ -1,26 +0,0 @@ -import armada.operators.grpc - - -def test_serialize_grpc_channel(): - src_chan_args = { - "target": "localhost:443", - "credentials_callback_args": { - "module_name": "channel_test", - "function_name": "get_credentials", - "function_kwargs": { - "example_arg": "test", - }, - }, - } - - source = armada.operators.grpc.GrpcChannelArguments(**src_chan_args) - - serialized = source.serialize() - assert serialized["target"] == src_chan_args["target"] - assert ( - serialized["credentials_callback_args"] - == src_chan_args["credentials_callback_args"] - ) - - reconstituted = armada.operators.grpc.GrpcChannelArguments(**serialized) - assert reconstituted == source diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete.py b/third_party/airflow/tests/unit/test_search_for_job_complete.py deleted file mode 100644 index 279d5f08e17..00000000000 --- a/third_party/airflow/tests/unit/test_search_for_job_complete.py +++ /dev/null @@ -1,75 +0,0 @@ -from armada.operators.utils import JobState, search_for_job_complete -from armada.jobservice import jobservice_pb2 - - -def 
test_failed_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.FAILED, error="Testing Failure" - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.FAILED - assert ( - job_complete[1] == "Armada test:id failed\nfailed with reason Testing Failure" - ) - - -def test_successful_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.SUCCEEDED - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:id succeeded" - - -def test_cancelled_event(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.CANCELLED - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - ) - assert job_complete[0] == JobState.CANCELLED - assert job_complete[1] == "Armada test:id cancelled" - - -def test_job_id_not_found(): - def test_callable(armada_queue: str, job_set_id: str, job_id: str): - return jobservice_pb2.JobServiceResponse( - state=jobservice_pb2.JobServiceResponse.JOB_ID_NOT_FOUND - ) - - job_complete = search_for_job_complete( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - job_status_callable=test_callable, - time_out_for_failure=5, - ) - assert job_complete[0] == JobState.JOB_ID_NOT_FOUND - assert ( - job_complete[1] == "Armada test:id could 
not find a job id and\nhit a timeout" - ) diff --git a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py b/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py deleted file mode 100644 index a842fa994d3..00000000000 --- a/third_party/airflow/tests/unit/test_search_for_job_complete_asyncio.py +++ /dev/null @@ -1,152 +0,0 @@ -from concurrent import futures -import logging - -import grpc -import pytest -import pytest_asyncio - -from job_service_mock import JobService, JobServiceOccasionalError - -from armada.operators.jobservice_asyncio import JobServiceAsyncIOClient -from armada.operators.jobservice import default_jobservice_channel_options -from armada.operators.utils import JobState, search_for_job_complete_async -from armada.jobservice import jobservice_pb2_grpc, jobservice_pb2 - - -@pytest.fixture -def server_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server(JobService(), server) - server.add_insecure_port("[::]:50100") - server.start() - yield - server.stop(False) - - -@pytest_asyncio.fixture(scope="function") -async def js_aio_client(server_mock): - channel = grpc.aio.insecure_channel( - target="127.0.0.1:50100", - options={ - "grpc.keepalive_time_ms": 30000, - }.items(), - ) - await channel.channel_ready() - assert channel.get_state(True) == grpc.ChannelConnectivity.READY - - return JobServiceAsyncIOClient(channel) - - -@pytest.fixture -def server_occasional_error_mock(): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - jobservice_pb2_grpc.add_JobServiceServicer_to_server( - JobServiceOccasionalError(), server - ) - server.add_insecure_port("[::]:50101") - server.start() - yield - server.stop(False) - - -@pytest_asyncio.fixture(scope="function") -async def js_aio_retry_client(server_occasional_error_mock): - channel = grpc.aio.insecure_channel( - target="127.0.0.1:50101", - options=default_jobservice_channel_options, - ) - 
await channel.channel_ready() - assert channel.get_state(True) == grpc.ChannelConnectivity.READY - - return JobServiceAsyncIOClient(channel) - - -@pytest.mark.asyncio -async def test_failed_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_failed", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.FAILED - assert ( - job_complete[1] - == "Armada test:test_failed failed\nfailed with reason Test Error" - ) - - -@pytest.mark.asyncio -async def test_successful_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_succeeded", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:test_succeeded succeeded" - - -@pytest.mark.asyncio -async def test_cancelled_event(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_cancelled", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.CANCELLED - assert job_complete[1] == "Armada test:test_cancelled cancelled" - - -@pytest.mark.asyncio -async def test_job_id_not_found(js_aio_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="id", - armada_queue="test", - job_set_id="test", - time_out_for_failure=5, - job_service_client=js_aio_client, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.JOB_ID_NOT_FOUND - assert ( - job_complete[1] == "Armada test:id could not find a job id and\nhit a 
timeout" - ) - - -@pytest.mark.asyncio -async def test_healthy(js_aio_client): - health = await js_aio_client.health() - assert health.status == jobservice_pb2.HealthCheckResponse.SERVING - - -@pytest.mark.asyncio -async def test_error_retry(js_aio_retry_client): - job_complete = await search_for_job_complete_async( - airflow_task_name="test", - job_id="test_succeeded", - armada_queue="test", - job_set_id="test", - job_service_client=js_aio_retry_client, - time_out_for_failure=5, - log=logging.getLogger(), - poll_interval=1, - ) - assert job_complete[0] == JobState.SUCCEEDED - assert job_complete[1] == "Armada test:test_succeeded succeeded" diff --git a/third_party/airflow/tox.ini b/third_party/airflow/tox.ini index abfb8db10a5..09dd8ce15ea 100644 --- a/third_party/airflow/tox.ini +++ b/third_party/airflow/tox.ini @@ -13,7 +13,7 @@ allowlist_externals = find xargs commands = - coverage run -m pytest tests/unit/ + coverage run -m unittest discover coverage xml # This executes the dag files in examples but really only checks for imports and python errors bash -c "find examples/ -maxdepth 1 -type f -name *.py | xargs python3" @@ -21,18 +21,18 @@ commands = [testenv:format] extras = format commands = - black --check armada/operators tests/ examples/ -# Disabled until mypy reaches v1.0 -# mypy --ignore-missing-imports armada/operators tests/ examples/ - flake8 armada/operators tests/ examples/ + black armada/ test/ examples/ +# Disabled until mypy reaches v1.0 +# mypy --ignore-missing-imports armada/operators test/ examples/ + flake8 armada/ test/ examples/ -[testenv:format-code] +[testenv:format-check] extras = format commands = - black armada/operators tests/ examples/ -# Disabled until mypy reaches v1.0 -# mypy --ignore-missing-imports armada/operators tests/ examples/ - flake8 armada/operators tests/ examples/ + black --check armada/ test/ examples/ +# Disabled until mypy reaches v1.0 +# mypy --ignore-missing-imports armada/operators test/ examples/ + flake8 armada/ 
test/ examples/ [testenv:docs] basepython = python3.10