Skip to content

Commit

Permalink
refactor dag call node functions (#16)
Browse files Browse the repository at this point in the history
The former `dag_call_node_from_<node-type>` functions are refactored to
return strings (or None) rather than the entire DagCallNode. Some new
test cases are also added to aid in future development of name-chain
functionality.
  • Loading branch information
topherinternational authored Dec 19, 2023
1 parent 496ca09 commit 46ba042
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 22 deletions.
34 changes: 16 additions & 18 deletions src/pylint_airflow/checkers/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,21 @@ def is_inferred_value_subtype_of_dag(function_node: Optional[astroid.ClassDef])
)


def dag_call_node_from_const(
const_node: astroid.Const, call_node: astroid.Call
) -> Optional[DagCallNode]:
def value_from_const_node(const_node: astroid.Const) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id extracted from the const_node argument"""
return DagCallNode(str(const_node.value), call_node)
return str(const_node.value)


def dag_call_node_from_name(
name_node: astroid.Name, call_node: astroid.Call
) -> Optional[DagCallNode]:
def value_from_name_node(name_node: astroid.Name) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id extracted from the name_node argument,
or None if the node value can't be extracted."""
name_val = safe_infer(name_node) # TODO: follow name chains
if isinstance(name_val, astroid.Const):
return dag_call_node_from_const(name_val, call_node)
return value_from_const_node(name_val)
return None


def dag_call_node_from_joined_string(
joined_str_node: astroid.JoinedStr, call_node: astroid.Call
) -> Optional[DagCallNode]:
def value_from_joined_str_node(joined_str_node: astroid.JoinedStr) -> Optional[str]:
"""Returns a DagCallNode instance with dag_id composed from the elements of the
joined_str_node argument, or None if the node value can't be extracted."""
dag_id_elements: List[str] = []
Expand All @@ -68,21 +62,25 @@ def dag_call_node_from_joined_string(
return None
elif isinstance(js_value, astroid.Const):
dag_id_elements.append(str(js_value.value))
return DagCallNode("".join(dag_id_elements), call_node)
return "".join(dag_id_elements)
# TODO: follow name chains


def dag_call_node_from_argument_value(
argument_value: astroid.NodeNG, call_node: astroid.Call
) -> Optional[DagCallNode]:
"""Detects argument string from Const, Name or JoinedStr (f-string), or None if no match"""
val = None
if isinstance(argument_value, astroid.Const):
return dag_call_node_from_const(argument_value, call_node)
if isinstance(argument_value, astroid.Name):
return dag_call_node_from_name(argument_value, call_node)
if isinstance(argument_value, astroid.JoinedStr):
return dag_call_node_from_joined_string(argument_value, call_node)
return None
val = value_from_const_node(argument_value)
elif isinstance(argument_value, astroid.Name):
val = value_from_name_node(argument_value)
elif isinstance(argument_value, astroid.JoinedStr):
val = value_from_joined_str_node(argument_value)

if not val: # if we didn't get a real value from the chain above
return None
return DagCallNode(val, call_node)


def find_dag_in_call_node(call_node: astroid.Call) -> Optional[DagCallNode]:
Expand Down
32 changes: 28 additions & 4 deletions tests/pylint_airflow/checkers/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,28 @@ def test_invalid_nodes_should_return_none(self, test_statement):
'test_id = "my_dag"\n models.DAG(dag_id=f"{test_id}_0")',
'test_id = "my_dag"\n DAG(f"{test_id}_0")',
'test_id = "my_dag"\n models.DAG(f"{test_id}_0")',
'test_id = "my_dag_0"\n my_id = test_id\n DAG(dag_id=my_id)',
'test_id = "my_dag_0"\n my_id = test_id\n models.DAG(dag_id=my_id)',
'test_id = "my_dag_0"\n my_id = test_id\n DAG(my_id)',
'test_id = "my_dag_0"\n my_id = test_id\n models.DAG(my_id)',
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag_0"\n the_id = test_id\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long
],
ids=[
"DAG Name call with f-string dag_id keyword argument",
"DAG Attribute call with f-string dag_id keyword argument",
"DAG Name call with f-string dag_id positional argument",
"DAG Attribute call with f-string dag_id positional argument",
"DAG Name call with double-variable dag_id keyword argument",
"DAG Attribute call with double-variable dag_id keyword argument",
"DAG Name call with double-variable dag_id positional argument",
"DAG Attribute call with double-variable dag_id positional argument",
"DAG Name call with triple-variable dag_id keyword argument",
"DAG Attribute call with triple-variable dag_id keyword argument",
"DAG Name call with triple-variable dag_id positional argument",
"DAG Attribute call with triple-variable dag_id positional argument",
],
)
def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_statement):
Expand All @@ -324,12 +340,20 @@ def test_valid_dag_call_with_variables_should_return_dag_id_and_node(self, test_
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(dag_id=my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n DAG(my_id)',
'test_id = "my_dag"\n my_id = f"{test_id}_0"\n models.DAG(my_id)',
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(dag_id=my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n DAG(my_id)', # pylint: disable=line-too-long
'test_id = "my_dag"\n the_id = f"{test_id}_0"\n my_id = the_id\n models.DAG(my_id)', # pylint: disable=line-too-long
],
ids=[
"DAG Name call with double-variable dag_id keyword argument",
"DAG Attribute call with double-variable dag_id keyword argument",
"DAG Name call with double-variable dag_id positional argument",
"DAG Attribute call with double-variable dag_id positional argument",
"DAG Name call with double-variable f-string dag_id keyword argument",
"DAG Attribute call with double-variable f-string dag_id keyword argument",
"DAG Name call with double-variable f-string dag_id positional argument",
"DAG Attribute call with double-variable f-string dag_id positional argument",
"DAG Name call with triple-variable f-string dag_id keyword argument",
"DAG Attribute call with triple-variable f-string dag_id keyword argument",
"DAG Name call with triple-variable f-string dag_id positional argument",
"DAG Attribute call with triple-variable f-string dag_id positional argument",
],
)
@pytest.mark.xfail(reason="Not yet implemented", raises=AssertionError, strict=True)
Expand Down

0 comments on commit 46ba042

Please sign in to comment.