Skip to content

Commit

Permalink
Better detection of free symbols in C++ tasklets
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 25, 2023
1 parent dcac66e commit bf9b023
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 30 deletions.
10 changes: 9 additions & 1 deletion dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ class Tasklet(CodeNode):
'additional side effects on the system state (e.g., callback). '
'Defaults to None, which lets the framework make assumptions based on '
'the tasklet contents')
ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing '
'the symbols used by this tasklet')

def __init__(self,
label,
Expand All @@ -355,6 +357,7 @@ def __init__(self,
code_exit="",
location=None,
side_effects=None,
ignored_symbols=None,
debuginfo=None):
super(Tasklet, self).__init__(label, location, inputs, outputs)

Expand All @@ -365,6 +368,7 @@ def __init__(self,
self.code_init = CodeBlock(code_init, dtypes.Language.CPP)
self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP)
self.side_effects = side_effects
self.ignored_symbols = ignored_symbols or set()
self.debuginfo = debuginfo

@property
Expand Down Expand Up @@ -393,7 +397,11 @@ def validate(self, sdfg, state):

@property
def free_symbols(self) -> Set[str]:
return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys())
symbols_to_ignore = self.in_connectors.keys() | self.out_connectors.keys()
symbols_to_ignore |= self.ignored_symbols

return self.code.get_free_symbols(symbols_to_ignore)


def has_side_effects(self, sdfg) -> bool:
"""
Expand Down
7 changes: 7 additions & 0 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,22 @@ def replace_properties_dict(node: Any,
if lang is dtypes.Language.CPP: # Replace in C++ code
prefix = ''
tokenized = tokenize_cpp.findall(code)
active_replacements = set()
for name, new_name in reduced_repl.items():
if name not in tokenized:
continue

# Use local variables and shadowing to replace
replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n'
prefix = replacement + prefix
active_replacements.add(name)
if prefix:
propval.code = prefix + code

# Ignore replaced symbols since they no longer exist as reads
if isinstance(node, dace.nodes.Tasklet):
node._ignored_symbols.update(active_replacements)

else:
warnings.warn('Replacement of %s with %s was not made '
'for string tasklet code of language %s' % (name, new_name, lang))
Expand Down
35 changes: 22 additions & 13 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
if TYPE_CHECKING:
import dace.sdfg.scope


def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo:
""" Returns a DebugInfo object for the position that called this function.
Expand Down Expand Up @@ -417,7 +418,6 @@ def is_leaf_memlet(self, e):
if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'):
return False
return True


def used_symbols(self, all_symbols: bool) -> Set[str]:
"""
Expand All @@ -438,13 +438,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
elif isinstance(n, nd.AccessNode):
# Add data descriptor symbols
freesyms |= set(map(str, n.desc(sdfg).used_symbols(all_symbols)))
elif (isinstance(n, nd.Tasklet) and n.language == dtypes.Language.Python):
# Consider callbacks defined as symbols as free
for stmt in n.code.code:
for astnode in ast.walk(stmt):
if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name)
and astnode.func.id in sdfg.symbols):
freesyms.add(astnode.func.id)
elif isinstance(n, nd.Tasklet):
if n.language == dtypes.Language.Python:
# Consider callbacks defined as symbols as free
for stmt in n.code.code:
for astnode in ast.walk(stmt):
if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name)
and astnode.func.id in sdfg.symbols):
freesyms.add(astnode.func.id)
else:
# Find all string tokens and filter them to sdfg.symbols, while ignoring connectors
codesyms = symbolic.symbols_in_code(
n.code.as_string,
potential_symbols=sdfg.symbols.keys(),
symbols_to_ignore=(n.in_connectors.keys() | n.out_connectors.keys() | n.ignored_symbols),
)
freesyms |= codesyms
continue

if hasattr(n, 'used_symbols'):
freesyms |= n.used_symbols(all_symbols)
Expand All @@ -462,7 +472,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
# Do not consider SDFG constants as symbols
new_symbols.update(set(sdfg.constants.keys()))
return freesyms - new_symbols

@property
def free_symbols(self) -> Set[str]:
"""
Expand All @@ -474,7 +484,6 @@ def free_symbols(self) -> Set[str]:
"""
return self.used_symbols(all_symbols=True)


def defined_symbols(self) -> Dict[str, dt.Data]:
"""
Returns a dictionary that maps currently-defined symbols in this SDFG
Expand Down Expand Up @@ -535,8 +544,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr,
# Filter out memlets which go out but the same data is written to the AccessNode by another memlet
for out_edge in list(out_edges):
for in_edge in list(in_edges):
if (in_edge.data.data == out_edge.data.data and
in_edge.data.dst_subset.covers(out_edge.data.src_subset)):
if (in_edge.data.data == out_edge.data.data
and in_edge.data.dst_subset.covers(out_edge.data.src_subset)):
out_edges.remove(out_edge)
break

