diff --git a/.gitignore b/.gitignore index 63e57c4..65d6e84 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ airflow_provider_ray.egg-info/* .idea .idea/ *.iml + +# VsCode +.vscode \ No newline at end of file diff --git a/ray_provider/example_dags/anyscale_cluster.py b/ray_provider/example_dags/anyscale_cluster.py new file mode 100644 index 0000000..ab59a50 --- /dev/null +++ b/ray_provider/example_dags/anyscale_cluster.py @@ -0,0 +1,65 @@ +from datetime import datetime + +from airflow.decorators import dag +from airflow.utils.trigger_rule import TriggerRule + +from ray_provider.operators.anyscale_cluster import ( + AnyscaleCreateClusterOperator, + AnyscaleStartClusterOperator, + AnyscaleTerminateClusterOperator, +) +from ray_provider.operators.anyscale_cluster import AnyscaleCreateSessionCommandOperator + + +AUTH_TOKEN = "" +DEFAULT_ARGS = { + "owner": "airflow", + "retries": 1, + "retry_delay": 0, +} + + +@dag( + schedule_interval=None, + start_date=datetime(2022, 9, 30), + tags=["demo"], + default_args=DEFAULT_ARGS, +) +def anyscale_cluster(): + + cluster = AnyscaleCreateClusterOperator( + task_id="create_cluster", + name="", + project_id="", + compute_config_id="", + cluster_environment_build_id="", + auth_token=AUTH_TOKEN, + ) + + start = AnyscaleStartClusterOperator( + task_id="start_cluster", + cluster_id=cluster.output["id"], + auth_token=AUTH_TOKEN, + wait_for_completion=True, + ) + + job = AnyscaleCreateSessionCommandOperator( + task_id="submit_job", + session_id=cluster.output["id"], + shell_command="python3 -c 'import ray'", + auth_token=AUTH_TOKEN, + wait_for_completion=True, + ) + + terminate = AnyscaleTerminateClusterOperator( + task_id="terminate_cluster", + auth_token=AUTH_TOKEN, + cluster_id=cluster.output["id"], + wait_for_completion=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + + cluster >> start >> job >> terminate + + +dag = anyscale_cluster() diff --git a/ray_provider/operators/__init__.py b/ray_provider/operators/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/ray_provider/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/ray_provider/operators/anyscale_cluster.py b/ray_provider/operators/anyscale_cluster.py new file mode 100644 index 0000000..2463ae1 --- /dev/null +++ b/ray_provider/operators/anyscale_cluster.py @@ -0,0 +1,242 @@ +import time + +from typing import List, Optional, Sequence + +from ray_provider.utils import push_to_xcom +from ray_provider.operators.base import AnyscaleBaseOperator +from ray_provider.sensors.anyscale_cluster import AnyscaleClusterSensor + +from airflow.utils.context import Context +from airflow.exceptions import AirflowException + +from anyscale.shared_anyscale_utils.utils.byod import BYODInfo +from anyscale.sdk.anyscale_client.models.cluster import Cluster + + +class AnyscaleCreateClusterOperator(AnyscaleBaseOperator): + """ + An operator that creates a Cluster. + + :param name: Name of the Cluster. (templated) + :param cluster_environment_build_id: Cluster Environment Build ID that the Cluster is using. (templated) + :param docker: Docker image for BYOD. (templated) + :param project_id: Project that this Cluster belongs to. If none, the Cluster will use the default Project. (templated) + :param ray_version: Ray version (only used for BYOD). (templated) (default: "1.13.0") + :param python_version: Python version (only used for BYOD). (templated) (default: "py38") + :param compute_config_id: Cluster Compute that the Cluster is using. (templated) + """ + + template_fields: Sequence[str] = [ + "name", + "auth_token", + "cluster_environment_build_id", + "docker", + "project_id", + "ray_version", + "python_version", + "compute_config_id", + ] + + def __init__( + self, + *, + name: str, + cluster_environment_build_id: str = None, + docker: str = None, + project_id: str = None, + ray_version: Optional[str] = "1.13.0", + python_version: Optional[str] = "py38", + compute_config_id: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.name = name + self.project_id = project_id + self.docker = docker + self.cluster_environment_build_id = cluster_environment_build_id + + self.ray_version = ray_version + self.python_version = python_version + self.compute_config_id = compute_config_id + + self._ignore_keys = [ + "services_urls", + "ssh_authorized_keys", + "ssh_private_key", + "user_service_token", + "access_token", + ] + + def _search_clusters(self) -> List[Cluster]: + clusters_query = { + "name": { + "equals": self.name, + }, + "project_id": self.project_id, + } + + clusters: List[Cluster] = self.sdk.search_clusters( + clusters_query=clusters_query).results + return clusters + + def _get_cluster_environment_build_id(self) -> str: + + cluster_environment_build_id = None + + if self.docker: + + cluster_environment_build_id = BYODInfo( + docker_image_name=self.docker, + python_version=self.python_version, + ray_version=self.ray_version, + ).encode() + + if self.cluster_environment_build_id: + if self.docker: + self.log.info( + "docker is ignored when cluster_environment_build_id is provided.") + + cluster_environment_build_id = self.cluster_environment_build_id + + if cluster_environment_build_id is None: + raise AirflowException( + "at least cluster_environment_build_id or docker must be provided.") + + return cluster_environment_build_id + + def execute(self, context: Context) -> None: + + clusters = self._search_clusters() + + if clusters: + self.log.info( + "cluster with name %s in %s already exists", self.name, self.project_id) + cluster = clusters[0].to_dict() + push_to_xcom(cluster, context, self._ignore_keys) + return + + cluster_environment_build_id = self._get_cluster_environment_build_id() + + create_cluster = { + "name": self.name, + "project_id": self.project_id, + "cluster_compute_id": self.compute_config_id, + "cluster_environment_build_id": cluster_environment_build_id, + } + + cluster: Cluster = self.sdk.create_cluster(create_cluster).result + + self.log.info("cluster created with id: %s", cluster.id) + push_to_xcom(cluster.to_dict(), context, self._ignore_keys) + + +class AnyscaleStartClusterOperator(AnyscaleBaseOperator): + """ + An operator that starts a cluster. + + :param cluster_id: ID of the Cluster to start. (templated) + :param start_cluster_options: Options to set when starting a cluster. (templated) + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + :param poke_interval: Poke interval that the operator will use to check if the cluster is started. (default: 60) + """ + + template_fields: Sequence[str] = [ + "auth_token", + "cluster_id", + "start_cluster_options" + ] + + def __init__( + self, + *, + cluster_id: str, + start_cluster_options: Optional[dict] = None, + wait_for_completion: Optional[bool] = True, + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_id = cluster_id + + self.start_cluster_options = start_cluster_options + + if self.start_cluster_options is None: + self.start_cluster_options = {} + + self.wait_for_completion = wait_for_completion + + self._ignore_keys = [] + + def execute(self, context: Context) -> None: + + self.log.info("starting cluster %s", self.cluster_id) + + cluster_operation = self.sdk.start_cluster( + cluster_id=self.cluster_id, + start_cluster_options=self.start_cluster_options + ).result + + if self.wait_for_completion: + while not AnyscaleClusterSensor( + task_id="wait_cluster", + cluster_id=self.cluster_id, + auth_token=self.auth_token, + ).poke(context): + + time.sleep(self.poke_interval) + + push_to_xcom(cluster_operation.to_dict(), context, self._ignore_keys) + + +class AnyscaleTerminateClusterOperator(AnyscaleBaseOperator): + """ + An operator that initializes workflow to transition the Cluster into the Terminated state. + :param cluster_id: ID of the Cluster to terminate. (templated) + :param terminate_cluster_options: Options to set when terminating a Cluster. (templated) + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + :param poke_interval: Poke interval that the operator will use to check if the cluster is terminated. (default: 60) + """ + + template_fields: Sequence[str] = [ + "cluster_id", + "auth_token", + "terminate_cluster_options", + ] + + def __init__( + self, + *, + cluster_id: str, + terminate_cluster_options: Optional[dict] = None, + wait_for_completion: Optional[bool] = True, + **kwargs, + ): + super().__init__(**kwargs) + self.cluster_id = cluster_id + + self.terminate_cluster_options = terminate_cluster_options + + if self.terminate_cluster_options is None: + self.terminate_cluster_options = {} + + self.wait_for_completion = wait_for_completion + self._ignore_keys = [] + + def execute(self, context: Context) -> None: + + cluster_operation = self.sdk.terminate_cluster( + cluster_id=self.cluster_id, + terminate_cluster_options=self.terminate_cluster_options).result + + self.log.info("terminating cluster %s", self.cluster_id) + + if self.wait_for_completion: + while not AnyscaleClusterSensor( + task_id="wait_cluster", + cluster_id=self.cluster_id, + auth_token=self.auth_token, + ).poke(context): + + time.sleep(self.poke_interval) + + push_to_xcom(cluster_operation.to_dict(), context, self._ignore_keys) diff --git a/ray_provider/operators/anyscale_job.py b/ray_provider/operators/anyscale_job.py new file mode 100644 index 0000000..03c8343 --- /dev/null +++ b/ray_provider/operators/anyscale_job.py @@ -0,0 +1,142 @@ +import time +from typing import Optional, Sequence +from ray_provider.utils import push_to_xcom + +from airflow.utils.context import Context +from airflow.exceptions import AirflowException + +from ray_provider.operators.base import AnyscaleBaseOperator +from ray_provider.sensors.anyscale_job import AnyscaleProductionJobSensor + +from anyscale.shared_anyscale_utils.utils.byod import BYODInfo +from anyscale.sdk.anyscale_client.models.create_production_job import CreateProductionJob + + +class AnyscaleCreateProductionJobOperator(AnyscaleBaseOperator): + """ + An operator that creates an Production Job. + :param name: Name of the job. (templated) + :param project_id: Id of the project this job will start clusters in. (templated) + :param entrypoint: A script that will be run to start your job. + This command will be run in the root directory of + the specified runtime env. Eg. 'python script.py' (templated) + :param build_id: The id of the cluster env build. + This id will determine the docker image your job is run on. (templated) + :param docker: Docker image for BYOD. (templated) + :param max_retries: The number of retries this job will attempt on failure. + Set to None to set infinite retries. (templated) + :param description: Description of the job. (templated) + :param runtime_env: A ray runtime env json. (templated) + Your entrypoint will be run in the environment specified by this runtime env. (templated) + :param compute_config_id: The id of the compute configuration that you want to use. + This id will specify the resources required for your job. (templated) + :param ray_version: Ray version (only used for BYOD). (templated) (default: "1.13.0") + :param python_version: Python version (only used for BYOD). (templated) (default: "py38") + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + :param poke_interval: Poke interval that the operator will use to check if the cluster is started. (default: 60) + """ + + template_fields: Sequence[str] = [ + "name", + "auth_token", + "project_id", + "entrypoint", + "build_id", + "docker", + "description", + "runtime_env", + "compute_config_id", + "ray_version", + "python_version", + ] + + def __init__( + self, + name: str, + project_id: str, + entrypoint: str, + build_id: str = None, + docker: str = None, + max_retries: int = 1, + description: str = None, + runtime_env: dict = None, + compute_config_id: str = None, + ray_version: Optional[str] = "1.13.0", + python_version: Optional[str] = "py38", + wait_for_completion: Optional[bool] = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.name = name + self.docker = docker + self.entrypoint = entrypoint + self.project_id = project_id + self.description = description + self.max_retries = max_retries + self.runtime_env = runtime_env + self.compute_config_id = compute_config_id + self.ray_version = ray_version + self.python_version = python_version + self.build_id = build_id + + self.wait_for_completion = wait_for_completion + self._ignore_keys = [] + + def _get_build_id(self) -> str: + + build_id = None + + if self.docker: + + build_id = BYODInfo( + docker_image_name=self.docker, + python_version=self.python_version, + ray_version=self.ray_version, + ).encode() + + if self.build_id: + if self.docker: + self.log.info( + "docker is ignored when cluster_environment_build_id is provided.") + + build_id = self.build_id + + if build_id is None: + raise AirflowException( + "at least build_id or docker must be provided.") + + return build_id + + def execute(self, context: Context) -> None: + build_id = self._get_build_id() + + create_production_job = CreateProductionJob( + name=self.name, + description=self.description, + project_id=self.project_id, + config={ + "entrypoint": self.entrypoint, + "build_id": build_id, + "runtime_env": self.runtime_env, + "compute_config_id": self.compute_config_id, + "max_retries": self.max_retries, + }, + ) + + production_job = self.sdk.create_job( + create_production_job).result + + self.log.info(f"production job {production_job.id} created") + + if self.wait_for_completion: + while not AnyscaleProductionJobSensor( + task_id="wait_job", + production_job_id=production_job.id, + auth_token=self.auth_token, + ).poke(context): + + time.sleep(self.poke_interval) + + push_to_xcom(production_job.to_dict(), context, + ignore_keys=self._ignore_keys) diff --git a/ray_provider/operators/anyscale_service.py b/ray_provider/operators/anyscale_service.py new file mode 100644 index 0000000..7c2c552 --- /dev/null +++ b/ray_provider/operators/anyscale_service.py @@ -0,0 +1,157 @@ +import time +from typing import Optional, Sequence + +from airflow.utils.context import Context +from airflow.exceptions import AirflowException +from airflow.utils.log.secrets_masker import mask_secret + +from ray_provider.utils import push_to_xcom +from ray_provider.operators.base import AnyscaleBaseOperator +from ray_provider.sensors.anyscale_service import AnyscaleServiceSensor + +from anyscale.shared_anyscale_utils.utils.byod import BYODInfo +from anyscale.sdk.anyscale_client.models.create_production_service import CreateProductionService + + +class AnyscaleApplyServiceOperator(AnyscaleBaseOperator): + """ + An Operator that puts a service. This operator will create a service + with the given name if it doesn't exist, and will otherwise update the service. + :param name: Name of the service. (templated) + :param project_id: Id of the project this job will start clusters in. (templated) + :param entrypoint: A script that will be run to start your service. + This command will be run in the root directory of the specified runtime env. Eg. 'python script.py'. (templated) + :param healthcheck_url: Healthcheck url. (templated) + :param build_id: The id of the cluster env build. + This id will determine the docker image your service is run on. (templated) + :param docker: Docker image for BYOD. (templated) + :param max_retries: The number of retries this job will attempt on failure. + Set to None to set infinite retries. (templated) + :param access: Whether service can be accessed by public internet traffic. + Possible values: ['private', 'public'] (templated) + :param description: Description of the Service. (templated) + :param runtime_env: A ray runtime env json. (templated) + Your entrypoint will be run in the environment specified by this runtime env. (templated) + :param compute_config_id: The id of the compute configuration that you want to use. + This id will specify the resources required for your job. (templated) + :param ray_version: Ray version (only used for BYOD). (templated) (default: "1.13.0") + :param python_version: Python version (only used for BYOD). (templated) (default: "py38") + :param wait_for_completion: If True, waits for creation of the service to complete. (default: True) + :param poke_interval: Poke interval that the operator will use to check if the service is ready. (default: 60) + """ + + template_fields: Sequence[str] = [ + "name", + "auth_token", + "project_id", + "entrypoint", + "build_id", + "docker", + "description", + "runtime_env", + "compute_config_id", + "ray_version", + "python_version", + "access", + ] + + def __init__( + self, + name: str, + project_id: str, + entrypoint: str, + healthcheck_url: str, + build_id: str = None, + docker: str = None, + max_retries: int = 0, + access: str = "private", + description: str = None, + runtime_env: dict = None, + compute_config_id: str = None, + ray_version: Optional[str] = "1.13.0", + python_version: Optional[str] = "py38", + wait_for_completion: Optional[bool] = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.name = name + self.docker = docker + self.access = access + self.entrypoint = entrypoint + self.project_id = project_id + self.description = description + self.max_retries = max_retries + self.runtime_env = runtime_env + self.healthcheck_url = healthcheck_url + self.compute_config_id = compute_config_id + self.ray_version = ray_version + self.python_version = python_version + self.build_id = build_id + + self.wait_for_completion = wait_for_completion + self._ignore_keys = [] + + def _get_build_id(self) -> str: + + build_id = None + + if self.docker: + + build_id = BYODInfo( + docker_image_name=self.docker, + python_version=self.python_version, + ray_version=self.ray_version, + ).encode() + + if self.build_id: + if self.docker: + self.log.info( + "docker is ignored when cluster_environment_build_id is provided.") + + build_id = self.build_id + + if build_id is None: + raise AirflowException( + "at least cluster_environment_build_id or docker must be provided.") + + return build_id + + def execute(self, context: Context) -> None: + + build_id = self._get_build_id() + + create_production_service = CreateProductionService( + name=self.name, + access=self.access, + description=self.description, + project_id=self.project_id, + healthcheck_url=self.healthcheck_url, + config={ + "entrypoint": self.entrypoint, + "build_id": build_id, + "runtime_env": self.runtime_env, + "compute_config_id": self.compute_config_id, + "max_retries": self.max_retries, + }, + ) + + production_service = self.sdk.apply_service( + create_production_service).result + + self.log.info("production service %s created", production_service.id) + + if self.wait_for_completion: + while not AnyscaleServiceSensor( + task_id="wait_service", + service_id=production_service.id, + auth_token=self.auth_token, + ).poke(context): + + time.sleep(self.poke_interval) + + self.log.info("service available at %s", production_service.url) + + xcom_payload = production_service.to_dict() + xcom_payload["token"] = mask_secret(xcom_payload["token"]) + push_to_xcom(xcom_payload, context, self._ignore_keys) diff --git a/ray_provider/operators/anyscale_session_command.py b/ray_provider/operators/anyscale_session_command.py new file mode 100644 index 0000000..2dfdf92 --- /dev/null +++ b/ray_provider/operators/anyscale_session_command.py @@ -0,0 +1,67 @@ +import time +from typing import Optional, Sequence + +from airflow.utils.context import Context +from ray_provider.utils import push_to_xcom +from ray_provider.operators.base import AnyscaleBaseOperator + +from ray_provider.sensors.anyscale_session_command import AnyscaleSessionCommandSensor + +_POKE_INTERVAL = 60 + + +class AnyscaleCreateSessionCommandOperator(AnyscaleBaseOperator): + """ + An Operator that creates and executes a shell command on a session. + Makes no assumption about the details of the shell command. + :param session_id: ID of the Session to execute this command on. (templated) + :param shell_command: Shell command string that will be executed. (templated) + :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True) + :param poke_interval: Poke interval that the operator will use to check if the session command has finished. (default: 60) + """ + + template_fields: Sequence[str] = [ + "session_id", + "auth_token", + "shell_command", + ] + + def __init__( + self, + *, + session_id: str, + shell_command: str, + wait_for_completion: Optional[bool] = True, + **kwargs, + ): + super().__init__(**kwargs) + + self.session_id = session_id + self.shell_command = shell_command + self.wait_for_completion = wait_for_completion + self._ignore_keys = [] + + def execute(self, context: Context): + + create_session_command = { + "session_id": self.session_id, + "shell_command": self.shell_command, + } + + session_command_response = self.sdk.create_session_command( + create_session_command).result + + self.log.info("session command with id %s created", + session_command_response.id) + + if self.wait_for_completion: + while not AnyscaleSessionCommandSensor( + task_id="wait_session_command", + session_command_id=session_command_response.id, + auth_token=self.auth_token, + ).poke(context): + + time.sleep(self.poke_interval) + + push_to_xcom(session_command_response.to_dict(), + context, self._ignore_keys) diff --git a/ray_provider/operators/base.py b/ray_provider/operators/base.py new file mode 100644 index 0000000..9af5a76 --- /dev/null +++ b/ray_provider/operators/base.py @@ -0,0 +1,32 @@ +from anyscale import AnyscaleSDK +from airflow.utils.context import Context + +from typing import Optional +from airflow.models.baseoperator import BaseOperator +from airflow.compat.functools import cached_property + + +class AnyscaleBaseOperator(BaseOperator): + """ + Anyscale Base Operator. + :param auth_token: Anyscale CLI token. + :param poke_interval: Time in seconds that the operator should wait for completion. (default: 60) + """ + + def __init__( + self, + *, + auth_token: str, + poke_interval: Optional[int] = 60, + **kwargs + ): + self.auth_token = auth_token + self.poke_interval = poke_interval + super().__init__(**kwargs) + + @cached_property + def sdk(self) -> AnyscaleSDK: + return AnyscaleSDK(auth_token=self.auth_token) + + def execute(self, context: Context): + raise NotImplementedError('Please implement execute() in subclass') diff --git a/ray_provider/sensors/__init__.py b/ray_provider/sensors/__init__.py new file mode 100644 index 0000000..13a8339 --- /dev/null +++ b/ray_provider/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/ray_provider/sensors/anyscale_cluster.py b/ray_provider/sensors/anyscale_cluster.py new file mode 100644 index 0000000..324334b --- /dev/null +++ b/ray_provider/sensors/anyscale_cluster.py @@ -0,0 +1,61 @@ +from typing import Sequence + +from airflow.utils.context import Context +from ray_provider.sensors.base import AnyscaleBaseSensor + + +class AnyscaleClusterSensor(AnyscaleBaseSensor): + """ + A Sensor that pokes the state of a cluster and returns when it reaches goal state. + + :param cluster_id: ID of the Cluster to retreive. (templated) + """ + + template_fields: Sequence[str] = [ + "auth_token", + "cluster_id", + ] + + def __init__( + self, + *, + cluster_id: str, + **kwargs, + ): + + super().__init__(**kwargs) + self.cluster_id = cluster_id + + def _log_services(self, response): + services = response.result.services_urls + + if services: + self.log.info("service urls:") + for name, service in services.to_dict().items(): + self.log.info("%s: %s", name, service) + + def _log_head_node(self, response): + head_node_info = response.result.head_node_info + + if head_node_info: + self.log.info("head node info:") + for name, info in head_node_info.to_dict().items(): + self.log.info("%s: %s", name, info) + + def poke(self, context: Context) -> bool: + + response = self.sdk.get_cluster(self.cluster_id) + + state = response.result.state + goal_state = response.result.goal_state + + self.log.info("current state: %s, goal state: %s", state, goal_state) + + if goal_state is not None and goal_state != state: + return False + + self.log.info("cluster reached goal state: %s", state) + self._log_head_node(response) + self._log_services(response) + + return True diff --git a/ray_provider/sensors/anyscale_job.py b/ray_provider/sensors/anyscale_job.py new file mode 100644 index 0000000..93c8393 --- /dev/null +++ b/ray_provider/sensors/anyscale_job.py @@ -0,0 +1,72 @@ +from typing import Sequence + +from airflow.utils.context import Context +from airflow.exceptions import AirflowException +from ray_provider.sensors.base import AnyscaleBaseSensor + + +class AnyscaleProductionJobSensor(AnyscaleBaseSensor): + """ + A Sensor that pokes the state of a production job and returns when it reaches goal state. + :param production_job_id: ID of the production job. (templated) + """ + + template_fields: Sequence[str] = [ + "production_job_id", + "auth_token", + ] + + def __init__( + self, + *, + production_job_id: str, + **kwargs, + ): + super().__init__(**kwargs) + self.production_job_id = production_job_id + + def _fetch_logs(self): + + try: + logs = self.sdk.get_production_job_logs( + self.production_job_id).results.logs + self.log.info("logs: \n %s", logs) + + except Exception: + self.log.warning("logs not found for %s", self.production_job_id) + + def poke(self, context: Context) -> bool: + + production_job = self.sdk.get_production_job( + production_job_id=self.production_job_id).result + + state = production_job.state + + self.log.info("current state: %s, goal state %s", + state.current_state, state.goal_state) + + operation_message = state.operation_message + if operation_message: + self.log.info(operation_message) + + if state.current_state in ("OUT_OF_RETRIES", "TERMINATED", "ERRORED"): + self._fetch_logs() + raise AirflowException( + "job ended with status {}, error: {}".format( + state.current_state, + state.error, + ) + ) + + if state.current_state != state.goal_state: + return False + + self.log.info( + "job %s reached goal state %s", self.production_job_id, state.goal_state) + + took = state.state_transitioned_at - production_job.created_at + + self.log.info("duration: %s", took.total_seconds()) + self._fetch_logs() + + return True diff --git a/ray_provider/sensors/anyscale_service.py b/ray_provider/sensors/anyscale_service.py new file mode 100644 index 0000000..2942bb7 --- /dev/null +++ b/ray_provider/sensors/anyscale_service.py @@ -0,0 +1,64 @@ +from typing import Sequence + +from airflow.utils.context import Context +from airflow.exceptions import AirflowException +from ray_provider.sensors.base import AnyscaleBaseSensor + + +class AnyscaleServiceSensor(AnyscaleBaseSensor): + """ + A Sensor that pokes the state of a service and returns when it reaches goal state. (EXPERIMENTAL) + :param service_id: ID of the service. (templated) + """ + + template_fields: Sequence[str] = [ + "service_id", + "auth_token", + ] + + def __init__( + self, + service_id: str, + goal_state: str = "RUNNING", + **kwargs, + ): + super().__init__(**kwargs) + + self.service_id = service_id + self.goal_state = goal_state + self._ignore_keys = [] + + def poke(self, context: Context) -> bool: + + response = self.sdk.get_service(service_id=self.service_id) + + state = response.result.state + + msg = f"current state: {state.current_state}, " f"goal state: {self.goal_state}" + + self.log.info(msg) + + operation_message = state.operation_message + + if operation_message: + self.log.info(operation_message) + + if state.current_state in ("OUT_OF_RETRIES", "TERMINATED", "ERRORED", "BROKEN"): + if self.goal_state == state.current_state: + return True + + msg = ( + f"job ended with status {state.current_state}, " f"error: {state.error}" + ) + raise AirflowException(msg) + + if state.current_state != self.goal_state: + return False + + self.log.info(f"service {self.service_id} reached goal state {self.goal_state}") + + took = response.result.state.state_transitioned_at - response.result.created_at + self.log.info(f"duration: {took.total_seconds()}") + self.log.info(f"service available at: {response.result.url}") + + return True diff --git a/ray_provider/sensors/anyscale_session_command.py b/ray_provider/sensors/anyscale_session_command.py new file mode 100644 index 0000000..6148b01 --- /dev/null +++ b/ray_provider/sensors/anyscale_session_command.py @@ -0,0 +1,51 @@ +from typing import Sequence + +from airflow.utils.context import Context + +from airflow.exceptions import AirflowException +from ray_provider.sensors.base import AnyscaleBaseSensor + + +class AnyscaleSessionCommandSensor(AnyscaleBaseSensor): + """ + A Sensor that pokes the state of a session command and returns when it reaches goal state. + :param session_command_id: ID of the Session Command to retrieve. + """ + + template_fields: Sequence[str] = [ + "session_command_id", + "auth_token", + ] + + def __init__( + self, + session_command_id: str, + **kwargs, + ): + super().__init__(**kwargs) + + self.session_command_id = session_command_id + + def poke(self, context: Context) -> bool: + + session_command = self.sdk.get_session_command( + self.session_command_id).result + + status_code = session_command.status_code + + if status_code is None: + return False + + took = session_command.finished_at - session_command.created_at + + self.log.info("duration: %s", took.total_seconds()) + self.log.info( + "session command %s ended with status code %s", + self.session_command_id, + session_command.status_code + ) + + if status_code != 0: + raise AirflowException("session command ended with errors") + + return True diff --git a/ray_provider/sensors/base.py b/ray_provider/sensors/base.py new file mode 100644 index 0000000..fd509a1 --- /dev/null +++ b/ray_provider/sensors/base.py @@ -0,0 +1,29 @@ +from anyscale import AnyscaleSDK +from airflow.utils.context import Context + +from airflow.sensors.base import BaseSensorOperator +from airflow.compat.functools import cached_property + + +class AnyscaleBaseSensor(BaseSensorOperator): + """ + Anyscale Base Sensor. + :param auth_token: Anyscale CLI token. + """ + + def __init__( + self, + *, + auth_token: str, + **kwargs + ): + + self.auth_token = auth_token + super().__init__(**kwargs) + + @cached_property + def sdk(self) -> AnyscaleSDK: + return AnyscaleSDK(auth_token=self.auth_token) + + def poke(self, context: Context) -> bool: + raise NotImplementedError("Please implement poke() in subclass") diff --git a/ray_provider/utils/__init__.py b/ray_provider/utils/__init__.py new file mode 100644 index 0000000..68ae74d --- /dev/null +++ b/ray_provider/utils/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from ray_provider.utils.utils import push_to_xcom diff --git a/ray_provider/utils/utils.py b/ray_provider/utils/utils.py new file mode 100644 index 0000000..03c25ea --- /dev/null +++ b/ray_provider/utils/utils.py @@ -0,0 +1,24 @@ +import json + +from copy import deepcopy +from airflow.utils.context import Context + + +def push_to_xcom(result: dict, context: Context, ignore_keys: list = None): + + if ignore_keys is None: + ignore_keys = [] + + ti = context["ti"] + result_copy = deepcopy(result) + + for key, value in result_copy.items(): + + if key in ignore_keys: + continue + + if type(value) is dict: + value = json.dumps(value, default=str) + + value = str(value) + ti.xcom_push(key=key, value=value) diff --git a/requirements.txt b/requirements.txt index 04fc278..4f30db5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy pandas==1.2.4 modin xgboost_ray +anyscale diff --git a/setup.py b/setup.py index 1b49b5a..fc1723f 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,8 @@ "pandas>=1.0.0", "filelock>=3.0.0", "ray>=1.3.0", - "apache-airflow-providers-http"], + "apache-airflow-providers-http", + "anyscale"], setup_requires=["setuptools", "wheel"], extras_require={}, author="Rob Deeb, Richard Liaw, Daniel Imberman, Pete DeJoy",