From 46ba0423a47bda84d93e1bd691c57c76c9fb8301 Mon Sep 17 00:00:00 2001 From: Topher Anderson <48180628+topherinternational@users.noreply.github.com> Date: Tue, 19 Dec 2023 09:52:14 +0100 Subject: [PATCH] refactor dag call node functions (#16) The former `dag_call_node_from_` functions are refactored to return strings (or None) rather than the entire DagCallNode. Some new test cases are also added to aid in future development of name-chain functionality. --- src/pylint_airflow/checkers/dag.py | 34 +++++++++++------------ tests/pylint_airflow/checkers/test_dag.py | 32 ++++++++++++++++++--- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/pylint_airflow/checkers/dag.py b/src/pylint_airflow/checkers/dag.py index 196fc16..4d8b707 100644 --- a/src/pylint_airflow/checkers/dag.py +++ b/src/pylint_airflow/checkers/dag.py @@ -35,27 +35,21 @@ def is_inferred_value_subtype_of_dag(function_node: Optional[astroid.ClassDef]) ) -def dag_call_node_from_const( - const_node: astroid.Const, call_node: astroid.Call -) -> Optional[DagCallNode]: +def value_from_const_node(const_node: astroid.Const) -> Optional[str]: """Returns a DagCallNode instance with dag_id extracted from the const_node argument""" - return DagCallNode(str(const_node.value), call_node) + return str(const_node.value) -def dag_call_node_from_name( - name_node: astroid.Name, call_node: astroid.Call -) -> Optional[DagCallNode]: +def value_from_name_node(name_node: astroid.Name) -> Optional[str]: """Returns a DagCallNode instance with dag_id extracted from the name_node argument, or None if the node value can't be extracted.""" name_val = safe_infer(name_node) # TODO: follow name chains if isinstance(name_val, astroid.Const): - return dag_call_node_from_const(name_val, call_node) + return value_from_const_node(name_val) return None -def dag_call_node_from_joined_string( - joined_str_node: astroid.JoinedStr, call_node: astroid.Call -) -> Optional[DagCallNode]: +def value_from_joined_str_node(joined_str_node: astroid.JoinedStr) -> Optional[str]: """Returns a DagCallNode instance with dag_id composed from the elements of the joined_str_node argument, or None if the node value can't be extracted.""" dag_id_elements: List[str] = [] @@ -68,7 +62,7 @@ def dag_call_node_from_joined_string( return None elif isinstance(js_value, astroid.Const): dag_id_elements.append(str(js_value.value)) - return DagCallNode("".join(dag_id_elements), call_node) + return "".join(dag_id_elements) # TODO: follow name chains @@ -76,13 +70,17 @@ def dag_call_node_from_argument_value( argument_value: astroid.NodeNG, call_node: astroid.Call ) -> Optional[DagCallNode]: """Detects argument string from Const, Name or JoinedStr (f-string), or None if no match""" + val = None if isinstance(argument_value, astroid.Const): - return dag_call_node_from_const(argument_value, call_node) - if isinstance(argument_value, astroid.Name): - return dag_call_node_from_name(argument_value, call_node) - if isinstance(argument_value, astroid.JoinedStr): - return dag_call_node_from_joined_string(argument_value, call_node) - return None + val = value_from_const_node(argument_value) + elif isinstance(argument_value, astroid.Name): + val = value_from_name_node(argument_value) + elif isinstance(argument_value, astroid.JoinedStr): + val = value_from_joined_str_node(argument_value) + + if not val: # if we didn't get a real value from the chain above + return None + return DagCallNode(val, call_node) def find_dag_in_call_node(call_node: astroid.Call) -> Optional[DagCallNode]: diff --git a/tests/pylint_airflow/checkers/test_dag.py b/tests/pylint_airflow/checkers/test_dag.py index 33b1106..ab1f4e6 100644 --- a/tests/pylint_airflow/checkers/test_dag.py +++ b/tests/pylint_airflow/checkers/test_dag.py @@ -295,12 +295,28 @@ def test_invalid_nodes_should_return_none(self, test_statement): 'test_id = "my_dag"\n models.DAG(dag_id=f"{test_id}_0")', 'test_id = "my_dag"\n DAG(f"{test_id}_0")', 'test_id = "my_dag"\n models.DAG(f"{test_id}_0")', + 'test_id = "my_dag_0"\n my_id = test_id\n DAG(dag_id=my_id)', + 'test_id = "my_dag_0"\n my_id = test_id\n models.DAG(dag_id=my_id)', + 'test_id = "my_dag_0"\n my_id = test_id\n DAG(my_id)', + 'test_id = "my_dag_0"\n my_id = test_id\n models.DAG(my_id)', + 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long ], ids=[ "DAG Name call with f-string dag_id keyword argument", "DAG Attribute call with f-string dag_id keyword argument", "DAG Name call with f-string dag_id positional argument", "DAG Attribute call with f-string dag_id positional argument", + "DAG Name call with double-variable dag_id keyword argument", + "DAG Attribute call with double-variable dag_id keyword argument", + "DAG Name call with double-variable dag_id positional argument", + "DAG Attribute call with double-variable dag_id positional argument", + "DAG Name call with triple-variable dag_id keyword argument", + "DAG Attribute call with triple-variable dag_id keyword argument", + "DAG Name call with triple-variable dag_id positional argument", + "DAG Attribute call with triple-variable dag_id positional argument", ], ) def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement): @@ -324,12 +340,20 @@ def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_ 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)', 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)', 'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)', + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long + 'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long ], ids=[ - "DAG Name call with double-variable dag_id keyword argument", - "DAG Attribute call with double-variable dag_id keyword argument", - "DAG Name call with double-variable dag_id positional argument", - "DAG Attribute call with double-variable dag_id positional argument", + "DAG Name call with double-variable f-string dag_id keyword argument", + "DAG Attribute call with double-variable f-string dag_id keyword argument", + "DAG Name call with double-variable f-string dag_id positional argument", + "DAG Attribute call with double-variable f-string dag_id positional argument", + "DAG Name call with triple-variable f-string dag_id keyword argument", + "DAG Attribute call with triple-variable f-string dag_id keyword argument", + "DAG Name call with triple-variable f-string dag_id positional argument", + "DAG Attribute call with triple-variable f-string dag_id positional argument", ], ) @pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)