From 9d8bc18beb6f278fe7ea54534a47335c385a53bd Mon Sep 17 00:00:00 2001 From: Topher Anderson <48180628+topherinternational@users.noreply.github.com> Date: Tue, 19 Dec 2023 13:40:25 +0100 Subject: [PATCH] find dag_ids from variable chains (#18) --- src/pylint_airflow/checkers/dag.py | 27 ++++++++++++--- tests/pylint_airflow/checkers/test_dag.py | 42 +++++------------------ 2 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/pylint_airflow/checkers/dag.py b/src/pylint_airflow/checkers/dag.py index 3bd75c0..a21274f 100644 --- a/src/pylint_airflow/checkers/dag.py +++ b/src/pylint_airflow/checkers/dag.py @@ -44,9 +44,28 @@ def value_from_name_node(name_node: astroid.Name) -> Optional[str]: """Returns a DagCallNode instance with dag_id extracted from the name_node argument, or None if the node value can't be extracted.""" name_val = safe_infer(name_node) # TODO: follow name chains - if isinstance(name_val, astroid.Const): - return value_from_const_node(name_val) - return None + if name_val: + if isinstance(name_val, astroid.Const): + return value_from_const_node(name_val) + return None + + # If astroid can't infer the name node value, we will have to walk the tree of assignments + return get_name_node_value_from_assignments(name_node) + + +def get_name_node_value_from_assignments(node: astroid.Name) -> Optional[str]: + """If a given Name node's value can't be inferred, we find out where the given name node was + assigned, and try to infer _that_ value. This function can/will get called recursively.""" + assign_frame_and_nodes = node.lookup(node.name) + for assign_name_node in assign_frame_and_nodes[1]: + if isinstance(assign_name_node, astroid.AssignName): + assign_node = assign_name_node.parent + if isinstance(assign_node, astroid.Assign): + assign_value = assign_node.value + return dag_id_from_argument_value(assign_value) + + # If we drop out of any of 'if' blocks, we give up + return None def value_from_joined_str_node(joined_str_node: astroid.JoinedStr) -> Optional[str]: @@ -108,7 +127,7 @@ def find_dag_in_call_node(call_node: astroid.Call) -> Optional[DagCallNode]: dag_id = dag_id_from_argument_value(call_node.args[0]) return DagCallNode(dag_id, call_node) if dag_id else None - # if we found neither a keyword arg or a positional arg + # If we found neither a keyword arg or a positional arg return None diff --git a/tests/pylint_airflow/checkers/test_dag.py b/tests/pylint_airflow/checkers/test_dag.py index ab1f4e6..ca7b683 100644 --- a/tests/pylint_airflow/checkers/test_dag.py +++ b/tests/pylint_airflow/checkers/test_dag.py @@ -303,6 +303,14 @@ def test_invalid_nodes_should_return_none(self, test_statement): 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(dag_id=my_id)', + 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)', + 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)', + 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)', + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long ], ids=[ "DAG Name call with f-string dag_id keyword argument", @@ -317,35 +325,6 @@ def test_invalid_nodes_should_return_none(self, test_statement): "DAG Attribute call with triple-variable dag_id keyword argument", "DAG Name call with triple-variable dag_id positional argument", "DAG Attribute call with triple-variable dag_id positional argument", - ], - ) - def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement): - test_code = f""" - from airflow import models - from airflow.models import DAG - - {test_statement} #@ - """ - - test_call = astroid.extract_node(test_code) - - result = find_dag_in_call_node(test_call) - - assert result == DagCallNode("my_dag_0", test_call) - - @pytest.mark.parametrize( - "test_statement", - [ - 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(dag_id=my_id)', - 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)', - 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)', - 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)', - 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long - 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long - 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long - 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long - ], - ids=[ "DAG Name call with double-variable f-string dag_id keyword argument", "DAG Attribute call with double-variable f-string dag_id keyword argument", "DAG Name call with double-variable f-string dag_id positional argument", @@ -356,10 +335,7 @@ def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_ "DAG Attribute call with triple-variable f-string dag_id positional argument", ], ) - @pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True) - def test_future_work_valid_dag_call_with_variables_should_return_dag_id_and_node( - self, test_statement - ): + def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement): test_code = f""" from airflow import models from airflow.models import DAG