Skip to content

Commit

Permalink
Support attributes in symbolic expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 14, 2023
1 parent 680a956 commit 36010fa
Showing 1 changed file with 61 additions and 48 deletions.
109 changes: 61 additions & 48 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ def eval(cls, x, y):
def _eval_is_boolean(self):
return True


class IfExpr(sympy.Function):

@classmethod
Expand Down Expand Up @@ -723,6 +724,19 @@ class IsNot(sympy.Function):
pass


class Attr(sympy.Function):
"""
Represents a get-attribute call on a function, equivalent to ``a.b`` in Python.
"""

@property
def free_symbols(self):
return {sympy.Symbol(str(self))}

def __str__(self):
return f'{self.args[0]}.{self.args[1]}'


def sympy_intdiv_fix(expr):
""" Fix for SymPy printing out reciprocal values when they should be
integral in "ceiling/floor" sympy functions.
Expand Down Expand Up @@ -926,10 +940,9 @@ def _process_is(elem: Union[Is, IsNot]):
return expr


class SympyBooleanConverter(ast.NodeTransformer):
class PythonOpToSympyConverter(ast.NodeTransformer):
"""
Replaces boolean operations with the appropriate SymPy functions to avoid
non-symbolic evaluation.
Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation.
"""
_ast_to_sympy_comparators = {
ast.Eq: 'Eq',
Expand All @@ -945,12 +958,37 @@ class SympyBooleanConverter(ast.NodeTransformer):
ast.NotIn: 'NotIn',
}

_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Not):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return node
elif isinstance(node.op, ast.Invert):
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in self._ast_to_sympy_functions:
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()),
node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BoolOp(self, node):
func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node)
Expand All @@ -970,8 +1008,7 @@ def visit_Compare(self, node: ast.Compare):
raise NotImplementedError
op = node.ops[0]
arguments = [node.left, node.comparators[0]]
func_node = ast.copy_location(
ast.Name(id=SympyBooleanConverter._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[])
return ast.copy_location(new_node, node)

Expand All @@ -984,40 +1021,18 @@ def visit_NameConstant(self, node):
return self.visit_Constant(node)

def visit_IfExp(self, node):
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[])
new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load),
args=[self.visit(node.test),
self.visit(node.body),
self.visit(node.orelse)],
keywords=[])
return ast.copy_location(new_node, node)

class BitwiseOpConverter(ast.NodeTransformer):
"""
Replaces C/C++ bitwise operations with functions to avoid sympification to boolean operations.
"""
_ast_to_sympy_functions = {
ast.BitAnd: 'BitwiseAnd',
ast.BitOr: 'BitwiseOr',
ast.BitXor: 'BitwiseXor',
ast.Invert: 'BitwiseNot',
ast.LShift: 'LeftShift',
ast.RShift: 'RightShift',
ast.FloorDiv: 'int_floor',
}

def visit_UnaryOp(self, node):
if isinstance(node.op, ast.Invert):
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)

def visit_BinOp(self, node):
if type(node.op) in BitwiseOpConverter._ast_to_sympy_functions:
func_node = ast.copy_location(
ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node)
new_node = ast.Call(func=func_node,
args=[self.visit(value) for value in (node.left, node.right)],
keywords=[])
return ast.copy_location(new_node, node)
return self.generic_visit(node)
def visit_Attribute(self, node):
new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load),
args=[self.visit(node.value), ast.Name(id=node.attr, ctx=ast.Load)],
keywords=[])
return ast.copy_location(new_node, node)


@lru_cache(maxsize=16384)
Expand Down Expand Up @@ -1070,21 +1085,17 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic:
'int_ceil': int_ceil,
'IfExpr': IfExpr,
'Mod': sympy.Mod,
'Attr': Attr,
}
# _clash1 enables all one-letter variables like N as symbols
# _clash also allows pi, beta, zeta and other common greek letters
locals.update(_sympy_clash)

if isinstance(expr, str):
# Sympy processes "not/and/or" as direct evaluation. Replace with
# And/Or(x, y), Not(x)
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr):
expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0]))

# NOTE: If the expression contains bitwise operations, replace them with user-functions.
# NOTE: Sympy does not support bitwise operations and converts them to boolean operations.
if re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr):
expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0]))
# Sympy processes "not/and/or" as direct evaluation. Replace with And/Or(x, y), Not(x)
# Also replaces bitwise operations with user-functions since SymPy does not support bitwise operations.
if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b|[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]|[\.]', expr):
expr = unparse(PythonOpToSympyConverter().visit(ast.parse(expr).body[0]))

# TODO: support SymExpr over-approximated expressions
try:
Expand Down Expand Up @@ -1125,6 +1136,8 @@ def _print_Function(self, expr):
return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))'
if str(expr.func) == 'OR':
return f'(({self._print(expr.args[0])}) or ({self._print(expr.args[1])}))'
if str(expr.func) == 'Attr':
return f'{self._print(expr.args[0])}.{self._print(expr.args[1])}'
return super()._print_Function(expr)

def _print_Mod(self, expr):
Expand Down Expand Up @@ -1377,6 +1390,6 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo
if is_length:
for arg in args:
facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)]

with sympy.assuming(*facts):
return sympy.ask(sympy.Q.is_true(sympy.Eq(*args)))

0 comments on commit 36010fa

Please sign in to comment.