refactor operator checker (#19)
A bunch of refactor items:
- Break the OperatorChecker into more methods/functions
- Each pylint rule we implement has a corresponding `check_<rule>()`
method, isolating the logic of the check from how we procure the nodes
needed to compute the check result
- Test the `check` methods individually (see the test sketch after this list)
- Use a custom dataclass for extracting the task/operator parameters
from an assignment node
- Move the pylint MSGS to the top of each module for easier reading
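
As an illustration of the last three points, here is a minimal test sketch (not part of this commit) showing how one of the new `check_<rule>()` methods could be exercised in isolation with a hand-built `TaskParameters` instance. The `MagicMock` linter stand-in and the `PythonOperator` snippet are assumptions made for the example, not code from the repository:

```python
# Hypothetical test sketch -- the mocked linter and the sample assignment are
# illustrative assumptions, not part of this commit's test suite.
from unittest.mock import MagicMock

import astroid

from pylint_airflow.checkers.operator import OperatorChecker, TaskParameters


def test_varname_taskid_mismatch_adds_message():
    # Parse an assignment so the check has a node to attach its message to;
    # the snippet is only parsed by astroid, never executed.
    node = astroid.extract_node('mytask = PythonOperator(task_id="other_task")')

    checker = OperatorChecker(MagicMock())  # linter stand-in
    checker.add_message = MagicMock()

    params = TaskParameters(var_name="mytask", task_id="other_task")
    checker.check_operator_varname_versus_task_id(node, params)

    checker.add_message.assert_called_once_with(
        "different-operator-varname-taskid", node=node
    )
```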
topherinternational authored Dec 20, 2023
1 parent 9d8bc18 commit 0915c73
Showing 3 changed files with 320 additions and 127 deletions.
87 changes: 45 additions & 42 deletions src/pylint_airflow/checkers/dag.py
@@ -10,6 +10,47 @@

from pylint_airflow.__pkginfo__ import BASE_ID

DAG_CHECKER_MSGS = {
f"W{BASE_ID}00": (
"TODO Don't place BaseHook calls at the top level of DAG script",
"basehook-top-level",
"TODO Airflow executes DAG scripts periodically and anything at the top level "
"of a script is executed. Therefore, move BaseHook calls into "
"functions/hooks/operators.",
),
f"E{BASE_ID}00": (
"DAG name %s already used",
"duplicate-dag-name",
"DAG name should be unique.",
),
f"E{BASE_ID}01": (
"TODO Task name {} already used",
"duplicate-task-name",
"TODO Task name within a DAG should be unique.",
),
f"E{BASE_ID}02": (
"TODO Task dependency {}->{} already set",
"duplicate-dependency",
"TODO Task dependencies can be defined only once.",
),
f"E{BASE_ID}03": (
"TODO DAG {} contains cycles",
"dag-with-cycles",
"TODO A DAG is acyclic and cannot contain cycles.",
),
f"E{BASE_ID}04": (
"TODO Task {} is not bound to any DAG instance",
"task-no-dag",
"TODO A task must know a DAG instance to run.",
),
f"C{BASE_ID}06": (
"For consistency match the DAG filename with the dag_id",
"match-dagid-filename",
"For consistency match the DAG filename with the dag_id.",
),
# TODO: add check to force kwargs for DAG definitions
}


@dataclass
class DagCallNode:
@@ -43,7 +84,7 @@ def value_from_const_node(const_node: astroid.Const) -> Optional[str]:
def value_from_name_node(name_node: astroid.Name) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id extracted from the name_node argument,
or None if the node value can't be extracted."""
name_val = safe_infer(name_node) # TODO: follow name chains
name_val = safe_infer(name_node)
if name_val:
if isinstance(name_val, astroid.Const):
return value_from_const_node(name_val)
@@ -64,12 +105,12 @@ def get_name_node_value_from_assignments(node: astroid.Name) -> Optional[str]:
assign_value = assign_node.value
return dag_id_from_argument_value(assign_value)

# If we drop out of any of 'if' blocks, we give up
# If we drop out of any 'if' blocks, we give up
return None


def value_from_joined_str_node(joined_str_node: astroid.JoinedStr) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id composed from the elements of the
"""Returns a DagCallNode instance with dag_id composed by joining the elements of the
joined_str_node argument, or None if the node value can't be extracted."""
dag_id_elements: List[str] = []
for js_value in joined_str_node.values:
@@ -134,45 +175,7 @@ def find_dag_in_call_node(call_node: astroid.Call) -> Optional[DagCallNode]:
class DagChecker(checkers.BaseChecker):
"""Checks conditions in the context of (a) complete DAG(s)."""

msgs = {
f"W{BASE_ID}00": (
"Don't place BaseHook calls at the top level of DAG script",
"basehook-top-level",
"Airflow executes DAG scripts periodically and anything at the top level "
"of a script is executed. Therefore, move BaseHook calls into "
"functions/hooks/operators.",
),
f"E{BASE_ID}00": (
"DAG name %s already used",
"duplicate-dag-name",
"DAG name should be unique.",
),
f"E{BASE_ID}01": (
"Task name {} already used",
"duplicate-task-name",
"Task name within a DAG should be unique.",
),
f"E{BASE_ID}02": (
"Task dependency {}->{} already set",
"duplicate-dependency",
"Task dependencies can be defined only once.",
),
f"E{BASE_ID}03": (
"DAG {} contains cycles",
"dag-with-cycles",
"A DAG is acyclic and cannot contain cycles.",
),
f"E{BASE_ID}04": (
"Task {} is not bound to any DAG instance",
"task-no-dag",
"A task must know a DAG instance to run.",
),
f"C{BASE_ID}06": (
"For consistency match the DAG filename with the dag_id",
"match-dagid-filename",
"For consistency match the DAG filename with the dag_id.",
),
}
msgs = DAG_CHECKER_MSGS

@staticmethod
def _dagids_to_deduplicated_nodes(
247 changes: 165 additions & 82 deletions src/pylint_airflow/checkers/operator.py
@@ -1,4 +1,7 @@
"""Checks on Airflow operators."""
import logging
from dataclasses import dataclass
from typing import Set, Optional

import astroid
from pylint import checkers
@@ -7,47 +10,137 @@

from pylint_airflow.__pkginfo__ import BASE_ID

logging.basicConfig(level=logging.WARNING)

OPERATOR_CHECKER_MSGS = {
f"C{BASE_ID}00": (
"Operator variable name and task_id argument should match",
"different-operator-varname-taskid",
"For consistency assign the same variable name and task_id to operators.",
),
f"C{BASE_ID}01": (
"Name the python_callable function '_[task_id]'",
"match-callable-taskid",
"For consistency name the callable function '_[task_id]', e.g. "
"PythonOperator(task_id='mytask', python_callable=_mytask).",
),
f"C{BASE_ID}02": (
"Avoid mixing task dependency directions",
"mixed-dependency-directions",
"For consistency don't mix directions in a single statement, instead split "
"over multiple statements.",
),
f"C{BASE_ID}03": (
"TODO Task {} has no dependencies. Verify or disable message.",
"task-no-dependencies",
"TODO Sometimes a task without any dependency is desired, however often it is "
"the result of a forgotten dependency.",
),
f"C{BASE_ID}04": (
"TODO Rename **kwargs variable to **context to show intent for Airflow task context",
"task-context-argname",
"TODO Indicate you expect Airflow task context variables in the **kwargs "
"argument by renaming to **context.",
),
f"C{BASE_ID}05": (
"TODO Extract variables from keyword arguments for explicitness",
"task-context-separate-arg",
"TODO To avoid unpacking kwargs from the Airflow task context in a function, you "
"can set the needed variables as arguments in the function.",
),
# TODO: add check to force kwargs for task definitions
# TODO: modify check to allow task_id matching python_callable name (without underscore)
}


@dataclass
class TaskParameters:
"""
Data class to hold parameters extracted from an assignment involving the instantiation of
an Operator/task.
var_name is always present
task_id might be missing if a constructor call is malformed (tasks have to have IDs)
python_callable_name will only be present for a PythonOperator
"""

var_name: str
task_id: Optional[str] = None
python_callable_name: Optional[str] = None


def collect_operators_from_binops(working_node: astroid.BinOp) -> Set[str]:
"""
Recursive helper that collects the binary operators (>> and/or <<) used in a dependency chain.
"""
binops_found = set()
if isinstance(working_node.left, astroid.BinOp):
binops_found.update(collect_operators_from_binops(working_node.left))
if isinstance(working_node.right, astroid.BinOp):
binops_found.update(collect_operators_from_binops(working_node.right))
if working_node.op in (">>", "<<"):
binops_found.add(working_node.op)

return binops_found


def is_assign_call_subtype_of_base_operator(node: astroid.Assign) -> bool:
"""Tests an Assign node and returns True if all of the following are true:
* The Assign value is a Call object
* The Call's func member can be inferred to a node
* The inferred value is not a BoundMethod (a method on a class instance)
* The inferred value is a ClassDef object (has "is_subtype_of" attribute)
* The inferred value is a subtype of "airflow.models.BaseOperator" or
"airflow.models.baseoperator.BaseOperator"
"""
if not isinstance(node.value, astroid.Call):
return False

function_node = safe_infer(node.value.func)
return (
function_node
and not isinstance(function_node, astroid.bases.BoundMethod)
and hasattr(function_node, "is_subtype_of")
and (
function_node.is_subtype_of("airflow.models.BaseOperator")
or function_node.is_subtype_of("airflow.models.baseoperator.BaseOperator")
# ^ TODO: are both of these subtypes relevant?
)
)


def get_task_parameters_from_assign(node: astroid.Assign) -> TaskParameters:
"""Extracts the callable name, task_id and var_name from an assignment whose right side is an
Operator construction (a task). callable_name and task_id can be None (showing an
underspecified task whose linting should be skipped)."""

assign_target = node.targets[0]
if not isinstance(assign_target, astroid.AssignName):
raise ValueError(
f"Target of Assign node {node} is not an AssignName ({assign_target});"
f" task cannot be linted."
)

var_name = assign_target.name
task_id = None
python_callable_name = None

if isinstance(node.value, astroid.Call): # we know this, but a check gives us type inference
for keyword in node.value.keywords:
if keyword.arg == "task_id" and isinstance(keyword.value, astroid.Const):
# TODO support other values than constants
task_id = keyword.value.value
continue
if keyword.arg == "python_callable" and isinstance(keyword.value, astroid.Name):
python_callable_name = keyword.value.name

return TaskParameters(var_name, task_id, python_callable_name)


class OperatorChecker(checkers.BaseChecker):
"""Checks on Airflow operators."""

msgs = {
f"C{BASE_ID}00": (
"Operator variable name and task_id argument should match",
"different-operator-varname-taskid",
"For consistency assign the same variable name and task_id to operators.",
),
f"C{BASE_ID}01": (
"Name the python_callable function '_[task_id]'",
"match-callable-taskid",
"For consistency name the callable function '_[task_id]', e.g. "
"PythonOperator(task_id='mytask', python_callable=_mytask).",
),
f"C{BASE_ID}02": (
"Avoid mixing task dependency directions",
"mixed-dependency-directions",
"For consistency don't mix directions in a single statement, instead split "
"over multiple statements.",
),
f"C{BASE_ID}03": (
"Task {} has no dependencies. Verify or disable message.",
"task-no-dependencies",
"Sometimes a task without any dependency is desired, however often it is "
"the result of a forgotten dependency.",
),
f"C{BASE_ID}04": (
"Rename **kwargs variable to **context to show intent for Airflow task context",
"task-context-argname",
"Indicate you expect Airflow task context variables in the **kwargs "
"argument by renaming to **context.",
),
f"C{BASE_ID}05": (
"Extract variables from keyword arguments for explicitness",
"task-context-separate-arg",
"To avoid unpacking kwargs from the Airflow task context in a function, you "
"can set the needed variables as arguments in the function.",
),
}
msgs = OPERATOR_CHECKER_MSGS

@utils.only_required_for_messages("different-operator-varname-taskid", "match-callable-taskid")
def visit_assign(self, node):
@@ -63,54 +156,44 @@ def _mytask(): print("dosomething")
def invalidname(): print("dosomething")
mytask = PythonOperator(task_id="mytask", python_callable=invalidname)
"""
if isinstance(node.value, astroid.Call):
function_node = safe_infer(node.value.func)
if (
function_node is not None
and not isinstance(function_node, astroid.bases.BoundMethod)
and hasattr(function_node, "is_subtype_of")
and (
function_node.is_subtype_of("airflow.models.BaseOperator")
or function_node.is_subtype_of("airflow.models.baseoperator.BaseOperator")
)
):
var_name = node.targets[0].name
task_id = None
python_callable_name = None

for keyword in node.value.keywords:
if keyword.arg == "task_id" and isinstance(keyword.value, astroid.Const):
# TODO support other values than constants
task_id = keyword.value.value
continue
if keyword.arg == "python_callable":
python_callable_name = keyword.value.name

if var_name != task_id:
self.add_message("different-operator-varname-taskid", node=node)

if python_callable_name and f"_{task_id}" != python_callable_name:
self.add_message("match-callable-taskid", node=node)

if is_assign_call_subtype_of_base_operator(node):
try:
task_parameters = get_task_parameters_from_assign(node)
except ValueError as val_err:
logging.warning("Task assignment expression could not be analyzed\n%s", val_err)
else:
self.check_operator_varname_versus_task_id(node, task_parameters)
self.check_callable_name_versus_task_id(node, task_parameters)

def check_operator_varname_versus_task_id(
self, node: astroid.Assign, task_parameters: TaskParameters
) -> None:
"""Adds a message if the assigned variable name and the task ID do not match.
A message is not added if either string argument is empty ("") or None."""
var_name = task_parameters.var_name
task_id = task_parameters.task_id
if var_name and task_id and var_name != task_id:
self.add_message("different-operator-varname-taskid", node=node)

def check_callable_name_versus_task_id(
self, node: astroid.Assign, task_parameters: TaskParameters
) -> None:
"""Adds a message if the callable name and the task ID prefixed with an underscore
do not match. A message is not added if either string argument is empty ("") or None."""
task_id = task_parameters.task_id
python_callable_name = task_parameters.python_callable_name
if python_callable_name and task_id and f"_{task_id}" != python_callable_name:
self.add_message("match-callable-taskid", node=node)

@utils.only_required_for_messages("mixed-dependency-directions")
def visit_binop(self, node):
"""Check for mixed dependency directions."""

def fetch_binops(node_):
"""
Method fetching binary operations (>> and/or <<).
Resides in separate function for recursion.
"""
binops_found = set()
if isinstance(node_.left, astroid.BinOp):
binops_found.update(fetch_binops(node_.left))
if isinstance(node_.right, astroid.BinOp):
binops_found.update(fetch_binops(node_.right))
if node_.op in (">>", "<<"):
binops_found.add(node_.op)

return binops_found

binops = fetch_binops(node)
if ">>" in binops and "<<" in binops:
self.check_mixed_dependency_directions(node)

def check_mixed_dependency_directions(self, node: astroid.BinOp) -> None:
"""Check for mixed dependency directions (a BinOp chain contains both >> and <<)."""
collected_operators = collect_operators_from_binops(node)
if ">>" in collected_operators and "<<" in collected_operators:
self.add_message("mixed-dependency-directions", node=node)