Skip to content

Commit

Permalink
augment XCom tests (#22)
Browse files Browse the repository at this point in the history
add testing for helper functions `get_task_ids_to_python_callable_specs`
and `get_xcoms_from_tasks`
  • Loading branch information
topherinternational authored Jan 4, 2024
1 parent 33e8730 commit ca8dbfd
Show file tree
Hide file tree
Showing 2 changed files with 305 additions and 7 deletions.
6 changes: 5 additions & 1 deletion src/pylint_airflow/checkers/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ def get_task_ids_to_python_callable_specs(node: astroid.Module) -> Dict[str, Pyt

# Store nodes containing python_callable arg as:
# {task_id: PythonOperatorSpec(call node, python_callable func name)}

task_ids_to_python_callable_specs = {}
for call_node in call_nodes:
if call_node.keywords:
task_id = ""
python_callable_function_name = ""
for keyword in call_node.keywords:
if keyword.arg == "python_callable" and isinstance(keyword.value, astroid.Name):
python_callable_function_name = keyword.value.name
python_callable_function_name = keyword.value.name # TODO: support lambdas
elif keyword.arg == "task_id" and isinstance(keyword.value, astroid.Const):
task_id = keyword.value.value # TODO: support non-Const args

Expand All @@ -77,6 +78,7 @@ def get_xcoms_from_tasks(
continue

callable_func = node.getattr(callable_func_name)[0]
# ^ TODO: handle builtins and attribute imports that will raise on this call

if not isinstance(callable_func, astroid.FunctionDef):
continue # Callable_func is str not FunctionDef when imported
Expand All @@ -97,6 +99,8 @@ def get_xcoms_from_tasks(
for keyword in callable_func_call.keywords:
if keyword.arg == "task_ids" and isinstance(keyword.value, astroid.Const):
xcoms_pulled_taskids.add(keyword.value.value)
# TODO: add support for xcom 'key' argument
# TODO: add support for non-Const argument values

return xcoms_pushed, xcoms_pulled_taskids

Expand Down
306 changes: 300 additions & 6 deletions tests/pylint_airflow/checkers/test_xcom.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,305 @@
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring,use-implicit-booleaness-not-comparison
"""Tests for the XCom checker and its helper functions."""
from unittest.mock import Mock

import astroid
import pytest
from pylint.testutils import CheckerTestCase, MessageTest

import pylint_airflow
from pylint_airflow.checkers.xcom import PythonOperatorSpec
from pylint_airflow.checkers.xcom import (
PythonOperatorSpec,
get_task_ids_to_python_callable_specs,
get_xcoms_from_tasks,
)


class TestGetTaskIdsToPythonCallableSpecs:
    """Tests the get_task_ids_to_python_callable_specs helper function which detects the
    python_callable functions passed to PythonOperator constructions.

    The sample snippets are indented to match the surrounding code; astroid.parse dedents its
    input before parsing, so indentation only has to be consistent within each snippet.
    """

    @pytest.mark.parametrize(
        "test_code",
        [
            """# empty module""",
            """
            print("test this")
            int("5")
            """,
            """
            x = 5
            y = x + 2
            """,
            """
            x = int(x="-5")  # keyword arg
            y = abs(x)  # positional arg
            """,
        ],
        ids=[
            "no code",
            "no assignments",
            "no calls",
            "no operators",
        ],
    )
    def test_should_return_empty_when_no_tasks(self, test_code):
        """Modules without any PythonOperator construction yield an empty mapping."""
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        assert result == {}

    def test_should_skip_lambda_callable(self):
        """A lambda passed as python_callable is expected to produce no spec."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        lambda_task = PythonOperator(task_id="lambda_task", python_callable=lambda x: print(x))
        """
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        assert result == {}

    def test_should_detect_builtin_callable(self):
        """A builtin referenced by name is recorded under its task_id."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=list)
        """
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        # body[0] is the import; body[1] is the assignment whose value is the operator call
        expected_result = {
            "builtin_task": PythonOperatorSpec(module.body[1].value, "list"),
        }

        assert result == expected_result

    # strict=True makes this test fail loudly once attribute callables are supported,
    # which is the reminder to drop the xfail marker.
    @pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
    def test_should_detect_imported_callable_as_attribute(self):
        """A callable referenced as an attribute (date.today) should be recorded."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        from datetime import date
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=date.today)
        """
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        # two imports precede the assignment, so the operator call is body[2].value
        expected_result = {
            "builtin_task": PythonOperatorSpec(module.body[2].value, "date.today"),
        }

        assert result == expected_result

    def test_should_detect_imported_callable_as_name(self):
        """A callable imported and referenced by bare name is recorded under its task_id."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        from datetime.date import today
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=today)
        """
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        expected_result = {
            "builtin_task": PythonOperatorSpec(module.body[2].value, "today"),
        }

        assert result == expected_result

    def test_should_detect_local_function_callables(self):
        """Functions defined in the same module are recorded, one spec per task_id."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        def task_func():
            print("bupkis")
        def aux_func():
            return 2 + 2
        local_task = PythonOperator(task_id="local_task", python_callable=task_func)
        another_task = PythonOperator(task_id="another_task", python_callable=aux_func)
        """
        module = astroid.parse(test_code)

        result = get_task_ids_to_python_callable_specs(module)

        # body layout: [0] import, [1] task_func, [2] aux_func, [3] local_task, [4] another_task
        expected_result = {
            "local_task": PythonOperatorSpec(module.body[3].value, "task_func"),
            "another_task": PythonOperatorSpec(module.body[4].value, "aux_func"),
        }

        assert result == expected_result


class TestGetXComsFromTasks:
    """Tests the get_xcoms_from_tasks helper function which detects the xcom pushes and pulls.

    Each test hands the helper a parsed module plus a hand-built
    {task_id: PythonOperatorSpec} mapping and checks the returned
    (xcoms_pushed, xcoms_pulled_taskids) tuple. The sample snippets are indented to match the
    surrounding code; astroid.parse dedents its input before parsing.
    """

    def test_should_return_empty_on_empty_input(self):
        """An empty spec mapping yields no pushes and no pulls."""
        # node argument isn't used when input dict is empty, so we can simply use a Mock object
        result = get_xcoms_from_tasks(Mock(), {})

        assert result == ({}, set())

    def test_should_skip_lambda_callable(self):
        """A spec whose callable name is "<lambda>" contributes nothing."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        lambda_task = PythonOperator(task_id="lambda_task", python_callable=lambda x: print(x))
        """
        module = astroid.parse(test_code)
        test_task_ids_to_python_callable_specs = {
            "lambda_task": PythonOperatorSpec(module.body[1].value, "<lambda>")
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        assert result == ({}, set())

    # strict=True makes this test fail loudly once the helper stops raising on builtins,
    # which is the reminder to drop the xfail marker.
    @pytest.mark.xfail(
        reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
    )
    def test_should_skip_builtin_callable(self):
        """A builtin callable should be skipped rather than raising on lookup."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=list)
        """
        module = astroid.parse(test_code)

        test_task_ids_to_python_callable_specs = {
            "builtin_task": PythonOperatorSpec(module.body[1].value, "list"),
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        assert result == ({}, set())

    @pytest.mark.xfail(
        reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
    )
    def test_should_skip_imported_callable_as_attribute(self):
        """An attribute callable (date.today) should be skipped rather than raising on lookup."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        from datetime import date
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=date.today)
        """
        module = astroid.parse(test_code)

        test_task_ids_to_python_callable_specs = {
            "builtin_task": PythonOperatorSpec(module.body[2].value, "date.today"),
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        assert result == ({}, set())

    def test_should_skip_imported_callable_as_name(self):
        """A callable imported by bare name is not a local function and contributes nothing."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        from datetime.date import today
        builtin_task = PythonOperator(task_id="builtin_task", python_callable=today)
        """
        module = astroid.parse(test_code)

        test_task_ids_to_python_callable_specs = {
            "builtin_task": PythonOperatorSpec(module.body[2].value, "today"),
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        assert result == ({}, set())

    def test_should_detect_xcom_push_tasks(self):
        """Tasks whose callable returns a value are reported as xcom pushes."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        def task_func():
            print("bupkis")
            return "done"
        def aux_func():
            return 2 + 2
        def another_func():
            print
        # TODO: detect a naked return statement and don't detect it as an xcom push
        # TODO: detect function inputs as xcom pulls when appropriate
        local_task = PythonOperator(task_id="local_task", python_callable=task_func)
        aux_task = PythonOperator(task_id="aux_task", python_callable=aux_func)
        another_task = PythonOperator(task_id="another_task", python_callable=another_func)
        """
        module = astroid.parse(test_code)

        # body layout: [0] import, [1]-[3] function defs, [4]-[6] operator assignments
        local_task_spec = PythonOperatorSpec(module.body[4].value, "task_func")
        aux_task_spec = PythonOperatorSpec(module.body[5].value, "aux_func")
        another_task_spec = PythonOperatorSpec(module.body[6].value, "another_func")
        test_task_ids_to_python_callable_specs = {
            "local_task": local_task_spec,
            "aux_task": aux_task_spec,
            "another_task": another_task_spec,
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        # another_func has no return statement, so another_task is not expected among the pushes
        expected_result = (
            {  # xcom pushes
                "local_task": local_task_spec,
                "aux_task": aux_task_spec,
            },
            set(),  # no xcom_pulls
        )

        assert result == expected_result

    def test_should_detect_xcom_pull_tasks(self):
        """xcom_pull(task_ids=...) calls are collected; only the returning task is a push."""
        test_code = """
        from airflow.operators.python_operator import PythonOperator
        def push_func(task_instance, **_):
            print("bupkis")
            return "done"
        def pull_func():
            push_val = task_instance.xcom_pull(task_ids="push_task")
        def pull_again_func():
            print(task_instance.xcom_pull(task_ids="push_task"))
        push_task = PythonOperator(task_id="push_task", python_callable=push_func)
        pull_task = PythonOperator(task_id="pull_task", python_callable=pull_func, provide_context=True)
        pull_again_task = PythonOperator(task_id="pull_again_task", python_callable=pull_again_func, provide_context=True)
        """
        module = astroid.parse(test_code)

        # body layout: [0] import, [1]-[3] function defs, [4]-[6] operator assignments
        push_task_spec = PythonOperatorSpec(module.body[4].value, "push_func")
        pull_task_spec = PythonOperatorSpec(module.body[5].value, "pull_func")
        pull_again_task_spec = PythonOperatorSpec(module.body[6].value, "pull_again_func")
        test_task_ids_to_python_callable_specs = {
            "push_task": push_task_spec,
            "pull_task": pull_task_spec,
            "pull_again_task": pull_again_task_spec,
        }

        result = get_xcoms_from_tasks(module, test_task_ids_to_python_callable_specs)

        # both pull functions pull from "push_task", so the pulled-ids set has a single entry
        expected_result = (
            {
                "push_task": push_task_spec,
            },
            {"push_task"},
        )

        assert result == expected_result


class TestXComChecker(CheckerTestCase):
Expand All @@ -15,7 +309,7 @@ class TestXComChecker(CheckerTestCase):

def test_used_xcom(self):
"""Test valid case: _pushtask() returns a value and _pulltask pulls and uses it."""
testcase = """
test_code = """
from airflow.operators.python_operator import PythonOperator
def _pushtask():
Expand All @@ -29,13 +323,13 @@ def _pulltask(task_instance, **_):
pulltask = PythonOperator(task_id="pulltask", python_callable=_pulltask, provide_context=True)
"""
ast = astroid.parse(testcase)
ast = astroid.parse(test_code)
with self.assertNoMessages():
self.checker.visit_module(ast)

def test_unused_xcom(self):
"""Test invalid case: _pushtask() returns a value but it's never used."""
testcase = """
test_code = """
from airflow.operators.python_operator import PythonOperator
def _pushtask():
Expand All @@ -49,7 +343,7 @@ def _pulltask():
pulltask = PythonOperator(task_id="pulltask", python_callable=_pulltask)
"""
ast = astroid.parse(testcase)
ast = astroid.parse(test_code)
expected_msg_node = ast.body[2].value
expected_args = "_pushtask"
with self.assertAddsMessages(
Expand Down

0 comments on commit ca8dbfd

Please sign in to comment.