Skip to content

Commit

Permalink
Fix python string to symbolic boolean interpretation in comparsons
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Nov 14, 2024
1 parent 17e4a88 commit cd06e2a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
18 changes: 17 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from functools import lru_cache
import sympy
Expand Down Expand Up @@ -986,6 +986,14 @@ class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
"""

interpret_numeric_booleans: bool

def __init__(self, interpret_numeric_booleans: bool = True):
super().__init__()

self.interpret_numeric_booleans = interpret_numeric_booleans

_ast_to_sympy_comparators = {
ast.Eq: 'Eq',
ast.Gt: 'Gt',
Expand Down Expand Up @@ -1067,6 +1075,14 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]

if self.interpret_numeric_booleans:
# Ensure constant values in boolean comparisons are interpreted als booleans.
if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], ast.Constant):
arguments[1] = ast.copy_location(ast.Constant(bool(node.comparators[0].value)), node.comparators[0])
elif isinstance(node.left, ast.Constant) and isinstance(node.comparators[0], ast.Compare):
arguments[0] = ast.copy_location(ast.Constant(bool(node.left.value)), node.left)

func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)
Expand Down
23 changes: 22 additions & 1 deletion tests/passes/dead_code_elimination_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" Various tests for dead code elimination passes. """

import numpy as np
Expand Down Expand Up @@ -45,6 +45,26 @@ def test_dse_unconditional():
assert set(sdfg.states()) == {s, s2, e}


def test_dse_edge_condition_with_integer_as_boolean_regression():
"""
This is a regression test for issue #1129, which describes dead state elimination incorrectly eliminating interstate
edges when integers are used as boolean values in interstate edge conditions. Code taken from issue #1129.
"""
sdfg = dace.SDFG('dse_edge_condition_with_integer_as_boolean_regression')
sdfg.add_scalar('N', dtype=dace.int32, transient=True)
sdfg.add_scalar('result', dtype=dace.int32)
state_init = sdfg.add_state()
state_middle = sdfg.add_state()
state_end = sdfg.add_state()
sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))',
assignments={'result': 'N'}))
sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)'))
sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'}))

res = DeadStateElimination().apply_pass(sdfg, {})
assert res is None


def test_dde_simple():

@dace.program
Expand Down Expand Up @@ -307,6 +327,7 @@ def test_dce_add_type_hint_of_variable(dtype):
if __name__ == '__main__':
test_dse_simple()
test_dse_unconditional()
test_dse_edge_condition_with_integer_as_boolean_regression()
test_dde_simple()
test_dde_libnode()
test_dde_access_node_in_scope(False)
Expand Down

0 comments on commit cd06e2a

Please sign in to comment.