Skip to content

Commit

Permalink
find dag_ids from variable chains (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
topherinternational authored Dec 19, 2023
1 parent 9c8fccb commit 9d8bc18
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 37 deletions.
27 changes: 23 additions & 4 deletions src/pylint_airflow/checkers/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,28 @@ def value_from_name_node(name_node: astroid.Name) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id extracted from the name_node argument,
or None if the node value can't be extracted."""
name_val = safe_infer(name_node) # TODO: follow name chains
if isinstance(name_val, astroid.Const):
return value_from_const_node(name_val)
return None
if name_val:
if isinstance(name_val, astroid.Const):
return value_from_const_node(name_val)
return None

# If astroid can't infer the name node value, we will have to walk the tree of assignments
return get_name_node_value_from_assignments(name_node)


def get_name_node_value_from_assignments(node: astroid.Name) -> Optional[str]:
"""If a given Name node's value can't be inferred, we find out where the given name node was
assigned, and try to infer _that_ value. This function can/will get called recursively."""
assign_frame_and_nodes = node.lookup(node.name)
for assign_name_node in assign_frame_and_nodes[1]:
if isinstance(assign_name_node, astroid.AssignName):
assign_node = assign_name_node.parent
if isinstance(assign_node, astroid.Assign):
assign_value = assign_node.value
return dag_id_from_argument_value(assign_value)

# If we drop out of any of 'if' blocks, we give up
return None


def value_from_joined_str_node(joined_str_node: astroid.JoinedStr) -> Optional[str]:
Expand Down Expand Up @@ -108,7 +127,7 @@ def find_dag_in_call_node(call_node: astroid.Call) -> Optional[DagCallNode]:
dag_id = dag_id_from_argument_value(call_node.args[0])
return DagCallNode(dag_id, call_node) if dag_id else None

# if we found neither a keyword arg or a positional arg
# If we found neither a keyword arg or a positional arg
return None


Expand Down
42 changes: 9 additions & 33 deletions tests/pylint_airflow/checkers/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,14 @@ def test_invalid_nodes_should_return_none(self, test_statement):
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(dag_id=my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)',
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long
],
ids=[
"DAG Name call with f-string dag_id keyword argument",
Expand All @@ -317,35 +325,6 @@ def test_invalid_nodes_should_return_none(self, test_statement):
"DAG Attribute call with triple-variable dag_id keyword argument",
"DAG Name call with triple-variable dag_id positional argument",
"DAG Attribute call with triple-variable dag_id positional argument",
],
)
def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement):
test_code = f"""
from airflow import models
from airflow.models import DAG
{test_statement} #@
"""

test_call = astroid.extract_node(test_code)

result = find_dag_in_call_node(test_call)

assert result == DagCallNode("my_dag_0", test_call)

@pytest.mark.parametrize(
"test_statement",
[
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(dag_id=my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)',
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long
],
ids=[
"DAG Name call with double-variable f-string dag_id keyword argument",
"DAG Attribute call with double-variable f-string dag_id keyword argument",
"DAG Name call with double-variable f-string dag_id positional argument",
Expand All @@ -356,10 +335,7 @@ def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_
"DAG Attribute call with triple-variable f-string dag_id positional argument",
],
)
@pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
def test_future_work_valid_dag_call_with_variables_should_return_dag_id_and_node(
self, test_statement
):
def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement):
test_code = f"""
from airflow import models
from airflow.models import DAG
Expand Down

0 comments on commit 9d8bc18

Please sign in to comment.