Added dbutils checker
nfx committed Mar 14, 2024
1 parent 12b65dd commit fc5c9d2
Showing 9 changed files with 260 additions and 119 deletions.
19 changes: 18 additions & 1 deletion CONTRIBUTING.md
@@ -1,5 +1,17 @@
# Contributing

<!-- TOC -->
* [Contributing](#contributing)
* [First Principles](#first-principles)
* [Common fixes for `mypy` errors](#common-fixes-for-mypy-errors)
* [..., expression has type "None", variable has type "str"](#-expression-has-type-none-variable-has-type-str)
* [..., has incompatible type "Path"; expected "str"](#-has-incompatible-type-path-expected-str)
* [Argument 2 to "get" of "dict" has incompatible type "None"; expected ...](#argument-2-to-get-of-dict-has-incompatible-type-none-expected-)
* [Local Setup](#local-setup)
* [First contribution](#first-contribution)
* [Troubleshooting](#troubleshooting)
<!-- TOC -->

## First Principles

Favoring standard libraries over external dependencies, especially in specific contexts like Databricks, is a best practice in software
@@ -114,4 +126,9 @@ Here are the example steps to submit your first contribution:

## Troubleshooting

If you encounter any package dependency errors after `git pull`, run `make clean`
If you encounter any package dependency errors after `git pull`, run `make clean dev`

### Running in isolation

See https://pylint.pycqa.org/en/latest/development_guide/how_tos/custom_checkers.html#testing-a-checker
`pylint --load-plugins=databricks.labs.pylint --disable=all --enable=<check> test.py`
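
For example, to run only the `dbutils-fs-cp` check added in this commit (`test.py` is a placeholder; any Python file works as the target):

`pylint --load-plugins=databricks.labs.pylint --disable=all --enable=dbutils-fs-cp test.py`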
19 changes: 13 additions & 6 deletions pyproject.toml
@@ -54,21 +54,24 @@ python="3.10"
path = ".venv"

[tool.hatch.envs.default.scripts]
test = "pytest -n 2 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20"
test = "pytest -n 2 --cov src --cov-report=xml --timeout 30 tests --durations 20"
coverage = "pytest -n 2 --cov src tests/unit --timeout 30 --cov-report=html --durations 20"
integration = "pytest -n 10 --cov src tests/integration --durations 20"
fmt = ["isort .",
"ruff format",
"ruff . --fix",
"mypy .",
"ruff check . --fix",
#"mypy .",
"pylint --output-format=colorized -j 0 src"]
verify = ["black --check .",
"isort . --check-only",
"ruff .",
"mypy .",
"pylint --output-format=colorized -j 0 src"]

[tool.isort]
[tool.mypy]
exclude = ["tests/samples"]

[tool.lint.isort]
profile = "black"

[tool.pytest.ini_options]
@@ -85,9 +88,13 @@ cache-dir = ".venv/ruff-cache"
target-version = "py310"
line-length = 120

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-first-party = ["databricks.labs.pylint"]

[tool.ruff.lint.per-file-ignores]

"tests/samples/*" = ["F403", "F405", "E402", "E501", "E722", "E731"]

[tool.coverage.run]
branch = true
parallel = true
@@ -145,7 +152,7 @@ fail-under = 10.0
# ignore-list. The regex matches against paths and can be in Posix or Windows
# format. Because '\\' represents the directory delimiter on Windows systems, it
# can't be used as an escape character.
# ignore-paths =
ignore-paths = ["tests/samples"]

# Files or directories matching the regular expression patterns are skipped. The
# regex matches against base names, not paths. The default value ignores Emacs
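
With the scripts above in place, the day-to-day commands go through hatch (a sketch, assuming hatch is installed and resolves the default environment at `.venv`):

hatch run test         # unit tests with coverage and a 30s per-test timeout
hatch run integration  # integration tests under tests/integration
hatch run fmt          # isort, ruff format, ruff check --fix, pylint
hatch run verify       # check-only: black, isort, ruff, mypy, pylint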
7 changes: 6 additions & 1 deletion src/databricks/labs/pylint/__init__.py
@@ -1 +1,6 @@
from databricks.labs.pylint.__about__ import __version__
def register(linter):
from databricks.labs.pylint.dbutils import DbutilsChecker
from databricks.labs.pylint.notebooks import NotebookChecker

linter.register_checker(NotebookChecker(linter))
linter.register_checker(DbutilsChecker(linter))
86 changes: 86 additions & 0 deletions src/databricks/labs/pylint/dbutils.py
@@ -0,0 +1,86 @@
# pylint checker for databricks dbutils
import astroid
from pylint.checkers import BaseChecker


class DbutilsChecker(BaseChecker):
name = "dbutils"

msgs = {
"E9899": (
"Use Databricks SDK instead: w.dbfs.copy(%s, %s)",
"dbutils-fs-cp",
"Migrate all usage of dbutils to Databricks SDK",
),
"E9898": (
"Use Databricks SDK instead: with w.dbfs.download(%s) as f: f.read()",
"dbutils-fs-head",
"Migrate all usage of dbutils to Databricks SDK",
),
"E9897": (
"Use Databricks SDK instead: w.dbfs.list(%s)",
"dbutils-fs-ls",
"Migrate all usage of dbutils to Databricks SDK",
),
"E9896": (
"Mounts are not supported with Unity Catalog, switch to using Unity Catalog Volumes instead",
"dbutils-fs-mount",
"Migrate all usage to Unity Catalog",
),
"E9889": (
"Credentials utility is not supported with Unity Catalog",
"dbutils-credentials",
"Migrate all usage to Unity Catalog",
),
"E9879": (
"""Use Databricks SDK instead: w.jobs.submit(
tasks=[jobs.SubmitTask(existing_cluster_id=...,
notebook_task=jobs.NotebookTask(notebook_path=%s),
task_key=...)
]).result(timeout=timedelta(minutes=%s))""",
"dbutils-notebook-run",
"Migrate all usage of dbutils to Databricks SDK",
),
"E9869": (
"Use Databricks SDK instead: from databricks.sdk import WorkspaceClient(); w = WorkspaceClient()",
"pat-token-leaked",
"Do not hardcode secrets in code, use Databricks Scopes instead",
),
}

def visit_call(self, node: astroid.Call):
# add message if dbutils.fs.cp() is used
if node.func.as_string() == "dbutils.fs.cp":
self.add_message("dbutils-fs-cp", node=node, args=(node.args[0].as_string(), node.args[1].as_string()))
# add message if dbutils.fs.head() is used
if node.func.as_string() == "dbutils.fs.head":
self.add_message("dbutils-fs-head", node=node, args=(node.args[0].as_string(),))
# add message if dbutils.fs.ls("/tmp") is used
if node.func.as_string() == "dbutils.fs.ls":
self.add_message("dbutils-fs-ls", node=node, args=(node.args[0].as_string(),))
# add message if dbutils.fs.mount("s3a://%s" % aws_bucket_name, "/mnt/%s" % mount_name) is used
if node.func.as_string() in {
"dbutils.fs.mount",
"dbutils.fs.mounts",
"dbutils.fs.unmount",
"dbutils.fs.updateMount",
"dbutils.fs.refreshMounts",
}:
self.add_message("dbutils-fs-mount", node=node)
# add message if dbutils.credentials.* is used
if node.func.as_string().startswith("dbutils.credentials."):
self.add_message("dbutils-credentials", node=node)
# add message if dbutils.notebook.run("My Other Notebook", 60) is used
if node.func.as_string() == "dbutils.notebook.run":
self.add_message(
"dbutils-notebook-run", node=node, args=(node.args[0].as_string(), node.args[1].as_string())
)

def visit_const(self, node: astroid.Const):
# add a message if a string constant matches dapi[0-9a-f]{32}, dkea[0-9a-f]{32}, or dosa[0-9a-f]{32};
# guard against non-string constants, which have no startswith()
if not isinstance(node.value, str):
return
if node.value.startswith(("dapi", "dkea", "dosa")):
self.add_message("pat-token-leaked", node=node)


def register(linter):
linter.register_checker(DbutilsChecker(linter))
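
For illustration only (not part of the commit), a snippet the new checker flags, followed by the SDK-based direction its messages point to; the paths are invented and the SDK calls mirror the message templates above:

# flagged as dbutils-fs-cp (E9899) and dbutils-fs-head (E9898)
dbutils.fs.cp("/tmp/source.csv", "/tmp/target.csv")
dbutils.fs.head("/tmp/target.csv")

# suggested replacement, as sketched in the checker messages (assumes a configured workspace)
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
w.dbfs.copy("/tmp/source.csv", "/tmp/target.csv")
with w.dbfs.download("/tmp/target.csv") as f:
    f.read()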
90 changes: 30 additions & 60 deletions src/databricks/labs/pylint/notebooks.py
@@ -1,44 +1,36 @@
import astroid
from pylint.checkers import BaseChecker, BaseRawFileChecker
from pylint.checkers import BaseRawFileChecker


class NotebookChecker(BaseRawFileChecker):
__implements__ = (BaseRawFileChecker,)

name = 'databricks-notebooks'
name = "databricks-notebooks"
msgs = {
'E9999': (
'dbutils.notebook.run() is not allowed',
'notebooks-dbutils-run',
'Used when dbutils.notebook.run() is used'
"E9996": (
"Notebooks should not have more than 75 cells",
"notebooks-too-many-cells",
"Used when the number of cells in a notebook is greater than 75",
),
'E9998': (
'dbutils.fs is not allowed',
'notebooks-dbutils-fs',
'Used when dbutils.fs is used'
),
'E9997': (
'dbutils.credentials is not allowed',
'notebooks-dbutils-credentials',
'Used when dbutils.credentials is used'
),
'E9996': (
'Notebooks should not have more than 75 cells',
'notebooks-too-many-cells',
'Used when the number of cells in a notebook is greater than 75'
),
'E9995': (
'Star import is not allowed',
'notebooks-star-import',
'Used when there is a star import from pyspark.sql.functions'
),
'E9994': (
'Using %run is not allowed',
'notebooks-percent-run',
'Used when `# MAGIC %run` comment is used',
"E9994": (
"Using %run is not allowed",
"notebooks-percent-run",
"Used when `# MAGIC %run` comment is used",
),
}

options = (
(
"max-cells",
{
"default": 75,
"type": "int",
"metavar": "<int>",
"help": "Maximum number of cells in the notebook",
},
),
)

def process_module(self, node: astroid.Module):
"""Read raw module. Need to do some tricks, as `ast` doesn't provide access for comments.
@@ -51,41 +43,19 @@ def process_module(self, node: astroid.Module):
"""
cells = 1
with node.stream() as stream:
for (lineno, line) in enumerate(stream):
for lineno, line in enumerate(stream):
lineno += 1
if lineno == 1 and line != b'# Databricks notebook source\n':
if lineno == 1 and line != b"# Databricks notebook source\n":
# this is not a Databricks notebook
return
if line == b'# COMMAND ----------\n':
if line == b"# COMMAND ----------\n":
cells += 1
if cells > 75:
self.add_message('notebooks-too-many-cells', line=lineno)
if cells > self.linter.config.max_cells:
self.add_message("notebooks-too-many-cells", line=lineno)
continue
if line.startswith(b'# MAGIC %run'):
self.add_message('notebooks-percent-run', line=lineno)

def visit_module(self, node):
# add message if dbutils.notebook.run() is used
if node.name == 'dbutils.notebook.run':
self.add_message('notebooks-dbutils-run', node=node)

# add message if dbutils.fs is used
if node.name == 'dbutils.fs':
self.add_message('notebooks-dbutils-fs', node=node)

# add message if dbutils.credentials is used
if node.name == 'dbutils.credentials':
self.add_message('notebooks-dbutils-credentials', node=node)

# Notebooks should not have more than 75 cells.
if len(node.body) > 75:
self.add_message('notebooks-too-many-cells', node=node)

def visit_importfrom(self, node: astroid.ImportFrom):
# add message if there's a star import from pyspark.sql.functions import *
if node.modname == 'pyspark.sql.functions' and node.names[0][0] == '*':
self.add_message('notebooks-star-import', node=node)
if line.startswith(b"# MAGIC %run"):
self.add_message("notebooks-percent-run", line=lineno)


def register(linter):
linter.register_checker(NotebookChecker(linter))
linter.register_checker(NotebookChecker(linter))
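
A minimal sketch (contents invented) of the notebook source layout the raw checker scans: the first line must be the Databricks marker, `# COMMAND ----------` starts a new cell, `# MAGIC %run` lines are reported, and the cell limit is now configurable via the checker's `max-cells` option:

# Databricks notebook source
print("first cell")

# COMMAND ----------

# MAGIC %run ./shared_setup   <- reported as notebooks-percent-run (E9994)

# COMMAND ----------

print("third cell")  # once the cell count exceeds max-cells (default 75), notebooks-too-many-cells (E9996) fires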
45 changes: 45 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,45 @@
from typing import Generic, TypeVar

import astroid
import astroid.rebuilder
import pytest
from pylint.checkers import BaseChecker
from pylint.testutils import UnittestLinter

T = TypeVar("T", bound=BaseChecker)


class TestSupport(Generic[T]):
def __init__(self, klass: type[T]):
linter = UnittestLinter()
checker = klass(linter)
checker.open()
linter.register_checker(checker)

self._checker = checker
self._linter = linter

def __lshift__(self, code: str):
node = astroid.extract_node(code)

klass_name = node.__class__.__name__
visitor = astroid.rebuilder.REDIRECT.get(klass_name, klass_name).lower()
getattr(self._checker, f"visit_{visitor}")(node)

out = set()
for message in self._linter.release_messages():
for message_definition in self._linter.msgs_store.get_message_definitions(message.msg_id):
user_message = message_definition.msg
if message.args:
user_message %= message.args
out.add(f"[{message.msg_id}] {user_message}")

return out


@pytest.fixture
def lint_with():
def factory(klass: type[T]) -> TestSupport[T]:
return TestSupport(klass)

yield factory
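
A hypothetical unit test built on the new `lint_with` fixture (the test name and snippet are assumptions, not part of this commit); `<<` lints the snippet with the given checker and returns the rendered messages:

from databricks.labs.pylint.dbutils import DbutilsChecker


def test_flags_dbutils_fs_cp(lint_with):
    messages = lint_with(DbutilsChecker) << "dbutils.fs.cp(src, dst)"
    assert "[dbutils-fs-cp] Use Databricks SDK instead: w.dbfs.copy(src, dst)" in messages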
9 changes: 2 additions & 7 deletions tests/samples/TestForPylint.py
@@ -10,7 +10,6 @@
# and here we do star import
from pyspark.sql.functions import *


# # COMMAND ----------
#
# # but no dbutils.library.restartPython()
@@ -22,9 +21,7 @@

# COMMAND ----------

df = spark \
.table('samples.nyctaxi.trips') \
.limit(10)
df = spark.table("samples.nyctaxi.trips").limit(10)
display(df)

# COMMAND ----------
Expand All @@ -33,7 +30,5 @@

# COMMAND ----------

df = (spark
.table('samples.nyctaxi.trips')
.limit(10))
df = spark.table("samples.nyctaxi.trips").limit(10)
display(df)
