diff --git a/dace/symbolic.py b/dace/symbolic.py index 9737080c52..beb8ccb288 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from functools import lru_cache import sympy @@ -986,6 +986,14 @@ class PythonOpToSympyConverter(ast.NodeTransformer): """ Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation. """ + + interpret_numeric_booleans: bool + + def __init__(self, interpret_numeric_booleans: bool = True): + super().__init__() + + self.interpret_numeric_booleans = interpret_numeric_booleans + _ast_to_sympy_comparators = { ast.Eq: 'Eq', ast.Gt: 'Gt', @@ -1067,6 +1075,14 @@ def visit_Compare(self, node: ast.Compare): raise NotImplementedError op = node.ops[0] arguments = [node.left, node.comparators[0]] + + if self.interpret_numeric_booleans: + # Ensure constant values in boolean comparisons are interpreted als booleans. + if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], ast.Constant): + arguments[1] = ast.copy_location(ast.Constant(bool(node.comparators[0].value)), node.comparators[0]) + elif isinstance(node.left, ast.Constant) and isinstance(node.comparators[0], ast.Compare): + arguments[0] = ast.copy_location(ast.Constant(bool(node.left.value)), node.left) + func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[]) return ast.copy_location(new_node, node) diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index a41a11c4d6..1832ad8321 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various tests for dead code elimination passes. """ import numpy as np @@ -45,6 +45,26 @@ def test_dse_unconditional(): assert set(sdfg.states()) == {s, s2, e} +def test_dse_edge_condition_with_integer_as_boolean_regression(): + """ + This is a regression test for issue #1129, which describes dead state elimination incorrectly eliminating interstate + edges when integers are used as boolean values in interstate edge conditions. Code taken from issue #1129. + """ + sdfg = dace.SDFG('dse_edge_condition_with_integer_as_boolean_regression') + sdfg.add_scalar('N', dtype=dace.int32, transient=True) + sdfg.add_scalar('result', dtype=dace.int32) + state_init = sdfg.add_state() + state_middle = sdfg.add_state() + state_end = sdfg.add_state() + sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))', + assignments={'result': 'N'})) + sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)')) + sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'})) + + res = DeadStateElimination().apply_pass(sdfg, {}) + assert res is None + + def test_dde_simple(): @dace.program @@ -307,6 +327,7 @@ def test_dce_add_type_hint_of_variable(dtype): if __name__ == '__main__': test_dse_simple() test_dse_unconditional() + test_dse_edge_condition_with_integer_as_boolean_regression() test_dde_simple() test_dde_libnode() test_dde_access_node_in_scope(False)