Skip to content

Commit

Permalink
Added checks for UC-incompatible Airflow `DatabricksSubmitRunOperator` task instances
Browse files Browse the repository at this point in the history
  • Loading branch information
nfx committed Mar 15, 2024
1 parent 8eb176d commit 8ab5ad8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ known-first-party = ["databricks.labs.pylint"]

[tool.ruff.lint.per-file-ignores]

"tests/samples/**/**" = ["F403", "F405", "E402", "E501", "E722", "E731"]
"tests/samples/**/**" = ["F403", "F405", "E402", "E501", "E722", "E731", "F821"]

[tool.coverage.run]
branch = true
Expand Down
55 changes: 33 additions & 22 deletions src/databricks/labs/pylint/airflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import astroid
from pylint.checkers import BaseChecker

Expand All @@ -17,30 +19,35 @@ def visit_call(self, node: astroid.Call):
operator = node.func.as_string()
if operator not in ("DatabricksCreateJobsOperator", "DatabricksSubmitRunOperator"):
return
for kwarg in node.keywords:
if kwarg.arg == "tasks":
self._check_tasks(kwarg.value)
continue
if kwarg.arg == "job_clusters":
self._check_job_clusters(kwarg.value)
continue
for arg, value in self._infer_kwargs(node.keywords).items():
if arg == "tasks":
self._check_tasks(value, node)
elif arg == "job_clusters":
self._check_job_clusters(value, node)
elif arg == "new_cluster":
self._check_new_cluster("ephemeral", value, node)

def _check_new_cluster(self, key: str, new_cluster: dict[str, Any], node: astroid.NodeNG):
if "data_security_mode" not in new_cluster:
self.add_message("missing-data-security-mode", node=node, args=(key,))

def _check_tasks(self, tasks: list[dict[str, Any]], node: astroid.NodeNG):
for task in tasks:
if "new_cluster" not in task:
return
self._check_new_cluster(task["task_key"], task["new_cluster"], node)

    def _check_tasks(self, value: astroid.NodeNG):
        """Check each inferred task's ``new_cluster`` for ``data_security_mode``.

        ``value`` is the AST node passed as the ``tasks`` keyword argument;
        astroid inference resolves it to the concrete list of task dicts.
        """
        for inferred in value.infer():
            for task in self._infer_value(inferred):
                # Tasks running on an existing cluster have no new_cluster spec.
                if "new_cluster" not in task:
                    continue
                if "data_security_mode" not in task["new_cluster"]:
                    self.add_message("missing-data-security-mode", node=value, args=(task["task_key"],))
def _check_job_clusters(self, job_clusters: list[dict[str, Any]], node: astroid.NodeNG):
for job_cluster in job_clusters:
if "new_cluster" not in job_cluster:
return
self._check_new_cluster(job_cluster["job_cluster_key"], job_cluster["new_cluster"], node)

    def _check_job_clusters(self, value: astroid.NodeNG):
        """Check each inferred job cluster's ``new_cluster`` for ``data_security_mode``.

        ``value`` is the AST node passed as the ``job_clusters`` keyword
        argument; astroid inference resolves it to the list of cluster dicts.
        """
        for inferred in value.infer():
            for job_cluster in self._infer_value(inferred):
                if "new_cluster" not in job_cluster:
                    continue
                # Flag this job cluster: data_security_mode must be explicit
                # for Unity Catalog compatibility.
                if "data_security_mode" not in job_cluster["new_cluster"]:
                    self.add_message("missing-data-security-mode", node=value, args=(job_cluster["job_cluster_key"],))
def _infer_kwargs(self, keywords: list[astroid.Keyword]):
kwargs = {}
for keyword in keywords:
kwargs[keyword.arg] = self._infer_value(keyword.value)
return kwargs

def _infer_value(self, value: astroid.NodeNG):
if isinstance(value, astroid.Dict):
Expand All @@ -53,6 +60,10 @@ def _infer_value(self, value: astroid.NodeNG):
return tuple(self._infer_value(elem) for elem in value.elts)
if isinstance(value, astroid.DictUnpack):
return {self._infer_value(key): self._infer_value(value) for key, value in value.items}
if isinstance(value, astroid.Name):
for inferred in value.inferred():
return self._infer_value(inferred)
raise ValueError(f"Cannot resolve variable: {value.name}")
raise ValueError(f"Unsupported type {type(value)}")

def _infer_dict(self, in_dict: astroid.Dict):
Expand Down
22 changes: 20 additions & 2 deletions tests/test_airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def test_missing_data_security_mode_in_job_clusters(lint_with):
messages = (
lint_with(AirflowChecker)
<< """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
<< """from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator
tasks = [
{
"task_key": "test",
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_missing_data_security_mode_in_job_clusters(lint_with):
def test_missing_data_security_mode_in_task_clusters(lint_with):
messages = (
lint_with(AirflowChecker)
<< """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
<< """from airflow.providers.databricks.operators.databricks import DatabricksCreateJobsOperator
tasks = [
{
"task_key": "banana",
Expand All @@ -62,3 +62,21 @@ def test_missing_data_security_mode_in_task_clusters(lint_with):
"[missing-data-security-mode] banana cluster missing 'data_security_mode' "
"required for Unity Catalog compatibility"
) in messages


def test_missing_data_security_mode_in_submit_run_clusters(lint_with):
    # DatabricksSubmitRunOperator takes a single anonymous `new_cluster` kwarg
    # (there is no task/job-cluster key), so the checker is expected to report
    # the offending cluster under the name "ephemeral".  The `#@` marker tags
    # the Call node for astroid's extract_node-style parsing.
    messages = (
        lint_with(AirflowChecker)
        << """from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
new_cluster = {"spark_version": "10.1.x-scala2.12", "num_workers": 2}
notebook_task = {
    "notebook_path": "/Users/airflow@example.com/PrepareData",
}
DatabricksSubmitRunOperator( #@
    task_id="notebook_run", new_cluster=new_cluster, notebook_task=notebook_task
)"""
    )
    assert (
        "[missing-data-security-mode] ephemeral cluster missing 'data_security_mode' "
        "required for Unity Catalog compatibility"
    ) in messages

0 comments on commit 8ab5ad8

Please sign in to comment.