From bf9b0231a67551c6794514d84522e9177c2ad68e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 24 Sep 2023 22:58:15 -0700 Subject: [PATCH] Better detection of free symbols in C++ tasklets --- dace/sdfg/nodes.py | 10 +++++- dace/sdfg/replace.py | 7 +++++ dace/sdfg/state.py | 35 +++++++++++++-------- dace/symbolic.py | 24 +++++++++++++- dace/transformation/passes/prune_symbols.py | 25 ++++++--------- 5 files changed, 71 insertions(+), 30 deletions(-) diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 3c8f38162f..f60460c50e 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -342,6 +342,8 @@ class Tasklet(CodeNode): 'additional side effects on the system state (e.g., callback). ' 'Defaults to None, which lets the framework make assumptions based on ' 'the tasklet contents') + ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing ' + 'the symbols used by this tasklet') def __init__(self, label, @@ -355,6 +357,7 @@ def __init__(self, code_exit="", location=None, side_effects=None, + ignored_symbols=None, debuginfo=None): super(Tasklet, self).__init__(label, location, inputs, outputs) @@ -365,6 +368,7 @@ def __init__(self, self.code_init = CodeBlock(code_init, dtypes.Language.CPP) self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP) self.side_effects = side_effects + self.ignored_symbols = ignored_symbols or set() self.debuginfo = debuginfo @property @@ -393,7 +397,11 @@ def validate(self, sdfg, state): @property def free_symbols(self) -> Set[str]: - return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys()) + symbols_to_ignore = self.in_connectors.keys() | self.out_connectors.keys() + symbols_to_ignore |= self.ignored_symbols + + return self.code.get_free_symbols(symbols_to_ignore) + def has_side_effects(self, sdfg) -> bool: """ diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 5e42830a75..4b36fad4fe 100644 --- a/dace/sdfg/replace.py +++ 
b/dace/sdfg/replace.py @@ -124,6 +124,7 @@ def replace_properties_dict(node: Any, if lang is dtypes.Language.CPP: # Replace in C++ code prefix = '' tokenized = tokenize_cpp.findall(code) + active_replacements = set() for name, new_name in reduced_repl.items(): if name not in tokenized: continue @@ -131,8 +132,14 @@ def replace_properties_dict(node: Any, # Use local variables and shadowing to replace replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' prefix = replacement + prefix + active_replacements.add(name) if prefix: propval.code = prefix + code + + # Ignore replaced symbols since they no longer exist as reads + if isinstance(node, dace.nodes.Tasklet): + node._ignored_symbols.update(active_replacements) + else: warnings.warn('Replacement of %s with %s was not made ' 'for string tasklet code of language %s' % (name, new_name, lang)) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 538f2114b9..8ad0c67bb8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: import dace.sdfg.scope + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. 
@@ -417,7 +418,6 @@ def is_leaf_memlet(self, e): if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): return False return True - def used_symbols(self, all_symbols: bool) -> Set[str]: """ @@ -438,13 +438,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: elif isinstance(n, nd.AccessNode): # Add data descriptor symbols freesyms |= set(map(str, n.desc(sdfg).used_symbols(all_symbols))) - elif (isinstance(n, nd.Tasklet) and n.language == dtypes.Language.Python): - # Consider callbacks defined as symbols as free - for stmt in n.code.code: - for astnode in ast.walk(stmt): - if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) - and astnode.func.id in sdfg.symbols): - freesyms.add(astnode.func.id) + elif isinstance(n, nd.Tasklet): + if n.language == dtypes.Language.Python: + # Consider callbacks defined as symbols as free + for stmt in n.code.code: + for astnode in ast.walk(stmt): + if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) + and astnode.func.id in sdfg.symbols): + freesyms.add(astnode.func.id) + else: + # Find all string tokens and filter them to sdfg.symbols, while ignoring connectors + codesyms = symbolic.symbols_in_code( + n.code.as_string, + potential_symbols=sdfg.symbols.keys(), + symbols_to_ignore=(n.in_connectors.keys() | n.out_connectors.keys() | n.ignored_symbols), + ) + freesyms |= codesyms + continue if hasattr(n, 'used_symbols'): freesyms |= n.used_symbols(all_symbols) @@ -462,7 +472,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # Do not consider SDFG constants as symbols new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - + @property def free_symbols(self) -> Set[str]: """ @@ -474,7 +484,6 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -535,8 +544,8 @@ def 
_read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data and - in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + if (in_edge.data.data == out_edge.data.data + and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): out_edges.remove(out_edge) break @@ -803,7 +812,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None - + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) diff --git a/dace/symbolic.py b/dace/symbolic.py index 0ab6e3f6ff..87fcc0036c 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -14,6 +14,7 @@ from dace import dtypes DEFAULT_SYMBOL_TYPE = dtypes.int32 +_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') # NOTE: Up to (including) version 1.8, sympy.abc._clash is a dictionary of the # form {'N': sympy.abc.N, 'I': sympy.abc.I, 'pi': sympy.abc.pi} @@ -1377,6 +1378,31 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo if is_length: for arg in args: facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)] - + with sympy.assuming(*facts): return sympy.ask(sympy.Q.is_true(sympy.Eq(*args))) + + +def symbols_in_code(code: str, potential_symbols: Set[str] = None, + symbols_to_ignore: Set[str] = None) -> Set[str]: + """ + Tokenizes a code string for symbols and returns a set thereof. + + :param code: The code to tokenize. + :param potential_symbols: If not None, filters symbols to this given set. + :param symbols_to_ignore: If not None, filters out symbols from this set. + :return: The set of identifier tokens found in ``code`` after both filters are applied.
+ """ + if not code: + return set() + if potential_symbols is not None and len(potential_symbols) == 0: + # Don't bother tokenizing for an empty set of potential symbols + return set() + + tokens = set(re.findall(_NAME_TOKENS, code)) + if potential_symbols is not None: + tokens &= potential_symbols + if symbols_to_ignore is None: + # Callers may omit the ignore set; subtracting None from a set raises TypeError + return tokens + return tokens - symbols_to_ignore diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 94fcbdbc58..cf55f7a9b2 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -1,16 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. import itertools -import re from dataclasses import dataclass from typing import Optional, Set, Tuple -from dace import SDFG, dtypes, properties +from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') - @dataclass(unsafe_hash=True) @properties.make_properties @@ -81,7 +78,7 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: # Add symbols in global/init/exit code for code in itertools.chain(sdfg.global_code.values(), sdfg.init_code.values(), sdfg.exit_code.values()): - result |= _symbols_in_code(code.as_string) + result |= symbolic.symbols_in_code(code.as_string) for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) @@ -94,21 +91,19 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for node in state.nodes(): if isinstance(node, nodes.Tasklet): if node.code.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code.as_string) + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_global.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_global.as_string) + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(),
node.ignored_symbols) if node.code_init.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_init.as_string) + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_exit.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_exit.as_string) - + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) for e in sdfg.edges(): result |= e.data.free_symbols return result - -def _symbols_in_code(code: str) -> Set[str]: - if not code: - return set() - return set(re.findall(_NAME_TOKENS, code))