From 8ab5ad85dfa07e310d660aff6afdd9b71332d52a Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Fri, 15 Mar 2024 15:17:30 +0100
Subject: [PATCH] Added checks for UC-incompatible task Airflow
 `DatabricksSubmitRunOperator` instances

---
 pyproject.toml                        |  2 +-
 src/databricks/labs/pylint/airflow.py | 55 ++++++++++++++++-----------
 tests/test_airflow.py                 | 22 ++++++++++-
 3 files changed, 54 insertions(+), 25 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8db66f7..2f6ca87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,7 +91,7 @@ known-first-party = ["databricks.labs.pylint"]
 
 [tool.ruff.lint.per-file-ignores]
-"tests/samples/**/**" = ["F403", "F405", "E402", "E501", "E722", "E731"]
+"tests/samples/**/**" = ["F403", "F405", "E402", "E501", "E722", "E731", "F821"]
 
 [tool.coverage.run]
 branch = true
diff --git a/src/databricks/labs/pylint/airflow.py b/src/databricks/labs/pylint/airflow.py
index fd543e4..42c3919 100644
--- a/src/databricks/labs/pylint/airflow.py
+++ b/src/databricks/labs/pylint/airflow.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 import astroid
 from pylint.checkers import BaseChecker
 
@@ -17,30 +19,35 @@ def visit_call(self, node: astroid.Call):
         operator = node.func.as_string()
         if operator not in ("DatabricksCreateJobsOperator", "DatabricksSubmitRunOperator"):
             return
-        for kwarg in node.keywords:
-            if kwarg.arg == "tasks":
-                self._check_tasks(kwarg.value)
-                continue
-            if kwarg.arg == "job_clusters":
-                self._check_job_clusters(kwarg.value)
-                continue
+        for arg, value in self._infer_kwargs(node.keywords).items():
+            if arg == "tasks":
+                self._check_tasks(value, node)
+            elif arg == "job_clusters":
+                self._check_job_clusters(value, node)
+            elif arg == "new_cluster":
+                self._check_new_cluster("ephemeral", value, node)
+
+    def _check_new_cluster(self, key: str, new_cluster: dict[str, Any], node: astroid.NodeNG):
+        if "data_security_mode" not in new_cluster:
+            self.add_message("missing-data-security-mode", node=node, args=(key,))
+
+    def _check_tasks(self, tasks: list[dict[str, Any]], node: astroid.NodeNG):
+        for task in tasks:
+            if "new_cluster" not in task:
+                continue
+            self._check_new_cluster(task["task_key"], task["new_cluster"], node)
 
-    def _check_tasks(self, value: astroid.NodeNG):
-        for inferred in value.infer():
-            for task in self._infer_value(inferred):
-                if "new_cluster" not in task:
-                    continue
-                if "data_security_mode" not in task["new_cluster"]:
-                    self.add_message("missing-data-security-mode", node=value, args=(task["task_key"],))
+    def _check_job_clusters(self, job_clusters: list[dict[str, Any]], node: astroid.NodeNG):
+        for job_cluster in job_clusters:
+            if "new_cluster" not in job_cluster:
+                continue
+            self._check_new_cluster(job_cluster["job_cluster_key"], job_cluster["new_cluster"], node)
 
-    def _check_job_clusters(self, value: astroid.NodeNG):
-        for inferred in value.infer():
-            for job_cluster in self._infer_value(inferred):
-                if "new_cluster" not in job_cluster:
-                    continue
-                # add message that this job cluster is missing data_security_mode
-                if "data_security_mode" not in job_cluster["new_cluster"]:
-                    self.add_message("missing-data-security-mode", node=value, args=(job_cluster["job_cluster_key"],))
+    def _infer_kwargs(self, keywords: list[astroid.Keyword]):
+        kwargs = {}
+        for keyword in keywords:
+            kwargs[keyword.arg] = self._infer_value(keyword.value)
+        return kwargs
 
     def _infer_value(self, value: astroid.NodeNG):
         if isinstance(value, astroid.Dict):
@@ -53,6 +60,10 @@ def _infer_value(self, value: astroid.NodeNG):
             return tuple(self._infer_value(elem) for elem in value.elts)
         if isinstance(value, astroid.DictUnpack):
             return {self._infer_value(key): self._infer_value(value) for key, value in value.items}
+        if isinstance(value, astroid.Name):
+            for inferred in value.inferred():
+                return self._infer_value(inferred)
+            raise ValueError(f"Cannot resolve variable: {value.name}")
         raise ValueError(f"Unsupported type {type(value)}")
 
     def _infer_dict(self, in_dict: astroid.Dict):
diff --git a/tests/test_airflow.py b/tests/test_airflow.py
index 414852c..4ef50c5 100644
--- a/tests/test_airflow.py
+++ b/tests/test_airflow.py
@@ -4,7 +4,7 @@ def test_missing_data_security_mode_in_job_clusters(lint_with):
     messages = (
         lint_with(AirflowChecker)
-        << """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
+        << """from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator
 tasks = [
     {
         "task_key": "test",
         "new_cluster": {
@@ -39,7 +39,7 @@ def test_missing_data_security_mode_in_job_clusters(lint_with):
 def test_missing_data_security_mode_in_task_clusters(lint_with):
     messages = (
         lint_with(AirflowChecker)
-        << """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
+        << """from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator
 tasks = [
     {
         "task_key": "banana",
@@ -62,3 +62,21 @@ def test_missing_data_security_mode_in_task_clusters(lint_with):
         "[missing-data-security-mode] banana cluster missing 'data_security_mode' "
         "required for Unity Catalog compatibility"
     ) in messages
+
+
+def test_missing_data_security_mode_in_submit_run_clusters(lint_with):
+    messages = (
+        lint_with(AirflowChecker)
+        << """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
+new_cluster = {"spark_version": "10.1.x-scala2.12", "num_workers": 2}
+notebook_task = {
+    "notebook_path": "/Users/airflow@example.com/PrepareData",
+}
+DatabricksSubmitRunOperator( #@
+    task_id="notebook_run", new_cluster=new_cluster, notebook_task=notebook_task
+)"""
+    )
+    assert (
+        "[missing-data-security-mode] ephemeral cluster missing 'data_security_mode' "
+        "required for Unity Catalog compatibility"
+    ) in messages