Expand Down Expand Up @@ -803,7 +812,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None):
self.nosync = False
self.location = location if location is not None else {}
self._default_lineinfo = None

def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
Expand Down
24 changes: 23 additions & 1 deletion dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dace import dtypes

DEFAULT_SYMBOL_TYPE = dtypes.int32
_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*')

# NOTE: Up to (including) version 1.8, sympy.abc._clash is a dictionary of the
# form {'N': sympy.abc.N, 'I': sympy.abc.I, 'pi': sympy.abc.pi}
Expand Down Expand Up @@ -1377,6 +1378,27 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo
if is_length:
for arg in args:
facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)]

with sympy.assuming(*facts):
return sympy.ask(sympy.Q.is_true(sympy.Eq(*args)))


def symbols_in_code(code: str, potential_symbols: Set[str] = None,
symbols_to_ignore: Set[str] = None) -> Set[str]:
"""
Tokenizes a code string for symbols and returns a set thereof.
:param code: The code to tokenize.
:param potential_symbols: If not None, filters symbols to this given set.
:param symbols_to_ignore: If not None, filters out symbols from this set.
"""
if not code:
return set()
if potential_symbols is not None and len(potential_symbols) == 0:
# Don't bother tokenizing for an empty set of potential symbols
return set()

tokens = set(re.findall(_NAME_TOKENS, code))
if potential_symbols is not None:
tokens &= potential_symbols
return tokens - symbols_to_ignore
25 changes: 10 additions & 15 deletions dace/transformation/passes/prune_symbols.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved.

import itertools
import re
from dataclasses import dataclass
from typing import Optional, Set, Tuple

from dace import SDFG, dtypes, properties
from dace import SDFG, dtypes, properties, symbolic
from dace.sdfg import nodes
from dace.transformation import pass_pipeline as ppl

_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*')


@dataclass(unsafe_hash=True)
@properties.make_properties
Expand Down Expand Up @@ -81,7 +78,7 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]:

# Add symbols in global/init/exit code
for code in itertools.chain(sdfg.global_code.values(), sdfg.init_code.values(), sdfg.exit_code.values()):
result |= _symbols_in_code(code.as_string)
result |= symbolic.symbols_in_code(code.as_string)

for desc in sdfg.arrays.values():
result |= set(map(str, desc.free_symbols))
Expand All @@ -94,21 +91,19 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]:
for node in state.nodes():
if isinstance(node, nodes.Tasklet):
if node.code.language != dtypes.Language.Python:
result |= _symbols_in_code(node.code.as_string)
result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(),
node.ignored_symbols)
if node.code_global.language != dtypes.Language.Python:
result |= _symbols_in_code(node.code_global.as_string)
result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(),
node.ignored_symbols)
if node.code_init.language != dtypes.Language.Python:
result |= _symbols_in_code(node.code_init.as_string)
result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(),
node.ignored_symbols)
if node.code_exit.language != dtypes.Language.Python:
result |= _symbols_in_code(node.code_exit.as_string)

result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(),
node.ignored_symbols)

for e in sdfg.edges():
result |= e.data.free_symbols

return result

def _symbols_in_code(code: str) -> Set[str]:
if not code:
return set()
return set(re.findall(_NAME_TOKENS, code))

0 comments on commit bf9b023

Please sign in to comment.