Skip to content

Commit

Permalink
add detection of attribute callables in XComChecker (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
topherinternational authored Jan 5, 2024
1 parent 8277d04 commit 57f3aa7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
31 changes: 21 additions & 10 deletions src/pylint_airflow/checkers/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Set, Dict, Tuple

import astroid
from astroid import AttributeInferenceError
from pylint import checkers
from pylint.checkers import utils

Expand Down Expand Up @@ -39,6 +40,7 @@ class PythonOperatorSpec:


def get_task_ids_to_python_callable_specs(node: astroid.Module) -> Dict[str, PythonOperatorSpec]:
# pylint: disable=too-many-nested-blocks
"""Fill this in"""
assign_nodes = node.nodes_of_class(astroid.Assign)
call_nodes = [assign.value for assign in assign_nodes if isinstance(assign.value, astroid.Call)]
Expand All @@ -50,16 +52,21 @@ def get_task_ids_to_python_callable_specs(node: astroid.Module) -> Dict[str, Pyt
for call_node in call_nodes:
if call_node.keywords:
task_id = ""
python_callable_function_name = ""
python_callable_name = ""
for keyword in call_node.keywords:
if keyword.arg == "python_callable" and isinstance(keyword.value, astroid.Name):
python_callable_function_name = keyword.value.name # TODO: support lambdas
elif keyword.arg == "task_id" and isinstance(keyword.value, astroid.Const):
task_id = keyword.value.value # TODO: support non-Const args

if python_callable_function_name:
kw_value = keyword.value
if keyword.arg == "python_callable":
if isinstance(kw_value, astroid.Name): # TODO: support lambdas
python_callable_name = kw_value.name
elif isinstance(kw_value, astroid.Attribute):
if isinstance(kw_value.expr, astroid.Name):
python_callable_name = f"{kw_value.expr.name}.{kw_value.attrname}"
elif keyword.arg == "task_id" and isinstance(kw_value, astroid.Const):
task_id = kw_value.value # TODO: support non-Const args

if python_callable_name:
task_ids_to_python_callable_specs[task_id] = PythonOperatorSpec(
call_node, python_callable_function_name
call_node, python_callable_name
)

return task_ids_to_python_callable_specs
Expand All @@ -77,8 +84,12 @@ def get_xcoms_from_tasks(
if callable_func_name == "<lambda>": # TODO support lambdas
continue

callable_func = node.getattr(callable_func_name)[0]
# ^ TODO: handle builtins and attribute imports that will raise on this call
try:
module_attribute = node.getattr(callable_func_name)
except AttributeInferenceError:
continue
else:
callable_func = module_attribute[0]

if not isinstance(callable_func, astroid.FunctionDef):
continue # Callable_func is str not FunctionDef when imported
Expand Down
11 changes: 0 additions & 11 deletions tests/pylint_airflow/checkers/test_xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def test_should_detect_builtin_callable(self):

assert result == expected_result

@pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
def test_should_detect_imported_callable_as_attribute(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down Expand Up @@ -136,7 +135,6 @@ def aux_func():

assert result == expected_result

@pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
def test_should_detect_local_function_callables_as_attributes(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down Expand Up @@ -189,9 +187,6 @@ def test_should_skip_lambda_callable(self):

assert result == ({}, set())

@pytest.mark.xfail(
reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
)
def test_should_skip_builtin_callable(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand All @@ -208,9 +203,6 @@ def test_should_skip_builtin_callable(self):

assert result == ({}, set())

@pytest.mark.xfail(
reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
)
def test_should_skip_imported_callable_as_attribute(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down Expand Up @@ -245,9 +237,6 @@ def test_should_skip_imported_callable_as_name(self):

assert result == ({}, set())

@pytest.mark.xfail(
reason="Not yet implemented", raises=astroid.AttributeInferenceError, strict=True
)
def test_should_skip_local_function_callables_as_attributes(self):
test_code = """
from airflow.operators.python_operator import PythonOperator
Expand Down

0 comments on commit 57f3aa7

Please sign in to comment.