diff --git a/tests/pylint_airflow/checkers/test_xcom.py b/tests/pylint_airflow/checkers/test_xcom.py index 046127b..b33716f 100644 --- a/tests/pylint_airflow/checkers/test_xcom.py +++ b/tests/pylint_airflow/checkers/test_xcom.py @@ -136,6 +136,34 @@ def aux_func(): assert result == expected_result + @pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True) + def test_should_detect_local_function_callables_as_attributes(self): + test_code = """ + from airflow.operators.python_operator import PythonOperator + from types import SimpleNamespace + + def task_func(): + print("bupkis") + + def aux_func(): + return 2 + 2 + + obj = SimpleNamespace() + obj.func = aux_func + + local_task = PythonOperator(task_id="local_task", python_callable=task_func) + another_task = PythonOperator(task_id="another_task", python_callable=obj.func) + """ + ast = astroid.parse(test_code) + result = get_task_ids_to_python_callable_specs(ast) + + expected_result = { + "local_task": PythonOperatorSpec(ast.body[6].value, "task_func"), + "another_task": PythonOperatorSpec(ast.body[7].value, "obj.func"), + } + + assert result == expected_result + class TestGetXComsFromTasks: """Tests the get_xcoms_from_tasks helper function which detects the xcom pushes and pulls.""" @@ -217,7 +245,38 @@ def test_should_skip_imported_callable_as_name(self): assert result == ({}, set()) - def test_should_detect_xcom_push_tasks(self): + @pytest.mark.xfail( + reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True + ) + def test_should_skip_local_function_callables_as_attributes(self): + test_code = """ + from airflow.operators.python_operator import PythonOperator + from types import SimpleNamespace + + def task_func(): + print("bupkis") + + def aux_func(): + return 2 + 2 + + obj = SimpleNamespace() + obj.func = aux_func + + local_task = PythonOperator(task_id="local_task", python_callable=task_func) + another_task = PythonOperator(task_id="another_task", python_callable=obj.func) + """ + ast = astroid.parse(test_code) + + test_task_ids_to_python_callable_specs = { + "local_task": PythonOperatorSpec(ast.body[6].value, "task_func"), + "another_task": PythonOperatorSpec(ast.body[7].value, "obj.func"), + } + + result = get_xcoms_from_tasks(ast, test_task_ids_to_python_callable_specs) + + assert result == ({}, set()) + + def test_should_detect_xcom_push_tasks_from_local_function_callables(self): test_code = """ from airflow.operators.python_operator import PythonOperator @@ -261,7 +320,7 @@ def another_func(): assert result == expected_result - def test_should_detect_xcom_pull_tasks(self): + def test_should_detect_xcom_pull_tasks_from_local_function_callables(self): test_code = """ from airflow.operators.python_operator import PythonOperator