Skip to content

Commit

Permalink
add XCom tests for local attribute callables (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
topherinternational committed Jan 5, 2024
1 parent ca8dbfd commit 8277d04
Showing 1 changed file with 61 additions and 2 deletions.
63 changes: 61 additions & 2 deletions tests/pylint_airflow/checkers/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,34 @@ def aux_func():

assert result == expected_result

@pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
def test_should_detect_local_function_callables_as_attributes(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
from types import SimpleNamespace
def task_func():
print("bupkis")
def aux_func():
return 2 + 2
obj = SimpleNamespace()
obj.func = aux_func
local_task = PythonOperator(task_id="local_task", python_callable=task_func)
another_task = PythonOperator(task_id="another_task", python_callable=obj.func)
"""
ast = astroid.parse(test_code)
result = get_task_ids_to_python_callable_specs(ast)

expected_result = {
"local_task": PythonOperatorSpec(ast.body[6].value, "task_func"),
"another_task": PythonOperatorSpec(ast.body[7].value, "obj.func"),
}

assert result == expected_result


class TestGetXComsFromTasks:
"""Tests the get_xcoms_from_tasks helper function which detects the xcom pushes and pulls."""
Expand Down Expand Up @@ -217,7 +245,38 @@ def test_should_skip_imported_callable_as_name(self):

assert result == ({}, set())

def test_should_detect_xcom_push_tasks(self):
@pytest.mark.xfail(
reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
)
def test_should_skip_local_function_callables_as_attributes(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
from types import SimpleNamespace
def task_func():
print("bupkis")
def aux_func():
return 2 + 2
obj = SimpleNamespace()
obj.func = aux_func
local_task = PythonOperator(task_id="local_task", python_callable=task_func)
another_task = PythonOperator(task_id="another_task", python_callable=obj.func)
"""
ast = astroid.parse(test_code)

test_task_ids_to_python_callable_specs = {
"local_task": PythonOperatorSpec(ast.body[6].value, "task_func"),
"another_task": PythonOperatorSpec(ast.body[7].value, "obj.func"),
}

result = get_xcoms_from_tasks(ast, test_task_ids_to_python_callable_specs)

assert result == ({}, set())

def test_should_detect_xcom_push_tasks_from_local_function_callables(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down Expand Up @@ -261,7 +320,7 @@ def another_func():

assert result == expected_result

def test_should_detect_xcom_pull_tasks(self):
def test_should_detect_xcom_pull_tasks_from_local_function_callables(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down

0 comments on commit 8277d04

Please sign in to comment.