diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index 138726ef1d..063c1f3e7d 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7,'3.11'] + python-version: [3.7,'3.12'] simplify: [0,1,autoopt] steps: diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index 9ee0772eeb..8a132f3df3 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -287,6 +287,7 @@ def get_workspace_sizes(self) -> Dict[dtypes.StorageType, int]: result: Dict[dtypes.StorageType, int] = {} for storage in self.external_memory_types: func = self._lib.get_symbol(f'__dace_get_external_memory_size_{storage.name}') + func.restype = ctypes.c_size_t result[storage] = func(self._libhandle, *self._lastargs[1]) return result @@ -449,8 +450,8 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: raise TypeError('Passing an object (type %s) to an array in argument "%s"' % (type(arg).__name__, a)) elif dtypes.is_array(arg) and not isinstance(atype, dt.Array): - # GPU scalars are pointers, so this is fine - if atype.storage != dtypes.StorageType.GPU_Global: + # GPU scalars and return values are pointers, so this is fine + if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a)) elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback) and diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 182604c892..a198ed371b 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -30,7 +30,7 @@ x < 5 /------>[s2]--------\\ - [s1] \ ->[s5] + [s1] \\ ->[s5] ------>[s3]->[s4]--/ x >= 5 @@ -82,6 +82,9 @@ class ControlFlow: # a string with its generated code. 
dispatch_state: Callable[[SDFGState], str] + # The parent control flow block of this one, used to avoid generating extraneous ``goto``s + parent: Optional['ControlFlow'] + @property def first_state(self) -> SDFGState: """ @@ -222,11 +225,18 @@ def as_cpp(self, codegen, symbols) -> str: out_edges = sdfg.out_edges(elem.state) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: - # If this is the last generated edge and it leads - # to the next state, skip emitting goto + # Skip gotos to immediate successors successor = None - if (j == (len(out_edges) - 1) and (i + 1) < len(self.elements)): - successor = self.elements[i + 1].first_state + # If this is the last generated edge + if j == (len(out_edges) - 1): + if (i + 1) < len(self.elements): + # If last edge leads to next state in block + successor = self.elements[i + 1].first_state + elif i == len(self.elements) - 1: + # If last edge leads to first state in next block + next_block = _find_next_block(self) + if next_block is not None: + successor = next_block.first_state expr += elem.generate_transition(sdfg, e, successor) else: @@ -350,6 +360,9 @@ class ForScope(ControlFlow): init_edges: List[InterstateEdge] #: All initialization edges def as_cpp(self, codegen, symbols) -> str: + + sdfg = self.guard.parent + # Initialize to either "int i = 0" or "i = 0" depending on whether # the type has been defined defined_vars = codegen.dispatcher.defined_vars @@ -359,9 +372,8 @@ def as_cpp(self, codegen, symbols) -> str: init = self.itervar else: init = f'{symbols[self.itervar]} {self.itervar}' - init += ' = ' + self.init - - sdfg = self.guard.parent + init += ' = ' + unparse_interstate_edge(self.init_edges[0].data.assignments[self.itervar], + sdfg, codegen=codegen) preinit = '' if self.init_edges: @@ -478,13 +490,14 @@ def children(self) -> List[ControlFlow]: def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge], leave_edge: Edge[InterstateEdge], back_edges: 
List[Edge[InterstateEdge]], - dispatch_state: Callable[[SDFGState], str]) -> Union[ForScope, WhileScope]: + dispatch_state: Callable[[SDFGState], + str], parent_block: GeneralBlock) -> Union[ForScope, WhileScope]: """ Helper method that constructs the correct structured loop construct from a set of states. Can construct for or while loops. """ - body = GeneralBlock(dispatch_state, [], [], [], [], [], True) + body = GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True) guard_inedges = sdfg.in_edges(guard) increment_edges = [e for e in guard_inedges if e in back_edges] @@ -535,10 +548,10 @@ def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[Intersta # Also ignore assignments in increment edge (handled in for stmt) body.assignments_to_ignore.append(increment_edge) - return ForScope(dispatch_state, itvar, guard, init, condition, update, body, init_edges) + return ForScope(dispatch_state, parent_block, itvar, guard, init, condition, update, body, init_edges) # Otherwise, it is a while loop - return WhileScope(dispatch_state, guard, condition, body) + return WhileScope(dispatch_state, parent_block, guard, condition, body) def _cases_from_branches( @@ -617,6 +630,31 @@ def _child_of(node: SDFGState, parent: SDFGState, ptree: Dict[SDFGState, SDFGSta return False +def _find_next_block(block: ControlFlow) -> Optional[ControlFlow]: + """ + Returns the immediate successor control flow block. + """ + # Find block in parent + parent = block.parent + if parent is None: + return None + ind = next(i for i, b in enumerate(parent.children) if b is block) + if ind == len(parent.children) - 1 or isinstance(parent, (IfScope, IfElseChain, SwitchCaseScope)): + # If last block, or other children are not reachable from current node (branches), + # recursively continue upwards + return _find_next_block(parent) + return parent.children[ind + 1] + + +def _reset_block_parents(block: ControlFlow): + """ + Fixes block parents after processing. 
+ """ + for child in block.children: + child.parent = block + _reset_block_parents(child) + + def _structured_control_flow_traversal(sdfg: SDFG, start: SDFGState, ptree: Dict[SDFGState, SDFGState], @@ -645,7 +683,7 @@ def _structured_control_flow_traversal(sdfg: SDFG, """ def make_empty_block(): - return GeneralBlock(dispatch_state, [], [], [], [], [], True) + return GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True) # Traverse states in custom order visited = set() if visited is None else visited @@ -657,7 +695,7 @@ def make_empty_block(): if node in visited or node is stop: continue visited.add(node) - stateblock = SingleState(dispatch_state, node) + stateblock = SingleState(dispatch_state, parent_block, node) oe = sdfg.out_edges(node) if len(oe) == 0: # End state @@ -708,12 +746,14 @@ def make_empty_block(): if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())): # If without else if oe[0].dst is mergestate: - branch_block = IfScope(dispatch_state, sdfg, node, oe[1].data.condition, cblocks[oe[1]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[1].data.condition, + cblocks[oe[1]]) elif oe[1].dst is mergestate: - branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition, + cblocks[oe[0]]) else: - branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]], - cblocks[oe[1]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition, + cblocks[oe[0]], cblocks[oe[1]]) else: # If there are 2 or more edges (one is not the negation of the # other): @@ -721,10 +761,10 @@ def make_empty_block(): if switch: # If all edges are of form "x == y" for a single x and # integer y, it is a switch/case - branch_block = SwitchCaseScope(dispatch_state, sdfg, node, switch[0], switch[1]) + branch_block = SwitchCaseScope(dispatch_state, 
parent_block, sdfg, node, switch[0], switch[1]) else: # Otherwise, create if/else if/.../else goto exit chain - branch_block = IfElseChain(dispatch_state, sdfg, node, + branch_block = IfElseChain(dispatch_state, parent_block, sdfg, node, [(e.data.condition, cblocks[e] if e in cblocks else make_empty_block()) for e in oe]) # End of branch classification @@ -739,11 +779,11 @@ def make_empty_block(): loop_exit = None scope = None if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node: - scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state) + scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state, parent_block) body_start = oe[0].dst loop_exit = oe[1].dst elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node: - scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state) + scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state, parent_block) body_start = oe[1].dst loop_exit = oe[0].dst @@ -836,7 +876,8 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState if len(common_frontier) == 1: branch_merges[state] = next(iter(common_frontier)) - root_block = GeneralBlock(dispatch_state, [], [], [], [], [], True) + root_block = GeneralBlock(dispatch_state, None, [], [], [], [], [], True) _structured_control_flow_traversal(sdfg, sdfg.start_state, ptree, branch_merges, back_edges, dispatch_state, root_block) + _reset_block_parents(root_block) return root_block diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py index eae0ed229e..e4456e3e18 100644 --- a/dace/codegen/cppunparse.py +++ b/dace/codegen/cppunparse.py @@ -78,6 +78,7 @@ import numpy as np import os import tokenize +import warnings import sympy import dace @@ -86,6 +87,21 @@ from dace import dtypes from dace.codegen.tools import type_inference + +if sys.version_info < (3, 8): + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = 
ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + # Large float and imaginary literals get turned into infinities in the AST. # We unparse those infinities to INFSTR. INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) @@ -573,7 +589,7 @@ def _generic_FunctionDef(self, t, is_async=False): self.write('/* async */ ') if getattr(t, "returns", False): - if isinstance(t.returns, ast.NameConstant): + if isinstance(t.returns, NameConstant): if t.returns.value is None: self.write('void') else: @@ -728,11 +744,27 @@ def _Repr(self, t): raise NotImplementedError('Invalid C++') def _Num(self, t): - repr_n = repr(t.n) + t_n = t.value if sys.version_info >= (3, 8) else t.n + repr_n = repr(t_n) # For complex values, use DTYPE_TO_TYPECLASS dictionary - if isinstance(t.n, complex): + if isinstance(t_n, complex): dtype = dtypes.DTYPE_TO_TYPECLASS[complex] + # Handle large integer values + if isinstance(t_n, int): + bits = t_n.bit_length() + if bits == 32: # Integer, potentially unsigned + if t_n >= 0: # unsigned + repr_n += 'U' + else: # signed, 64-bit + repr_n += 'LL' + elif 32 < bits <= 63: + repr_n += 'LL' + elif bits == 64 and t_n >= 0: + repr_n += 'ULL' + elif bits >= 64: + warnings.warn(f'Value wider than 64 bits encountered in expression ({t_n}), emitting as-is') + if repr_n.endswith("j"): self.write("%s(0, %s)" % (dtype, repr_n.replace("inf", INFSTR)[:-1])) else: @@ -831,8 +863,23 @@ def _Tuple( self.write(")") unop = {"Invert": "~", "Not": "!", "UAdd": "+", "USub": "-"} + unop_lambda = {'Invert': (lambda x: ~x), 'Not': (lambda x: not x), 'UAdd': (lambda x: +x), 'USub': (lambda x: -x)} def _UnaryOp(self, t): + # Dispatch constants after applying the operation + if sys.version_info[:2] < (3, 8): + if isinstance(t.operand, ast.Num): + newval = self.unop_lambda[t.op.__class__.__name__](t.operand.n) + 
newnode = ast.Num(n=newval) + self.dispatch(newnode) + return + else: + if isinstance(t.operand, ast.Constant): + newval = self.unop_lambda[t.op.__class__.__name__](t.operand.value) + newnode = ast.Constant(value=newval) + self.dispatch(newnode) + return + self.write("(") self.write(self.unop[t.op.__class__.__name__]) self.write(" ") @@ -867,13 +914,13 @@ def _BinOp(self, t): self.write(")") # Special cases for powers elif t.op.__class__.__name__ == 'Pow': - if isinstance(t.right, (ast.Num, ast.Constant, ast.UnaryOp)): + if isinstance(t.right, (NumConstant, ast.Constant, ast.UnaryOp)): power = None - if isinstance(t.right, (ast.Num, ast.Constant)): - power = t.right.n + if isinstance(t.right, (NumConstant, ast.Constant)): + power = t.right.value if sys.version_info >= (3, 8) else t.right.n elif isinstance(t.right, ast.UnaryOp) and isinstance(t.right.op, ast.USub): - if isinstance(t.right.operand, (ast.Num, ast.Constant)): - power = -t.right.operand.n + if isinstance(t.right.operand, (NumConstant, ast.Constant)): + power = - (t.right.operand.value if sys.version_info >= (3, 8) else t.right.operand.n) if power is not None and int(power) == power: negative = power < 0 @@ -953,7 +1000,9 @@ def _Attribute(self, t): # Special case: 3.__abs__() is a syntax error, so if t.value # is an integer literal then we need to either parenthesize # it or add an extra space to get 3 .__abs__(). 
- if (isinstance(t.value, (ast.Num, ast.Constant)) and isinstance(t.value.n, int)): + if isinstance(t.value, ast.Constant) and isinstance(t.value.value, int): + self.write(" ") + elif sys.version_info < (3, 8) and isinstance(t.value, ast.Num) and isinstance(t.value.n, int): self.write(" ") if (isinstance(t.value, ast.Name) and t.value.id in ('dace', 'dace::math', 'dace::cmath')): self.write("::") diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index bc7163ea9b..c0d3b657a1 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -448,7 +448,7 @@ class PAPIUtils(object): def available_counters() -> Dict[str, int]: """ Returns the available PAPI counters on this machine. Only works on - \*nix based systems with ``grep`` and ``papi-tools`` installed. + *nix based systems with ``grep`` and ``papi-tools`` installed. :return: A set of available PAPI counters in the form of a dictionary mapping from counter name to the number of native hardware diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index d3d4f50ccd..3d26f76214 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Helper functions for C++ code generation. NOTE: The C++ code generator is currently located in cpu.py. @@ -9,6 +9,7 @@ import itertools import math import numbers +import sys import warnings import sympy as sp @@ -218,6 +219,11 @@ def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode=None) -> str: from dace.codegen.targets.framecode import DaCeCodeGenerator # Avoid import loop framecode: DaCeCodeGenerator = framecode + if '.' 
in name: + root = name.split('.')[0] + if root in sdfg.arrays and isinstance(sdfg.arrays[root], data.Structure): + name = name.replace('.', '->') + # Special case: If memory is persistent and defined in this SDFG, add state # struct to name if (desc.transient and desc.lifetime in (dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External)): @@ -992,8 +998,7 @@ def _Name(self, t: ast.Name): if t.id not in self.sdfg.arrays: return super()._Name(t) - # Replace values with their code-generated names (for example, - # persistent arrays) + # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[t.id] self.write(ptr(t.id, desc, self.sdfg, self.codegen)) @@ -1271,7 +1276,8 @@ def visit_BinOp(self, node: ast.BinOp): evaluated_constant = symbolic.evaluate(unparsed, self.constants) evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True) value = ast.parse(evaluated).body[0].value - if isinstance(evaluated_node, numbers.Number) and evaluated_node != value.n: + if isinstance(evaluated_node, numbers.Number) and evaluated_node != ( + value.value if sys.version_info >= (3, 8) else value.n): raise TypeError node.right = ast.parse(evaluated).body[0].value except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError): diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 0464672390..88dda0058f 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -59,11 +59,11 @@ def __init__(self, frame_codegen, sdfg): def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''): for k, v in struct.members.items(): if isinstance(v, data.Structure): - _visit_structure(v, args, f'{prefix}.{k}') + _visit_structure(v, args, f'{prefix}->{k}') elif isinstance(v, data.StructArray): - _visit_structure(v.stype, args, f'{prefix}.{k}') + _visit_structure(v.stype, args, f'{prefix}->{k}') elif isinstance(v, data.Data): - args[f'{prefix}.{k}'] = v + 
args[f'{prefix}->{k}'] = v # Keeps track of generated connectors, so we know how to access them in nested scopes arglist = dict(self._frame.arglist) @@ -221,8 +221,8 @@ def allocate_view(self, sdfg: SDFG, dfg: SDFGState, state_id: int, node: nodes.A if isinstance(v, data.Data): ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer - self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) - self._dispatcher.defined_vars.add(f"{name}.{k}", defined_type, ctypedef) + self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef) + self._dispatcher.defined_vars.add(f"{name}->{k}", defined_type, ctypedef) # TODO: Find a better way to do this (the issue is with pointers of pointers) if atype.endswith('*'): atype = atype[:-1] @@ -299,9 +299,6 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d name = node.data alloc_name = cpp.ptr(name, nodedesc, sdfg, self._frame) name = alloc_name - # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and - # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. 
- alloc_name = alloc_name.replace('.', '->') if nodedesc.transient is False: return @@ -331,7 +328,7 @@ def allocate_array(self, sdfg, dfg, state_id, node, nodedesc, function_stream, d if isinstance(v, data.Data): ctypedef = dtypes.pointer(v.dtype).ctype if isinstance(v, data.Array) else v.dtype.ctype defined_type = DefinedType.Scalar if isinstance(v, data.Scalar) else DefinedType.Pointer - self._dispatcher.declared_arrays.add(f"{name}.{k}", defined_type, ctypedef) + self._dispatcher.declared_arrays.add(f"{name}->{k}", defined_type, ctypedef) self.allocate_array(sdfg, dfg, state_id, nodes.AccessNode(f"{name}.{k}"), v, function_stream, declaration_stream, allocation_stream) return @@ -1184,9 +1181,6 @@ def memlet_definition(self, if not types: types = self._dispatcher.defined_vars.get(ptr, is_global=True) var_type, ctypedef = types - # NOTE: `expr` may only be a name or a sequence of names and dots. The latter indicates nested data and - # NOTE: structures. Since structures are implemented as pointers, we replace dots with arrows. 
- ptr = ptr.replace('.', '->') if fpga.is_fpga_array(desc): decouple_array_interfaces = Config.get_bool("compiler", "xilinx", "decouple_array_interfaces") @@ -1517,9 +1511,10 @@ def make_restrict(expr: str) -> str: arguments += [ f'{atype} {restrict} {aname}' for (atype, aname, _), restrict in zip(memlet_references, restrict_args) ] + fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) arguments += [ f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys()) - if aname not in sdfg.constants + if aname in fsyms and aname not in sdfg.constants ] arguments = ', '.join(arguments) return f'void {sdfg_label}({arguments}) {{' @@ -1528,9 +1523,10 @@ def generate_nsdfg_call(self, sdfg, state, node, memlet_references, sdfg_label, prepend = [] if state_struct: prepend = ['__state'] + fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) args = ', '.join(prepend + [argval for _, _, argval in memlet_references] + [ - cpp.sym2cpp(symval) - for symname, symval in sorted(node.symbol_mapping.items()) if symname not in sdfg.constants + cpp.sym2cpp(symval) for symname, symval in sorted(node.symbol_mapping.items()) + if symname in fsyms and symname not in sdfg.constants ]) return f'{sdfg_label}({args});' @@ -1814,11 +1810,11 @@ def _generate_MapEntry( # Find if bounds are used within the scope scope = state_dfg.scope_subgraph(node, False, False) - fsyms = scope.free_symbols + fsyms = self._frame.free_symbols(scope) # Include external edges for n in scope.nodes(): for e in state_dfg.all_edges(n): - fsyms |= e.data.free_symbols + fsyms |= e.data.used_symbols(False, e) fsyms = set(map(str, fsyms)) ntid_is_used = '__omp_num_threads' in fsyms diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index ee49f04d03..a465d2bbc0 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1939,6 +1939,13 @@ def generate_kernel_scope(self, sdfg: SDFG, dfg_scope: 
ScopeSubgraphView, state_ kernel_params: list, function_stream: CodeIOStream, kernel_stream: CodeIOStream): node = dfg_scope.source_nodes()[0] + # Get the thread/block index type + ttype = Config.get('compiler', 'cuda', 'thread_id_type') + tidtype = getattr(dtypes, ttype, False) + if not isinstance(tidtype, dtypes.typeclass): + raise ValueError(f'Configured type "{ttype}" for ``thread_id_type`` does not match any DaCe data type. ' + 'See ``dace.dtypes`` for available types (for example ``int32``).') + # allocating shared memory for dynamic threadblock maps if has_dtbmap: kernel_stream.write( @@ -1990,8 +1997,8 @@ def generate_kernel_scope(self, sdfg: SDFG, dfg_scope: ScopeSubgraphView, state_ expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) - kernel_stream.write('int %s = %s;' % (varname, expr), sdfg, state_id, node) - self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, 'int') + kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', sdfg, state_id, node) + self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, tidtype.ctype) # Delinearize beyond the third dimension if len(krange) > 3: @@ -2010,8 +2017,8 @@ def generate_kernel_scope(self, sdfg: SDFG, dfg_scope: ScopeSubgraphView, state_ ) expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) - kernel_stream.write('int %s = %s;' % (varname, expr), sdfg, state_id, node) - self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, 'int') + kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', sdfg, state_id, node) + self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, tidtype.ctype) # Dispatch internal code assert CUDACodeGen._in_device_code is False diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 9ee5c2ef17..b1eb42fe60 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -471,7 +471,7 @@ def dispatch_state(state: SDFGState) -> str: # If disabled, generate entire graph as general 
control flow block states_topological = list(sdfg.topological_sort(sdfg.start_state)) last = states_topological[-1] - cft = cflow.GeneralBlock(dispatch_state, + cft = cflow.GeneralBlock(dispatch_state, None, [cflow.SingleState(dispatch_state, s, s is last) for s in states_topological], [], [], [], [], False) @@ -886,8 +886,8 @@ def generate_code(self, # NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not # redefine the symbols in such cases. - if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping.keys() - and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName] == isvarName)): + if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping + and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)): continue isvar = data.Scalar(isvarType) callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg) diff --git a/dace/codegen/targets/intel_fpga.py b/dace/codegen/targets/intel_fpga.py index 095a5ce9df..d3c46b0069 100644 --- a/dace/codegen/targets/intel_fpga.py +++ b/dace/codegen/targets/intel_fpga.py @@ -729,9 +729,10 @@ def generate_module(self, sdfg, state, kernel_name, module_name, subgraph, param def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references, sdfg_label): # Intel FPGA needs to deal with streams arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references] + fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) arguments += [ f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys()) - if aname not in sdfg.constants + if aname in fsyms and aname not in sdfg.constants ] arguments = ', '.join(arguments) function_header = f'void {sdfg_label}({arguments}) {{' diff --git a/dace/codegen/targets/xilinx.py b/dace/codegen/targets/xilinx.py index e802907652..5d82cfeafc 100644 --- a/dace/codegen/targets/xilinx.py +++ b/dace/codegen/targets/xilinx.py 
@@ -368,9 +368,10 @@ def generate_flatten_loop_post(kernel_stream, sdfg, state_id, node): def generate_nsdfg_header(self, sdfg, state, state_id, node, memlet_references, sdfg_label): # TODO: Use a single method for GPU kernels, FPGA modules, and NSDFGs arguments = [f'{atype} {aname}' for atype, aname, _ in memlet_references] + fsyms = node.sdfg.used_symbols(all_symbols=False, keep_defined_in_mapping=True) arguments += [ f'{node.sdfg.symbols[aname].as_arg(aname)}' for aname in sorted(node.symbol_mapping.keys()) - if aname not in sdfg.constants + if aname in fsyms and aname not in sdfg.constants ] arguments = ', '.join(arguments) return f'void {sdfg_label}({arguments}) {{\n#pragma HLS INLINE' diff --git a/dace/codegen/tools/type_inference.py b/dace/codegen/tools/type_inference.py index 3d91e5f964..f159088461 100644 --- a/dace/codegen/tools/type_inference.py +++ b/dace/codegen/tools/type_inference.py @@ -338,7 +338,15 @@ def _BinOp(t, symbols, inferred_symbols): return dtypes.result_type_of(type_left, type_right) # Special case for integer power elif t.op.__class__.__name__ == 'Pow': - if (isinstance(t.right, (ast.Num, ast.Constant)) and int(t.right.n) == t.right.n and t.right.n >= 0): + if (sys.version_info >= (3, 8) and isinstance(t.right, ast.Constant) and + int(t.right.value) == t.right.value and t.right.value >= 0): + if t.right.value != 0: + type_left = _dispatch(t.left, symbols, inferred_symbols) + for i in range(int(t.right.n) - 1): + _dispatch(t.left, symbols, inferred_symbols) + return dtypes.result_type_of(type_left, dtypes.typeclass(np.uint32)) + elif (sys.version_info < (3, 8) and isinstance(t.right, ast.Num) and + int(t.right.n) == t.right.n and t.right.n >= 0): if t.right.n != 0: type_left = _dispatch(t.left, symbols, inferred_symbols) for i in range(int(t.right.n) - 1): @@ -405,6 +413,9 @@ def _infer_dtype(t: Union[ast.Name, ast.Attribute]): def _Attribute(t, symbols, inferred_symbols): inferred_type = _dispatch(t.value, symbols, inferred_symbols) + if 
(isinstance(inferred_type, dtypes.pointer) and isinstance(inferred_type.base_type, dtypes.struct) and + t.attr in inferred_type.base_type.fields): + return inferred_type.base_type.fields[t.attr] return inferred_type diff --git a/dace/config_schema.yml b/dace/config_schema.yml index e378b6c1f2..08a427aa52 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -413,6 +413,17 @@ required: a specified larger block size in the third dimension. Default value is derived from hardware limits on common GPUs. + thread_id_type: + type: str + title: Thread/block index data type + default: int32 + description: > + Defines the data type for a thread and block index in the generated code. + The type is based on the type-classes in ``dace.dtypes``. For example, + ``uint64`` is equivalent to ``dace.uint64``. Change this setting when large + index types are needed to address memory offsets that are beyond the 32-bit + range, or to reduce memory usage. + ############################################# # General FPGA flags diff --git a/dace/data.py b/dace/data.py index 3b571e6537..0a9858458b 100644 --- a/dace/data.py +++ b/dace/data.py @@ -243,6 +243,10 @@ def __hash__(self): def as_arg(self, with_types=True, for_call=False, name=None): """Returns a string for a C++ function signature (e.g., `int *A`). """ raise NotImplementedError + + def as_python_arg(self, with_types=True, for_call=False, name=None): + """Returns a string for a Data-Centric Python function signature (e.g., `A: dace.int32[M]`). 
""" + raise NotImplementedError def used_symbols(self, all_symbols: bool) -> Set[symbolic.SymbolicType]: """ @@ -583,6 +587,13 @@ def as_arg(self, with_types=True, for_call=False, name=None): if not with_types or for_call: return name return self.dtype.as_arg(name) + + def as_python_arg(self, with_types=True, for_call=False, name=None): + if self.storage is dtypes.StorageType.GPU_Global: + return Array(self.dtype, [1]).as_python_arg(with_types, for_call, name) + if not with_types or for_call: + return name + return f"{name}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}" def sizes(self): return None @@ -849,6 +860,13 @@ def as_arg(self, with_types=True, for_call=False, name=None): if self.may_alias: return str(self.dtype.ctype) + ' *' + arrname return str(self.dtype.ctype) + ' * __restrict__ ' + arrname + + def as_python_arg(self, with_types=True, for_call=False, name=None): + arrname = name + + if not with_types or for_call: + return arrname + return f"{arrname}: {dtypes.TYPECLASS_TO_STRING[self.dtype].replace('::', '.')}{list(self.shape)}" def sizes(self): return [d.name if isinstance(d, symbolic.symbol) else str(d) for d in self.shape] diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index a66ee5c0d6..d95fa87e58 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -1,5 +1,5 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
-from fparser.two.Fortran2008 import Fortran2008 as f08 +from fparser.two import Fortran2008 as f08 from fparser.two import Fortran2003 as f03 from fparser.two import symbol_table @@ -523,6 +523,31 @@ def declaration_type_spec(self, node: FASTNode): def assumed_shape_spec_list(self, node: FASTNode): return node + def parse_shape_specification(self, dim: f03.Explicit_Shape_Spec, size: List[FASTNode], offset: List[int]): + + dim_expr = [i for i in dim.children if i is not None] + + # handle size definition + if len(dim_expr) == 1: + dim_expr = dim_expr[0] + #now to add the dimension to the size list after processing it if necessary + size.append(self.create_ast(dim_expr)) + offset.append(1) + # Here we support arrays that have size declaration - with initial offset. + elif len(dim_expr) == 2: + # extract offets + for expr in dim_expr: + if not isinstance(expr, f03.Int_Literal_Constant): + raise TypeError("Array offsets must be constant expressions!") + offset.append(int(dim_expr[0].tostr())) + + fortran_size = int(dim_expr[1].tostr()) - int(dim_expr[0].tostr()) + 1 + fortran_ast_size = f03.Int_Literal_Constant(str(fortran_size)) + + size.append(self.create_ast(fortran_ast_size)) + else: + raise TypeError("Array dimension must be at most two expressions") + def type_declaration_stmt(self, node: FASTNode): #decide if its a intrinsic variable type or a derived type @@ -574,33 +599,44 @@ def type_declaration_stmt(self, node: FASTNode): alloc = False symbol = False + attr_size = None + attr_offset = None for i in attributes: if i.string.lower() == "allocatable": alloc = True if i.string.lower() == "parameter": symbol = True + if isinstance(i, f08.Attr_Spec_List): + + dimension_spec = get_children(i, "Dimension_Attr_Spec") + if len(dimension_spec) == 0: + continue + + attr_size = [] + attr_offset = [] + sizes = get_child(dimension_spec[0], ["Explicit_Shape_Spec_List"]) + + for shape_spec in get_children(sizes, [f03.Explicit_Shape_Spec]): + 
self.parse_shape_specification(shape_spec, attr_size, attr_offset) + vardecls = [] for var in names: #first handle dimensions size = None + offset = None var_components = self.create_children(var) array_sizes = get_children(var, "Explicit_Shape_Spec_List") actual_name = get_child(var_components, ast_internal_classes.Name_Node) if len(array_sizes) == 1: array_sizes = array_sizes[0] size = [] + offset = [] for dim in array_sizes.children: #sanity check if isinstance(dim, f03.Explicit_Shape_Spec): - dim_expr = [i for i in dim.children if i is not None] - if len(dim_expr) == 1: - dim_expr = dim_expr[0] - #now to add the dimension to the size list after processing it if necessary - size.append(self.create_ast(dim_expr)) - else: - raise TypeError("Array dimension must be a single expression") + self.parse_shape_specification(dim, size, offset) #handle initializiation init = None @@ -615,15 +651,26 @@ def type_declaration_stmt(self, node: FASTNode): if symbol == False: - vardecls.append( - ast_internal_classes.Var_Decl_Node(name=actual_name.name, - type=testtype, - alloc=alloc, - sizes=size, - kind=kind, - line_number=node.item.span)) + if attr_size is None: + vardecls.append( + ast_internal_classes.Var_Decl_Node(name=actual_name.name, + type=testtype, + alloc=alloc, + sizes=size, + offsets=offset, + kind=kind, + line_number=node.item.span)) + else: + vardecls.append( + ast_internal_classes.Var_Decl_Node(name=actual_name.name, + type=testtype, + alloc=alloc, + sizes=attr_size, + offsets=attr_offset, + kind=kind, + line_number=node.item.span)) else: - if size is None: + if size is None and attr_size is None: self.symbols[actual_name.name] = init vardecls.append( ast_internal_classes.Symbol_Decl_Node(name=actual_name.name, @@ -631,16 +678,26 @@ def type_declaration_stmt(self, node: FASTNode): alloc=alloc, init=init, line_number=node.item.span)) + elif attr_size is not None: + vardecls.append( + ast_internal_classes.Symbol_Array_Decl_Node(name=actual_name.name, + 
type=testtype, + alloc=alloc, + sizes=attr_size, + offsets=attr_offset, + kind=kind, + init=init, + line_number=node.item.span)) else: vardecls.append( ast_internal_classes.Symbol_Array_Decl_Node(name=actual_name.name, type=testtype, alloc=alloc, sizes=size, + offsets=offset, kind=kind, init=init, line_number=node.item.span)) - return ast_internal_classes.Decl_Stmt_Node(vardecl=vardecls, line_number=node.item.span) def entity_decl(self, node: FASTNode): @@ -994,7 +1051,7 @@ def specification_part(self, node: FASTNode): decls = [self.create_ast(i) for i in node.children if isinstance(i, f08.Type_Declaration_Stmt)] - uses = [self.create_ast(i) for i in node.children if isinstance(i, f08.Use_Stmt)] + uses = [self.create_ast(i) for i in node.children if isinstance(i, f03.Use_Stmt)] tmp = [self.create_ast(i) for i in node.children] typedecls = [i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node)] symbols = [] diff --git a/dace/frontend/fortran/ast_internal_classes.py b/dace/frontend/fortran/ast_internal_classes.py index 6bdfb61faf..70a43e21b8 100644 --- a/dace/frontend/fortran/ast_internal_classes.py +++ b/dace/frontend/fortran/ast_internal_classes.py @@ -1,5 +1,5 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Any, List, Tuple, Type, TypeVar, Union, overload +from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload # The node class is the base class for all nodes in the AST. It provides attributes including the line number and fields. # Attributes are not used when walking the tree, but are useful for debugging and for code generation. 
@@ -11,6 +11,14 @@ def __init__(self, *args, **kwargs): # real signature unknown self.integrity_exceptions = [] self.read_vars = [] self.written_vars = [] + self.parent: Optional[ + Union[ + Subroutine_Subprogram_Node, + Function_Subprogram_Node, + Main_Program_Node, + Module_Node + ] + ] = None for k, v in kwargs.items(): setattr(self, k, v) @@ -199,6 +207,7 @@ class Symbol_Array_Decl_Node(Statement_Node): ) _fields = ( 'sizes', + 'offsets' 'typeref', 'init', ) @@ -213,6 +222,7 @@ class Var_Decl_Node(Statement_Node): ) _fields = ( 'sizes', + 'offsets', 'typeref', 'init', ) diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 7e5cd3bf00..e2a7246aed 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -1,7 +1,7 @@ # Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. from dace.frontend.fortran import ast_components, ast_internal_classes -from typing import List, Tuple, Set +from typing import Dict, List, Optional, Tuple, Set import copy @@ -310,6 +310,65 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No return ast_internal_classes.Execution_Part_Node(execution=newbody) +class ParentScopeAssigner(NodeVisitor): + """ + For each node, it assigns its parent scope - program, subroutine, function. + + If the parent node is one of the "parent" types, we assign it as the parent. + Otherwise, we look for the parent of my parent to cover nested AST nodes within + a single scope. 
+ """ + def __init__(self): + pass + + def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None): + + parent_node_types = [ + ast_internal_classes.Subroutine_Subprogram_Node, + ast_internal_classes.Function_Subprogram_Node, + ast_internal_classes.Main_Program_Node, + ast_internal_classes.Module_Node + ] + + if parent_node is not None and type(parent_node) in parent_node_types: + node.parent = parent_node + elif parent_node is not None: + node.parent = parent_node.parent + + # Copied from `generic_visit` to recursively parse all leafs + for field, value in iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast_internal_classes.FNode): + self.visit(item, node) + elif isinstance(value, ast_internal_classes.FNode): + self.visit(value, node) + +class ScopeVarsDeclarations(NodeVisitor): + """ + Creates a mapping (scope name, variable name) -> variable declaration. + + The visitor is used to access information on variable dimension, sizes, and offsets. 
+ """ + + def __init__(self): + + self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {} + + def get_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> ast_internal_classes.FNode: + return self.scope_vars[(self._scope_name(scope), variable_name)] + + def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node): + + parent_name = self._scope_name(node.parent) + var_name = node.name + self.scope_vars[(parent_name, var_name)] = node + + def _scope_name(self, scope: ast_internal_classes.FNode) -> str: + if isinstance(scope, ast_internal_classes.Main_Program_Node): + return scope.name.name.name + else: + return scope.name.name class IndexExtractorNodeLister(NodeVisitor): """ @@ -336,9 +395,20 @@ class IndexExtractor(NodeTransformer): Uses the IndexExtractorNodeLister to find all array subscript expressions in the AST node and its children that have to be extracted into independent expressions It then creates a new temporary variable for each of them and replaces the index expression with the variable. + + Before parsing the AST, the transformation first runs: + - ParentScopeAssigner to ensure that each node knows its scope assigner. + - ScopeVarsDeclarations to aggregate all variable declarations for each function. 
""" - def __init__(self, count=0): + def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = False, count=0): + self.count = count + self.normalize_offsets = normalize_offsets + + if normalize_offsets: + ParentScopeAssigner().visit(ast) + self.scope_vars = ScopeVarsDeclarations() + self.scope_vars.visit(ast) def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]: @@ -367,9 +437,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No lister.visit(child) res = lister.nodes temp = self.count + + if res is not None: for j in res: - for i in j.indices: + for idx, i in enumerate(j.indices): if isinstance(i, ast_internal_classes.ParDecl_Node): continue else: @@ -383,16 +455,34 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No line_number=child.line_number) ], line_number=child.line_number)) - newbody.append( - ast_internal_classes.BinOp_Node( - op="=", - lval=ast_internal_classes.Name_Node(name=tmp_name), - rval=ast_internal_classes.BinOp_Node( - op="-", - lval=i, - rval=ast_internal_classes.Int_Literal_Node(value="1"), - line_number=child.line_number), - line_number=child.line_number)) + if self.normalize_offsets: + + # Find the offset of a variable to which we are assigning + var_name = child.lval.name.name + variable = self.scope_vars.get_var(child.parent, var_name) + offset = variable.offsets[idx] + + newbody.append( + ast_internal_classes.BinOp_Node( + op="=", + lval=ast_internal_classes.Name_Node(name=tmp_name), + rval=ast_internal_classes.BinOp_Node( + op="-", + lval=i, + rval=ast_internal_classes.Int_Literal_Node(value=str(offset)), + line_number=child.line_number), + line_number=child.line_number)) + else: + newbody.append( + ast_internal_classes.BinOp_Node( + op="=", + lval=ast_internal_classes.Name_Node(name=tmp_name), + rval=ast_internal_classes.BinOp_Node( + op="-", + lval=i, + 
rval=ast_internal_classes.Int_Literal_Node(value="1"), + line_number=child.line_number), + line_number=child.line_number)) newbody.append(self.visit(child)) return ast_internal_classes.Execution_Part_Node(execution=newbody) @@ -646,6 +736,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, rangepos: list, count: int, newbody: list, + scope_vars: ScopeVarsDeclarations, declaration=True, is_sum_to_loop=False): """ @@ -662,16 +753,40 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node, currentindex = 0 indices = [] - for i in node.indices: + offsets = scope_vars.get_var(node.parent, node.name.name).offsets + + for idx, i in enumerate(node.indices): if isinstance(i, ast_internal_classes.ParDecl_Node): + if i.type == "ALL": - ranges.append([ - ast_internal_classes.Int_Literal_Node(value="1"), - ast_internal_classes.Name_Range_Node(name="f2dace_MAX", - type="INTEGER", - arrname=node.name, - pos=currentindex) - ]) + + lower_boundary = None + if offsets[idx] != 1: + lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx])) + else: + lower_boundary = ast_internal_classes.Int_Literal_Node(value="1") + + upper_boundary = ast_internal_classes.Name_Range_Node(name="f2dace_MAX", + type="INTEGER", + arrname=node.name, + pos=currentindex) + """ + When there's an offset, we add MAX_RANGE + offset. + But since the generated loop has `<=` condition, we need to subtract 1. 
+ """ + if offsets[idx] != 1: + upper_boundary = ast_internal_classes.BinOp_Node( + lval=upper_boundary, + op="+", + rval=ast_internal_classes.Int_Literal_Node(value=str(offsets[idx])) + ) + upper_boundary = ast_internal_classes.BinOp_Node( + lval=upper_boundary, + op="-", + rval=ast_internal_classes.Int_Literal_Node(value="1") + ) + ranges.append([lower_boundary, upper_boundary]) + else: ranges.append([i.range[0], i.range[1]]) rangepos.append(currentindex) @@ -693,9 +808,13 @@ class ArrayToLoop(NodeTransformer): """ Transforms the AST by removing array expressions and replacing them with loops """ - def __init__(self): + def __init__(self, ast): self.count = 0 + ParentScopeAssigner().visit(ast) + self.scope_vars = ScopeVarsDeclarations() + self.scope_vars.visit(ast) + def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): newbody = [] for child in node.execution: @@ -709,7 +828,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No val = child.rval ranges = [] rangepos = [] - par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, True) + par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, self.scope_vars, True) if res_range is not None and len(res_range) > 0: rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)] @@ -717,7 +836,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No rangeposrval = [] rangesrval = [] - par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, False) + par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False) for i, j in zip(ranges, rangesrval): if i != j: @@ -791,8 +910,11 @@ class SumToLoop(NodeTransformer): """ Transforms the AST by removing array sums and replacing them with loops """ - def __init__(self): + def __init__(self, ast): self.count = 0 + ParentScopeAssigner().visit(ast) + self.scope_vars = 
ScopeVarsDeclarations() + self.scope_vars.visit(ast) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): newbody = [] @@ -811,7 +933,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No rangeposrval = [] rangesrval = [] - par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, False, True) + par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False, True) range_index = 0 body = ast_internal_classes.BinOp_Node(lval=current, diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index d7112892fe..b15435f4ff 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -133,7 +133,7 @@ def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG): for i in node: self.translate(i, sdfg) else: - warnings.warn("WARNING:", node.__class__.__name__) + warnings.warn(f"WARNING: {node.__class__.__name__}") def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG): """ @@ -1015,10 +1015,46 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG): if node.name not in self.contexts[sdfg.name].containers: self.contexts[sdfg.name].containers.append(node.name) +def create_ast_from_string( + source_string: str, + sdfg_name: str, + transform: bool = False, + normalize_offsets: bool = False +): + """ + Creates an AST from a Fortran file in a string + :param source_string: The fortran file as a string + :param sdfg_name: The name to be given to the resulting SDFG + :return: The resulting AST + + """ + parser = pf().create(std="f2008") + reader = fsr(source_string) + ast = parser(reader) + tables = SymbolTable + own_ast = ast_components.InternalFortranAst(ast, tables) + program = own_ast.create_ast(ast) + + functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines() + functions_and_subroutines_builder.visit(program) + 
functions_and_subroutines = functions_and_subroutines_builder.nodes + + if transform: + program = ast_transforms.functionStatementEliminator(program) + program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) + program = ast_transforms.CallExtractor().visit(program) + program = ast_transforms.SignToIf().visit(program) + program = ast_transforms.ArrayToLoop(program).visit(program) + program = ast_transforms.SumToLoop(program).visit(program) + program = ast_transforms.ForDeclarer().visit(program) + program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) + + return (program, own_ast) def create_sdfg_from_string( source_string: str, sdfg_name: str, + normalize_offsets: bool = False ): """ Creates an SDFG from a fortran file in a string @@ -1040,10 +1076,10 @@ def create_sdfg_from_string( program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) - program = ast_transforms.ArrayToLoop().visit(program) - program = ast_transforms.SumToLoop().visit(program) + program = ast_transforms.ArrayToLoop(program).visit(program) + program = ast_transforms.SumToLoop(program).visit(program) program = ast_transforms.ForDeclarer().visit(program) - program = ast_transforms.IndexExtractor().visit(program) + program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) ast2sdfg = AST_translator(own_ast, __file__) sdfg = SDFG(sdfg_name) ast2sdfg.top_level = program @@ -1082,10 +1118,10 @@ def create_sdfg_from_fortran_file(source_string: str): program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program) program = ast_transforms.CallExtractor().visit(program) program = ast_transforms.SignToIf().visit(program) - program = ast_transforms.ArrayToLoop().visit(program) - program = ast_transforms.SumToLoop().visit(program) + program = 
ast_transforms.ArrayToLoop(program).visit(program) + program = ast_transforms.SumToLoop(program).visit(program) program = ast_transforms.ForDeclarer().visit(program) - program = ast_transforms.IndexExtractor().visit(program) + program = ast_transforms.IndexExtractor(program).visit(program) ast2sdfg = AST_translator(own_ast, __file__) sdfg = SDFG(source_string) ast2sdfg.top_level = program diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 4a0ec88531..67d8b6aded 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -15,6 +15,12 @@ from dace import dtypes, symbolic +if sys.version_info >= (3, 8): + NumConstant = ast.Constant +else: + NumConstant = ast.Num + + def _remove_outer_indentation(src: str): """ Removes extra indentation from a source Python function. @@ -66,8 +72,9 @@ def is_constant(node: ast.AST) -> bool: if sys.version_info >= (3, 8): if isinstance(node, ast.Constant): return True - if isinstance(node, (ast.Num, ast.Str, ast.NameConstant)): # For compatibility - return True + else: + if isinstance(node, (ast.Num, ast.Str, ast.NameConstant)): # For compatibility + return True return False @@ -82,13 +89,14 @@ def evalnode(node: ast.AST, gvars: Dict[str, Any]) -> Any: """ if not isinstance(node, ast.AST): return node - if isinstance(node, ast.Index): # For compatibility + if sys.version_info < (3, 9) and isinstance(node, ast.Index): # For compatibility node = node.value - if isinstance(node, ast.Num): # For compatibility - return node.n if sys.version_info >= (3, 8): if isinstance(node, ast.Constant): return node.value + else: + if isinstance(node, ast.Num): # For compatibility + return node.n # Replace internal constants with their values node = copy_tree(node) @@ -112,7 +120,7 @@ def rname(node): if isinstance(node, str): return node - if isinstance(node, ast.Num): + if sys.version_info < (3, 8) and isinstance(node, ast.Num): return str(node.n) if isinstance(node, ast.Name): # form x 
return node.id @@ -174,11 +182,11 @@ def subscript_to_ast_slice(node, without_array=False): # Python <3.9 compatibility result_slice = None - if isinstance(node.slice, ast.Index): + if sys.version_info < (3, 9) and isinstance(node.slice, ast.Index): slc = node.slice.value if not isinstance(slc, ast.Tuple): result_slice = [slc] - elif isinstance(node.slice, ast.ExtSlice): + elif sys.version_info < (3, 9) and isinstance(node.slice, ast.ExtSlice): slc = tuple(node.slice.dims) else: slc = node.slice @@ -196,7 +204,7 @@ def subscript_to_ast_slice(node, without_array=False): # Slice if isinstance(s, ast.Slice): result_slice.append((s.lower, s.upper, s.step)) - elif isinstance(s, ast.Index): # Index (Python <3.9) + elif sys.version_info < (3, 9) and isinstance(s, ast.Index): # Index (Python <3.9) result_slice.append(s.value) else: # Index result_slice.append(s) @@ -226,7 +234,7 @@ def _Subscript(self, t): self.dispatch(t.value) self.write('[') # Compatibility - if isinstance(t.slice, ast.Index): + if sys.version_info < (3, 9) and isinstance(t.slice, ast.Index): slc = t.slice.value else: slc = t.slice @@ -600,9 +608,9 @@ def visit_Name(self, node: ast.Name): def visit_Constant(self, node): return self.visit_Num(node) - def visit_Num(self, node: ast.Num): + def visit_Num(self, node: NumConstant): newname = f'__uu{self.id}' - self.gvars[newname] = node.n + self.gvars[newname] = node.value if sys.version_info >= (3, 8) else node.n self.id += 1 return ast.copy_location(ast.Name(id=newname, ctx=ast.Load()), node) @@ -705,3 +713,17 @@ def escape_string(value: Union[bytes, str]): return value.encode("unicode_escape").decode("utf-8") # Python 2.x return value.encode('string_escape') + + +def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]: + """ + Parses function arguments (both positional and keyword) from a Call node, + based on the function's argument names. If an argument was not given, it will + not be in the result. 
+ """ + result = {} + for arg, aname in zip(node.args, argnames): + result[aname] = arg + for kw in node.keywords: + result[kw.arg] = kw.value + return result diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index ea1970dafd..69e650beaa 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass): The DaCe framework cannot analyze these tasklets for optimization. """ - def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python): + def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False): if isinstance(language, str): language = dtypes.Language[language] self.language = language + self.side_effects = side_effects def __enter__(self): if self.language != dtypes.Language.Python: diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index 6ef627a430..a95bf82046 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -1,7 +1,7 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ast import copy -import re +import sys from collections import namedtuple from typing import Any, Dict, List, Optional, Tuple, Union from dataclasses import dataclass @@ -16,6 +16,22 @@ MemletType = Union[ast.Call, ast.Attribute, ast.Subscript, ast.Name] +if sys.version_info < (3, 8): + _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + _simple_ast_nodes = (ast.Constant, ast.Name) + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + @dataclass class MemletExpr: name: str @@ -114,7 +130,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): offsets.append(idx) idx += 1 new_idx += 1 - elif (isinstance(dim, ast.Ellipsis) or dim is Ellipsis + elif ((sys.version_info < (3, 8) and isinstance(dim, ast.Ellipsis)) or dim is Ellipsis or (isinstance(dim, ast.Constant) and dim.value is Ellipsis) or (isinstance(dim, ast.Name) and dim.id is Ellipsis)): if has_ellipsis: @@ -125,7 +141,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): ndslice[j] = (0, array.shape[j] - 1, 1) idx += 1 new_idx += 1 - elif (dim is None or (isinstance(dim, (ast.Constant, ast.NameConstant)) and dim.value is None)): + elif (dim is None or (isinstance(dim, (ast.Constant, NameConstant)) and dim.value is None)): new_axes.append(new_idx) new_idx += 1 # NOTE: Do not increment idx here @@ -200,7 +216,7 @@ def _fill_missing_slices(das, ast_ndslice, array, indices): def parse_memlet_subset(array: data.Data, node: Union[ast.Name, ast.Subscript], das: Dict[str, Any], - parsed_slice: Any = None) -> Tuple[subsets.Range, List[int]]: + parsed_slice: Any = None) -> Tuple[subsets.Range, List[int], List[int]]: """ Parses an AST subset and returns access range, as well as new dimensions to add. 
@@ -209,7 +225,7 @@ def parse_memlet_subset(array: data.Data, e.g., negative indices or empty shapes). :param node: AST node representing whole array or subset thereof. :param das: Dictionary of defined arrays and symbols mapped to their values. - :return: A 2-tuple of (subset, list of new axis indices). + :return: A 3-tuple of (subset, list of new axis indices, list of index-to-array-dimension correspondence). """ # Get memlet range ndslice = [(0, s - 1, 1) for s in array.shape] @@ -285,7 +301,11 @@ def ParseMemlet(visitor, if len(node.value.args) >= 2: write_conflict_resolution = node.value.args[1] - subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice) + try: + subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice) + except IndexError: + raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. ' + f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}') # If undefined, default number of accesses is the slice size if num_accesses is None: diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..733c3c7f62 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -50,6 +50,36 @@ DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]] +if sys.version_info < (3, 8): + _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + _simple_ast_nodes = (ast.Constant, ast.Name) + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + +if sys.version_info < (3, 9): + Index = ast.Index + ExtSlice = ast.ExtSlice +else: + Index = type(None) + ExtSlice = type(None) + + +if sys.version_info < (3, 12): + TypeAlias 
= type(None) +else: + TypeAlias = ast.TypeAlias + + class SkipCall(Exception): """ Exception used to skip calls to functions that cannot be parsed. """ pass @@ -279,7 +309,7 @@ def repl_callback(repldict): # Extra AST node types that are disallowed after preprocessing _DISALLOWED_STMTS = DISALLOWED_STMTS + [ 'Global', 'Assert', 'Print', 'Nonlocal', 'Raise', 'Starred', 'AsyncFor', 'ListComp', 'GeneratorExp', 'SetComp', - 'DictComp', 'comprehension' + 'DictComp', 'comprehension', 'TypeAlias', 'TypeVar', 'ParamSpec', 'TypeVarTuple' ] TaskletType = Union[ast.FunctionDef, ast.With, ast.For] @@ -981,16 +1011,16 @@ def visit_TopLevelExpr(self, node): raise DaceSyntaxError(self, node, 'Local variable is already a tasklet input or output') self.outputs[connector] = memlet return None # Remove from final tasklet code - elif isinstance(node.value, ast.Str): + elif isinstance(node.value, StrConstant): return self.visit_TopLevelStr(node.value) return self.generic_visit(node) # Detect external tasklet code - def visit_TopLevelStr(self, node: ast.Str): + def visit_TopLevelStr(self, node: StrConstant): if self.extcode != None: raise DaceSyntaxError(self, node, 'Cannot provide more than one intrinsic implementation ' + 'for tasklet') - self.extcode = node.s + self.extcode = node.value if sys.version_info >= (3, 8) else node.s # TODO: Should get detected by _parse_Tasklet() if self.lang is None: @@ -1611,7 +1641,7 @@ def _parse_for_indices(self, node: ast.Expr): return indices - def _parse_value(self, node: Union[ast.Name, ast.Num, ast.Constant]): + def _parse_value(self, node: Union[ast.Name, NumConstant, ast.Constant]): """Parses a value Arguments: @@ -1626,7 +1656,7 @@ def _parse_value(self, node: Union[ast.Name, ast.Num, ast.Constant]): if isinstance(node, ast.Name): return node.id - elif isinstance(node, ast.Num): + elif sys.version_info < (3, 8) and isinstance(node, ast.Num): return str(node.n) elif isinstance(node, ast.Constant): return str(node.value) @@ -1646,14 +1676,14 @@ 
def _parse_slice(self, node: ast.Slice): return (self._parse_value(node.lower), self._parse_value(node.upper), self._parse_value(node.step) if node.step is not None else "1") - def _parse_index_as_range(self, node: Union[ast.Index, ast.Tuple]): + def _parse_index_as_range(self, node: Union[Index, ast.Tuple]): """ Parses an index as range :param node: Index node :return: Range in (from, to, step) format """ - if isinstance(node, ast.Index): + if sys.version_info < (3, 9) and isinstance(node, ast.Index): val = self._parse_value(node.value) elif isinstance(node, ast.Tuple): val = self._parse_value(node.elts) @@ -1760,7 +1790,7 @@ def visit_ast_or_value(arg): iterator = 'dace.map' else: ranges = [] - if isinstance(node.slice, (ast.Tuple, ast.ExtSlice)): + if isinstance(node.slice, (ast.Tuple, ExtSlice)): for s in node.slice.dims: ranges.append(self._parse_slice(s)) elif isinstance(node.slice, ast.Slice): @@ -2344,12 +2374,11 @@ def _is_test_simple(self, node: ast.AST): # Fix for scalar promotion tests # TODO: Maybe those tests should use the SDFG API instead of the # Python frontend which can change how it handles conditions. 
- simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) - is_test_simple = isinstance(node, simple_ast_nodes) + is_test_simple = isinstance(node, _simple_ast_nodes) if not is_test_simple: if isinstance(node, ast.Compare): - is_left_simple = isinstance(node.left, simple_ast_nodes) - is_right_simple = (len(node.comparators) == 1 and isinstance(node.comparators[0], simple_ast_nodes)) + is_left_simple = isinstance(node.left, _simple_ast_nodes) + is_right_simple = (len(node.comparators) == 1 and isinstance(node.comparators[0], _simple_ast_nodes)) if is_left_simple and is_right_simple: return True elif isinstance(node, ast.BoolOp): @@ -2510,6 +2539,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): # Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE) langInf = None + side_effects = None if isinstance(node, ast.FunctionDef) and \ hasattr(node, 'decorator_list') and \ isinstance(node.decorator_list, list) and \ @@ -2522,6 +2552,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): langArg = node.decorator_list[0].args[0].value langInf = dtypes.Language[langArg] + # Extract arguments from with statement + if isinstance(node, ast.With): + expr = node.items[0].context_expr + if isinstance(expr, ast.Call): + args = astutils.parse_function_arguments(expr, ['language', 'side_effects']) + langArg = args.get('language', None) + side_effects = args.get('side_effects', None) + langInf = astutils.evalnode(langArg, {**self.globals, **self.defined}) + if isinstance(langInf, str): + langInf = dtypes.Language[langInf] + + side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined}) + ttrans = TaskletTransformer(self, self.defined, self.sdfg, @@ -2536,6 +2579,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): symbols=self.symbols) node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name) + if side_effects is not None: + 
node.side_effects = side_effects + # Convert memlets to their actual data nodes for i in inputs.values(): if not isinstance(i, tuple) and i.data in self.scope_vars.keys(): @@ -3160,6 +3206,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): if (not is_return and isinstance(target, ast.Name) and true_name and not op and not isinstance(true_array, data.Scalar) and not (true_array.shape == (1, ))): + if true_name in self.views: + if result in self.sdfg.arrays and self.views[true_name] == ( + result, Memlet.from_array(result, self.sdfg.arrays[result])): + continue + else: + raise DaceSyntaxError(self, target, 'Cannot reassign View "{}"'.format(name)) if (isinstance(result, str) and result in self.sdfg.arrays and self.sdfg.arrays[result].is_equivalent(true_array)): # Skip error if the arrays are defined exactly in the same way. @@ -4270,7 +4322,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): func = None funcname = None # If the call directly refers to an SDFG or dace-compatible program - if isinstance(node.func, ast.Num): + if sys.version_info < (3, 8) and isinstance(node.func, ast.Num): if self._has_sdfg(node.func.n): func = node.func.n elif isinstance(node.func, ast.Constant): @@ -4589,15 +4641,15 @@ def _visitname(self, name: str, node: ast.AST): return rname #### Visitors that return arrays - def visit_Str(self, node: ast.Str): + def visit_Str(self, node: StrConstant): # A string constant returns a string literal return StringLiteral(node.s) - def visit_Bytes(self, node: ast.Bytes): + def visit_Bytes(self, node: BytesConstant): # A bytes constant returns a string literal return StringLiteral(node.s) - def visit_Num(self, node: ast.Num): + def visit_Num(self, node: NumConstant): if isinstance(node.n, bool): return dace.bool_(node.n) if isinstance(node.n, (int, float, complex)): @@ -4617,7 +4669,7 @@ def visit_Name(self, node: ast.Name): # If visiting a name, check if it is a defined variable or a global return 
self._visitname(node.id, node) - def visit_NameConstant(self, node: ast.NameConstant): + def visit_NameConstant(self, node: NameConstant): return self.visit_Constant(node) def visit_Attribute(self, node: ast.Attribute): @@ -4666,6 +4718,9 @@ def visit_Dict(self, node: ast.Dict): def visit_Lambda(self, node: ast.Lambda): # Return a string representation of the function return astutils.unparse(node) + + def visit_TypeAlias(self, node: TypeAlias): + raise NotImplementedError('Type aliases are not supported in DaCe') ############################################################ @@ -4892,7 +4947,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: res = self.visit(s) else: res = self._visit_ast_or_value(s) - elif isinstance(s, ast.Index): + elif sys.version_info < (3, 9) and isinstance(s, ast.Index): res = self._parse_subscript_slice(s.value) elif isinstance(s, ast.Slice): lower = s.lower @@ -4910,7 +4965,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: res = ((lower, upper, step), ) elif isinstance(s, ast.Tuple): res = tuple(self._parse_subscript_slice(d, multidim=True) for d in s.elts) - elif isinstance(s, ast.ExtSlice): + elif sys.version_info < (3, 9) and isinstance(s, ast.ExtSlice): res = tuple(self._parse_subscript_slice(d, multidim=True) for d in s.dims) else: res = _promote(s) @@ -4972,8 +5027,8 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False): # If the value is a tuple of constants (e.g., array.shape) and the # slice is constant, return the value itself nslice = self.visit(node.slice) - if isinstance(nslice, (ast.Index, Number)): - if isinstance(nslice, ast.Index): + if isinstance(nslice, (Index, Number)): + if sys.version_info < (3, 9) and isinstance(nslice, ast.Index): v = self._parse_value(nslice.value) else: v = nslice @@ -5037,7 +5092,7 @@ def _visit_ast_or_value(self, node: ast.AST) -> Any: out = out[0] return out - def visit_Index(self, node: ast.Index) -> Any: + def visit_Index(self, node: 
Index) -> Any: if isinstance(node.value, ast.Tuple): for i, elt in enumerate(node.value.elts): node.value.elts[i] = self._visit_ast_or_value(elt) @@ -5045,7 +5100,7 @@ def visit_Index(self, node: ast.Index) -> Any: node.value = self._visit_ast_or_value(node.value) return node - def visit_ExtSlice(self, node: ast.ExtSlice) -> Any: + def visit_ExtSlice(self, node: ExtSlice) -> Any: for i, dim in enumerate(node.dims): node.dims[i] = self._visit_ast_or_value(dim) diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 10a1ab120e..1636e57ad0 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -20,6 +20,20 @@ from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure, StringLiteral) +if sys.version_info < (3, 8): + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + class DaceRecursionError(Exception): """ Exception that indicates a recursion in a data-centric parsed context. 
@@ -342,13 +356,13 @@ def remake_dict(args): # Remake keyword argument names from AST kwarg_names = [] for kw in arg.keys: - if isinstance(kw, ast.Num): + if sys.version_info >= (3, 8) and isinstance(kw, ast.Constant): + kwarg_names.append(kw.value) + elif sys.version_info < (3, 8) and isinstance(kw, ast.Num): kwarg_names.append(kw.n) - elif isinstance(kw, (ast.Str, ast.Bytes)): + elif sys.version_info < (3, 8) and isinstance(kw, (ast.Str, ast.Bytes)): kwarg_names.append(kw.s) - elif isinstance(kw, ast.NameConstant): - kwarg_names.append(kw.value) - elif sys.version_info >= (3, 8) and isinstance(kw, ast.Constant): + elif sys.version_info < (3, 8) and isinstance(kw, ast.NameConstant): kwarg_names.append(kw.value) else: raise NotImplementedError(f'Key type {type(kw).__name__} is not supported') @@ -754,7 +768,8 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: def visit_Call(self, node: ast.Call) -> Any: from dace.frontend.python.interface import in_program, inline # Avoid import loop - if hasattr(node.func, 'n') and isinstance(node.func.n, SDFGConvertible): + if (hasattr(node.func, 'value') and isinstance(node.func.value, SDFGConvertible) or + sys.version_info < (3, 8) and hasattr(node.func, 'n') and isinstance(node.func.n, SDFGConvertible)): # Skip already-parsed calls return self.generic_visit(node) @@ -858,7 +873,8 @@ def visit_JoinedStr(self, node: ast.JoinedStr) -> Any: parsed = [ not isinstance(v, ast.FormattedValue) or isinstance(v.value, ast.Constant) for v in visited.values ] - values = [v.s if isinstance(v, ast.Str) else astutils.unparse(v.value) for v in visited.values] + values = [v.s if sys.version_info < (3, 8) and isinstance(v, ast.Str) else astutils.unparse(v.value) + for v in visited.values] return ast.copy_location( ast.Constant(kind='', value=''.join(('{%s}' % v) if not p else v for p, v in zip(parsed, values))), node) @@ -1268,7 +1284,7 @@ def _convert_to_ast(contents: Any): node) else: # Augment closure with new value - newnode = 
self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True) + newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True) return newnode return _convert_to_ast(contents) @@ -1358,7 +1374,7 @@ def _get_given_args(self, node: ast.Call, function: 'DaceProgram') -> Set[str]: def visit_Call(self, node: ast.Call): # Only parse calls to parsed SDFGConvertibles - if not isinstance(node.func, (ast.Num, ast.Constant)): + if not isinstance(node.func, (NumConstant, ast.Constant)): self.seen_calls.add(astutils.unparse(node.func)) return self.generic_visit(node) if hasattr(node.func, 'oldnode'): @@ -1366,10 +1382,7 @@ def visit_Call(self, node: ast.Call): self.seen_calls.add(astutils.unparse(node.func.oldnode.func)) else: self.seen_calls.add(astutils.rname(node.func.oldnode)) - if isinstance(node.func, ast.Num): - value = node.func.n - else: - value = node.func.value + value = node.func.value if sys.version_info >= (3, 8) else node.func.n if not hasattr(value, '__sdfg__') or isinstance(value, SDFG): return self.generic_visit(node) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index 9643d51c1f..eace0c8336 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -617,9 +617,10 @@ def _elementwise(pv: 'ProgramVisitor', def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: dace.typeclass = None): """ Implements a simple call of the form `out = func(inp)`. 
""" + create_input = True if isinstance(inpname, (list, tuple)): # TODO investigate this inpname = inpname[0] - if not isinstance(inpname, str): + if not isinstance(inpname, str) and not symbolic.issymbolic(inpname): # Constant parameter cst = inpname inparr = data.create_datadescriptor(cst) @@ -627,6 +628,10 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: inparr.transient = True sdfg.add_constant(inpname, cst, inparr) sdfg.add_datadesc(inpname, inparr) + elif symbolic.issymbolic(inpname): + dtype = symbolic.symtype(inpname) + inparr = data.Scalar(dtype) + create_input = False else: inparr = sdfg.arrays[inpname] @@ -636,10 +641,17 @@ def _simple_call(sdfg: SDFG, state: SDFGState, inpname: str, func: str, restype: outarr.dtype = restype num_elements = data._prod(inparr.shape) if num_elements == 1: - inp = state.add_read(inpname) + if create_input: + inp = state.add_read(inpname) + inconn_name = '__inp' + else: + inconn_name = symbolic.symstr(inpname) + out = state.add_write(outname) - tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, '__out = {f}(__inp)'.format(f=func)) - state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) + tasklet = state.add_tasklet(func, {'__inp'} if create_input else {}, {'__out'}, + f'__out = {func}({inconn_name})') + if create_input: + state.add_edge(inp, None, tasklet, '__inp', Memlet.from_array(inpname, inparr)) state.add_edge(tasklet, '__out', out, None, Memlet.from_array(outname, outarr)) else: state.add_mapped_tasklet( @@ -2158,8 +2170,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[-1], arr2.shape[-2]) if res is None: - warnings.warn(f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' - f'second tensor/matrix {arr2.shape[-2]} may not match', UserWarning) + warnings.warn( + f'Last mode of first tesnsor/matrix {arr1.shape[-1]} and second-last mode of ' + f'second 
tensor/matrix {arr2.shape[-2]} may not match', UserWarning) elif not res: raise SyntaxError('Matrix dimension mismatch %s != %s' % (arr1.shape[-1], arr2.shape[-2])) @@ -2176,8 +2189,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[-1], arr2.shape[0]) if res is None: - warnings.warn(f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Number of matrix columns {arr1.shape[-1]} and length of vector {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Number of matrix columns {} must match" "size of vector {}.".format(arr1.shape[1], arr2.shape[0])) @@ -2188,8 +2202,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[0], arr2.shape[0]) if res is None: - warnings.warn(f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Length of vector {arr1.shape[0]} and number of matrix rows {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Size of vector {} must match number of matrix " "rows {} must match".format(arr1.shape[0], arr2.shape[0])) @@ -2200,8 +2215,9 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op res = symbolic.equal(arr1.shape[0], arr2.shape[0]) if res is None: - warnings.warn(f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} ' - f'may not match', UserWarning) + warnings.warn( + f'Length of first vector {arr1.shape[0]} and length of second vector {arr2.shape[0]} ' + f'may not match', UserWarning) elif not res: raise SyntaxError("Vectors in vector product must have same size: " "{} vs. 
{}".format(arr1.shape[0], arr2.shape[0])) @@ -4401,10 +4417,13 @@ def _datatype_converter(sdfg: SDFG, state: SDFGState, arg: UfuncInput, dtype: dt # Set tasklet parameters impl = { - 'name': "_convert_to_{}_".format(dtype.to_string()), + 'name': + "_convert_to_{}_".format(dtype.to_string()), 'inputs': ['__inp'], 'outputs': ['__out'], - 'code': "__out = dace.{}(__inp)".format(dtype.to_string()) + 'code': + "__out = {}(__inp)".format(f"dace.{dtype.to_string()}" if dtype not in (dace.bool, + dace.bool_) else dtype.to_string()) } if dtype in (dace.bool, dace.bool_): impl['code'] = "__out = dace.bool_(__inp)" diff --git a/dace/libraries/blas/environments/cublas.py b/dace/libraries/blas/environments/cublas.py index d4ab879e61..ef73b511c0 100644 --- a/dace/libraries/blas/environments/cublas.py +++ b/dace/libraries/blas/environments/cublas.py @@ -25,7 +25,7 @@ class cuBLAS: def handle_setup_code(node): location = node.location if not location or "gpu" not in node.location: - location = 0 + location = -1 # -1 means current device else: try: location = int(location["gpu"]) diff --git a/dace/libraries/blas/environments/rocblas.py b/dace/libraries/blas/environments/rocblas.py index 5d752ed690..47e16531ff 100644 --- a/dace/libraries/blas/environments/rocblas.py +++ b/dace/libraries/blas/environments/rocblas.py @@ -25,7 +25,7 @@ class rocBLAS: def handle_setup_code(node): location = node.location if not location or "gpu" not in node.location: - location = 0 + location = -1 # -1 means current device else: try: location = int(location["gpu"]) diff --git a/dace/libraries/blas/include/dace_cublas.h b/dace/libraries/blas/include/dace_cublas.h index 8ec03c2b37..3547a009d2 100644 --- a/dace/libraries/blas/include/dace_cublas.h +++ b/dace/libraries/blas/include/dace_cublas.h @@ -21,8 +21,10 @@ static void CheckCublasError(cublasStatus_t const& status) { } static cublasHandle_t CreateCublasHandle(int device) { - if (cudaSetDevice(device) != cudaSuccess) { - throw 
std::runtime_error("Failed to set CUDA device."); + if (device >= 0) { + if (cudaSetDevice(device) != cudaSuccess) { + throw std::runtime_error("Failed to set CUDA device."); + } } cublasHandle_t handle; CheckCublasError(cublasCreate(&handle)); @@ -65,8 +67,10 @@ class _CublasConstants { } _CublasConstants(int device) { - if (cudaSetDevice(device) != cudaSuccess) { - throw std::runtime_error("Failed to set CUDA device."); + if (device >= 0) { + if (cudaSetDevice(device) != cudaSuccess) { + throw std::runtime_error("Failed to set CUDA device."); + } } // Allocate constant zero with the largest used size cudaMalloc(&zero_, sizeof(cuDoubleComplex) * 1); diff --git a/dace/libraries/blas/include/dace_rocblas.h b/dace/libraries/blas/include/dace_rocblas.h index 7a7e4a75ee..00469136a3 100644 --- a/dace/libraries/blas/include/dace_rocblas.h +++ b/dace/libraries/blas/include/dace_rocblas.h @@ -24,8 +24,10 @@ static void CheckRocblasError(rocblas_status const& status) { } static rocblas_handle CreateRocblasHandle(int device) { - if (hipSetDevice(device) != hipSuccess) { - throw std::runtime_error("Failed to set HIP device."); + if (device >= 0) { + if (hipSetDevice(device) != hipSuccess) { + throw std::runtime_error("Failed to set HIP device."); + } } rocblas_handle handle; CheckRocblasError(rocblas_create_handle(&handle)); @@ -68,53 +70,55 @@ class _RocblasConstants { } _RocblasConstants(int device) { - if (hipSetDevice(device) != hipSuccess) { - throw std::runtime_error("Failed to set HIP device."); + if (device >= 0) { + if (hipSetDevice(device) != hipSuccess) { + throw std::runtime_error("Failed to set HIP device."); + } } // Allocate constant zero with the largest used size - hipMalloc(&zero_, sizeof(hipDoubleComplex) * 1); - hipMemset(zero_, 0, sizeof(hipDoubleComplex) * 1); + (void)hipMalloc(&zero_, sizeof(hipDoubleComplex) * 1); + (void)hipMemset(zero_, 0, sizeof(hipDoubleComplex) * 1); // Allocate constant one - hipMalloc(&half_pone_, sizeof(__half) * 1); + 
(void)hipMalloc(&half_pone_, sizeof(__half) * 1); __half half_pone = __float2half(1.0f); - hipMemcpy(half_pone_, &half_pone, sizeof(__half) * 1, + (void)hipMemcpy(half_pone_, &half_pone, sizeof(__half) * 1, hipMemcpyHostToDevice); - hipMalloc(&float_pone_, sizeof(float) * 1); + (void)hipMalloc(&float_pone_, sizeof(float) * 1); float float_pone = 1.0f; - hipMemcpy(float_pone_, &float_pone, sizeof(float) * 1, + (void)hipMemcpy(float_pone_, &float_pone, sizeof(float) * 1, hipMemcpyHostToDevice); - hipMalloc(&double_pone_, sizeof(double) * 1); + (void)hipMalloc(&double_pone_, sizeof(double) * 1); double double_pone = 1.0; - hipMemcpy(double_pone_, &double_pone, sizeof(double) * 1, + (void)hipMemcpy(double_pone_, &double_pone, sizeof(double) * 1, hipMemcpyHostToDevice); - hipMalloc(&complex64_pone_, sizeof(hipComplex) * 1); + (void)hipMalloc(&complex64_pone_, sizeof(hipComplex) * 1); hipComplex complex64_pone = make_hipFloatComplex(1.0f, 0.0f); - hipMemcpy(complex64_pone_, &complex64_pone, sizeof(hipComplex) * 1, + (void)hipMemcpy(complex64_pone_, &complex64_pone, sizeof(hipComplex) * 1, hipMemcpyHostToDevice); - hipMalloc(&complex128_pone_, sizeof(hipDoubleComplex) * 1); + (void)hipMalloc(&complex128_pone_, sizeof(hipDoubleComplex) * 1); hipDoubleComplex complex128_pone = make_hipDoubleComplex(1.0, 0.0); - hipMemcpy(complex128_pone_, &complex128_pone, sizeof(hipDoubleComplex) * 1, + (void)hipMemcpy(complex128_pone_, &complex128_pone, sizeof(hipDoubleComplex) * 1, hipMemcpyHostToDevice); // Allocate custom factors and default to zero - hipMalloc(&custom_alpha_, sizeof(hipDoubleComplex) * 1); - hipMemset(custom_alpha_, 0, sizeof(hipDoubleComplex) * 1); - hipMalloc(&custom_beta_, sizeof(hipDoubleComplex) * 1); - hipMemset(custom_beta_, 0, sizeof(hipDoubleComplex) * 1); + (void)hipMalloc(&custom_alpha_, sizeof(hipDoubleComplex) * 1); + (void)hipMemset(custom_alpha_, 0, sizeof(hipDoubleComplex) * 1); + (void)hipMalloc(&custom_beta_, sizeof(hipDoubleComplex) * 1); + 
(void)hipMemset(custom_beta_, 0, sizeof(hipDoubleComplex) * 1); } _RocblasConstants(_RocblasConstants const&) = delete; ~_RocblasConstants() { - hipFree(zero_); - hipFree(half_pone_); - hipFree(float_pone_); - hipFree(double_pone_); - hipFree(complex64_pone_); - hipFree(complex128_pone_); - hipFree(custom_alpha_); - hipFree(custom_beta_); + (void)hipFree(zero_); + (void)hipFree(half_pone_); + (void)hipFree(float_pone_); + (void)hipFree(double_pone_); + (void)hipFree(complex64_pone_); + (void)hipFree(complex128_pone_); + (void)hipFree(custom_alpha_); + (void)hipFree(custom_beta_); } _RocblasConstants& operator=(_RocblasConstants const&) = delete; diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 2db2055ae5..83be99d78b 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -184,11 +184,11 @@ def expansion(node, state, sdfg): code = '' if dtype in (dace.complex64, dace.complex128): code = f''' - {dtype.ctype} alpha = {alpha}; - {dtype.ctype} beta = {beta}; + {dtype.ctype} __alpha = {alpha}; + {dtype.ctype} __beta = {beta}; ''' - opt['alpha'] = '&alpha' - opt['beta'] = '&beta' + opt['alpha'] = '&__alpha' + opt['beta'] = '&__beta' code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, " "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, " @@ -287,12 +287,12 @@ def expansion(cls, node, state, sdfg): # Set pointer mode to host call_prefix += f'''{cls.set_pointer_mode}(__dace_{cls.backend}blas_handle, {cls.pointer_host}); - {dtype.ctype} alpha = {alpha}; - {dtype.ctype} beta = {beta}; + {dtype.ctype} __alpha = {alpha}; + {dtype.ctype} __beta = {beta}; ''' call_suffix += f'''{cls.set_pointer_mode}(__dace_{cls.backend}blas_handle, {cls.pointer_device});''' - alpha = f'({cdtype} *)&alpha' - beta = f'({cdtype} *)&beta' + alpha = f'({cdtype} *)&__alpha' + beta = f'({cdtype} *)&__beta' else: alpha = constants[node.alpha] beta = constants[node.beta] diff --git a/dace/libraries/blas/nodes/matmul.py 
b/dace/libraries/blas/nodes/matmul.py index f0767a0473..83d07ded29 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -217,5 +217,7 @@ class MatMul(dace.sdfg.nodes.LibraryNode): default=0, desc="A scalar which will be multiplied with C before adding C") - def __init__(self, name, location=None): + def __init__(self, name, location=None, alpha=1, beta=0): + self.alpha = alpha + self.beta = beta super().__init__(name, location=location, inputs={"_a", "_b"}, outputs={"_c"}) diff --git a/dace/libraries/lapack/environments/cusolverdn.py b/dace/libraries/lapack/environments/cusolverdn.py index c92c8bf3e7..4daad8062e 100644 --- a/dace/libraries/lapack/environments/cusolverdn.py +++ b/dace/libraries/lapack/environments/cusolverdn.py @@ -24,7 +24,7 @@ class cuSolverDn: def handle_setup_code(node): location = node.location if not location or "gpu" not in node.location: - location = 0 + location = -1 # -1 means current device else: try: location = int(location["gpu"]) diff --git a/dace/libraries/lapack/include/dace_cusolverdn.h b/dace/libraries/lapack/include/dace_cusolverdn.h index 2da65ffa2f..f262541f0b 100644 --- a/dace/libraries/lapack/include/dace_cusolverdn.h +++ b/dace/libraries/lapack/include/dace_cusolverdn.h @@ -21,8 +21,10 @@ static void CheckCusolverDnError(cusolverStatus_t const& status) { } static cusolverDnHandle_t CreateCusolverDnHandle(int device) { - if (cudaSetDevice(device) != cudaSuccess) { - throw std::runtime_error("Failed to set CUDA device."); + if (device >= 0) { + if (cudaSetDevice(device) != cudaSuccess) { + throw std::runtime_error("Failed to set CUDA device."); + } } cusolverDnHandle_t handle; CheckCusolverDnError(cusolverDnCreate(&handle)); diff --git a/dace/libraries/linalg/environments/cutensor.py b/dace/libraries/linalg/environments/cutensor.py index e3572a0673..0022ec1f57 100644 --- a/dace/libraries/linalg/environments/cutensor.py +++ b/dace/libraries/linalg/environments/cutensor.py @@ -24,7 +24,7 
@@ class cuTensor: def handle_setup_code(node): location = node.location if not location or "gpu" not in node.location: - location = 0 + location = -1 # -1 means current device else: try: location = int(location["gpu"]) diff --git a/dace/libraries/linalg/include/dace_cutensor.h b/dace/libraries/linalg/include/dace_cutensor.h index 8079892285..ddad2feaa3 100644 --- a/dace/libraries/linalg/include/dace_cutensor.h +++ b/dace/libraries/linalg/include/dace_cutensor.h @@ -20,8 +20,10 @@ static void CheckCuTensorError(cutensorStatus_t const& status) { } static cutensorHandle_t CreateCuTensorHandle(int device) { - if (cudaSetDevice(device) != cudaSuccess) { - throw std::runtime_error("Failed to set CUDA device."); + if (device >= 0) { + if (cudaSetDevice(device) != cudaSuccess) { + throw std::runtime_error("Failed to set CUDA device."); + } } cutensorHandle_t handle; CheckCuTensorError(cutensorInit(&handle)); diff --git a/dace/libraries/sparse/environments/cusparse.py b/dace/libraries/sparse/environments/cusparse.py index 0970557944..a731f75bf7 100644 --- a/dace/libraries/sparse/environments/cusparse.py +++ b/dace/libraries/sparse/environments/cusparse.py @@ -24,7 +24,7 @@ class cuSPARSE: def handle_setup_code(node): location = node.location if not location or "gpu" not in node.location: - location = 0 + location = -1 # -1 means current device else: try: location = int(location["gpu"]) diff --git a/dace/libraries/sparse/include/dace_cusparse.h b/dace/libraries/sparse/include/dace_cusparse.h index 82470089e0..9d28bb4748 100644 --- a/dace/libraries/sparse/include/dace_cusparse.h +++ b/dace/libraries/sparse/include/dace_cusparse.h @@ -20,8 +20,10 @@ static void CheckCusparseError(cusparseStatus_t const& status) { } static cusparseHandle_t CreateCusparseHandle(int device) { - if (cudaSetDevice(device) != cudaSuccess) { - throw std::runtime_error("Failed to set CUDA device."); + if (device >= 0) { + if (cudaSetDevice(device) != cudaSuccess) { + throw std::runtime_error("Failed 
to set CUDA device."); + } } cusparseHandle_t handle; CheckCusparseError(cusparseCreate(&handle)); diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index 0f76c7e252..dd026ea62c 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1562,13 +1562,14 @@ class Reduce(dace.sdfg.nodes.LibraryNode): identity = Property(allow_none=True) def __init__(self, + name, wcr='lambda a, b: a', axes=None, identity=None, schedule=dtypes.ScheduleType.Default, debuginfo=None, **kwargs): - super().__init__(name='Reduce', **kwargs) + super().__init__(name=name, **kwargs) self.wcr = wcr self.axes = axes self.identity = identity @@ -1577,7 +1578,7 @@ def __init__(self, @staticmethod def from_json(json_obj, context=None): - ret = Reduce("lambda a, b: a", None) + ret = Reduce('reduce', 'lambda a, b: a', None) dace.serialize.set_properties_from_json(ret, json_obj, context=context) return ret diff --git a/dace/libraries/stencil/subscript_converter.py b/dace/libraries/stencil/subscript_converter.py index 8abb3fc6c8..d159b345cb 100644 --- a/dace/libraries/stencil/subscript_converter.py +++ b/dace/libraries/stencil/subscript_converter.py @@ -1,9 +1,34 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import ast +import sys from collections import defaultdict from typing import Tuple +if sys.version_info < (3, 8): + _simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + BytesConstant = ast.Bytes + EllipsisConstant = ast.Ellipsis + NameConstant = ast.NameConstant + NumConstant = ast.Num + StrConstant = ast.Str +else: + _simple_ast_nodes = (ast.Constant, ast.Name) + BytesConstant = ast.Constant + EllipsisConstant = ast.Constant + NameConstant = ast.Constant + NumConstant = ast.Constant + StrConstant = ast.Constant + + +if sys.version_info < (3, 9): + Index = ast.Index + ExtSlice = ast.ExtSlice +else: + Index = type(None) + ExtSlice = type(None) + + class SubscriptConverter(ast.NodeTransformer): """ Finds all subscript accesses using constant indices in the given code, and @@ -67,9 +92,9 @@ def visit_Subscript(self, node: ast.Subscript): # This can be a bunch of different things, varying between Python 3.8 # and Python 3.9, so try hard to unpack it into an index we can use. index_tuple = node.slice - if isinstance(index_tuple, (ast.Subscript, ast.Index)): + if isinstance(index_tuple, (ast.Subscript, Index)): index_tuple = index_tuple.value - if isinstance(index_tuple, (ast.Constant, ast.Num)): + if isinstance(index_tuple, (ast.Constant, NumConstant)): index_tuple = (index_tuple, ) if isinstance(index_tuple, ast.Tuple): index_tuple = index_tuple.elts diff --git a/dace/memlet.py b/dace/memlet.py index 74a1320a3b..d448ca1134 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -512,22 +512,47 @@ def validate(self, sdfg, state): if self.data is not None and self.data not in sdfg.arrays: raise KeyError('Array "%s" not found in SDFG' % self.data) - def used_symbols(self, all_symbols: bool) -> Set[str]: + def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]: """ Returns a set of symbols used in this edge's properties. 
:param all_symbols: If False, only returns the set of symbols that will be used in the generated code and are needed as arguments. + :param edge: If given, provides richer context-based tests for the case + of ``all_symbols=False``. """ # Symbolic properties are in volume, and the two subsets result = set() + view_edge = False if all_symbols: result |= set(map(str, self.volume.free_symbols)) - if self.src_subset: - result |= self.src_subset.free_symbols - - if self.dst_subset: - result |= self.dst_subset.free_symbols + elif edge is not None: # Not all symbols are requested, and an edge is given + view_edge = False + from dace.sdfg import nodes + if isinstance(edge.dst, nodes.CodeNode) or isinstance(edge.src, nodes.CodeNode): + view_edge = True + elif edge.dst_conn == 'views' and isinstance(edge.dst, nodes.AccessNode): + view_edge = True + elif edge.src_conn == 'views' and isinstance(edge.src, nodes.AccessNode): + view_edge = True + + if not view_edge: + if self.src_subset: + result |= self.src_subset.free_symbols + + if self.dst_subset: + result |= self.dst_subset.free_symbols + else: + # View edges do not require the end of the range nor strides + if self.src_subset: + for rb, _, _ in self.src_subset.ndrange(): + if symbolic.issymbolic(rb): + result |= set(map(str, rb.free_symbols)) + + if self.dst_subset: + for rb, _, _ in self.dst_subset.ndrange(): + if symbolic.issymbolic(rb): + result |= set(map(str, rb.free_symbols)) return result diff --git a/dace/properties.py b/dace/properties.py index 61e569341f..44f8b4fbcc 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -1001,8 +1001,11 @@ def get_free_symbols(self, defined_syms: Set[str] = None) -> Set[str]: if self.language == dace.dtypes.Language.Python: visitor = TaskletFreeSymbolVisitor(defined_syms) if self.code: - for stmt in self.code: - visitor.visit(stmt) + if isinstance(self.code, list): + for stmt in self.code: + visitor.visit(stmt) + else: + visitor.visit(self.code) return visitor.free_symbols 
return set() diff --git a/dace/sdfg/analysis/schedule_tree/__init__.py b/dace/sdfg/analysis/schedule_tree/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dace/sdfg/analysis/schedule_tree/passes.py b/dace/sdfg/analysis/schedule_tree/passes.py new file mode 100644 index 0000000000..cc33245875 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/passes.py @@ -0,0 +1,60 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Assortment of passes for schedule trees. +""" + +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from typing import Set + + +def remove_unused_and_duplicate_labels(stree: tn.ScheduleTreeScope): + """ + Removes unused and duplicate labels from the schedule tree. + + :param stree: The schedule tree to remove labels from. + """ + + class FindGotos(tn.ScheduleNodeVisitor): + + def __init__(self): + self.gotos: Set[str] = set() + + def visit_GotoNode(self, node: tn.GotoNode): + if node.target is not None: + self.gotos.add(node.target) + + class RemoveLabels(tn.ScheduleNodeTransformer): + + def __init__(self, labels_to_keep: Set[str]) -> None: + self.labels_to_keep = labels_to_keep + self.labels_seen = set() + + def visit_StateLabel(self, node: tn.StateLabel): + if node.state.name not in self.labels_to_keep: + return None + if node.state.name in self.labels_seen: + return None + self.labels_seen.add(node.state.name) + return node + + fg = FindGotos() + fg.visit(stree) + return RemoveLabels(fg.gotos).visit(stree) + + +def remove_empty_scopes(stree: tn.ScheduleTreeScope): + """ + Removes empty scopes from the schedule tree. + + :warning: This pass is not safe to use for for-loops, as it will remove indices that may be used after the loop. 
+ """ + + class RemoveEmptyScopes(tn.ScheduleNodeTransformer): + + def visit_scope(self, node: tn.ScheduleTreeScope): + if len(node.children) == 0: + return None + + return self.generic_visit(node) + + return RemoveEmptyScopes().visit(stree) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py new file mode 100644 index 0000000000..917f748cb8 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -0,0 +1,743 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +from collections import defaultdict +import copy +from typing import Dict, List, Set +import dace +from dace import data, subsets, symbolic +from dace.codegen import control_flow as cf +from dace.sdfg.sdfg import InterstateEdge, SDFG +from dace.sdfg.state import SDFGState +from dace.sdfg import utils as sdutil, graph as gr, nodes as nd +from dace.sdfg.replace import replace_datadesc_names +from dace.frontend.python.astutils import negate_expr +from dace.sdfg.analysis.schedule_tree import treenodes as tn, passes as stpasses +from dace.transformation.passes.analysis import StateReachability +from dace.transformation.helpers import unsqueeze_memlet +from dace.properties import CodeBlock +from dace.memlet import Memlet + +import networkx as nx +import time +import sys + +NODE_TO_SCOPE_TYPE = { + dace.nodes.MapEntry: tn.MapScope, + dace.nodes.ConsumeEntry: tn.ConsumeScope, + dace.nodes.PipelineEntry: tn.PipelineScope, +} + + +def dealias_sdfg(sdfg: SDFG): + """ + Renames all data containers in an SDFG tree (i.e., nested SDFGs) to use the same data descriptors + as the top-level SDFG. This function takes care of offsetting memlets and internal + uses of arrays such that there is one naming system, and no aliasing of managed memory. + + This function operates in-place. + + :param sdfg: The SDFG to operate on. 
+ """ + for nsdfg in sdfg.all_sdfgs_recursive(): + + if not nsdfg.parent: + continue + + replacements: Dict[str, str] = {} + inv_replacements: Dict[str, List[str]] = {} + parent_edges: Dict[str, Memlet] = {} + to_unsqueeze: Set[str] = set() + + parent_sdfg = nsdfg.parent_sdfg + parent_state = nsdfg.parent + parent_node = nsdfg.parent_nsdfg_node + + for name, desc in nsdfg.arrays.items(): + if desc.transient: + continue + for edge in parent_state.edges_by_connector(parent_node, name): + parent_name = edge.data.data + assert parent_name in parent_sdfg.arrays + if name != parent_name: + replacements[name] = parent_name + parent_edges[name] = edge + if parent_name in inv_replacements: + inv_replacements[parent_name].append(name) + to_unsqueeze.add(parent_name) + else: + inv_replacements[parent_name] = [name] + break + + if to_unsqueeze: + for parent_name in to_unsqueeze: + parent_arr = parent_sdfg.arrays[parent_name] + if isinstance(parent_arr, data.View): + parent_arr = data.Array(parent_arr.dtype, parent_arr.shape, parent_arr.transient, + parent_arr.allow_conflicts, parent_arr.storage, parent_arr.location, + parent_arr.strides, parent_arr.offset, parent_arr.may_alias, + parent_arr.lifetime, parent_arr.alignment, parent_arr.debuginfo, + parent_arr.total_size, parent_arr.start_offset, parent_arr.optional, + parent_arr.pool) + elif isinstance(parent_arr, data.StructureView): + parent_arr = data.Structure(parent_arr.members, parent_arr.name, parent_arr.transient, + parent_arr.storage, parent_arr.location, parent_arr.lifetime, + parent_arr.debuginfo) + child_names = inv_replacements[parent_name] + for name in child_names: + child_arr = copy.deepcopy(parent_arr) + child_arr.transient = False + nsdfg.arrays[name] = child_arr + for state in nsdfg.states(): + for e in state.edges(): + if not state.is_leaf_memlet(e): + continue + + mpath = state.memlet_path(e) + src, dst = mpath[0].src, mpath[-1].dst + + # We need to take directionality of the memlet into account and unsqueeze 
either to source or + # destination subset + if isinstance(src, nd.AccessNode) and src.data in child_names: + src_data = src.data + new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True) + else: + src_data = None + new_src_memlet = None + # We need to take directionality of the memlet into account + if isinstance(dst, nd.AccessNode) and dst.data in child_names: + dst_data = dst.data + new_dst_memlet = unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True) + else: + dst_data = None + new_dst_memlet = None + + if new_src_memlet is not None: + e.data.src_subset = new_src_memlet.subset + if new_dst_memlet is not None: + e.data.dst_subset = new_dst_memlet.subset + if e.data.data == src_data: + e.data.data = new_src_memlet.data + elif e.data.data == dst_data: + e.data.data = new_dst_memlet.data + + for e in nsdfg.edges(): + repl_dict = dict() + syms = e.data.read_symbols() + for memlet in e.data.get_read_memlets(nsdfg.arrays): + if memlet.data in child_names: + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges[memlet.data].data) + if memlet.data in syms: + syms.remove(memlet.data) + for s in syms: + if s in parent_edges: + repl_dict[s] = str(parent_edges[s].data) + e.data.replace_dict(repl_dict) + for name in child_names: + edge = parent_edges[name] + for e in parent_state.memlet_tree(edge): + if e.data.data == parent_name: + e.data.subset = subsets.Range.from_array(parent_arr) + else: + e.data.other_subset = subsets.Range.from_array(parent_arr) + + if replacements: + symbolic.safe_replace(replacements, lambda d: replace_datadesc_names(nsdfg, d), value_as_string=True) + parent_node.in_connectors = { + replacements[c] if c in replacements else c: t + for c, t in parent_node.in_connectors.items() + } + parent_node.out_connectors = { + replacements[c] if c in replacements else c: t + for c, t in parent_node.out_connectors.items() + } + for e in parent_state.all_edges(parent_node): + if e.src_conn in 
replacements: + e._src_conn = replacements[e.src_conn] + elif e.dst_conn in replacements: + e._dst_conn = replacements[e.dst_conn]


def normalize_memlet(sdfg: SDFG, state: SDFGState, original: gr.MultiConnectorEdge[Memlet], data: str) -> Memlet:
    """
    Normalizes a memlet to a given data descriptor.

    The given edge is not modified: a new edge object is built around a deep copy of its memlet, and that copy is
    what gets initialized, possibly rewritten, and returned.

    If the copied memlet already refers to ``data`` (or to a name starting with ``data.``), it is returned as is.
    Otherwise it is redirected to ``data``: when its ``_is_data_src`` flag is set, ``subset`` takes the destination
    subset and ``other_subset`` the source subset; when the flag is unset, the reverse. The flag is then set to True.

    :param sdfg: The SDFG.
    :param state: The state.
    :param original: The original edge whose memlet is normalized.
    :param data: The name of the data descriptor the returned memlet should refer to.
    :return: A new memlet.
    """
    # Shallow copy edge (the memlet itself is deep-copied so that the original is left untouched)
    edge = gr.MultiConnectorEdge(original.src, original.src_conn, original.dst, original.dst_conn,
                                 copy.deepcopy(original.data), original.key)
    edge.data.try_initialize(sdfg, state, edge)

    # Already expressed in terms of the requested container (or one of its dotted members): nothing to do
    if '.' in edge.data.data and edge.data.data.startswith(data + '.'):
        return edge.data
    if edge.data.data == data:
        return edge.data

    memlet = edge.data
    if memlet._is_data_src:
        new_subset, new_osubset = memlet.get_dst_subset(edge, state), memlet.get_src_subset(edge, state)
    else:
        new_subset, new_osubset = memlet.get_src_subset(edge, state), memlet.get_dst_subset(edge, state)

    memlet.data = data
    memlet.subset = new_subset
    memlet.other_subset = new_osubset
    memlet._is_data_src = True
    return memlet


def replace_memlets(sdfg: SDFG, input_mapping: Dict[str, Memlet], output_mapping: Dict[str, Memlet]):
    """
    Replaces all uses of data containers in memlets and interstate edges in an SDFG.

    The SDFG is modified in place. Dataflow memlets are rewritten via ``unsqueeze_memlet``; interstate edge
    assignments and conditions are rewritten by plain string substitution.

    :param sdfg: The SDFG.
    :param input_mapping: A mapping from internal data descriptor names to external input memlets.
    :param output_mapping: A mapping from internal data descriptor names to external output memlets.
    """
    for state in sdfg.states():
        for e in state.edges():
            # Look at both ends of the full memlet path to decide which mapping applies
            mpath = state.memlet_path(e)
            src = mpath[0].src
            dst = mpath[-1].dst
            memlet = e.data
            if isinstance(src, dace.nodes.AccessNode) and src.data in input_mapping:
                src_data = src.data
                src_memlet = unsqueeze_memlet(memlet, input_mapping[src.data], use_src_subset=True)
            else:
                src_data = None
                src_memlet = None
            if isinstance(dst, dace.nodes.AccessNode) and dst.data in output_mapping:
                dst_data = dst.data
                dst_memlet = unsqueeze_memlet(memlet, output_mapping[dst.data], use_dst_subset=True)
            else:
                dst_data = None
                dst_memlet = None

            # Other cases (code->code)
            if src_data is None and dst_data is None:
                if e.data.data in input_mapping:
                    memlet = unsqueeze_memlet(memlet, input_mapping[e.data.data])
                elif e.data.data in output_mapping:
                    memlet = unsqueeze_memlet(memlet, output_mapping[e.data.data])
                e.data = memlet
            else:
                # Update the existing memlet in place, one side at a time
                if src_memlet is not None:
                    memlet.src_subset = src_memlet.subset
                if dst_memlet is not None:
                    memlet.dst_subset = dst_memlet.subset
                if memlet.data == src_data:
                    memlet.data = src_memlet.data
                elif memlet.data == dst_data:
                    memlet.data = dst_memlet.data

    for e in sdfg.edges():
        repl_dict = dict()
        syms = e.data.read_symbols()
        for memlet in e.data.get_read_memlets(sdfg.arrays):
            if memlet.data in input_mapping or memlet.data in output_mapping:
                # If array name is both in the input connectors and output connectors with different
                # memlets, this is undefined behavior. Prefer output
                if memlet.data in input_mapping:
                    mapping = input_mapping
                if memlet.data in output_mapping:
                    mapping = output_mapping

                repl_dict[str(memlet)] = str(unsqueeze_memlet(memlet, mapping[memlet.data]))
                # The container was handled as a memlet expression; do not also replace it as a bare symbol
                if memlet.data in syms:
                    syms.remove(memlet.data)
        for s in syms:
            if s in input_mapping:
                repl_dict[s] = str(input_mapping[s])

        # Manual replacement with strings
        # TODO(later): Would be MUCH better to use MemletReplacer / e.data.replace_dict(repl_dict, replace_keys=False)
        for find, replace in repl_dict.items():
            for k, v in e.data.assignments.items():
                if find in v:
                    e.data.assignments[k] = v.replace(find, replace)
            condstr = e.data.condition.as_string
            if find in condstr:
                e.data.condition.as_string = condstr.replace(find, replace)


def remove_name_collisions(sdfg: SDFG):
    """
    Removes name collisions in nested SDFGs by renaming states, data containers, and symbols.

    All SDFGs returned by ``sdfg.all_sdfgs_recursive()`` are visited, sharing one set of seen state labels and one
    set of seen identifiers (data containers, top-level symbols and constants). Names in an SDFG without a parent
    nested SDFG node are never replaced, and neither are non-transient data containers. The SDFGs are modified in
    place.

    :param sdfg: The SDFG.
    """
    state_names_seen = set()
    identifiers_seen = set()

    for nsdfg in sdfg.all_sdfgs_recursive():
        # Rename duplicate states
        for state in nsdfg.nodes():
            if state.label in state_names_seen:
                state.set_label(data.find_new_name(state.label, state_names_seen))
            state_names_seen.add(state.label)

        replacements: Dict[str, str] = {}
        parent_node = nsdfg.parent_nsdfg_node

        # Preserve top-level SDFG names
        do_not_replace = False
        if not parent_node:
            do_not_replace = True

        # Rename duplicate data containers
        for name, desc in nsdfg.arrays.items():
            if name in identifiers_seen:
                if not desc.transient or do_not_replace:
                    continue

                new_name = data.find_new_name(name, identifiers_seen)
                replacements[name] = new_name
                name = new_name
            identifiers_seen.add(name)

        # Rename duplicate top-level symbols
        for name in nsdfg.get_all_toplevel_symbols():
            # Will already be renamed during conversion
            if parent_node is not None and name in parent_node.symbol_mapping:
                continue

            if name in identifiers_seen and not do_not_replace:
                new_name = data.find_new_name(name, identifiers_seen)
                replacements[name] = new_name
                name = new_name
            identifiers_seen.add(name)

        # Rename duplicate constants
        for name in nsdfg.constants_prop.keys():
            if name in identifiers_seen and not do_not_replace:
                new_name = data.find_new_name(name, identifiers_seen)
                replacements[name] = new_name
                name = new_name
            identifiers_seen.add(name)

        # If there is a name collision, replace all uses of the old names with the new names
        if replacements:
            nsdfg.replace_dict(replacements)


def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str,
                    viewed_name: str) -> tn.ViewNode:
    """
    Helper function to create a view schedule tree node from a memlet edge.

    :param state: The state containing the edge.
    :param edge: The edge connecting the view and the viewed container.
    :param view_name: Name of the view container (becomes the node's target).
    :param viewed_name: Name of the viewed container (becomes the node's source); the memlet is normalized to it.
    :return: The view schedule tree node.
    """
    sdfg = state.parent
    normalized = normalize_memlet(sdfg, state, edge, viewed_name)
    return tn.ViewNode(target=view_name,
                       source=viewed_name,
                       memlet=normalized,
                       src_desc=sdfg.arrays[viewed_name],
                       view_desc=sdfg.arrays[view_name])


def replace_symbols_until_set(nsdfg: dace.nodes.NestedSDFG):
    """
    Replaces symbol values in a nested SDFG until their value has been reset. This is used for matching symbol
    namespaces between an SDFG and a nested SDFG.

    The nested SDFG (``nsdfg.sdfg``) is modified in place. In each state, and on each of its outgoing interstate
    edges, the entries of ``nsdfg.symbol_mapping`` are applied except for symbols assigned on an interstate edge
    leading to that state or to a state from which it is reachable.

    :param nsdfg: The nested SDFG node.
    """
    mapping = nsdfg.symbol_mapping
    sdfg = nsdfg.sdfg
    reachable_states = StateReachability().apply_pass(sdfg, {})[sdfg.sdfg_id]
    redefined_symbols: Dict[SDFGState, Set[str]] = defaultdict(set)

    # Collect redefined symbols
    for e in sdfg.edges():
        redefined = e.data.assignments.keys()
        redefined_symbols[e.dst] |= redefined
        for reachable in reachable_states[e.dst]:
            redefined_symbols[reachable] |= redefined

    # Replace everything but the redefined symbols
    for state in sdfg.nodes():
        per_state_mapping = {k: v for k, v in mapping.items() if k not in redefined_symbols[state]}
        symbolic.safe_replace(per_state_mapping, state.replace_dict)
        for e in sdfg.out_edges(state):
            symbolic.safe_replace(per_state_mapping, lambda d: e.data.replace_dict(d, replace_keys=False))


def prepare_schedule_tree_edges(state: SDFGState) -> Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]:
    """
    Creates a dictionary mapping edges to their corresponding schedule tree nodes, if relevant.
    This handles view edges, reference sets, and dynamic map inputs.

    :param state: The state.
    :return: A 2-tuple of (dictionary mapping edges to schedule tree nodes, dictionary mapping each scope entry
             node - as returned by ``state.entry_node`` - to the list of handled edges in that scope).
    """
    # NOTE(review): the return annotation above names only the first element of the tuple that is returned
    result: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode] = {}
    scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]] = defaultdict(list)
    edges_to_ignore = set()
    sdfg = state.parent

    for edge in state.edges():
        if edge in edges_to_ignore or edge in result:
            continue
        if edge.data.is_empty():  # Ignore empty memlets
            edges_to_ignore.add(edge)
            continue

        # Part of a memlet path - only consider innermost memlets
        mtree = state.memlet_tree(edge)
        all_edges = set(e for e in mtree)
        leaves = set(mtree.leaves())
        edges_to_ignore.update(all_edges - leaves)

        # For every tree leaf, create a copy/view/reference set node as necessary
        for e in leaves:
            if e in edges_to_ignore or e in result:
                continue

            # 1. Check for views
            if isinstance(e.src, dace.nodes.AccessNode):
                desc = e.src.desc(sdfg)
                if isinstance(desc, (dace.data.View, dace.data.StructureView)):
                    vedge = sdutil.get_view_edge(state, e.src)
                    if e is vedge:
                        viewed_node = sdutil.get_view_node(state, e.src)
                        result[e] = _make_view_node(state, e, e.src.data, viewed_node.data)
                        scope = state.entry_node(e.dst if mtree.downwards else e.src)
                        scope_to_edges[scope].append(e)
                        continue
            if isinstance(e.dst, dace.nodes.AccessNode):
                desc = e.dst.desc(sdfg)
                if isinstance(desc, (dace.data.View, dace.data.StructureView)):
                    vedge = sdutil.get_view_edge(state, e.dst)
                    if e is vedge:
                        viewed_node = sdutil.get_view_node(state, e.dst)
                        result[e] = _make_view_node(state, e, e.dst.data, viewed_node.data)
                        scope = state.entry_node(e.dst if mtree.downwards else e.src)
                        scope_to_edges[scope].append(e)
                        continue

            # 2. Check for reference sets
            if isinstance(e.dst, dace.nodes.AccessNode) and e.dst_conn == 'set':
                assert isinstance(e.dst.desc(sdfg), dace.data.Reference)
                result[e] = tn.RefSetNode(target=e.dst.data,
                                          memlet=e.data,
                                          src_desc=sdfg.arrays[e.data.data],
                                          ref_desc=sdfg.arrays[e.dst.data])
                scope = state.entry_node(e.dst if mtree.downwards else e.src)
                scope_to_edges[scope].append(e)
                continue

            # 3. Check for copies
            # Get both ends of the memlet path
            mpath = state.memlet_path(e)
            src = mpath[0].src
            dst = mpath[-1].dst
            if not isinstance(src, dace.nodes.AccessNode):
                continue
            if not isinstance(dst, (dace.nodes.AccessNode, dace.nodes.EntryNode)):
                continue

            # If the edge destination is the innermost node, it is a downward-pointing path
            is_target_dst = e.dst is dst

            innermost_node = dst if is_target_dst else src
            outermost_node = src if is_target_dst else dst

            # Normalize memlets to their innermost node, or source->destination if it is a same-scope edge
            if e.src is src and e.dst is dst:
                outermost_node = src
                innermost_node = dst

            if isinstance(dst, dace.nodes.EntryNode):
                # Special case: dynamic map range has no data
                result[e] = tn.DynScopeCopyNode(target=e.dst_conn, memlet=e.data)
            else:
                target_name = innermost_node.data
                new_memlet = normalize_memlet(sdfg, state, e, outermost_node.data)
                result[e] = tn.CopyNode(target=target_name, memlet=new_memlet)

            scope = state.entry_node(e.dst if mtree.downwards else e.src)
            scope_to_edges[scope].append(e)

    return result, scope_to_edges


def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]:
    """
    Use scope-aware topological sort to get nodes by scope and return the schedule tree of this state.

    Nested SDFG nodes are converted recursively (in place) and their schedule tree children are inserted directly
    into the result.

    :param state: The state.
    :return: A list of schedule tree nodes for the whole state.
    """
    result: List[tn.ScheduleTreeNode] = []
    sdfg = state.parent

    edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode]
    scope_to_edges: Dict[nd.EntryNode, List[gr.MultiConnectorEdge[Memlet]]]
    edge_to_stree, scope_to_edges = prepare_schedule_tree_edges(state)
    edges_to_ignore = set()

    # Handle all unscoped edges to generate output views
    views = _generate_views_in_scope(scope_to_edges[None], edge_to_stree, sdfg, state)
    result.extend(views)

    # Stack of enclosing node lists; ``result`` always refers to the list of the scope currently being filled
    scopes: List[List[tn.ScheduleTreeNode]] = []
    for node in sdutil.scope_aware_topological_sort(state):
        if isinstance(node, dace.nodes.EntryNode):
            # Handle dynamic scope inputs
            for e in state.in_edges(node):
                if e in edges_to_ignore:
                    continue
                if e in edge_to_stree:
                    result.append(edge_to_stree[e])
                    edges_to_ignore.add(e)

            # Handle all scoped edges to generate (views)
            views = _generate_views_in_scope(scope_to_edges[node], edge_to_stree, sdfg, state)
            result.extend(views)

            # Create scope node and add to stack
            scopes.append(result)
            subnodes = []
            result.append(NODE_TO_SCOPE_TYPE[type(node)](node=node, children=subnodes))
            # Subsequent nodes are appended to the same list object that was handed to the scope node
            result = subnodes
        elif isinstance(node, dace.nodes.ExitNode):
            result = scopes.pop()
        elif isinstance(node, dace.nodes.NestedSDFG):
            nested_array_mapping_input = {}
            nested_array_mapping_output = {}
            generated_nviews = set()

            # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG
            replace_symbols_until_set(node)

            # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined
            for e in state.all_edges(node):
                conn = e.dst_conn if e.dst is node else e.src_conn
                if e.data.is_empty() or not conn:
                    continue
                res = sdutil.map_view_to_array(node.sdfg.arrays[conn], sdfg.arrays[e.data.data], e.data.subset)
                no_mapping = False
                if res is None:
                    no_mapping = True
                else:
                    mapping, expanded, squeezed = res
                    if expanded:  # "newaxis" slices will be seen as views (for now)
                        no_mapping = True
                    else:
                        if e.dst is node:
                            nested_array_mapping_input[conn] = e.data
                        else:
                            nested_array_mapping_output[conn] = e.data

                if no_mapping:  # Must use view (nview = nested SDFG view)
                    if conn not in generated_nviews:
                        result.append(
                            tn.NView(target=conn,
                                     source=e.data.data,
                                     memlet=e.data,
                                     src_desc=sdfg.arrays[e.data.data],
                                     view_desc=node.sdfg.arrays[conn]))
                        generated_nviews.add(conn)

            replace_memlets(node.sdfg, nested_array_mapping_input, nested_array_mapping_output)

            # Insert the nested SDFG flattened
            nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False)
            result.extend(nested_stree.children)
        elif isinstance(node, dace.nodes.Tasklet):
            in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn}
            out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn}
            result.append(tn.TaskletNode(node=node, in_memlets=in_memlets, out_memlets=out_memlets))
        elif isinstance(node, dace.nodes.LibraryNode):
            # NOTE: LibraryNodes do not necessarily have connectors
            if node.in_connectors:
                in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn}
            else:
                in_memlets = set([e.data for e in state.in_edges(node)])
            if node.out_connectors:
                out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn}
            else:
                out_memlets = set([e.data for e in state.out_edges(node)])
            result.append(tn.LibraryCall(node=node, in_memlets=in_memlets, out_memlets=out_memlets))
        elif isinstance(node, dace.nodes.AccessNode):
            # If one of the neighboring edges has a schedule tree node attached to it, use that
            # (except for views, which were generated above)
            for e in state.all_edges(node):
                if e in edges_to_ignore:
                    continue
                if e in edge_to_stree:
                    if isinstance(edge_to_stree[e], tn.ViewNode):
                        continue
                    result.append(edge_to_stree[e])
                    edges_to_ignore.add(e)

    # Every scope that was entered must have been exited
    assert len(scopes) == 0

    return result
+ +def _generate_views_in_scope(edges: List[gr.MultiConnectorEdge[Memlet]], + edge_to_stree: Dict[gr.MultiConnectorEdge[Memlet], tn.ScheduleTreeNode], sdfg: SDFG, + state: SDFGState) -> List[tn.ScheduleTreeNode]: + """ + Generates all view and reference set edges in the correct order. This function is intended to be used + at the beginning of a scope. + """ + result: List[tn.ScheduleTreeNode] = [] + + # Make a dependency graph of all the views + g = nx.DiGraph() + node_to_stree = {} + for e in edges: + if e not in edge_to_stree: + continue + st = edge_to_stree[e] + if not isinstance(st, tn.ViewNode): + continue + g.add_edge(st.source, st.target) + node_to_stree[st.target] = st + + # Traverse in order and deduplicate + already_generated = set() + for n in nx.topological_sort(g): + if n in node_to_stree and n not in already_generated: + result.append(node_to_stree[n]) + already_generated.add(n) + + return result + + +def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope: + """ + Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of + the SDFG. + Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, etc.) + or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + + It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + ``from_schedule_tree`` function. + + :param sdfg: The SDFG to convert. + :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might not be + usable after the conversion if ``in_place`` is True! + :return: A schedule tree representing the given SDFG. 
+ """ + from dace.transformation import helpers as xfh # Avoid import loop + + if not in_place: + sdfg = copy.deepcopy(sdfg) + + # Prepare SDFG for conversion + ############################# + + # Split edges with assignments and conditions + xfh.split_interstate_edges(sdfg) + + # Replace code->code edges with data<->code edges + xfh.replace_code_to_code_edges(sdfg) + + if toplevel: # Top-level SDFG preparation (only perform once) + dealias_sdfg(sdfg) + # Handle name collisions (in arrays, state labels, symbols) + remove_name_collisions(sdfg) + + ############################# + + # Create initial tree from CFG + cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') + + # Traverse said tree (also into states) to create the schedule tree + def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: + result: List[tn.ScheduleTreeNode] = [] + if isinstance(node, cf.GeneralBlock): + subnodes: List[tn.ScheduleTreeNode] = [] + for n in node.elements: + subnodes.extend(totree(n, node)) + if not node.sequential: + # Nest in general block + result = [tn.GBlock(children=subnodes)] + else: + # Use the sub-nodes directly + result = subnodes + + elif isinstance(node, cf.SingleState): + result = state_schedule_tree(node.state) + + # Add interstate assignments unrelated to structured control flow + if parent is not None: + for e in sdfg.out_edges(node.state): + edge_body = [] + + if e not in parent.assignments_to_ignore: + for aname, aval in e.data.assignments.items(): + edge_body.append( + tn.AssignNode(name=aname, + value=CodeBlock(aval), + edge=InterstateEdge(assignments={aname: aval}))) + + if not parent.sequential: + if e not in parent.gotos_to_ignore: + edge_body.append(tn.GotoNode(target=e.dst.label)) + else: + if e in parent.gotos_to_break: + edge_body.append(tn.BreakNode()) + elif e in parent.gotos_to_continue: + edge_body.append(tn.ContinueNode()) + + if e not in parent.gotos_to_ignore and not 
e.data.is_unconditional(): + if sdfg.out_degree(node.state) == 1 and parent.sequential: + # Conditional state in sequential block! Add "if not condition goto exit" + result.append( + tn.StateIfScope(condition=CodeBlock(negate_expr(e.data.condition)), + children=[tn.GotoNode(target=None)])) + result.extend(edge_body) + else: + # Add "if condition" with the body above + result.append(tn.StateIfScope(condition=e.data.condition, children=edge_body)) + else: + result.extend(edge_body) + + elif isinstance(node, cf.ForScope): + result.append(tn.ForScope(header=node, children=totree(node.body))) + elif isinstance(node, cf.IfScope): + result.append(tn.IfScope(condition=node.condition, children=totree(node.body))) + if node.orelse is not None: + result.append(tn.ElseScope(children=totree(node.orelse))) + elif isinstance(node, cf.IfElseChain): + # Add "if" for the first condition, "elif"s for the rest + result.append(tn.IfScope(condition=node.body[0][0], children=totree(node.body[0][1]))) + for cond, body in node.body[1:]: + result.append(tn.ElifScope(condition=cond, children=totree(body))) + # "else goto exit" + result.append(tn.ElseScope(children=[tn.GotoNode(target=None)])) + elif isinstance(node, cf.WhileScope): + result.append(tn.WhileScope(header=node, children=totree(node.body))) + elif isinstance(node, cf.DoWhileScope): + result.append(tn.DoWhileScope(header=node, children=totree(node.body))) + else: + # e.g., "SwitchCaseScope" + raise tn.UnsupportedScopeException(type(node).__name__) + + if node.first_state is not None: + result = [tn.StateLabel(state=node.first_state)] + result + + return result + + # Recursive traversal of the control flow tree + result = tn.ScheduleTreeScope(children=totree(cfg)) + + # Clean up tree + stpasses.remove_unused_and_duplicate_labels(result) + + return result + + +if __name__ == '__main__': + s = time.time() + sdfg = SDFG.from_file(sys.argv[1]) + print('Loaded SDFG in', time.time() - s, 'seconds') + s = time.time() + stree = 
as_schedule_tree(sdfg, in_place=True) + print('Created schedule tree in', time.time() - s, 'seconds') + + with open('output_stree.txt', 'w') as fp: + fp.write(stree.as_string(-1) + '\n') diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py new file mode 100644 index 0000000000..99918cd2a4 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -0,0 +1,408 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +from dataclasses import dataclass, field +from dace import nodes, data, subsets +from dace.codegen import control_flow as cf +from dace.properties import CodeBlock +from dace.sdfg import InterstateEdge +from dace.sdfg.state import SDFGState +from dace.symbolic import symbol +from dace.memlet import Memlet +from typing import Dict, Iterator, List, Optional, Set, Union + +INDENTATION = ' ' + + +class UnsupportedScopeException(Exception): + pass + + +@dataclass +class ScheduleTreeNode: + parent: Optional['ScheduleTreeScope'] = field(default=None, init=False) + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'UNSUPPORTED' + + def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: + """ + Traverse tree nodes in a pre-order manner. 
+ """ + yield self + + +@dataclass +class ScheduleTreeScope(ScheduleTreeNode): + children: List['ScheduleTreeNode'] + containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) + symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) + + def __init__(self, + children: Optional[List['ScheduleTreeNode']] = None): + self.children = children or [] + if self.children: + for child in children: + child.parent = self + + def as_string(self, indent: int = 0): + if not self.children: + return (indent + 1) * INDENTATION + 'pass' + return '\n'.join([child.as_string(indent + 1) for child in self.children]) + + def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: + """ + Traverse tree nodes in a pre-order manner. + """ + yield from super().preorder_traversal() + for child in self.children: + yield from child.preorder_traversal() + + # TODO: Helper function that gets input/output memlets of the scope + + +@dataclass +class ControlFlowScope(ScheduleTreeScope): + pass + + +@dataclass +class DataflowScope(ScheduleTreeScope): + node: nodes.EntryNode + + +@dataclass +class GBlock(ControlFlowScope): + """ + General control flow block. Contains a list of states + that can run in arbitrary order based on edges (gotos). + Normally contains irreducible control flow. 
+ """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'gblock:\n' + return result + super().as_string(indent) + + +@dataclass +class StateLabel(ScheduleTreeNode): + state: SDFGState + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'label {self.state.name}:' + + +@dataclass +class GotoNode(ScheduleTreeNode): + target: Optional[str] = None #: If None, equivalent to "goto exit" or "return" + + def as_string(self, indent: int = 0): + name = self.target or 'exit' + return indent * INDENTATION + f'goto {name}' + + +@dataclass +class AssignNode(ScheduleTreeNode): + """ + Represents a symbol assignment that is not part of a structured control flow block. + """ + name: str + value: CodeBlock + edge: InterstateEdge + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + + +@dataclass +class ForScope(ControlFlowScope): + """ + For loop scope. + """ + header: cf.ForScope + + def as_string(self, indent: int = 0): + node = self.header + + result = (indent * INDENTATION + f'for {node.itervar} = {node.init}; {node.condition.as_string}; ' + f'{node.itervar} = {node.update}:\n') + return result + super().as_string(indent) + + +@dataclass +class WhileScope(ControlFlowScope): + """ + While loop scope. + """ + header: cf.WhileScope + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class DoWhileScope(ControlFlowScope): + """ + Do/While loop scope. + """ + header: cf.DoWhileScope + + def as_string(self, indent: int = 0): + header = indent * INDENTATION + 'do:\n' + footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' + return header + super().as_string(indent) + footer + + +@dataclass +class IfScope(ControlFlowScope): + """ + If branch scope. 
+ """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'if {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class StateIfScope(IfScope): + """ + A special class of an if scope in general blocks for if statements that are part of a state transition. + """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' + return result + super(IfScope, self).as_string(indent) + + +@dataclass +class BreakNode(ScheduleTreeNode): + """ + Represents a break statement. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'break' + + +@dataclass +class ContinueNode(ScheduleTreeNode): + """ + Represents a continue statement. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'continue' + + +@dataclass +class ElifScope(ControlFlowScope): + """ + Else-if branch scope. + """ + condition: CodeBlock + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' + return result + super().as_string(indent) + + +@dataclass +class ElseScope(ControlFlowScope): + """ + Else branch scope. + """ + + def as_string(self, indent: int = 0): + result = indent * INDENTATION + 'else:\n' + return result + super().as_string(indent) + + +@dataclass +class MapScope(DataflowScope): + """ + Map scope. + """ + + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class ConsumeScope(DataflowScope): + """ + Consume scope. 
+ """ + + def as_string(self, indent: int = 0): + node: nodes.ConsumeEntry = self.node + cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string + result = indent * INDENTATION + f'consume (PE {node.consume.pe_index} out of {node.consume.num_pes}) while {cond}:\n' + return result + super().as_string(indent) + + +@dataclass +class PipelineScope(DataflowScope): + """ + Pipeline scope. + """ + + def as_string(self, indent: int = 0): + rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) + result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' + return result + super().as_string(indent) + + +@dataclass +class TaskletNode(ScheduleTreeNode): + node: nodes.Tasklet + in_memlets: Dict[str, Memlet] + out_memlets: Dict[str, Memlet] + + def as_string(self, indent: int = 0): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + if not out_memlets: + return indent * INDENTATION + f'tasklet({in_memlets})' + return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + + +@dataclass +class LibraryCall(ScheduleTreeNode): + node: nodes.LibraryNode + in_memlets: Union[Dict[str, Memlet], Set[Memlet]] + out_memlets: Union[Dict[str, Memlet], Set[Memlet]] + + def as_string(self, indent: int = 0): + if isinstance(self.in_memlets, set): + in_memlets = ', '.join(f'{v}' for v in self.in_memlets) + else: + in_memlets = ', '.join(f'{v}' for v in self.in_memlets.values()) + if isinstance(self.out_memlets, set): + out_memlets = ', '.join(f'{v}' for v in self.out_memlets) + else: + out_memlets = ', '.join(f'{v}' for v in self.out_memlets.values()) + libname = type(self.node).__name__ + # Get the properties of the library node without its superclasses + own_properties = ', '.join(f'{k}={getattr(self.node, k)}' for k, v in self.node.__properties__.items() + if v.owner not in {nodes.Node, 
nodes.CodeNode, nodes.LibraryNode}) + return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + + +@dataclass +class CopyNode(ScheduleTreeNode): + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + if self.memlet.other_subset is not None and any(s != 0 for s in self.memlet.other_subset.min_element()): + offset = f'[{self.memlet.other_subset}]' + else: + offset = '' + if self.memlet.wcr is not None: + wcr = f' with {self.memlet.wcr}' + else: + wcr = '' + + return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + + +@dataclass +class DynScopeCopyNode(ScheduleTreeNode): + """ + A special case of a copy node that is used in dynamic scope inputs (e.g., dynamic map ranges). + """ + target: str + memlet: Memlet + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + + +@dataclass +class ViewNode(ScheduleTreeNode): + target: str #: View name + source: str #: Viewed container name + memlet: Memlet + src_desc: data.Data + view_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + + +@dataclass +class NView(ViewNode): + """ + Nested SDFG view node. Subclass of a view that specializes in nested SDFG boundaries. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' + + +@dataclass +class RefSetNode(ScheduleTreeNode): + """ + Reference set node. Sets a reference to a data container. 
+ """ + target: str + memlet: Memlet + src_desc: data.Data + ref_desc: data.Data + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + + +# Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes +class ScheduleNodeVisitor: + + def visit(self, node: ScheduleTreeNode): + """Visit a node.""" + if isinstance(node, list): + return [self.visit(snode) for snode in node] + if isinstance(node, ScheduleTreeScope) and hasattr(self, 'visit_scope'): + return self.visit_scope(node) + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node: ScheduleTreeNode): + if isinstance(node, ScheduleTreeScope): + for child in node.children: + self.visit(child) + + +class ScheduleNodeTransformer(ScheduleNodeVisitor): + + def visit(self, node: ScheduleTreeNode): + if isinstance(node, list): + result = [] + for snode in node: + new_node = self.visit(snode) + if new_node is not None: + result.append(new_node) + return result + + return super().visit(node) + + def generic_visit(self, node: ScheduleTreeNode): + new_values = [] + if isinstance(node, ScheduleTreeScope): + for value in node.children: + if isinstance(value, ScheduleTreeNode): + value = self.visit(value) + if value is None: + continue + elif not isinstance(value, ScheduleTreeNode): + new_values.extend(value) + continue + new_values.append(value) + for val in new_values: + val.parent = node + node.children[:] = new_values + return node diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py new file mode 100644 index 0000000000..59a2c178d2 --- /dev/null +++ b/dace/sdfg/memlet_utils.py @@ -0,0 +1,79 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import ast +from dace.frontend.python import memlet_parser +from dace import data, Memlet +from typing import Callable, Dict, Optional, Set, Union + + +class MemletReplacer(ast.NodeTransformer): + """ + Iterates over all memlet expressions (name or subscript with matching array in SDFG) in a code block. + The callable can also return another memlet to replace the current one. + """ + + def __init__(self, + arrays: Dict[str, data.Data], + process: Callable[[Memlet], Union[Memlet, None]], + array_filter: Optional[Set[str]] = None) -> None: + """ + Create a new memlet replacer. + + :param arrays: A mapping from array names to data descriptors. + :param process: A callable that takes a memlet and returns a memlet or None. + :param array_filter: An optional subset of array names to process. + """ + self.process = process + self.arrays = arrays + self.array_filter = array_filter or self.arrays.keys() + + def _parse_memlet(self, node: Union[ast.Name, ast.Subscript]) -> Memlet: + """ + Parses a memlet from a subscript or name node. + + :param node: The node to parse. + :return: The parsed memlet. + """ + # Get array name + if isinstance(node, ast.Name): + data = node.id + elif isinstance(node, ast.Subscript): + data = node.value.id + else: + raise TypeError('Expected Name or Subscript') + + # Parse memlet subset + array = self.arrays[data] + subset, newaxes, _ = memlet_parser.parse_memlet_subset(array, node, self.arrays) + if newaxes: + raise NotImplementedError('Adding new axes to memlets is not supported') + + return Memlet(data=data, subset=subset) + + def _memlet_to_ast(self, memlet: Memlet) -> ast.Subscript: + """ + Converts a memlet to a subscript node. + + :param memlet: The memlet to convert. + :return: The converted node. 
+ """ + return ast.parse(f'{memlet.data}[{memlet.subset}]').body[0].value + + def _replace(self, node: Union[ast.Name, ast.Subscript]) -> ast.Subscript: + cur_memlet = self._parse_memlet(node) + new_memlet = self.process(cur_memlet) + if new_memlet is None: + return node + + new_node = self._memlet_to_ast(new_memlet) + return ast.copy_location(new_node, node) + + def visit_Name(self, node: ast.Name): + if node.id in self.array_filter: + return self._replace(node) + return self.generic_visit(node) + + def visit_Subscript(self, node: ast.Subscript): + if isinstance(node.value, ast.Name) and node.value.id in self.array_filter: + return self._replace(node) + return self.generic_visit(node) diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 28431deeea..32369a19a3 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -342,6 +342,10 @@ class Tasklet(CodeNode): 'additional side effects on the system state (e.g., callback). ' 'Defaults to None, which lets the framework make assumptions based on ' 'the tasklet contents') + ignored_symbols = SetProperty(element_type=str, desc='A set of symbols to ignore when computing ' + 'the symbols used by this tasklet. 
Used to skip certain symbols in non-Python ' + 'tasklets, where only string analysis is possible; and to skip globals in Python ' + 'tasklets that should not be given as parameters to the SDFG.') def __init__(self, label, @@ -355,6 +359,7 @@ def __init__(self, code_exit="", location=None, side_effects=None, + ignored_symbols=None, debuginfo=None): super(Tasklet, self).__init__(label, location, inputs, outputs) @@ -365,6 +370,7 @@ def __init__(self, self.code_init = CodeBlock(code_init, dtypes.Language.CPP) self.code_exit = CodeBlock(code_exit, dtypes.Language.CPP) self.side_effects = side_effects + self.ignored_symbols = ignored_symbols or set() self.debuginfo = debuginfo @property @@ -393,7 +399,11 @@ def validate(self, sdfg, state): @property def free_symbols(self) -> Set[str]: - return self.code.get_free_symbols(self.in_connectors.keys() | self.out_connectors.keys()) + symbols_to_ignore = self.in_connectors.keys() | self.out_connectors.keys() + symbols_to_ignore |= self.ignored_symbols + + return self.code.get_free_symbols(symbols_to_ignore) + def has_side_effects(self, sdfg) -> bool: """ @@ -581,16 +591,19 @@ def from_json(json_obj, context=None): return ret def used_symbols(self, all_symbols: bool) -> Set[str]: - free_syms = set().union(*(map(str, - pystr_to_symbolic(v).free_symbols) for v in self.symbol_mapping.values()), - *(map(str, - pystr_to_symbolic(v).free_symbols) for v in self.location.values())) + free_syms = set().union(*(map(str, pystr_to_symbolic(v).free_symbols) for v in self.location.values())) + + keys_to_use = set(self.symbol_mapping.keys()) # Filter out unused internal symbols from symbol mapping if not all_symbols: internally_used_symbols = self.sdfg.used_symbols(all_symbols=False) - free_syms &= internally_used_symbols - + keys_to_use &= internally_used_symbols + + free_syms |= set().union(*(map(str, + pystr_to_symbolic(v).free_symbols) for k, v in self.symbol_mapping.items() + if k in keys_to_use)) + return free_syms @property @@ -640,7 
+653,7 @@ def validate(self, sdfg, state, references: Optional[Set[int]] = None, **context raise NameError('Data descriptor "%s" not found in nested SDFG connectors' % dname) if dname in connectors and desc.transient: raise NameError('"%s" is a connector but its corresponding array is transient' % dname) - + # Validate inout connectors from dace.sdfg import utils # Avoids circular import inout_connectors = self.in_connectors.keys() & self.out_connectors.keys() diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 0fec4812b7..0554775dcd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1477,8 +1477,8 @@ def propagate_subset(memlets: List[Memlet], new_memlet.volume = simplify(sum(m.volume for m in memlets) * functools.reduce(lambda a, b: a * b, rng.size(), 1)) if any(m.dynamic for m in memlets): new_memlet.dynamic = True - elif symbolic.issymbolic(new_memlet.volume) and any(s not in defined_variables - for s in new_memlet.volume.free_symbols): + if symbolic.issymbolic(new_memlet.volume) and any(s not in defined_variables + for s in new_memlet.volume.free_symbols): new_memlet.dynamic = True new_memlet.volume = 0 diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 5e42830a75..4b36fad4fe 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -124,6 +124,7 @@ def replace_properties_dict(node: Any, if lang is dtypes.Language.CPP: # Replace in C++ code prefix = '' tokenized = tokenize_cpp.findall(code) + active_replacements = set() for name, new_name in reduced_repl.items(): if name not in tokenized: continue @@ -131,8 +132,14 @@ def replace_properties_dict(node: Any, # Use local variables and shadowing to replace replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' prefix = replacement + prefix + active_replacements.add(name) if prefix: propval.code = prefix + code + + # Ignore replaced symbols since they no longer exist as reads + if isinstance(node, dace.nodes.Tasklet): + 
node._ignored_symbols.update(active_replacements) + else: warnings.warn('Replacement of %s with %s was not made ' 'for string tasklet code of language %s' % (name, new_name, lang)) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index a23d2616f9..a85e773337 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -62,7 +62,7 @@ def __getitem__(self, key): token = tokens.pop(0) result = result.members[token] return result - + def __setitem__(self, key, val): if isinstance(key, str) and '.' in key: raise KeyError('NestedDict does not support setting nested keys') @@ -273,7 +273,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: rhs_symbols = set() for lhs, rhs in self.assignments.items(): # Always add LHS symbols to the set of candidate free symbols - rhs_symbols |= symbolic.free_symbols_and_functions(rhs) + rhs_symbols |= set(map(str, dace.symbolic.symbols_in_ast(ast.parse(rhs)))) # Add the RHS to the set of candidate defined symbols ONLY if it has not been read yet # This also solves the ordering issue that may arise in cases like the 3rd example above if lhs not in cond_symbols and lhs not in rhs_symbols: @@ -756,7 +756,7 @@ def replace_dict(self, if replace_in_graph: # Replace in inter-state edges for edge in self.edges(): - edge.data.replace_dict(repldict) + edge.data.replace_dict(repldict, replace_keys=replace_keys) # Replace in states for state in self.nodes(): @@ -1323,7 +1323,7 @@ def arrays_recursive(self): if isinstance(node, nd.NestedSDFG): yield from node.sdfg.arrays_recursive() - def used_symbols(self, all_symbols: bool) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool=False) -> Set[str]: """ Returns a set of symbol names that are used by the SDFG, but not defined within it. 
This property is used to determine the symbolic @@ -1331,27 +1331,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: :param all_symbols: If False, only returns the set of symbols that will be used in the generated code and are needed as arguments. + :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping + will be removed from the set of defined symbols. """ defined_syms = set() free_syms = set() - # Exclude data descriptor names, constants, and shapes of global data descriptors - not_strictly_necessary_global_symbols = set() - for name, desc in self.arrays.items(): + # Exclude data descriptor names and constants + for name in self.arrays.keys(): defined_syms.add(name) - if not all_symbols: - used_desc_symbols = desc.used_symbols(all_symbols) - not_strictly_necessary = (desc.used_symbols(all_symbols=True) - used_desc_symbols) - not_strictly_necessary_global_symbols |= set(map(str, not_strictly_necessary)) - defined_syms |= set(self.constants_prop.keys()) - # Start with the set of SDFG free symbols - if all_symbols: - free_syms |= set(self.symbols.keys()) - else: - free_syms |= set(s for s in self.symbols.keys() if s not in not_strictly_necessary_global_symbols) + # Add used symbols from init and exit code + for code in self.init_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + for code in self.exit_code.values(): + free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) # Add free state symbols used_before_assignment = set() @@ -1362,7 +1358,8 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: ordered_states = self.nodes() for state in ordered_states: - free_syms |= state.used_symbols(all_symbols) + state_fsyms = state.used_symbols(all_symbols) + free_syms |= state_fsyms # Add free inter-state symbols for e in self.out_edges(state): @@ -1370,13 +1367,22 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # subracting the (true) 
free symbols from the edge's assignment keys. This way we can correctly # compute the symbols that are used before being assigned. efsyms = e.data.used_symbols(all_symbols) - defined_syms |= set(e.data.assignments.keys()) - efsyms + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_fsyms) used_before_assignment.update(efsyms - defined_syms) free_syms |= efsyms # Remove symbols that were used before they were assigned defined_syms -= used_before_assignment + # Remove from defined symbols those that are in the symbol mapping + if self.parent_nsdfg_node is not None and keep_defined_in_mapping: + defined_syms -= set(self.parent_nsdfg_node.symbol_mapping.keys()) + + # Add the set of SDFG symbol parameters + # If all_symbols is False, those symbols would only be added in the case of non-Python tasklets + if all_symbols: + free_syms |= set(self.symbols.keys()) + # Subtract symbols defined in inter-state edges and constants return free_syms - defined_syms @@ -1392,6 +1398,29 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) + def get_all_toplevel_symbols(self) -> Set[str]: + """ + Returns a set of all symbol names that are used by the SDFG's state machine. + This includes all symbols in the descriptor repository and interstate edges, + whether free or defined. Used to identify duplicates when, e.g., inlining or + dealiasing a set of nested SDFGs. + """ + # Exclude constants and data descriptor names + exclude = set(self.arrays.keys()) | set(self.constants_prop.keys()) + + syms = set() + + # Start with the set of SDFG free symbols + syms |= set(self.symbols.keys()) + + # Add inter-state symbols + for e in self.edges(): + syms |= set(e.data.assignments.keys()) + syms |= e.data.free_symbols + + # Subtract excluded symbols + return syms - exclude + def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: """ Determines what data containers are read and written in this SDFG. 
Does @@ -1458,7 +1487,7 @@ def init_signature(self, for_call=False, free_symbols=None) -> str: :param for_call: If True, returns arguments that can be used when calling the SDFG. """ # Get global free symbols scalar arguments - free_symbols = free_symbols or self.free_symbols + free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) return ", ".join( dt.Scalar(self.symbols[k]).as_arg(name=k, with_types=not for_call, for_call=for_call) for k in sorted(free_symbols) if not k.startswith('__dace')) @@ -1478,6 +1507,21 @@ def signature_arglist(self, with_types=True, for_call=False, with_arrays=True, a arglist = arglist or self.arglist(scalars_only=not with_arrays) return [v.as_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] + def python_signature_arglist(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> List[str]: + """ Returns a list of arguments necessary to call this SDFG, + formatted as a list of Data-Centric Python definitions. + + :param with_types: If True, includes argument types in the result. + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. + :param arglist: An optional cached argument list. + :return: A list of strings. For example: `['A: dace.float32[M]', 'b: dace.int32']`. + """ + arglist = arglist or self.arglist(scalars_only=not with_arrays, free_symbols=[]) + return [v.as_python_arg(name=k, with_types=with_types, for_call=for_call) for k, v in arglist.items()] + def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: """ Returns a C/C++ signature of this SDFG, used when generating code. 
@@ -1493,6 +1537,21 @@ def signature(self, with_types=True, for_call=False, with_arrays=True, arglist=N """ return ", ".join(self.signature_arglist(with_types, for_call, with_arrays, arglist)) + def python_signature(self, with_types=True, for_call=False, with_arrays=True, arglist=None) -> str: + """ Returns a Data-Centric Python signature of this SDFG, used when generating code. + + :param with_types: If True, includes argument types (can be used + for a function prototype). If False, only + include argument names (can be used for function + calls). + :param for_call: If True, returns arguments that can be used when + calling the SDFG. + :param with_arrays: If True, includes arrays, otherwise, + only symbols and scalars are included. + :param arglist: An optional cached argument list. + """ + return ", ".join(self.python_signature_arglist(with_types, for_call, with_arrays, arglist)) + def _repr_html_(self): """ HTML representation of the SDFG, used mainly for Jupyter notebooks. """ diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a4a6648401..1ff8fe4cf1 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -7,7 +7,7 @@ import inspect import itertools import warnings -from typing import Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload +from typing import TYPE_CHECKING, Any, AnyStr, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, overload import dace from dace import data as dt @@ -24,6 +24,9 @@ from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset +if TYPE_CHECKING: + import dace.sdfg.scope + def _getdebuginfo(old_dinfo=None) -> dtypes.DebugInfo: """ Returns a DebugInfo object for the position that called this function. 
@@ -409,6 +412,13 @@ def scope_children(self, ################################################################### # Query, subgraph, and replacement methods + def is_leaf_memlet(self, e): + if isinstance(e.src, nd.ExitNode) and e.src_conn and e.src_conn.startswith('OUT_'): + return False + if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): + return False + return True + def used_symbols(self, all_symbols: bool) -> Set[str]: """ Returns a set of symbol names that are used in the state. @@ -428,13 +438,23 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: elif isinstance(n, nd.AccessNode): # Add data descriptor symbols freesyms |= set(map(str, n.desc(sdfg).used_symbols(all_symbols))) - elif (isinstance(n, nd.Tasklet) and n.language == dtypes.Language.Python): - # Consider callbacks defined as symbols as free - for stmt in n.code.code: - for astnode in ast.walk(stmt): - if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) - and astnode.func.id in sdfg.symbols): - freesyms.add(astnode.func.id) + elif isinstance(n, nd.Tasklet): + if n.language == dtypes.Language.Python: + # Consider callbacks defined as symbols as free + for stmt in n.code.code: + for astnode in ast.walk(stmt): + if (isinstance(astnode, ast.Call) and isinstance(astnode.func, ast.Name) + and astnode.func.id in sdfg.symbols): + freesyms.add(astnode.func.id) + else: + # Find all string tokens and filter them to sdfg.symbols, while ignoring connectors + codesyms = symbolic.symbols_in_code( + n.code.as_string, + potential_symbols=sdfg.symbols.keys(), + symbols_to_ignore=(n.in_connectors.keys() | n.out_connectors.keys() | n.ignored_symbols), + ) + freesyms |= codesyms + continue if hasattr(n, 'used_symbols'): freesyms |= n.used_symbols(all_symbols) @@ -442,24 +462,17 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: freesyms |= n.free_symbols # Free symbols from memlets - def _is_leaf_memlet(e): - if isinstance(e.src, nd.ExitNode) and 
e.src_conn and e.src_conn.startswith('OUT_'): - return False - if isinstance(e.dst, nd.EntryNode) and e.dst_conn and e.dst_conn.startswith('IN_'): - return False - return True - for e in self.edges(): # If used for code generation, only consider memlet tree leaves - if not all_symbols and not _is_leaf_memlet(e): + if not all_symbols and not self.is_leaf_memlet(e): continue - freesyms |= e.data.used_symbols(all_symbols) + freesyms |= e.data.used_symbols(all_symbols, e) # Do not consider SDFG constants as symbols new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - + @property def free_symbols(self) -> Set[str]: """ @@ -471,7 +484,6 @@ def free_symbols(self) -> Set[str]: """ return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -532,8 +544,8 @@ def _read_and_write_sets(self) -> Tuple[Dict[AnyStr, List[Subset]], Dict[AnyStr, # Filter out memlets which go out but the same data is written to the AccessNode by another memlet for out_edge in list(out_edges): for in_edge in list(in_edges): - if (in_edge.data.data == out_edge.data.data and - in_edge.data.dst_subset.covers(out_edge.data.src_subset)): + if (in_edge.data.data == out_edge.data.data + and in_edge.data.dst_subset.covers(out_edge.data.src_subset)): out_edges.remove(out_edge) break @@ -676,14 +688,15 @@ def arglist(self, defined_syms=None, shared_transients=None) -> Dict[str, dt.Dat defined_syms = defined_syms or self.defined_symbols() scalar_args.update({ k: dt.Scalar(defined_syms[k]) if k in defined_syms else sdfg.arrays[k] - for k in self.free_symbols if not k.startswith('__dace') and k not in sdfg.constants + for k in self.used_symbols(all_symbols=False) if not k.startswith('__dace') and k not in sdfg.constants }) # Add scalar arguments from free symbols of data descriptors for arg in data_args.values(): scalar_args.update({ str(k): dt.Scalar(k.dtype) - for k in 
arg.free_symbols if not str(k).startswith('__dace') and str(k) not in sdfg.constants + for k in arg.used_symbols(all_symbols=False) + if not str(k).startswith('__dace') and str(k) not in sdfg.constants }) # Fill up ordered dictionary @@ -800,7 +813,7 @@ def __init__(self, label=None, sdfg=None, debuginfo=None, location=None): self.nosync = False self.location = location if location is not None else {} self._default_lineinfo = None - + def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) @@ -1450,7 +1463,7 @@ def add_reduce( """ import dace.libraries.standard as stdlib # Avoid import loop debuginfo = _getdebuginfo(debuginfo or self._default_lineinfo) - result = stdlib.Reduce(wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) + result = stdlib.Reduce('Reduce', wcr, axes, identity, schedule=schedule, debuginfo=debuginfo) self.add_node(result) return result diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 3396335ece..1078414161 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -810,7 +810,7 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg out_edges = state.out_edges(view) # Invalid case: No data to view - if len(in_edges) == 0 or len(out_edges) == 0: + if len(in_edges) == 0 and len(out_edges) == 0: return None # If there is one edge (in/out) that leads (via memlet path) to an access diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py new file mode 100644 index 0000000000..6e311cde0c --- /dev/null +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -0,0 +1,285 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import sympy as sp +from typing import Dict + + +class UnionFind: + """ + Simple, not really optimized UnionFind implementation. 
+ """ + + def __init__(self, elements) -> None: + self.ids = {e: e for e in elements} + + def add_element(self, e): + if e in self.ids: + return False + self.ids.update({e: e}) + return True + + def find(self, e): + prev = e + curr = self.ids[e] + while prev != curr: + prev = curr + curr = self.ids[curr] + # shorten the path + self.ids[e] = curr + return curr + + def union(self, e, f): + if f not in self.ids: + self.add_element(f) + self.ids[self.find(e)] = f + + +class ContradictingAssumptions(Exception): + pass + + +class Assumptions: + """ + Summarises the assumptions for a single symbol in three lists: equal, greater, lesser. + """ + + def __init__(self) -> None: + self.greater = [] + self.lesser = [] + self.equal = [] + + def add_greater(self, g): + if isinstance(g, sp.Symbol): + self.greater.append(g) + else: + self.greater = [x for x in self.greater if isinstance(x, sp.Symbol) or x > g] + if len([y for y in self.greater if not isinstance(y, sp.Symbol)]) == 0: + self.greater.append(g) + self.check_consistency() + + def add_lesser(self, l): + if isinstance(l, sp.Symbol): + self.lesser.append(l) + else: + self.lesser = [x for x in self.lesser if isinstance(x, sp.Symbol) or x < l] + if len([y for y in self.lesser if not isinstance(y, sp.Symbol)]) == 0: + self.lesser.append(l) + self.check_consistency() + + def add_equal(self, e): + for x in self.equal: + if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: + raise ContradictingAssumptions() + self.equal.append(e) + self.check_consistency() + + def check_consistency(self): + if len(self.equal) > 0: + # we know exact value + for e in self.equal: + for g in self.greater: + if (e <= g) == True: + raise ContradictingAssumptions() + for l in self.lesser: + if (e >= l) == True: + raise ContradictingAssumptions() + else: + # check if any greater > any lesser + for g in self.greater: + for l in self.lesser: + if (g > l) == True: + raise ContradictingAssumptions() + return True + + def 
num_assumptions(self): + # returns the number of individual assumptions for this symbol + return len(self.greater) + len(self.lesser) + len(self.equal) + + +def propagate_assumptions(x, y, condensed_assumptions): + """ + Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and + x<5. Then, this method adds y<5 to the assumptions. + + :param x: A symbol. + :param y: Another symbol equal to x. + :param condensed_assumptions: Current assumptions over all symbols. + """ + if x == y: + return + assum_x = condensed_assumptions[x] + if y not in condensed_assumptions: + condensed_assumptions[y] = Assumptions() + assum_y = condensed_assumptions[y] + for e in assum_x.equal: + if e is not sp.Symbol(y): + assum_y.add_equal(e) + for g in assum_x.greater: + assum_y.add_greater(g) + for l in assum_x.lesser: + assum_y.add_lesser(l) + assum_y.check_consistency() + + +def propagate_assumptions_equal_symbols(condensed_assumptions): + """ + This method handles two things: 1) It generates the substitution dict for all equality assumptions. + And 2) it propagates assumptions to all equal symbols. For each equivalence class, we find a unique + representative using UnionFind. Then, all assumptions get propagated to this symbol using + ``propagate_assumptions``. + + :param condensed_assumptions: Current assumptions over all symbols. + :return: Returns a tuple consisting of 2 substitution dicts. The first one replaces each symbol with + the unique representative of its equivalence class. The second dict replaces each symbol with its numeric + value (if we assume it to be equal to some value, e.g. N==5). 
+ """ + # Make one set with unique identifier for each equality class + uf = UnionFind(list(condensed_assumptions)) + for sym in condensed_assumptions: + for other in condensed_assumptions[sym].equal: + if isinstance(other, sp.Symbol): + # we assume sym == other --> union these + uf.union(sym, other.name) + + equality_subs1 = {} + + # For each equivalence class, we now have one unique identifier. + # For each class, we give all the assumptions to this single symbol. + # And we swap each symbol in class for this symbol. + for sym in list(condensed_assumptions): + for other in condensed_assumptions[sym].equal: + if isinstance(other, sp.Symbol): + propagate_assumptions(sym, uf.find(sym), condensed_assumptions) + equality_subs1.update({sym: sp.Symbol(uf.find(sym))}) + + equality_subs2 = {} + # In a second step, each symbol gets replaced with its equal number (if present) + # using equality_subs2. + for sym, assum in condensed_assumptions.items(): + for e in assum.equal: + if not isinstance(e, sp.Symbol): + equality_subs2.update({sym: e}) + + # Imagine we have M>N and M==10. We need to deduce N<10 from that. Following code handles that: + for sym, assum in condensed_assumptions.items(): + for g in assum.greater: + if isinstance(g, sp.Symbol): + for e in condensed_assumptions[g.name].equal: + if not isinstance(e, sp.Symbol): + condensed_assumptions[sym].add_greater(e) + assum.greater.remove(g) + for l in assum.lesser: + if isinstance(l, sp.Symbol): + for e in condensed_assumptions[l.name].equal: + if not isinstance(e, sp.Symbol): + condensed_assumptions[sym].add_lesser(e) + assum.lesser.remove(l) + return equality_subs1, equality_subs2 + + +def parse_assumptions(assumptions, array_symbols): + """ + Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and + keeps only the strongest ones. 
Afterwards it constructs two substitution dicts for the equality + assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions. + The other assumptions get handled by N tuples of substitution dicts (N = max number of concurrent + assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol. + First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol. + + :param assumptions: List of assumption strings. + :param array_symbols: List of symbols we assume to be positive, since they are the size of a data container. + :return: Tuple consisting of the 2 dicts responsible for the equality assumptions and the list of size N + responsible for all other assumptions. + """ + + # TODO: This assumptions system can be improved further, especially the deduction of further assumptions + # from the ones we already have. An example of what is not working currently: + # We have assumptions N>0 N<5 and M>5. + # In the first substitution round we use N>0 and M>5. + # In the second substitution round we use N<5. + # Therefore, Max(M, N) will not be evaluated to M, even though from the input assumptions + # one can clearly deduce M>N. + # This happens since N<5 and M>5 are not in the same substitution round. + # The easiest way to fix this is probably to actually deduce the M>N assumption. + # This guarantees that in some substitution round, we will replace M with N + _p_M, where + # _p_M is some positive symbol. Hence, we would resolve Max(M, N) to N + _p_M, which is M. + + # I suspect there to be many more cases where further assumptions will not be deduced properly. + # But if the user enters assumptions as explicitly as possible, e.g. N<5 M>5 M>N, then everything + # works fine. + + # For each symbol x appearing as a data container size, we can assume x>0. + # TODO (later): Analyze size of shapes more, such that e.g. shape N + 1 --> We can assume N > -1. 
+ # For now we only extract assumptions out of shapes if shape consists of only a single symbol. + for sym in array_symbols: + assumptions.append(f'{sym.name}>0') + + if assumptions is None: + return {}, [({}, {})] + + # Gather assumptions, keeping only the strongest ones for each symbol. + condensed_assumptions: Dict[str, Assumptions] = {} + for a in assumptions: + if '==' in a: + symbol, rhs = a.split('==') + if symbol not in condensed_assumptions: + condensed_assumptions[symbol] = Assumptions() + try: + condensed_assumptions[symbol].add_equal(int(rhs)) + except ValueError: + condensed_assumptions[symbol].add_equal(sp.Symbol(rhs)) + elif '>' in a: + symbol, rhs = a.split('>') + if symbol not in condensed_assumptions: + condensed_assumptions[symbol] = Assumptions() + try: + condensed_assumptions[symbol].add_greater(int(rhs)) + except ValueError: + condensed_assumptions[symbol].add_greater(sp.Symbol(rhs)) + # add the opposite, i.e. for x>y, we add y<x + if rhs not in condensed_assumptions: + condensed_assumptions[rhs] = Assumptions() + condensed_assumptions[rhs].add_lesser(sp.Symbol(symbol)) + elif '<' in a: + symbol, rhs = a.split('<') + if symbol not in condensed_assumptions: + condensed_assumptions[symbol] = Assumptions() + try: + condensed_assumptions[symbol].add_lesser(int(rhs)) + except ValueError: + condensed_assumptions[symbol].add_lesser(sp.Symbol(rhs)) + # add the opposite, i.e. for x<y, we add y>x + if rhs not in condensed_assumptions: + condensed_assumptions[rhs] = Assumptions() + condensed_assumptions[rhs].add_greater(sp.Symbol(symbol)) + + # Handle equal assumptions. + equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions) + + # How many assumptions does symbol with most assumptions have? + curr_max = -1 + for _, assum in condensed_assumptions.items(): + if assum.num_assumptions() > curr_max: + curr_max = assum.num_assumptions() + + all_subs = [] + for i in range(curr_max): + all_subs.append(({}, {})) + + # Construct all the substitution dicts. In each substitution round we take at most one assumption for each + # symbol. Each round has two dicts: First one swaps in the assumption and second one restores the initial + # symbol. 
+ for sym, assum in condensed_assumptions.items(): + i = 0 + for g in assum.greater: + replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g}) + i += 1 + for l in assum.lesser: + replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + l}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l}) + i += 1 + + return equality_subs, all_subs diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index a80e769f64..e592fd11b5 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -328,4 +328,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # now we have a triple (node, oNode, exitCandidates) nodes_oNodes_exits.append((node, oNode, exitCandidates)) + # remove artificial end node + sdfg_nx.remove_node(artificial_end_node) return nodes_oNodes_exits diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index a05fe10266..3549e86a20 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -19,6 +19,9 @@ import warnings from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits +from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.transformation.passes.symbol_ssa import StrictSymbolSSA +from dace.transformation.pass_pipeline import FixedPointPipeline def get_array_size_symbols(sdfg): @@ -39,22 +42,6 @@ def get_array_size_symbols(sdfg): return symbols -def posify_certain_symbols(expr, syms_to_posify): - """ - Takes an expression and evaluates it while assuming that certain symbols are positive. - - :param expr: The expression to evaluate. 
- :param syms_to_posify: List of symbols we assume to be positive. - :note: This is adapted from the Sympy function posify. - """ - - expr = sp.sympify(expr) - - reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} - expr = expr.subs(reps) - return expr.subs({r: s for s, r in reps.items()}) - - def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. @@ -64,7 +51,7 @@ def symeval(val, symbols): """ first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} - return val.subs(first_replacement).subs(second_replacement) + return sp.simplify(val.subs(first_replacement).subs(second_replacement)) def evaluate_symbols(base, new): @@ -87,7 +74,14 @@ def count_work_matmul(node, symbols, state): result *= symeval(C_memlet.data.subset.size()[-1], symbols) # K result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return result + return sp.sympify(result) + + +def count_depth_matmul(node, symbols, state): + # optimal depth of a matrix multiplication is O(log(size of shared dimension)): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) + return bigo(sp.log(size_shared_dimension)) def count_work_reduce(node, symbols, state): @@ -102,7 +96,12 @@ def count_work_reduce(node, symbols, state): result *= in_memlet.data.volume else: result = 0 - return result + return sp.sympify(result) + + +def count_depth_reduce(node, symbols, state): + # optimal depth of reduction is log of the work + return bigo(sp.log(count_work_reduce(node, symbols, state))) LIBNODES_TO_WORK = { @@ -111,22 +110,6 @@ def count_work_reduce(node, symbols, state): Reduce: count_work_reduce, } - -def count_depth_matmul(node, symbols, state): - # For 
now we set it equal to work: see comments in count_depth_reduce just below - return count_work_matmul(node, symbols, state) - - -def count_depth_reduce(node, symbols, state): - # depth of reduction is log2 of the work - # TODO: Can we actually assume this? Or is it equal to the work? - # Another thing to consider is that we essetially do NOT count wcr edges as operations for now... - - # return sp.ceiling(sp.log(count_work_reduce(node, symbols, state), 2)) - # set it equal to work for now - return count_work_reduce(node, symbols, state) - - LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, Transpose: lambda *args: 0, @@ -254,9 +237,9 @@ def count_depth_code(code): def tasklet_work(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: + # simplified work analysis for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) - + return oedge.data.num_accesses elif tasklet_node.code.language == dtypes.Language.Python: return count_arithmetic_ops_code(tasklet_node.code.code) else: @@ -267,11 +250,10 @@ def tasklet_work(tasklet_node, state): def tasklet_depth(tasklet_node, state): - # TODO: how to get depth of CPP tasklets? - # For now we use depth == work: if tasklet_node.code.language == dtypes.Language.CPP: + # Depth == work for CPP tasklets. 
for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) + return oedge.data.num_accesses if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: @@ -282,19 +264,41 @@ def tasklet_depth(tasklet_node, state): def get_tasklet_work(node, state): - return tasklet_work(node, state), -1 + return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) def get_tasklet_work_depth(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) def get_tasklet_avg_par(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + + +def update_value_map(old, new): + # add new assignments to old + old.update({k: v for k, v in new.items() if k not in old}) + # check for conflicts: + for k, v in new.items(): + if k in old and old[k] != v: + # conflict detected --> forget this mapping completely + old.pop(k) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: +def do_initial_subs(w, d, eq, subs1): + """ + Calls subs three times for the give (w)ork and (d)epth values. + """ + return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) + + +def sdfg_work_depth(sdfg: SDFG, + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. 
@@ -304,6 +308,11 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana :param w_d_map: Dictionary which will save the result. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ @@ -313,9 +322,16 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols) - state_works[state] = sp.simplify(state_work * state.executions) - state_depths[state] = sp.simplify(state_depth * state.executions) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) + + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) # Prepare the SDFG for a depth analysis by breaking loops. 
This removes the edge between the last loop state and @@ -329,12 +345,18 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge # - for all exits e, add edge (oNode, e). This edge may already exist + # - remove edge from node to exit (if present, i.e. while-do loop) + # - This ensures that every node with > 1 outgoing edge is a branch guard + # - useful for detailed anaylsis. for node, oNode, exits in nodes_oNodes_exits: sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) for e in exits: if len(sdfg.edges_between(oNode, e)) == 0: # no edge there yet sdfg.add_edge(oNode, e, InterstateEdge()) + if len(sdfg.edges_between(node, e)) > 0: + # edge present --> remove it + sdfg.remove_edge(sdfg.edges_between(node, e)[0]) # add a dummy exit to the SDFG, such that each path ends there. dummy_exit = sdfg.add_state('dummy_exit') @@ -345,6 +367,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. work_map: Dict[SDFGState, sp.Expr] = {} depth_map: Dict[SDFGState, sp.Expr] = {} + # Keeps track of assignments done on InterstateEdges. + state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} # The dummy state has 0 work and depth. state_depths[dummy_exit] = sp.sympify(0) state_works[dummy_exit] = sp.sympify(0) @@ -353,40 +377,67 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions # have been calculated. 
traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() + while traversal_q: - state, depth, work, ie = traversal_q.popleft() + state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() if ie is not None: visited.add(ie) - n_depth = sp.simplify(depth + state_depths[state]) - n_work = sp.simplify(work + state_works[state]) + if state in state_value_map: + # update value map: + update_value_map(state_value_map[state], value_map) + else: + state_value_map[state] = value_map + + # ignore assignments such as tmp=x[0], as those do not give much information. + value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} + n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) + n_work = sp.simplify((work + state_works[state]).subs(value_map)) # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # and hence als state in work_map - # if current path has 0 depth, we don't do anything. + if state in depth_map: # this means we have already visited this state before + cse = common_subexpr_stack.pop() + # if current path has 0 depth (--> 0 work as well), we don't do anything. 
if n_depth != 0: - # see if we need to update the work and depth of the current state + # check if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism - old_avg_par = sp.simplify(work_map[state] / depth_map[state]) - new_avg_par = sp.simplify(n_work / n_depth) - - if depth_map[state] == 0 or new_avg_par < old_avg_par: - # old value was divided by zero or new path gives actually worse avg par, then we keep new value - depth_map[state] = n_depth - work_map[state] = n_work + if depth_map[state] == 0: + # old value was divided by zero --> we take new value anyway + depth_map[state] = cse[1] + n_depth + work_map[state] = cse[0] + n_work + else: + old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) + # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) + depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[state], True)) + work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[state], True)) else: depth_map[state] = n_depth work_map[state] = n_work else: # search heaviest and deepest path separately if state in depth_map: # and consequently also in work_map - depth_map[state] = sp.Max(depth_map[state], n_depth) - work_map[state] = sp.Max(work_map[state], n_work) + # This cse value would appear in both arguments of the Max. Hence, for performance reasons, + # we pull it out of the Max expression. + # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). + # This increases performance drastically, expecially since we avoid nesting Max expressions + # for cases where cse itself contains Max operators. 
+ cse = common_subexpr_stack.pop() + if detailed_analysis: + # This MAX should be covered in the more detailed analysis + cond = condition_stack.pop() + work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) + depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + else: + work_map[state] = cse[0] + sp.Max(work_map[state], n_work) + depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) else: depth_map[state] = n_depth work_map[state] = n_work @@ -397,7 +448,22 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana pass else: for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) + if len(out_edges) > 1: + # It is important to copy these stacks. Else both branches operate on the same stack. + # state is a branch guard --> save condition on stack + new_cond_stack = list(condition_stack) + new_cond_stack.append(oedge.data.condition_sympy()) + # same for common_subexr_stack + new_cse_stack = list(common_subexpr_stack) + new_cse_stack.append((work_map[state], depth_map[state])) + # same for value_map + new_value_map = dict(state_value_map[state]) + new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) + else: + value_map.update(oedge.data.assignments) + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + common_subexpr_stack, value_map)) try: max_depth = depth_map[dummy_exit] @@ -408,16 +474,21 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana raise Exception( 'Analysis failed, since not all loops got detected. 
It may help to use more structured loop constructs.') - sdfg_result = (sp.simplify(max_work), sp.simplify(max_depth)) + sdfg_result = (max_work, max_depth) w_d_map[get_uuid(sdfg)] = sdfg_result return sdfg_result -def scope_work_depth(state: SDFGState, - w_d_map: Dict[str, sp.Expr], - analyze_tasklet, - symbols, - entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: +def scope_work_depth( + state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + entry: nd.EntryNode = None, + detailed_analysis: bool = False, +) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. This works by traversing through the scope analyzing the work and depth of each encountered node. @@ -430,7 +501,14 @@ def scope_work_depth(state: SDFGState, this can be done in linear time by traversing the graph in topological order. :param state: The state in which the scope to analyze is contained. - :param sym_map: A dictionary mapping symbols to their values. + :param w_d_map: Dictionary saving the final result for each SDFG element. + :param analyze_tasklet: Function used to analyze tasklets. Either analyzes just work, work and depth or average parallelism. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. 
:return: A tuple containing the work and depth of the scope. """ @@ -447,7 +525,9 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, node, + detailed_analysis) + s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work w_d_map[get_uuid(node, state)] = (s_work, s_depth) @@ -457,8 +537,13 @@ def scope_work_depth(state: SDFGState, elif isinstance(node, nd.Tasklet): # add up work for whole state, but also save work for this node in w_d_map t_work, t_depth = analyze_tasklet(node, state) + # check if tasklet has any outgoing wcr edges + for e in state.out_edges(node): + if e.data.wcr is not None: + t_work += count_arithmetic_ops_code(e.data.wcr) + t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1) work += t_work - w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) + w_d_map[get_uuid(node, state)] = (t_work, t_depth) elif isinstance(node, nd.NestedSDFG): # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. # We only want global symbols in our final work depth expressions. @@ -466,18 +551,35 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. 
- nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, + subs1, detailed_analysis) + nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map work += nsdfg_work w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) elif isinstance(node, nd.LibraryNode): - lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - work += lib_node_work - lib_node_depth = -1 # not analyzed + try: + lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + except KeyError: + # add a symbol to the top level sdfg, such that the user can define it in the extension + top_level_sdfg = state.parent + # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, + # such that the user can define its value. But it doesn't... + # How to achieve this? 
+ top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) + lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) + lib_node_depth = sp.sympify(-1) # not analyzed if analyze_tasklet != get_tasklet_work: # we are analyzing depth - lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + try: + lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + except KeyError: + top_level_sdfg = state.parent + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) + lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) + lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) + work += lib_node_work w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) if entry is not None: @@ -485,8 +587,8 @@ def scope_work_depth(state: SDFGState, if isinstance(entry, nd.MapEntry): nmap: nd.Map = entry.map range: Range = nmap.range - n_exec = range.num_elements_exact() - work = work * sp.simplify(n_exec) + n_exec = range.num_elements() + work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) else: print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') @@ -510,6 +612,7 @@ def scope_work_depth(state: SDFGState, traversal_q.append((node, sp.sympify(0), None)) # this map keeps track of the length of the longest path ending at each state so far seen. 
depth_map = {} + wcr_depth_map = {} while traversal_q: node, in_depth, in_edge = traversal_q.popleft() @@ -534,19 +637,51 @@ def scope_work_depth(state: SDFGState, # replace out_edges with the out_edges of the scope exit node out_edges = state.out_edges(exit_node) for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[node], oedge)) + # check for wcr + wcr_depth = sp.sympify(0) + if oedge.data.wcr is not None: + # This division gives us the number of writes to each single memory location, which is the depth + # as these need to be sequential (without assumptions on HW etc). + wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() + if get_uuid(node, state) in wcr_depth_map: + # max + wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], + wcr_depth) + else: + wcr_depth_map[get_uuid(node, state)] = wcr_depth + # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. + wcr_depth = wcr_depth if not isinstance(oedge.dst, nd.MapExit) else sp.sympify(0) + + # only append if it's actually new information + # this e.g. 
helps for huge nested SDFGs with lots of inputs/outputs inside a map scope + append = True + for n, d, _ in traversal_q: + if oedge.dst == n and depth_map[node] + wcr_depth == d: + append = False + break + if append: + traversal_q.append((oedge.dst, depth_map[node] + wcr_depth, oedge)) + else: + visited.add(oedge) if len(out_edges) == 0 or node == scope_exit: # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) + for uuid in wcr_depth_map: + w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid]) # summarise work / depth of the whole scope in the dictionary - scope_result = (sp.simplify(work), sp.simplify(max_depth)) + scope_result = (work, max_depth) w_d_map[get_uuid(state)] = scope_result return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols, + equality_subs, + subs1, + detailed_analysis=False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -554,13 +689,23 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param w_d_map: The result will be saved to this map. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. 
""" - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, None, + detailed_analysis) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> None: +def analyze_sdfg(sdfg: SDFG, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + assumptions: [str], + detailed_analysis: bool = False) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -568,12 +713,24 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No condition and an assignment. :param sdfg: The SDFG to analyze. :param w_d_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. - :param analyze_tasklet: The function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param analyze_tasklet: Function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param assumptions: List of strings. Each string corresponds to one assumption for some symbol, e.g. 'N>5'. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). """ # deepcopy such that original sdfg not changed sdfg = deepcopy(sdfg) + # apply SSA pass + pipeline = FixedPointPipeline([StrictSymbolSSA()]) + pipeline.apply_pass(sdfg, {}) + + array_symbols = get_array_size_symbols(sdfg) + # parse assumptions + equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) + # Run state propagation for all SDFGs recursively. 
This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): @@ -581,17 +738,36 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols) + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, + detailed_analysis) - # Note: This posify could be done more often to improve performance. - array_symbols = get_array_size_symbols(sdfg) for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. - v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols) - v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols) + v_w, v_d = do_subs(v_w, v_d, all_subs) + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) +def do_subs(work, depth, all_subs): + """ + Handles all substitutions beyond the equality substitutions and the first substitution. + :param work: Some work expression. + :param depth: Some depth expression. + :param all_subs: List of substitution pairs to perform. + :return: Work depth expressions after doing all substitutions. 
+ """ + # first do subs2 of first sub + # then do all the remaining subs + subs2 = all_subs[0][1] if len(all_subs) > 0 else {} + work, depth = sp.simplify(sp.sympify(work).subs(subs2)), sp.simplify(sp.sympify(depth).subs(subs2)) + for i in range(1, len(all_subs)): + subs1, subs2 = all_subs[i] + work, depth = sp.simplify(work.subs(subs1)), sp.simplify(depth.subs(subs1)) + work, depth = sp.simplify(work.subs(subs2)), sp.simplify(depth.subs(subs2)) + return work, depth + + ################################################################################ # Utility functions for running the analysis from the command line ############# ################################################################################ @@ -608,7 +784,9 @@ def main() -> None: choices=['work', 'workDepth', 'avgPar'], default='workDepth', help='Choose what to analyze. Default: workDepth') + parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5') + parser.add_argument("--detailed", action="store_true", help="Turns on detailed mode.") args = parser.parse_args() if not os.path.exists(args.filename): @@ -624,7 +802,7 @@ def main() -> None: sdfg = SDFG.from_file(args.filename) work_depth_map = {} - analyze_sdfg(sdfg, work_depth_map, analyze_tasklet) + analyze_sdfg(sdfg, work_depth_map, analyze_tasklet, args.assume, args.detailed) if args.analyze == 'workDepth': for k, v, in work_depth_map.items(): diff --git a/dace/symbolic.py b/dace/symbolic.py index 0ab6e3f6ff..f3dfcfb36d 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -14,6 +14,7 @@ from dace import dtypes DEFAULT_SYMBOL_TYPE = dtypes.int32 +_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') # NOTE: Up to (including) version 1.8, sympy.abc._clash is a dictionary of the # form {'N': sympy.abc.N, 'I': sympy.abc.I, 'pi': sympy.abc.pi} @@ -658,6 +659,7 @@ def eval(cls, x, y): def _eval_is_boolean(self): return True + class IfExpr(sympy.Function): @classmethod @@ -723,6 +725,19 @@ class 
IsNot(sympy.Function): pass +class Attr(sympy.Function): + """ + Represents a get-attribute call on a function, equivalent to ``a.b`` in Python. + """ + + @property + def free_symbols(self): + return {sympy.Symbol(str(self))} + + def __str__(self): + return f'{self.args[0]}.{self.args[1]}' + + def sympy_intdiv_fix(expr): """ Fix for SymPy printing out reciprocal values when they should be integral in "ceiling/floor" sympy functions. @@ -926,10 +941,9 @@ def _process_is(elem: Union[Is, IsNot]): return expr -class SympyBooleanConverter(ast.NodeTransformer): +class PythonOpToSympyConverter(ast.NodeTransformer): """ - Replaces boolean operations with the appropriate SymPy functions to avoid - non-symbolic evaluation. + Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation. """ _ast_to_sympy_comparators = { ast.Eq: 'Eq', @@ -945,12 +959,37 @@ class SympyBooleanConverter(ast.NodeTransformer): ast.NotIn: 'NotIn', } + _ast_to_sympy_functions = { + ast.BitAnd: 'BitwiseAnd', + ast.BitOr: 'BitwiseOr', + ast.BitXor: 'BitwiseXor', + ast.Invert: 'BitwiseNot', + ast.LShift: 'LeftShift', + ast.RShift: 'RightShift', + ast.FloorDiv: 'int_floor', + } + def visit_UnaryOp(self, node): if isinstance(node.op, ast.Not): func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[]) return ast.copy_location(new_node, node) - return node + elif isinstance(node.op, ast.Invert): + func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), + node) + new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[]) + return ast.copy_location(new_node, node) + return self.generic_visit(node) + + def visit_BinOp(self, node): + if type(node.op) in self._ast_to_sympy_functions: + func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), + node) 
+ new_node = ast.Call(func=func_node, + args=[self.visit(value) for value in (node.left, node.right)], + keywords=[]) + return ast.copy_location(new_node, node) + return self.generic_visit(node) def visit_BoolOp(self, node): func_node = ast.copy_location(ast.Name(id=type(node.op).__name__, ctx=ast.Load()), node) @@ -970,8 +1009,7 @@ def visit_Compare(self, node: ast.Compare): raise NotImplementedError op = node.ops[0] arguments = [node.left, node.comparators[0]] - func_node = ast.copy_location( - ast.Name(id=SympyBooleanConverter._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node) + func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[]) return ast.copy_location(new_node, node) @@ -984,41 +1022,28 @@ def visit_NameConstant(self, node): return self.visit_Constant(node) def visit_IfExp(self, node): - new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), args=[node.test, node.body, node.orelse], keywords=[]) + new_node = ast.Call(func=ast.Name(id='IfExpr', ctx=ast.Load), + args=[self.visit(node.test), + self.visit(node.body), + self.visit(node.orelse)], + keywords=[]) return ast.copy_location(new_node, node) - -class BitwiseOpConverter(ast.NodeTransformer): - """ - Replaces C/C++ bitwise operations with functions to avoid sympification to boolean operations. 
- """ - _ast_to_sympy_functions = { - ast.BitAnd: 'BitwiseAnd', - ast.BitOr: 'BitwiseOr', - ast.BitXor: 'BitwiseXor', - ast.Invert: 'BitwiseNot', - ast.LShift: 'LeftShift', - ast.RShift: 'RightShift', - ast.FloorDiv: 'int_floor', - } - - def visit_UnaryOp(self, node): - if isinstance(node.op, ast.Invert): - func_node = ast.copy_location( - ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node) - new_node = ast.Call(func=func_node, args=[self.visit(node.operand)], keywords=[]) - return ast.copy_location(new_node, node) - return self.generic_visit(node) - - def visit_BinOp(self, node): - if type(node.op) in BitwiseOpConverter._ast_to_sympy_functions: - func_node = ast.copy_location( - ast.Name(id=BitwiseOpConverter._ast_to_sympy_functions[type(node.op)], ctx=ast.Load()), node) - new_node = ast.Call(func=func_node, - args=[self.visit(value) for value in (node.left, node.right)], + + def visit_Subscript(self, node): + if isinstance(node.value, ast.Attribute): + attr = ast.Subscript(value=ast.Name(id=node.value.attr, ctx=ast.Load()), slice=node.slice, ctx=ast.Load()) + new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load), + args=[self.visit(node.value.value), self.visit(attr)], keywords=[]) return ast.copy_location(new_node, node) return self.generic_visit(node) + def visit_Attribute(self, node): + new_node = ast.Call(func=ast.Name(id='Attr', ctx=ast.Load), + args=[self.visit(node.value), ast.Name(id=node.attr, ctx=ast.Load)], + keywords=[]) + return ast.copy_location(new_node, node) + @lru_cache(maxsize=16384) def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: @@ -1070,21 +1095,17 @@ def pystr_to_symbolic(expr, symbol_map=None, simplify=None) -> sympy.Basic: 'int_ceil': int_ceil, 'IfExpr': IfExpr, 'Mod': sympy.Mod, + 'Attr': Attr, } # _clash1 enables all one-letter variables like N as symbols # _clash also allows pi, beta, zeta and other common greek letters locals.update(_sympy_clash) if 
isinstance(expr, str): - # Sympy processes "not/and/or" as direct evaluation. Replace with - # And/Or(x, y), Not(x) - if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b', expr): - expr = unparse(SympyBooleanConverter().visit(ast.parse(expr).body[0])) - - # NOTE: If the expression contains bitwise operations, replace them with user-functions. - # NOTE: Sympy does not support bitwise operations and converts them to boolean operations. - if re.search('[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]', expr): - expr = unparse(BitwiseOpConverter().visit(ast.parse(expr).body[0])) + # Sympy processes "not/and/or" as direct evaluation. Replace with And/Or(x, y), Not(x) + # Also replaces bitwise operations with user-functions since SymPy does not support bitwise operations. + if re.search(r'\bnot\b|\band\b|\bor\b|\bNone\b|==|!=|\bis\b|\bif\b|[&]|[|]|[\^]|[~]|[<<]|[>>]|[//]|[\.]', expr): + expr = unparse(PythonOpToSympyConverter().visit(ast.parse(expr).body[0])) # TODO: support SymExpr over-approximated expressions try: @@ -1125,6 +1146,8 @@ def _print_Function(self, expr): return f'(({self._print(expr.args[0])}) and ({self._print(expr.args[1])}))' if str(expr.func) == 'OR': return f'(({self._print(expr.args[0])}) or ({self._print(expr.args[1])}))' + if str(expr.func) == 'Attr': + return f'{self._print(expr.args[0])}.{self._print(expr.args[1])}' return super()._print_Function(expr) def _print_Mod(self, expr): @@ -1377,6 +1400,29 @@ def equal(a: SymbolicType, b: SymbolicType, is_length: bool = True) -> Union[boo if is_length: for arg in args: facts += [sympy.Q.integer(arg), sympy.Q.positive(arg)] - + with sympy.assuming(*facts): return sympy.ask(sympy.Q.is_true(sympy.Eq(*args))) + + +def symbols_in_code(code: str, potential_symbols: Set[str] = None, + symbols_to_ignore: Set[str] = None) -> Set[str]: + """ + Tokenizes a code string for symbols and returns a set thereof. + + :param code: The code to tokenize. 
+ :param potential_symbols: If not None, filters symbols to this given set. + :param symbols_to_ignore: If not None, filters out symbols from this set. + """ + if not code: + return set() + if potential_symbols is not None and len(potential_symbols) == 0: + # Don't bother tokenizing for an empty set of potential symbols + return set() + + tokens = set(re.findall(_NAME_TOKENS, code)) + if potential_symbols is not None: + tokens &= potential_symbols + if symbols_to_ignore is None: + return tokens + return tokens - symbols_to_ignore diff --git a/dace/transformation/dataflow/mpi.py b/dace/transformation/dataflow/mpi.py index 8138b86b26..b6a467dc21 100644 --- a/dace/transformation/dataflow/mpi.py +++ b/dace/transformation/dataflow/mpi.py @@ -23,9 +23,9 @@ class MPITransformMap(transformation.SingleStateTransformation): .. code-block:: text Input1 - Output1 - \ / + \\ / Input2 --- MapEntry -- Arbitrary R -- MapExit -- Output2 - / \ + / \\ InputN - OutputN diff --git a/dace/transformation/dataflow/tasklet_fusion.py b/dace/transformation/dataflow/tasklet_fusion.py index 99f8f625be..d6b4a3039b 100644 --- a/dace/transformation/dataflow/tasklet_fusion.py +++ b/dace/transformation/dataflow/tasklet_fusion.py @@ -249,6 +249,9 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): t1.language) for in_edge in graph.in_edges(t1): + if in_edge.src_conn is None and isinstance(in_edge.src, dace.nodes.EntryNode): + if len(new_tasklet.in_connectors) > 0: + continue graph.add_edge(in_edge.src, in_edge.src_conn, new_tasklet, in_edge.dst_conn, in_edge.data) for in_edge in graph.in_edges(t2): diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 73da318e94..8986c4e37f 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1307,6 +1307,23 @@ def redirect_edge(state: SDFGState, return new_edge +def replace_code_to_code_edges(sdfg: SDFG): + """ + Adds access nodes between all code->code edges in each state. 
+ + :param sdfg: The SDFG to process. + """ + for state in sdfg.nodes(): + for edge in state.edges(): + if not isinstance(edge.src, nodes.CodeNode) or not isinstance(edge.dst, nodes.CodeNode): + continue + # Add access nodes + aname = state.add_access(edge.data.data) + state.add_edge(edge.src, edge.src_conn, aname, None, edge.data) + state.add_edge(aname, None, edge.dst, edge.dst_conn, copy.deepcopy(edge.data)) + state.remove_edge(edge) + + def can_run_state_on_fpga(state: SDFGState): """ Checks if state can be executed on FPGA. Used by FPGATransformState diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 1ca92d5ffd..86e1cde062 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -14,6 +14,7 @@ Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + @properties.make_properties class StateReachability(ppl.Pass): """ @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta """ reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - reachable[sdfg.sdfg_id] = {} - tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for state in sdfg.nodes(): - reachable[sdfg.sdfg_id][state] = set(tc.successors(state)) + result: Dict[SDFGState, Set[SDFGState]] = {} + + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + + for n, v in reachable_nodes(sdfg.nx): + result[n] = set(v) + + reachable[sdfg.sdfg_id] = result + return reachable +def _single_shortest_path_length_no_self(adj, source): + """Yields (node, level) in a breadth first search, without the first level + unless a self-edge exists. + + Adapted from Shortest Path Length helper function in NetworkX. 
+ + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : dict + starting nodes, e.g. {source: 1} or {target: 1} + cutoff : int or float + level at which we stop the process + """ + firstlevel = {source: 1} + + seen = {} # level (number of hops) when seen in BFS + level = 0 # the current level + nextlevel = set(firstlevel) # set of nodes to check at next level + n = len(adj) + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = set() # and start a new set (fringe) + found = [] + for v in thislevel: + if v not in seen: + if level == 0 and v is source: # Skip 0-length path to self + found.append(v) + continue + seen[v] = level # set the level of vertex v + found.append(v) + yield (v, level) + if len(seen) == n: + return + for v in found: + nextlevel.update(adj[v]) + level += 1 + del seen + + +def reachable_nodes(G): + """Computes the reachable nodes in G.""" + adj = G.adj + for n in G: + yield (n, dict(_single_shortest_path_length_no_self(adj, n))) + + @properties.make_properties class SymbolAccessSets(ppl.Pass): """ @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass( - self, top_sdfg: SDFG, _ - ) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, + _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. 
""" @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {SymbolAccessSets, StateReachability} - def _find_dominating_write( - self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState] - ) -> Optional[Edge[InterstateEdge]]: + def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], + state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src in_edges = last_state.parent.in_edges(last_state) @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[ - Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]] - ] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] + symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], + Tuple[Set[str], + Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {AccessSets, FindAccessNodes, StateReachability} - def _find_dominating_write( - self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge], - access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], - no_self_shadowing: bool = False - ) -> Optional[Tuple[SDFGState, nd.AccessNode]]: + def _find_dominating_write(self, + desc: str, + state: SDFGState, + read: Union[nd.AccessNode, InterstateEdge], + access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], 
Set[nd.AccessNode]]], + state_idom: Dict[SDFGState, SDFGState], + access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): # If the read is also a write, it shadows itself. iedges = state.in_edges(read) @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write( - desc, state, oedge.data, access_nodes, idom, access_sets - ) + write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + access_sets) result[desc][write].add((state, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: for write_node in access_nodes[desc][state][1]: if not (state, write_node) in result[desc]: - write = self._find_dominating_write( - desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True - ) + write = self._find_dominating_write(desc, + state, + write_node, + access_nodes, + idom, + access_sets, + no_self_shadowing=True) result[desc][write].add((state, write_node)) # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index c197adf827..9cec6d11af 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -102,12 +102,8 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for e in sdfg.out_edges(state): e.data.replace_dict(mapping, replace_keys=False) - # If symbols are never unknown any longer, remove from SDFG + # Gather initial propagated symbols 
result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} - # Remove from symbol repository - for sym in result: - if sym in sdfg.symbols: - sdfg.remove_symbol(sym) # Remove single-valued symbols from data descriptors (e.g., symbolic array size) sdfg.replace_dict({k: v @@ -121,6 +117,14 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = for sym in intersection: del edge.data.assignments[sym] + # If symbols are never unknown any longer, remove from SDFG + fsyms = sdfg.used_symbols(all_symbols=False) + result = {k: v for k, v in result.items() if k not in fsyms} + for sym in result: + if sym in sdfg.symbols: + # Remove from symbol repository and nested SDFG symbol mapipng + sdfg.remove_symbol(sym) + result = set(result.keys()) if self.recursive: @@ -188,7 +192,7 @@ def collect_constants(self, if len(in_edges) == 1: # Special case, propagate as-is if state not in result: # Condition evaluates to False when state is the start-state result[state] = {} - + # First the prior state if in_edges[0].src in result: # Condition evaluates to False when state is the start-state self._propagate(result[state], result[in_edges[0].src]) diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 94fcbdbc58..cf55f7a9b2 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -1,16 +1,13 @@ # Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
import itertools -import re from dataclasses import dataclass from typing import Optional, Set, Tuple -from dace import SDFG, dtypes, properties +from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -_NAME_TOKENS = re.compile(r'[a-zA-Z_][a-zA-Z_0-9]*') - @dataclass(unsafe_hash=True) @properties.make_properties @@ -81,7 +78,7 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: # Add symbols in global/init/exit code for code in itertools.chain(sdfg.global_code.values(), sdfg.init_code.values(), sdfg.exit_code.values()): - result |= _symbols_in_code(code.as_string) + result |= symbolic.symbols_in_code(code.as_string) for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) @@ -94,21 +91,19 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for node in state.nodes(): if isinstance(node, nodes.Tasklet): if node.code.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code.as_string) + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_global.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_global.as_string) + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_init.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_init.as_string) + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) if node.code_exit.language != dtypes.Language.Python: - result |= _symbols_in_code(node.code_exit.as_string) - + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) for e in sdfg.edges(): result |= e.data.free_symbols return result - -def _symbols_in_code(code: str) -> Set[str]: - if not code: - return set() - return set(re.findall(_NAME_TOKENS, code)) diff --git a/requirements.txt 
b/requirements.txt index 33cd58a0bf..996449dbef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ charset-normalizer==3.1.0 click==8.1.3 dill==0.3.6 Flask==2.3.2 -fparser==0.1.2 +fparser==0.1.3 idna==3.4 importlib-metadata==6.6.0 itsdangerous==2.1.2 @@ -20,7 +20,7 @@ PyYAML==6.0 requests==2.31.0 six==1.16.0 sympy==1.9 -urllib3==2.0.3 +urllib3==2.0.6 websockets==11.0.3 Werkzeug==2.3.5 zipp==3.15.0 diff --git a/setup.py b/setup.py index b1737aed5a..a0ac2e2d49 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", ], - python_requires='>=3.6, <3.12', + python_requires='>=3.6, <3.13', packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), package_data={ '': [ @@ -74,7 +74,7 @@ include_package_data=True, install_requires=[ 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', - 'fparser >= 0.1.2', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', + 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' ] + cmake_requires, extras_require={ diff --git a/tests/blas/nodes/dot_test.py b/tests/blas/nodes/dot_test.py index d5f1d24263..a936be60a9 100755 --- a/tests/blas/nodes/dot_test.py +++ b/tests/blas/nodes/dot_test.py @@ -92,20 +92,23 @@ def run_test(target, size, vector_length): def test_dot_pure(): - return run_test("pure", 64, 1) + assert isinstance(run_test("pure", 64, 1), dace.SDFG) +# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) @xilinx_test() def test_dot_xilinx(): return run_test("xilinx", 64, 16) +# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) @xilinx_test() def test_dot_xilinx_decoupled(): with set_temporary("compiler", "xilinx", "decouple_array_interfaces", 
value=True): return run_test("xilinx", 64, 16) +# TODO: Refactor to use assert or return True/False (pytest deprecation of returning non-booleans) @intel_fpga_test() def test_dot_intel_fpga(): return run_test("intel_fpga", 64, 16) @@ -119,4 +122,4 @@ def test_dot_intel_fpga(): args = parser.parse_args() size = args.N - run_test(target, size, vector_length) + run_test(args.target, size, args.vector_length) diff --git a/tests/codegen/codegen_used_symbols_test.py b/tests/codegen/codegen_used_symbols_test.py new file mode 100644 index 0000000000..afa0ca0a05 --- /dev/null +++ b/tests/codegen/codegen_used_symbols_test.py @@ -0,0 +1,95 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests used-symbols in code generation.""" +import dace +import numpy +import pytest + + +n0i, n0j, n0k = (dace.symbol(s, dtype=dace.int32) for s in ('n0i', 'n0j', 'n0k')) +n1i, n1j, n1k = (dace.symbol(s, dtype=dace.int64) for s in ('n1i', 'n1j', 'n1k')) + + +@dace.program +def rprj3(r: dace.float64[n0i, n0j, n0k], s: dace.float64[n1i, n1j, n1k]): + + for i, j, k in dace.map[1:s.shape[0] - 1, 1:s.shape[1] - 1, 1:s.shape[2] - 1]: + + s[i, j, k] = ( + 0.5000 * r[2 * i, 2 * j, 2 * k] + + 0.2500 * (r[2 * i - 1, 2 * j, 2 * k] + r[2 * i + 1, 2 * j, 2 * k] + r[2 * i, 2 * j - 1, 2 * k] + + r[2 * i, 2 * j + 1, 2 * k] + r[2 * i, 2 * j, 2 * k - 1] + r[2 * i, 2 * j, 2 * k + 1]) + + 0.1250 * (r[2 * i - 1, 2 * j - 1, 2 * k] + r[2 * i - 1, 2 * j + 1, 2 * k] + + r[2 * i + 1, 2 * j - 1, 2 * k] + r[2 * i + 1, 2 * j + 1, 2 * k] + + r[2 * i - 1, 2 * j, 2 * k - 1] + r[2 * i - 1, 2 * j, 2 * k + 1] + + r[2 * i + 1, 2 * j, 2 * k - 1] + r[2 * i + 1, 2 * j, 2 * k + 1] + + r[2 * i, 2 * j - 1, 2 * k - 1] + r[2 * i, 2 * j - 1, 2 * k + 1] + + r[2 * i, 2 * j + 1, 2 * k - 1] + r[2 * i, 2 * j + 1, 2 * k + 1]) + + 0.0625 * (r[2 * i - 1, 2 * j - 1, 2 * k - 1] + r[2 * i - 1, 2 * j - 1, 2 * k + 1] + + r[2 * i - 1, 2 * j + 1, 2 * k - 1] + r[2 * i - 1, 2 * j + 1, 2 * k + 1] + + r[2 * i + 1, 2 * j 
- 1, 2 * k - 1] + r[2 * i + 1, 2 * j - 1, 2 * k + 1] + + r[2 * i + 1, 2 * j + 1, 2 * k - 1] + r[2 * i + 1, 2 * j + 1, 2 * k + 1])) + + +def test_codegen_used_symbols_cpu(): + + rng = numpy.random.default_rng(42) + r = rng.random((10, 10, 10)) + s_ref = numpy.zeros((4, 4, 4)) + s_val = numpy.zeros((4, 4, 4)) + + rprj3.f(r, s_ref) + rprj3(r, s_val) + + assert numpy.allclose(s_ref, s_val) + + +def test_codegen_used_symbols_cpu_2(): + + @dace.program + def rprj3_nested(r: dace.float64[n0i, n0j, n0k], s: dace.float64[n1i, n1j, n1k]): + rprj3(r, s) + + rng = numpy.random.default_rng(42) + r = rng.random((10, 10, 10)) + s_ref = numpy.zeros((4, 4, 4)) + s_val = numpy.zeros((4, 4, 4)) + + rprj3.f(r, s_ref) + rprj3_nested(r, s_val) + + assert numpy.allclose(s_ref, s_val) + + +@pytest.mark.gpu +def test_codegen_used_symbols_gpu(): + + sdfg = rprj3.to_sdfg() + for _, desc in sdfg.arrays.items(): + if not desc.transient and isinstance(desc, dace.data.Array): + desc.storage = dace.StorageType.GPU_Global + sdfg.apply_gpu_transformations() + func = sdfg.compile() + + try: + import cupy + + rng = numpy.random.default_rng(42) + r = rng.random((10, 10, 10)) + r_dev = cupy.asarray(r) + s_ref = numpy.zeros((4, 4, 4)) + s_val = cupy.zeros((4, 4, 4)) + + rprj3.f(r, s_ref) + func(r=r_dev, s=s_val, n0i=10, n0j=10, n0k=10, n1i=4, n1j=4, n1k=4) + + assert numpy.allclose(s_ref, s_val) + + except (ImportError, ModuleNotFoundError): + pass + + +if __name__ == "__main__": + + test_codegen_used_symbols_cpu() + test_codegen_used_symbols_cpu_2() + test_codegen_used_symbols_gpu() diff --git a/tests/codegen/control_flow_detection_test.py b/tests/codegen/control_flow_detection_test.py index 99d6a39b29..982140f7ed 100644 --- a/tests/codegen/control_flow_detection_test.py +++ b/tests/codegen/control_flow_detection_test.py @@ -120,6 +120,33 @@ def test_single_outedge_branch(): assert np.allclose(res, 2) +def test_extraneous_goto(): + + @dace.program + def tester(a: dace.float64[20]): + if a[0] < 0: + 
a[1] = 1 + a[2] = 1 + + sdfg = tester.to_sdfg(simplify=True) + assert 'goto' not in sdfg.generate_code()[0].code + + +def test_extraneous_goto_nested(): + + @dace.program + def tester(a: dace.float64[20]): + if a[0] < 0: + if a[0] < 1: + a[1] = 1 + else: + a[1] = 2 + a[2] = 1 + + sdfg = tester.to_sdfg(simplify=True) + assert 'goto' not in sdfg.generate_code()[0].code + + if __name__ == '__main__': test_for_loop_detection() test_invalid_for_loop_detection() @@ -128,3 +155,5 @@ def test_single_outedge_branch(): test_edge_sympy_function('TrueFalse') test_edge_sympy_function('SwitchCase') test_single_outedge_branch() + test_extraneous_goto() + test_extraneous_goto_nested() diff --git a/tests/compile_sdfg_test.py b/tests/compile_sdfg_test.py index 33ace1156a..3120359262 100644 --- a/tests/compile_sdfg_test.py +++ b/tests/compile_sdfg_test.py @@ -51,7 +51,7 @@ def tester(a: int): return a + 1 csdfg = tester.to_sdfg().compile() - with pytest.warns(None, match='Casting'): + with pytest.warns(UserWarning, match='Casting'): result = csdfg(0.1) assert result.item() == 1 diff --git a/tests/fortran/array_attributes_test.py b/tests/fortran/array_attributes_test.py new file mode 100644 index 0000000000..af433905bc --- /dev/null +++ b/tests/fortran/array_attributes_test.py @@ -0,0 +1,117 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_array_attribute_no_offset(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5) :: d + + do i=1,5 + d(i) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + a = np.full([5], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,5): + # offset -1 is already added + assert a[i-1] == i * 2 + +def test_fortran_frontend_array_attribute_offset(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(50:54) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(50:54) :: d + + do i=50,54 + d(i) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + a = np.full([60], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(50,54): + # offset -1 is already added + assert a[i-1] == i * 2 + +def test_fortran_frontend_array_offset(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision d(50:54) + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision d(50:54) + + do i=50,54 + d(i) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test") + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + assert len(sdfg.data('d').offset) == 1 + assert sdfg.data('d').offset[0] == -1 + + a = np.full([60], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(50,54): + # offset -1 is already added + assert a[i-1] == i * 2 + + +if __name__ == "__main__": + + test_fortran_frontend_array_offset() + test_fortran_frontend_array_attribute_no_offset() + test_fortran_frontend_array_attribute_offset() diff --git a/tests/fortran/array_to_loop_offset.py b/tests/fortran/array_to_loop_offset.py new file mode 100644 index 0000000000..43d01d9b6b --- /dev/null +++ b/tests/fortran/array_to_loop_offset.py @@ -0,0 +1,119 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_arr2loop_without_offset(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,3) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,3) :: d + + do i=1,5 + d(i, :) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,6): + for j in range(1,4): + assert a[i-1, j-1] == i * 2 + +def test_fortran_frontend_arr2loop_1d_offset(): + """ + Tests that the generated array map correctly handles offsets. + """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(2:6) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(2:6) :: d + + d(:) = 5 + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + + a = np.full([6], 42, order="F", dtype=np.float64) + sdfg(d=a) + assert a[0] == 42 + for i in range(2,7): + assert a[i-1] == 5 + +def test_fortran_frontend_arr2loop_2d_offset(): + """ + Tests that the generated array map correctly handles offsets. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(5,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(5,7:9) :: d + + do i=1,5 + d(i, :) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Now test to verify it executes correctly with no offset normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,9], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(1,6): + for j in range(7,10): + assert a[i-1, j-1] == i * 2 + +if __name__ == "__main__": + + test_fortran_frontend_arr2loop_1d_offset() + test_fortran_frontend_arr2loop_2d_offset() + test_fortran_frontend_arr2loop_without_offset() diff --git a/tests/fortran/offset_normalizer.py b/tests/fortran/offset_normalizer.py new file mode 100644 index 0000000000..b4138c1cac --- /dev/null +++ b/tests/fortran/offset_normalizer.py @@ -0,0 +1,164 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +from dace.frontend.fortran import ast_transforms, fortran_parser + +def test_fortran_frontend_offset_normalizer_1d(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(50:54) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(50:54) :: d + + do i=50,54 + d(i) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + idx_assignment = loop.body.execution[1] + assert idx_assignment.rval.rval.value == "50" + + # Now test to verify it executes correctly + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 1 + assert sdfg.data('d').shape[0] == 5 + + a = np.full([5], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(0,5): + assert a[i] == (50+i)* 2 + +def test_fortran_frontend_offset_normalizer_2d(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(50:54,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(50:54,7:9) :: d + + do i=50,54 + do j=7,9 + d(i, j) = i * 2.0 + 3 * j + end do + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + nested_loop = loop.body.execution[1] + + idx = nested_loop.body.execution[1] + assert idx.lval.name == 'tmp_index_0' + assert idx.rval.rval.value == "50" + + idx2 = nested_loop.body.execution[3] + assert idx2.lval.name == 'tmp_index_1' + assert idx2.rval.rval.value == "7" + + # Now test to verify it executes correctly + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == (50+i) * 2 + 3 * (7 + j) + +def test_fortran_frontend_offset_normalizer_2d_arr2loop(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM index_offset_test + implicit none + double precision, dimension(50:54,7:9) :: d + CALL index_test_function(d) + end + + SUBROUTINE index_test_function(d) + double precision, dimension(50:54,7:9) :: d + + do i=50,54 + d(i, :) = i * 2.0 + end do + + END SUBROUTINE index_test_function + """ + + # Test to verify that offset is normalized correctly + ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True) + + for subroutine in ast.subroutine_definitions: + + loop = subroutine.execution_part.execution[1] + nested_loop = loop.body.execution[1] + + idx = nested_loop.body.execution[1] + assert idx.lval.name == 'tmp_index_0' + assert idx.rval.rval.value == "50" + + idx2 = nested_loop.body.execution[3] + assert idx2.lval.name == 'tmp_index_1' + assert idx2.rval.rval.value == "7" + + # Now test to verify it executes correctly with no normalization + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True) + sdfg.save('test.sdfg') + sdfg.simplify(verbose=True) + sdfg.compile() + + assert len(sdfg.data('d').shape) == 2 + assert sdfg.data('d').shape[0] == 5 + assert sdfg.data('d').shape[1] == 3 + + a = np.full([5,3], 42, order="F", dtype=np.float64) + sdfg(d=a) + for i in range(0,5): + for j in range(0,3): + assert a[i, j] == (50 + i) * 2 + +if __name__ == "__main__": + + test_fortran_frontend_offset_normalizer_1d() + test_fortran_frontend_offset_normalizer_2d() + test_fortran_frontend_offset_normalizer_2d_arr2loop() diff --git a/tests/fortran/parent_test.py b/tests/fortran/parent_test.py new file mode 100644 index 0000000000..b1d08eaf37 --- /dev/null +++ b/tests/fortran/parent_test.py @@ -0,0 +1,91 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +from dace.frontend.fortran import fortran_parser + +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_parent(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + PROGRAM access_test + implicit none + double precision d(4) + d(1)=0 + CALL array_access_test_function(d) + end + + SUBROUTINE array_access_test_function(d) + double precision d(4) + + d(2)=5.5 + + END SUBROUTINE array_access_test_function + """ + ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") + ast_transforms.ParentScopeAssigner().visit(ast) + + assert ast.parent is None + assert ast.main_program.parent == None + + main_program = ast.main_program + # Both executed lines + for execution in main_program.execution_part.execution: + assert execution.parent == main_program + # call to the function + call_node = main_program.execution_part.execution[1] + assert isinstance(call_node, ast_internal_classes.Call_Expr_Node) + for arg in call_node.args: + assert arg.parent == main_program + + for subroutine in ast.subroutine_definitions: + + assert subroutine.parent == None + assert subroutine.execution_part.parent == subroutine + for execution in subroutine.execution_part.execution: + assert execution.parent == subroutine + +def test_fortran_frontend_module(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. + """ + test_string = """ + module test_module + implicit none + ! 
good enough approximation + integer, parameter :: pi = 4 + end module test_module + + PROGRAM access_test + implicit none + double precision d(4) + d(1)=0 + CALL array_access_test_function(d) + end + + SUBROUTINE array_access_test_function(d) + double precision d(4) + + d(2)=5.5 + + END SUBROUTINE array_access_test_function + """ + ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") + ast_transforms.ParentScopeAssigner().visit(ast) + + assert ast.parent is None + assert ast.main_program.parent == None + + module = ast.modules[0] + assert module.parent == None + specification = module.specification_part.specifications[0] + assert specification.parent == module + + +if __name__ == "__main__": + + test_fortran_frontend_parent() + test_fortran_frontend_module() diff --git a/tests/fortran/scope_arrays.py b/tests/fortran/scope_arrays.py new file mode 100644 index 0000000000..0eb0cf44b2 --- /dev/null +++ b/tests/fortran/scope_arrays.py @@ -0,0 +1,47 @@ +# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. + +from dace.frontend.fortran import fortran_parser + +import dace.frontend.fortran.ast_transforms as ast_transforms +import dace.frontend.fortran.ast_internal_classes as ast_internal_classes + + +def test_fortran_frontend_parent(): + """ + Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct. 
+ """ + test_string = """ + PROGRAM scope_test + implicit none + double precision d(4) + double precision, dimension(5) :: arr + double precision, dimension(50:54) :: arr3 + CALL scope_test_function(d) + end + + SUBROUTINE scope_test_function(d) + double precision d(4) + double precision, dimension(50:54) :: arr4 + + d(2)=5.5 + + END SUBROUTINE scope_test_function + """ + + ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test") + ast_transforms.ParentScopeAssigner().visit(ast) + visitor = ast_transforms.ScopeVarsDeclarations() + visitor.visit(ast) + + for var in ['d', 'arr', 'arr3']: + assert ('scope_test', var) in visitor.scope_vars + assert isinstance(visitor.scope_vars[('scope_test', var)], ast_internal_classes.Var_Decl_Node) + assert visitor.scope_vars[('scope_test', var)].name == var + + for var in ['d', 'arr4']: + assert ('scope_test_function', var) in visitor.scope_vars + assert visitor.scope_vars[('scope_test_function', var)].name == var + +if __name__ == "__main__": + + test_fortran_frontend_parent() diff --git a/tests/fpga/hbm_transform_test.py b/tests/fpga/hbm_transform_test.py index 6438ac7492..0346837fbc 100644 --- a/tests/fpga/hbm_transform_test.py +++ b/tests/fpga/hbm_transform_test.py @@ -1,7 +1,6 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. from dace.fpga_testing import xilinx_test -from numpy.lib import math from dace.sdfg.state import SDFGState import numpy as np from dace import dtypes diff --git a/tests/library/gemm_test.py b/tests/library/gemm_test.py index df60d1aa43..07e9006ece 100644 --- a/tests/library/gemm_test.py +++ b/tests/library/gemm_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
import pytest import warnings import itertools @@ -132,7 +132,10 @@ def numpy_gemm(A, B, C, transA, transB, alpha, beta): assert diff <= 1e-5 -@pytest.mark.parametrize(('implementation', ), [('pure', ), ('MKL', ), pytest.param('cuBLAS', marks=pytest.mark.gpu)]) +@pytest.mark.parametrize( + ('implementation', ), + [('pure', ), pytest.param('MKL', marks=pytest.mark.mkl), + pytest.param('cuBLAS', marks=pytest.mark.gpu)]) def test_library_gemm(implementation): param_grid_trans = dict( transA=[True, False], diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 48853cdf26..d2c348ce95 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -1,231 +1,246 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Tests for numpy advanced indexing syntax. See also: -https://numpy.org/devdocs/reference/arrays.indexing.html -""" -import dace -import numpy as np -import pytest - -N = dace.symbol('N') -M = dace.symbol('M') - - -def test_flat(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A.flat - - A = np.random.rand(20, 30) - res = indexing_test(A) - assert np.allclose(A.flat, res) - - -def test_flat_noncontiguous(): - with dace.config.set_temporary('compiler', 'allow_view_arguments', value=True): - - @dace.program - def indexing_test(A): - return A.flat - - A = np.random.rand(20, 30).transpose() - res = indexing_test(A) - assert np.allclose(A.flat, res) - - -def test_ellipsis(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - return A[1:5, ..., 0] - - A = np.random.rand(5, 5, 5, 5, 5) - res = indexing_test(A) - assert np.allclose(A[1:5, ..., 0], res) - - -def test_aug_implicit(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - A[:, 1:5][:, 0:2] += 5 - - A = np.random.rand(5, 5, 5, 5, 5) - regression = np.copy(A) - regression[:, 1:5][:, 0:2] += 5 - indexing_test(A) - assert np.allclose(A, 
regression) - - -def test_ellipsis_aug(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - A[1:5, ..., 0] += 5 - - A = np.random.rand(5, 5, 5, 5, 5) - regression = np.copy(A) - regression[1:5, ..., 0] += 5 - indexing_test(A) - assert np.allclose(A, regression) - - -def test_newaxis(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A[:, np.newaxis, None, :] - - A = np.random.rand(20, 30) - res = indexing_test(A) - assert res.shape == (20, 1, 1, 30) - assert np.allclose(A[:, np.newaxis, None, :], res) - - -def test_multiple_newaxis(): - @dace.program - def indexing_test(A: dace.float64[10, 20, 30]): - return A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis] - - A = np.random.rand(10, 20, 30) - res = indexing_test(A) - assert res.shape == (1, 10, 1, 1, 20, 1, 30, 1) - assert np.allclose(A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis], res) - - -def test_index_intarr_1d(): - @dace.program - def indexing_test(A: dace.float64[N], indices: dace.int32[M]): - return A[indices] - - A = np.random.rand(20) - indices = [1, 10, 15] - res = indexing_test(A, indices, M=3) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_literal(): - @dace.program - def indexing_test(A: dace.float64[20]): - return A[[1, 10, 15]] - - A = np.random.rand(20) - indices = [1, 10, 15] - res = indexing_test(A) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_constant(): - indices = [1, 10, 15] - - @dace.program - def indexing_test(A: dace.float64[20]): - return A[indices] - - A = np.random.rand(20) - res = indexing_test(A) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_multi(): - @dace.program - def indexing_test(A: dace.float64[20, 10, 30], indices: dace.int32[3]): - return A[indices, 2:7:2, [15, 10, 1]] - - A = np.random.rand(20, 10, 30) - indices = [1, 10, 15] - res = indexing_test(A, indices) - # FIXME: NumPy behavior is unclear in this case - assert 
np.allclose(np.diag(A[indices, 2:7:2, [15, 10, 1]]), res) - - -def test_index_intarr_nd(): - @dace.program - def indexing_test(A: dace.float64[4, 3], rows: dace.int64[2, 2], columns: dace.int64[2, 2]): - return A[rows, columns] - - A = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64) - rows = np.array([[0, 0], [3, 3]], dtype=np.intp) - columns = np.array([[0, 2], [0, 2]], dtype=np.intp) - expected = A[rows, columns] - res = indexing_test(A, rows, columns) - assert np.allclose(expected, res) - - -def test_index_boolarr_rhs(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A[A > 15] - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - regression = A[A > 15] - - # Right-hand side boolean array indexing is unsupported - with pytest.raises(IndexError): - res = indexing_test(A) - assert np.allclose(regression, res) - - -def test_index_multiboolarr(): - @dace.program - def indexing_test(A: dace.float64[20, 20], B: dace.bool[20]): - A[B, B] = 2 - - A = np.ndarray((20, 20), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 20) - B = A[:, 1] > 0 - - # Advanced indexing with multiple boolean arrays should be disallowed - with pytest.raises(IndexError): - indexing_test(A, B) - - -def test_index_boolarr_fixed(): - @dace.program - def indexing_test(A: dace.float64[20, 30], barr: dace.bool[20, 30]): - A[barr] += 5 - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - barr = A > 15 - regression = np.copy(A) - regression[barr] += 5 - - indexing_test(A, barr) - - assert np.allclose(regression, A) - - -def test_index_boolarr_inline(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - A[A > 15] = 2 - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - regression = np.copy(A) - regression[A > 15] = 2 - - indexing_test(A) - - assert np.allclose(regression, A) - - -if 
__name__ == '__main__': - test_flat() - test_flat_noncontiguous() - test_ellipsis() - test_aug_implicit() - test_ellipsis_aug() - test_newaxis() - test_multiple_newaxis() - test_index_intarr_1d() - test_index_intarr_1d_literal() - test_index_intarr_1d_constant() - test_index_intarr_1d_multi() - test_index_intarr_nd() - test_index_boolarr_rhs() - test_index_multiboolarr() - test_index_boolarr_fixed() - test_index_boolarr_inline() +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests for numpy advanced indexing syntax. See also: +https://numpy.org/devdocs/reference/arrays.indexing.html +""" +import dace +from dace.frontend.python.common import DaceSyntaxError +import numpy as np +import pytest + +N = dace.symbol('N') +M = dace.symbol('M') + + +def test_flat(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A.flat + + A = np.random.rand(20, 30) + res = indexing_test(A) + assert np.allclose(A.flat, res) + + +def test_flat_noncontiguous(): + with dace.config.set_temporary('compiler', 'allow_view_arguments', value=True): + + @dace.program + def indexing_test(A): + return A.flat + + A = np.random.rand(20, 30).transpose() + res = indexing_test(A) + assert np.allclose(A.flat, res) + + +def test_ellipsis(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + return A[1:5, ..., 0] + + A = np.random.rand(5, 5, 5, 5, 5) + res = indexing_test(A) + assert np.allclose(A[1:5, ..., 0], res) + + +def test_aug_implicit(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + A[:, 1:5][:, 0:2] += 5 + + A = np.random.rand(5, 5, 5, 5, 5) + regression = np.copy(A) + regression[:, 1:5][:, 0:2] += 5 + indexing_test(A) + assert np.allclose(A, regression) + + +def test_ellipsis_aug(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + A[1:5, ..., 0] += 5 + + A = np.random.rand(5, 5, 5, 5, 5) + regression = np.copy(A) + regression[1:5, ..., 0] += 5 + indexing_test(A) + 
assert np.allclose(A, regression) + + +def test_newaxis(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A[:, np.newaxis, None, :] + + A = np.random.rand(20, 30) + res = indexing_test(A) + assert res.shape == (20, 1, 1, 30) + assert np.allclose(A[:, np.newaxis, None, :], res) + + +def test_multiple_newaxis(): + + @dace.program + def indexing_test(A: dace.float64[10, 20, 30]): + return A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis] + + A = np.random.rand(10, 20, 30) + res = indexing_test(A) + assert res.shape == (1, 10, 1, 1, 20, 1, 30, 1) + assert np.allclose(A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis], res) + + +def test_index_intarr_1d(): + + @dace.program + def indexing_test(A: dace.float64[N], indices: dace.int32[M]): + return A[indices] + + A = np.random.rand(20) + indices = [1, 10, 15] + res = indexing_test(A, indices, M=3) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_literal(): + + @dace.program + def indexing_test(A: dace.float64[20]): + return A[[1, 10, 15]] + + A = np.random.rand(20) + indices = [1, 10, 15] + res = indexing_test(A) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_constant(): + indices = [1, 10, 15] + + @dace.program + def indexing_test(A: dace.float64[20]): + return A[indices] + + A = np.random.rand(20) + res = indexing_test(A) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_multi(): + + @dace.program + def indexing_test(A: dace.float64[20, 10, 30], indices: dace.int32[3]): + return A[indices, 2:7:2, [15, 10, 1]] + + A = np.random.rand(20, 10, 30) + indices = [1, 10, 15] + res = indexing_test(A, indices) + # FIXME: NumPy behavior is unclear in this case + assert np.allclose(np.diag(A[indices, 2:7:2, [15, 10, 1]]), res) + + +def test_index_intarr_nd(): + + @dace.program + def indexing_test(A: dace.float64[4, 3], rows: dace.int64[2, 2], columns: dace.int64[2, 2]): + return A[rows, columns] + + A = 
np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64) + rows = np.array([[0, 0], [3, 3]], dtype=np.intp) + columns = np.array([[0, 2], [0, 2]], dtype=np.intp) + expected = A[rows, columns] + res = indexing_test(A, rows, columns) + assert np.allclose(expected, res) + + +def test_index_boolarr_rhs(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A[A > 15] + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + regression = A[A > 15] + + # Right-hand side boolean array indexing is unsupported + with pytest.raises(IndexError): + res = indexing_test(A) + assert np.allclose(regression, res) + + +def test_index_multiboolarr(): + + @dace.program + def indexing_test(A: dace.float64[20, 20], B: dace.bool[20]): + A[B, B] = 2 + + A = np.ndarray((20, 20), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 20) + B = A[:, 1] > 0 + + # Advanced indexing with multiple boolean arrays should be disallowed + with pytest.raises(DaceSyntaxError): + indexing_test(A, B) + + +def test_index_boolarr_fixed(): + + @dace.program + def indexing_test(A: dace.float64[20, 30], barr: dace.bool[20, 30]): + A[barr] += 5 + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + barr = A > 15 + regression = np.copy(A) + regression[barr] += 5 + + indexing_test(A, barr) + + assert np.allclose(regression, A) + + +def test_index_boolarr_inline(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + A[A > 15] = 2 + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + regression = np.copy(A) + regression[A > 15] = 2 + + indexing_test(A) + + assert np.allclose(regression, A) + + +if __name__ == '__main__': + test_flat() + test_flat_noncontiguous() + test_ellipsis() + test_aug_implicit() + test_ellipsis_aug() + test_newaxis() + test_multiple_newaxis() + test_index_intarr_1d() + test_index_intarr_1d_literal() 
+ test_index_intarr_1d_constant() + test_index_intarr_1d_multi() + test_index_intarr_nd() + test_index_boolarr_rhs() + test_index_multiboolarr() + test_index_boolarr_fixed() + test_index_boolarr_inline() diff --git a/tests/numpy/ufunc_support_test.py b/tests/numpy/ufunc_support_test.py index 65737a2ceb..df0234259b 100644 --- a/tests/numpy/ufunc_support_test.py +++ b/tests/numpy/ufunc_support_test.py @@ -127,7 +127,8 @@ def test_ufunc_add_where(): W = np.random.randint(2, size=(10, ), dtype=np.bool_) C = ufunc_add_where(A, B, W) assert (np.array_equal(np.add(A, B, where=W)[W], C[W])) - assert (not np.array_equal((A + B)[np.logical_not(W)], C[np.logical_not(W)])) + if not np.all(W): # If all of W is True, np.logical_not(W) would result in empty arrays + assert (not np.array_equal((A + B)[np.logical_not(W)], C[np.logical_not(W)])) @dace.program @@ -154,18 +155,6 @@ def test_ufunc_add_where_false(): assert (not np.array_equal(A + B, C)) -@dace.program -def ufunc_add_where_false(A: dace.int32[10], B: dace.int32[10]): - return np.add(A, B, where=False) - - -def test_ufunc_add_where_false(): - A = np.random.randint(1, 10, size=(10, ), dtype=np.int32) - B = np.random.randint(1, 10, size=(10, ), dtype=np.int32) - C = ufunc_add_where_false(A, B) - assert (not np.array_equal(A + B, C)) - - @dace.program def ufunc_add_where_list(A: dace.int32[2], B: dace.int32[2]): return np.add(A, B, where=[True, False]) @@ -456,7 +445,7 @@ def test_ufunc_add_outer_where(): B = np.random.randint(1, 10, size=(2, 2, 2, 2, 2), dtype=np.int32) W = np.random.randint(2, size=(2, 2, 2, 2, 2, 2, 2, 2, 2, 2), dtype=np.bool_) s = ufunc_add_outer_where(A, B, W) - assert (np.array_equal(np.add.outer(A, B, where=W)[W], s[W])) + assert np.array_equal(np.add.outer(A, B, where=W)[W], s[W]) @dace.program @@ -472,7 +461,7 @@ def test_ufunc_add_outer_where2(): C = ufunc_add_outer_where2(A, B, W) where = np.empty((2, 2, 2, 2, 2, 2, 2, 2, 2, 2), dtype=np.bool_) where[:] = W - assert 
(np.array_equal(np.add.outer(A, B, where=W)[where], C[where])) + assert np.array_equal(np.add.outer(A, B, where=W)[where], C[where]) @compare_numpy_output() diff --git a/tests/python_frontend/argument_test.py b/tests/python_frontend/argument_test.py index 1f43337eb8..cb47188029 100644 --- a/tests/python_frontend/argument_test.py +++ b/tests/python_frontend/argument_test.py @@ -2,6 +2,7 @@ import dace import pytest +import numpy as np N = dace.symbol('N') @@ -16,5 +17,29 @@ def test_extra_args(): imgcpy([[1, 2], [3, 4]], [[4, 3], [2, 1]], 0.0, 1.0) +def test_missing_arguments_regression(): + + def nester(a, b, T): + for i, j in dace.map[0:20, 0:20]: + start = 0 + end = min(T, 6) + + elem: dace.float64 = 0 + for ii in range(start, end): + if ii % 2 == 0: + elem += b[ii] + + a[j, i] = elem + + @dace.program + def tester(x: dace.float64[20, 20]): + gdx = np.ones((10, ), dace.float64) + for T in range(2): + nester(x, gdx, T) + + tester.to_sdfg().compile() + + if __name__ == '__main__': test_extra_args() + test_missing_arguments_regression() diff --git a/tests/python_frontend/type_statement_test.py b/tests/python_frontend/type_statement_test.py new file mode 100644 index 0000000000..16ec1613db --- /dev/null +++ b/tests/python_frontend/type_statement_test.py @@ -0,0 +1,22 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import pytest + + +# TODO: Investigate why pytest parses the DaCeProgram, even when the test is not supposed to run. 
+# @pytest.mark.py312 +# def test_type_statement(): + +# @dace.program +# def type_statement(): +# type Scalar[T] = T +# A: Scalar[dace.float32] = 0 +# return A + +# with pytest.raises(dace.frontend.python.common.DaceSyntaxError): +# type_statement() + + +if __name__ == '__main__': + # test_type_statement() + pass diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py new file mode 100644 index 0000000000..0811682870 --- /dev/null +++ b/tests/schedule_tree/naming_test.py @@ -0,0 +1,204 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.passes.constant_propagation import ConstantPropagation + +import pytest +from typing import List + + +def _irreducible_loop_to_loop(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + + return sdfg + + +def _nested_irreducible_loops(): + sdfg = _irreducible_loop_to_loop() + nsdfg = _irreducible_loop_to_loop() + + l1 = sdfg.node(5) + l1.add_nested_sdfg(nsdfg, None, {}, {}) + return sdfg + + +def test_clash_states(): + """ + Same test as test_irreducible_in_loops, but all states in the nested SDFG share names with the top SDFG + """ + sdfg = _nested_irreducible_loops() + + 
stree = as_schedule_tree(sdfg) + unique_names = set() + for node in stree.preorder_traversal(): + if isinstance(node, tn.StateLabel): + if node.state.name in unique_names: + raise NameError('Name clash') + unique_names.add(node.state.name) + + +@pytest.mark.parametrize('constprop', (False, True)) +def test_clash_symbol_mapping(constprop): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [200], dace.float64) + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + sdfg.add_symbol('k', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state() + sdfg.add_edge(state, state2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nsdfg.add_array('out', [100], dace.float64) + nsdfg.add_transient('tmp', [100], dace.float64) + nstate = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + + # Copy + # The code should end up as `tmp[N:N+2] <- out[M+1:M+3]` + # In the outer SDFG: `tmp[N:N+2] <- A[M+101:M+103]` + r = nstate.add_access('out') + w = nstate.add_access('tmp') + nstate.add_edge(r, None, w, None, dace.Memlet(data='out', subset='k:k+2', other_subset='M:M+2')) + + # Tasklet + # The code should end up as `tmp[M] -> Tasklet -> out[N + 1]` + # In the outer SDFG: `tmp[M] -> Tasklet -> A[N + 101]` + r = nstate2.add_access('tmp') + w = nstate2.add_access('out') + t = nstate2.add_tasklet('dosomething', {'a'}, {'b'}, 'b = a + 1') + nstate2.add_edge(r, None, t, 'a', dace.Memlet('tmp[N]')) + nstate2.add_edge(t, 'b', w, None, dace.Memlet('out[k]')) + + # Connect nested SDFG to parent SDFG with an offset memlet + nsdfg_node = state2.add_nested_sdfg(nsdfg, None, {}, {'out'}, {'N': 'M', 'M': 'N', 'k': 'k'}) + w = state2.add_write('A') + state2.add_edge(nsdfg_node, 'out', w, None, dace.Memlet('A[100:200]')) + + # Get rid of k + if constprop: + 
ConstantPropagation().apply_pass(sdfg, {}) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) in (2, 4) # Either with assignments or without + + # With assignments + if len(stree.children) == 4: + assert constprop is False + assert isinstance(stree.children[0], tn.AssignNode) + assert isinstance(stree.children[1], tn.CopyNode) + assert isinstance(stree.children[2], tn.AssignNode) + assert isinstance(stree.children[3], tn.TaskletNode) + assert stree.children[1].memlet.data == 'A' + assert str(stree.children[1].memlet.src_subset) == 'k + 100:k + 102' + assert str(stree.children[1].memlet.dst_subset) == 'N:N + 2' + assert stree.children[3].in_memlets['a'].data == 'tmp' + assert str(stree.children[3].in_memlets['a'].src_subset) == 'M' + assert stree.children[3].out_memlets['b'].data == 'A' + assert str(stree.children[3].out_memlets['b'].dst_subset) == 'k + 100' + else: + assert constprop is True + assert isinstance(stree.children[0], tn.CopyNode) + assert isinstance(stree.children[1], tn.TaskletNode) + assert stree.children[0].memlet.data == 'A' + assert str(stree.children[0].memlet.src_subset) == 'M + 101:M + 103' + assert str(stree.children[0].memlet.dst_subset) == 'N:N + 2' + assert stree.children[1].in_memlets['a'].data == 'tmp' + assert str(stree.children[1].in_memlets['a'].src_subset) == 'M' + assert stree.children[1].out_memlets['b'].data == 'A' + assert str(stree.children[1].out_memlets['b'].dst_subset) == 'N + 101' + + +def test_edgecase_symbol_mapping(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('M', dace.int64) + sdfg.add_symbol('N', dace.int64) + + state = sdfg.add_state() + state2 = sdfg.add_state_after(state) + + nsdfg = dace.SDFG('nester') + nsdfg.add_symbol('M', dace.int64) + nsdfg.add_symbol('N', dace.int64) + nsdfg.add_symbol('k', dace.int64) + nstate = nsdfg.add_state() + nstate.add_tasklet('dosomething', {}, {}, 'print(k)', side_effects=True) + nstate2 = nsdfg.add_state() + nstate3 = nsdfg.add_state() + nsdfg.add_edge(nstate, 
nstate2, dace.InterstateEdge(assignments={'k': 'M + 1'})) + nsdfg.add_edge(nstate2, nstate3, dace.InterstateEdge(assignments={'l': 'k'})) + + state2.add_nested_sdfg(nsdfg, None, {}, {}, {'N': 'M', 'M': 'N', 'k': 'M + 1'}) + + stree = as_schedule_tree(sdfg) + + # k is reassigned internally, so that should be preserved + assert len(stree.children) == 3 + assert isinstance(stree.children[0], tn.TaskletNode) + assert 'M + 1' in stree.children[0].node.code.as_string + assert isinstance(stree.children[1], tn.AssignNode) + assert stree.children[1].name == 'k' + assert stree.children[1].value.as_string == '(N + 1)' + assert isinstance(stree.children[2], tn.AssignNode) + assert stree.children[2].name == 'l' + assert stree.children[2].value.as_string in ('k', '(N + 1)') + + +def _check_for_name_clashes(stree: tn.ScheduleTreeNode): + + def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): + for child in node.children: + if isinstance(child, tn.ForScope): + itervar = child.header.itervar + if itervar in scopes: + raise NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + [itervar]) + elif isinstance(child, tn.MapScope): + itervars = child.node.map.params + if any(itervar in scopes for itervar in itervars): + raise NameError('Nested scope redefines iteration variable') + _traverse(child, scopes + itervars) + elif isinstance(child, tn.ScheduleTreeScope): + _traverse(child, scopes) + + _traverse(stree, []) + + +def test_clash_iteration_symbols(): + sdfg = _nested_irreducible_loops() + + stree = as_schedule_tree(sdfg) + _check_for_name_clashes(stree) + + +if __name__ == '__main__': + test_clash_states() + test_clash_symbol_mapping(False) + test_clash_symbol_mapping(True) + test_edgecase_symbol_mapping() + test_clash_iteration_symbols() diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py new file mode 100644 index 0000000000..161f15d6c1 --- /dev/null +++ b/tests/schedule_tree/nesting_test.py @@ -0,0 +1,234 @@ 
+# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Nesting and dealiasing tests for schedule trees. +""" +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.transformation.dataflow import RemoveSliceView + +import pytest + +N = dace.symbol('N') +T = dace.symbol('T') + + +def test_stree_mpath_multiscope(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:T]: + A[i + j, k + l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_multiscope_dependent(): + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + for l in dace.map[0:k]: + A[i + j, l] = 1 + + # The test should generate different SDFGs for different simplify configurations, + # but the same schedule tree + stree = as_schedule_tree(tester.to_sdfg()) + assert [type(n) for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.MapScope, tn.TaskletNode] + + +def test_stree_mpath_nested(): + + @dace.program + def nester(A, i, k, j): + for l in range(k): + A[i + j, l] = 1 + + @dace.program + def tester(A: dace.float64[N, N]): + for i in dace.map[0:N:T]: + for j, k in dace.map[0:T, 0:N]: + nester(A, i, j, k) + + stree = as_schedule_tree(tester.to_sdfg()) + + # Simplifying yields a different SDFG due to scalars and symbols, so testing is slightly different + simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') + + if simplified: + assert [type(n) + for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, 
tn.TaskletNode] + + tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] + + if simplified: + assert str(next(iter(tasklet.out_memlets.values()))) == 'A[i + k, l]' + else: + assert str(next(iter(tasklet.out_memlets.values()))).endswith(', l]') + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_same_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = state.add_read('A') + w = state.add_write('B') + if not dst_subset: + state.add_nedge(r, w, dace.Memlet(data='A', subset='2*N:3*N', other_subset='N:2*N')) + else: + state.add_nedge(r, w, dace.Memlet(data='B', subset='N:2*N', other_subset='2*N:3*N')) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 and isinstance(stree.children[0], tn.CopyNode) + assert stree.children[0].target == 'B' + assert stree.children[0].as_string() == 'B[N:2*N] = copy A[2*N:3*N]' + + +@pytest.mark.parametrize('dst_subset', (False, True)) +def test_stree_copy_different_scope(dst_subset): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [3 * N], dace.float64) + sdfg.add_array('B', [3 * N], dace.float64) + state = sdfg.add_state() + + r = state.add_read('A') + w = state.add_write('B') + me, mx = state.add_map('something', dict(i='0:1')) + if not dst_subset: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='A', subset='2*N:3*N', other_subset='N + i:2*N + i')) + else: + state.add_memlet_path(r, me, w, memlet=dace.Memlet(data='B', subset='N + i:2*N + i', other_subset='2*N:3*N')) + state.add_nedge(w, mx, dace.Memlet()) + + stree = as_schedule_tree(sdfg) + stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.MapScope, tn.CopyNode] + assert stree_nodes[-1].target == 'B' + assert stree_nodes[-1].as_string() == 'B[N + i:2*N + i] = copy A[2*N:3*N]' + + +def test_dealias_nested_call(): + + @dace.program + def nester(a, b): + 
b[:] = a + + @dace.program + def tester(a: dace.float64[40], b: dace.float64[40]): + nester(b[1:21], a[10:30]) + + sdfg = tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'b' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' + + +def test_dealias_nested_call_samearray(): + + @dace.program + def nester(a, b): + b[:] = a + + @dace.program + def tester(a: dace.float64[40]): + nester(a[1:21], a[10:30]) + + sdfg = tester.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(RemoveSliceView) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + copy = stree.children[0] + assert isinstance(copy, tn.CopyNode) + assert copy.target == 'a' + assert copy.memlet.data == 'a' + assert str(copy.memlet.src_subset) == '1:21' + assert str(copy.memlet.dst_subset) == '10:30' + + +@pytest.mark.parametrize('simplify', (False, True)) +def test_dealias_memlet_composition(simplify): + + def nester2(c): + c[2] = 1 + + def nester1(b): + nester2(b[-5:]) + + @dace.program + def tester(a: dace.float64[N, N]): + nester1(a[:, 1]) + + sdfg = tester.to_sdfg(simplify=simplify) + stree = as_schedule_tree(sdfg) + + # Simplifying yields a different SDFG due to views, so testing is slightly different + if simplify: + assert len(stree.children) == 1 + tasklet = stree.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert str(next(iter(tasklet.out_memlets.values()))) == 'a[N - 3, 1]' + else: + assert len(stree.children) == 3 + stree_nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in stree_nodes] == [tn.ViewNode, tn.ViewNode, tn.TaskletNode] + + +def test_dealias_interstate_edge(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], 
dace.float64) + + nsdfg = dace.SDFG('nester') + nsdfg.add_array('A', [19], dace.float64) + nsdfg.add_array('B', [15], dace.float64) + nsdfg.add_symbol('m', dace.float64) + nstate1 = nsdfg.add_state() + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate1, nstate2, dace.InterstateEdge(condition='B[1] > 0', assignments=dict(m='A[2]'))) + + # Connect to nested SDFG both with flipped definitions and offset memlets + state = sdfg.add_state() + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'A', 'B'}, {}) + ra = state.add_read('A') + rb = state.add_read('B') + state.add_edge(ra, None, nsdfg_node, 'B', dace.Memlet('A[1:20]')) + state.add_edge(rb, None, nsdfg_node, 'A', dace.Memlet('B[2:17]')) + + sdfg.validate() + stree = as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in nodes] == [tn.StateIfScope, tn.GotoNode, tn.AssignNode] + assert 'A[2]' in nodes[0].condition.as_string + assert 'B[4]' in nodes[-1].value.as_string + + +if __name__ == '__main__': + test_stree_mpath_multiscope() + test_stree_mpath_multiscope_dependent() + test_stree_mpath_nested() + test_stree_copy_same_scope(False) + test_stree_copy_same_scope(True) + test_stree_copy_different_scope(False) + test_stree_copy_different_scope(True) + test_dealias_nested_call() + test_dealias_nested_call_samearray() + test_dealias_memlet_composition(False) + test_dealias_memlet_composition(True) + test_dealias_interstate_edge() diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py new file mode 100644 index 0000000000..09779c670f --- /dev/null +++ b/tests/schedule_tree/schedule_test.py @@ -0,0 +1,289 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +import numpy as np + + +def test_for_in_map_in_for(): + + @dace.program + def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, 10]): + for i in range(10): + for j in dace.map[0:10]: + atile = dace.define_local([10], dace.float32) + atile[:] = A[i] + for k in range(10): + with dace.tasklet: + a << atile[k] + b << B[k, j] + cin << C[i, j] + c >> C[i, j] + c = cin + a * b + + sdfg = matmul.to_sdfg() + stree = as_schedule_tree(sdfg) + + assert len(stree.children) == 1 # for + fornode = stree.children[0] + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 2 # copy, for + copynode, fornode = mapnode.children + assert isinstance(copynode, tn.CopyNode) + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # tasklet + tasklet = fornode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + + +def test_libnode(): + M, N, K = (dace.symbol(s) for s in 'MNK') + + @dace.program + def matmul_lib(a: dace.float64[M, K], b: dace.float64[K, N]): + return a @ b + + sdfg = matmul_lib.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 1 + assert isinstance(stree.children[0], tn.LibraryCall) + assert (stree.children[0].as_string() == + '__return[0:M, 0:N] = library MatMul[alpha=1, beta=0](a[0:M, 0:K], b[0:K, 0:N])') + + +def test_nesting(): + + @dace.program + def nest2(a: dace.float64[10]): + a += 1 + + @dace.program + def nest1(a: dace.float64[5, 10]): + for i in range(5): + nest2(a[:, i]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a[:5]) + nest1(a[5:10]) + nest1(a[10:15]) + nest1(a[15:]) + + sdfg = main.to_sdfg(simplify=True) + stree = as_schedule_tree(sdfg) + + # Despite two levels of nesting, 
immediate children are the 4 for loops + assert len(stree.children) == 4 + offsets = ['', '5', '10', '15'] + for fornode, offset in zip(stree.children, offsets): + assert isinstance(fornode, tn.ForScope) + assert len(fornode.children) == 1 # map + mapnode = fornode.children[0] + assert isinstance(mapnode, tn.MapScope) + assert len(mapnode.children) == 1 # tasklet + tasklet = mapnode.children[0] + assert isinstance(tasklet, tn.TaskletNode) + assert offset in str(next(iter(tasklet.in_memlets.values()))) + + +def test_nesting_view(): + + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a): + for i in range(5): + subset = a[:, i, :] + nest2(subset.reshape((40, ))) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a.reshape((4, 5, 10))) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert any(isinstance(node, tn.ViewNode) for node in stree.children) + + +def test_nesting_nview(): + + @dace.program + def nest2(a: dace.float64[40]): + a += 1 + + @dace.program + def nest1(a: dace.float64[4, 5, 10]): + for i in range(5): + nest2(a[:, i, :]) + + @dace.program + def main(a: dace.float64[20, 10]): + nest1(a) + + sdfg = main.to_sdfg() + stree = as_schedule_tree(sdfg) + assert isinstance(stree.children[0], tn.NView) + + +def test_irreducible_sub_sdfg(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain + s = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + # Add an irreducible CFG + s1 = sdfg.add_state() + s2 = sdfg.add_state() + + sdfg.add_edge(s, s1, dace.InterstateEdge('a < b')) + # sdfg.add_edge(s, s2, dace.InterstateEdge('a >= b')) + sdfg.add_edge(s1, s2, dace.InterstateEdge('b > 9')) + sdfg.add_edge(s2, s1, dace.InterstateEdge('b < 19')) + e = sdfg.add_state() + sdfg.add_edge(s1, e, dace.InterstateEdge('a < 0')) + sdfg.add_edge(s2, e, dace.InterstateEdge('b < 0')) + + # Add a loop following general block + sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') + + stree = 
as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 # Only one gblock + assert node_types[-1] == tn.ForScope # Check that loop was detected + + +def test_irreducible_in_loops(): + sdfg = dace.SDFG('irreducible') + # Add a simple chain of two for loops with goto from second to first's body + s1 = sdfg.add_state_after(sdfg.add_state_after(sdfg.add_state())) + s2 = sdfg.add_state() + e = sdfg.add_state() + + # Add a loop + l1 = sdfg.add_state() + l2 = sdfg.add_state_after(l1) + sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + + l3 = sdfg.add_state() + l4 = sdfg.add_state_after(l3) + sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + + # Irreducible part + sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) + + # Avoiding undefined behavior + sdfg.edges_between(l3, l4)[0].data.condition.as_string = 'i >= 5' + + stree = as_schedule_tree(sdfg) + node_types = [type(n) for n in stree.preorder_traversal()] + assert node_types.count(tn.GBlock) == 1 + assert node_types.count(tn.ForScope) == 2 + + +def test_reference(): + sdfg = dace.SDFG('tester') + sdfg.add_symbol('n', dace.int32) + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + sdfg.add_array('C', [20], dace.float64) + sdfg.add_reference('ref', [20], dace.float64) + + init = sdfg.add_state() + s1 = sdfg.add_state() + s2 = sdfg.add_state() + end = sdfg.add_state() + sdfg.add_edge(init, s1, dace.InterstateEdge('n > 0')) + sdfg.add_edge(init, s2, dace.InterstateEdge('n <= 0')) + sdfg.add_edge(s1, end, dace.InterstateEdge()) + sdfg.add_edge(s2, end, dace.InterstateEdge()) + + s1.add_edge(s1.add_access('A'), None, s1.add_access('ref'), 'set', dace.Memlet('A[0:20]')) + s2.add_edge(s2.add_access('B'), None, s2.add_access('ref'), 'set', dace.Memlet('B[0:20]')) + end.add_nedge(end.add_access('ref'), end.add_access('C'), dace.Memlet('ref[0:20]')) + + stree = 
as_schedule_tree(sdfg) + nodes = list(stree.preorder_traversal())[1:] + assert [type(n) for n in nodes] == [tn.IfScope, tn.RefSetNode, tn.ElseScope, tn.RefSetNode, tn.CopyNode] + assert nodes[1].as_string() == 'ref = refset to A[0:20]' + assert nodes[3].as_string() == 'ref = refset to B[0:20]' + + +def test_code_to_code(): + sdfg = dace.SDFG('tester') + sdfg.add_scalar('scal', dace.int32, transient=True) + state = sdfg.add_state() + t1 = state.add_tasklet('a', {}, {'out'}, 'out = 5') + t2 = state.add_tasklet('b', {'inp'}, {}, 'print(inp)', side_effects=True) + state.add_edge(t1, 'out', t2, 'inp', dace.Memlet('scal')) + + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, tn.TaskletNode) for c in stree.children) + assert stree.children[1].as_string().startswith('tasklet(scal') + + +def test_dyn_map_range(): + H = dace.symbol() + nnz = dace.symbol('nnz') + W = dace.symbol() + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = as_schedule_tree(sdfg) + assert len(stree.children) == 2 + assert all(isinstance(c, tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + start, end, dynrangemap = mapscope.children + assert isinstance(start, tn.DynScopeCopyNode) + assert isinstance(end, tn.DynScopeCopyNode) + assert isinstance(dynrangemap, tn.MapScope) + + +def test_multiview(): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20, 20], dace.float64) + sdfg.add_array('B', [20, 20], dace.float64) + sdfg.add_view('Av', [400], dace.float64) + sdfg.add_view('Avv', [10, 40], dace.float64) + sdfg.add_view('Bv', [400], dace.float64) + sdfg.add_view('Bvv', [10, 40], dace.float64) + state = sdfg.add_state() + av = state.add_access('Av') + bv = state.add_access('Bv') 
+ bvv = state.add_access('Bvv') + avv = state.add_access('Avv') + state.add_edge(state.add_read('A'), None, av, None, dace.Memlet('A[0:20, 0:20]')) + state.add_edge(av, None, avv, 'views', dace.Memlet('Av[0:400]')) + state.add_edge(avv, None, bvv, None, dace.Memlet('Avv[0:10, 0:40]')) + state.add_edge(bvv, 'views', bv, None, dace.Memlet('Bv[0:400]')) + state.add_edge(bv, 'views', state.add_write('B'), None, dace.Memlet('Bv[0:400]')) + + stree = as_schedule_tree(sdfg) + assert [type(n) for n in stree.children] == [tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.ViewNode, tn.CopyNode] + + +if __name__ == '__main__': + test_for_in_map_in_for() + test_libnode() + test_nesting() + test_nesting_view() + test_nesting_nview() + test_irreducible_sub_sdfg() + test_irreducible_in_loops() + test_reference() + test_code_to_code() + test_dyn_map_range() + test_multiview() diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py index 02b8f0c174..55e3a936a7 100644 --- a/tests/sdfg/data/structure_test.py +++ b/tests/sdfg/data/structure_test.py @@ -443,6 +443,52 @@ def test_direct_read_structure(): assert np.allclose(B, ref) +def test_direct_read_structure_loops(): + + M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) + csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), + name='CSRMatrix') + + sdfg = dace.SDFG('csr_to_dense_direct_loops') + + sdfg.add_datadesc('A', csr_obj) + sdfg.add_array('B', [M, N], dace.float32) + + state = sdfg.add_state() + + indices = state.add_access('A.indices') + data = state.add_access('A.data') + B = state.add_access('B') + + t = state.add_tasklet('indirection', {'j', '__val'}, {'__out'}, '__out[i, j] = __val') + state.add_edge(indices, None, t, 'j', dace.Memlet(data='A.indices', subset='idx')) + state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx')) + state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1)) + + 
idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1') + i_before, i_guard, i_after = sdfg.add_loop(None, idx_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=idx_after) + + func = sdfg.compile() + + rng = np.random.default_rng(42) + A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = np.zeros((20, 20), dtype=np.float32) + + inpA = csr_obj.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0], + indices=A.indices.__array_interface__['data'][0], + data=A.data.__array_interface__['data'][0], + rows=A.shape[0], + cols=A.shape[1], + M=A.shape[0], + N=A.shape[1], + nnz=A.nnz) + + func(A=inpA, B=B, M=20, N=20, nnz=A.nnz) + ref = A.toarray() + + assert np.allclose(B, ref) + + def test_direct_read_nested_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) csr_obj = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), @@ -505,3 +551,4 @@ def test_direct_read_nested_structure(): test_write_nested_structure() test_direct_read_structure() test_direct_read_nested_structure() + test_direct_read_structure_loops() diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py new file mode 100644 index 0000000000..467838fc56 --- /dev/null +++ b/tests/sdfg/memlet_utils_test.py @@ -0,0 +1,67 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +import numpy as np +import pytest +from dace.sdfg import memlet_utils as mu + + +def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: + for i, s in enumerate(memlet.subset): + if s == 0: + memlet.subset[i] = 1 + return memlet + + +@pytest.mark.parametrize('filter_type', ['none', 'same_array', 'different_array']) +def test_replace_memlet(filter_type): + # Prepare SDFG + sdfg = dace.SDFG('replace_memlet') + sdfg.add_array('A', [2, 2], dace.float64) + sdfg.add_array('B', [1], dace.float64) + state1 = sdfg.add_state() + state2 = sdfg.add_state() + state3 = sdfg.add_state() + end_state = sdfg.add_state() + sdfg.add_edge(state1, state2, dace.InterstateEdge('A[0, 0] > 0')) + sdfg.add_edge(state1, state3, dace.InterstateEdge('A[0, 0] <= 0')) + sdfg.add_edge(state2, end_state, dace.InterstateEdge()) + sdfg.add_edge(state3, end_state, dace.InterstateEdge()) + + t2 = state2.add_tasklet('write_one', {}, {'out'}, 'out = 1') + t3 = state3.add_tasklet('write_two', {}, {'out'}, 'out = 2') + w2 = state2.add_write('B') + w3 = state3.add_write('B') + state2.add_memlet_path(t2, w2, src_conn='out', memlet=dace.Memlet('B')) + state3.add_memlet_path(t3, w3, src_conn='out', memlet=dace.Memlet('B')) + + # Filter memlets + if filter_type == 'none': + filter = set() + elif filter_type == 'same_array': + filter = {'A'} + elif filter_type == 'different_array': + filter = {'B'} + + # Replace memlets in conditions + replacer = mu.MemletReplacer(sdfg.arrays, _replace_zero_with_one, filter) + for e in sdfg.edges(): + e.data.condition.code[0] = replacer.visit(e.data.condition.code[0]) + + # Compile and run + sdfg.compile() + + A = np.array([[1, 1], [1, -1]], dtype=np.float64) + B = np.array([0], dtype=np.float64) + sdfg(A=A, B=B) + + if filter_type in {'none', 'same_array'}: + assert B[0] == 2 + else: + assert B[0] == 1 + + +if __name__ == '__main__': + test_replace_memlet('none') + test_replace_memlet('same_array') + test_replace_memlet('different_array') diff --git 
a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py index 133afe8ae4..05375007df 100644 --- a/tests/sdfg/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -1,14 +1,18 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ import dace as dc -from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, parse_assumptions from dace.sdfg.work_depth_analysis.helpers import get_uuid +from dace.sdfg.work_depth_analysis.assumptions import ContradictingAssumptions import sympy as sp from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion +from pytest import raises + # TODO: add tests for library nodes (e.g. reduce, matMul) +# TODO: add tests for average parallelism N = dc.symbol('N') M = dc.symbol('M') @@ -65,11 +69,11 @@ def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): @dc.program def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): if x[10] > 50: - if x[9] > 50: + if x[9] > 40: z[:] = x + y # N work, 1 depth z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth else: - if y[9] > 50: + if y[9] > 30: for i in range(K): sum += x[i] # K work, K depth else: @@ -153,6 +157,22 @@ def break_while_loop(x: dc.float64[N]): x += 1 +@dc.program +def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive + if x[0] > 5: + x[:] += 1 # N+1 work, 1 depth + else: + for i in range(M): # M work, M depth + y[i + 1] += y[i] + if M > N: + y[:N + 1] += x[:] # N+1 work, 1 depth + else: + x[:M + 1] += y[:] # M+1 work, 1 depth + # --> Work: Max(N+1, M) + Max(N+1, M+1) + # Depth: Max(1, M) + 1 + + +#(sdfg, (expected_work, expected_depth)) tests_cases = [ (single_map, (N, 1)), (single_for_loop, (N, N)), @@ -164,25 +184,18 @@ def 
break_while_loop(x: dc.float64[N]): (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), (max_of_positive_symbol, (3 * N**2, 3 * N)), (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True) * N, sp.Symbol('num_execs_0_2', - nonnegative=True))), + (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N, - sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)))), - (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N, - 2 * sp.Symbol('num_execs_0_7', nonnegative=True))), - (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True) * N, sp.Symbol('num_execs_0_6', - nonnegative=True))), + (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, sp.Max(1, sp.Symbol('num_execs_0_1')))), + (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), (break_for_loop, (N**2, N)), - (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True) * N, sp.Symbol('num_execs_0_5', nonnegative=True))) + (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) ] def test_work_depth(): - good = 0 - failed = 0 - exception = 0 - failed_tests = [] for test, correct in tests_cases: w_d_map = {} sdfg = test.to_sdfg() @@ -190,12 +203,60 @@ def test_work_depth(): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) - - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] + # 
substitute each symbol without assumptions. + # We do this since sp.Symbol('N') == sp.Symbol('N', positive=True) --> False. + reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} + res = (res[0].subs(reps), res[1].subs(reps)) + reps = { + s: sp.Symbol(s.name) + for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols) + } + correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) # check result assert correct == res +x, y, z, a = sp.symbols('x y z a') + +# (expr, assumptions, result) +assumptions_tests = [ + (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y), + (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)), (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', + 'x>3'], 11 + x), + (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000), + (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M) +] + +# These assumptions should trigger the ContradictingAssumptions exception. +tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'], + ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], + ['x==5', 'x<4']] + + +def test_assumption_system(): + for expr, assums, res in assumptions_tests: + equality_subs, all_subs = parse_assumptions(assums, set()) + initial_expr = expr + expr = expr.subs(equality_subs[0]) + expr = expr.subs(equality_subs[1]) + for subs1, subs2 in all_subs: + expr = expr.subs(subs1) + expr = expr.subs(subs2) + assert expr == res + + for assums in tests_for_exception: + # check that the Exception gets raised. 
+ with raises(ContradictingAssumptions): + parse_assumptions(assums, set()) + + if __name__ == '__main__': test_work_depth() + test_assumption_system() diff --git a/tests/symbol_dependent_transients_test.py b/tests/symbol_dependent_transients_test.py index f718abf379..8033b6b196 100644 --- a/tests/symbol_dependent_transients_test.py +++ b/tests/symbol_dependent_transients_test.py @@ -45,7 +45,7 @@ def _make_sdfg(name, storage=dace.dtypes.StorageType.CPU_Heap, isview=False): body2_state.add_nedge(read_a, read_tmp1, dace.Memlet(f'A[2:{N}-2, 2:{N}-2, i:{N}]')) else: read_tmp1 = body2_state.add_read('tmp1') - rednode = standard.Reduce(wcr='lambda a, b : a + b', identity=0) + rednode = standard.Reduce('sum', wcr='lambda a, b : a + b', identity=0) if storage == dace.dtypes.StorageType.GPU_Global: rednode.implementation = 'CUDA (device)' elif storage == dace.dtypes.StorageType.FPGA_Global: diff --git a/tests/symbol_mapping_replace_test.py b/tests/symbol_mapping_replace_test.py index cd47320bf1..cbb572bc81 100644 --- a/tests/symbol_mapping_replace_test.py +++ b/tests/symbol_mapping_replace_test.py @@ -27,14 +27,15 @@ def outer(A, inp1: float, inp2: float): def test_symbol_mapping_replace(): - with dace.config.set_temporary('optimizer', 'automatic_simplification', value=True): - A = np.ones((10, 10, 10)) - ref = A.copy() - b = 2.0 - c = 2.0 - outer(A, inp1=b, inp2=c) - outer.f(ref, inp1=b, inp2=c) - assert (np.allclose(A, ref)) + # TODO/NOTE: Setting temporary config values does not work in the CI + # with dace.config.set_temporary('optimizer', 'automatic_simplification', value=True): + A = np.ones((10, 10, 10)) + ref = A.copy() + b = 2.0 + c = 2.0 + outer(A, inp1=b, inp2=c) + outer.f(ref, inp1=b, inp2=c) + assert (np.allclose(A, ref)) if __name__ == '__main__': diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index 67c60c01bf..dca775bb7a 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ 
b/tests/transformations/move_loop_into_map_test.py @@ -96,17 +96,17 @@ def test_multiple_edges(self): def test_itervar_in_map_range(self): sdfg = should_not_apply_1.to_sdfg(simplify=True) count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEquals(count, 0) + self.assertEqual(count, 0) def test_itervar_in_data(self): sdfg = should_not_apply_2.to_sdfg(simplify=True) count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEquals(count, 0) + self.assertEqual(count, 0) def test_non_injective_index(self): sdfg = should_not_apply_3.to_sdfg(simplify=True) count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEquals(count, 0) + self.assertEqual(count, 0) def test_apply_multiple_times(self): sdfg = apply_multiple_times.to_sdfg(simplify=True) diff --git a/tests/transformations/tasklet_fusion_test.py b/tests/transformations/tasklet_fusion_test.py index c7fd6802d5..743010e8c9 100644 --- a/tests/transformations/tasklet_fusion_test.py +++ b/tests/transformations/tasklet_fusion_test.py @@ -213,6 +213,49 @@ def test_map_with_tasklets(language: str, with_data: bool): ref = map_with_tasklets.f(A, B) assert (np.allclose(C, ref)) +def test_none_connector(): + @dace.program + def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]): + tmp = dace.define_local([32], dace.float32) + for i in dace.map[0:32]: + with dace.tasklet: + a >> tmp[i] + a = 0 + + tmp2 = dace.define_local([32], dace.float32) + for i in dace.map[0:32]: + with dace.tasklet: + a << A[i] + b >> tmp2[i] + b = a + 1 + + + for i in dace.map[0:32]: + with dace.tasklet: + a << tmp[i] + b << tmp2[i] + c >> B[i] + c = a + b + + sdfg = sdfg_none_connector.to_sdfg() + sdfg.simplify() + applied = sdfg.apply_transformations_repeated(MapFusion) + assert applied == 2 + + map_entry = None + for node in sdfg.start_state.nodes(): + if isinstance(node, dace.nodes.MapEntry): + map_entry = node + break + + assert map_entry is not None + assert len([edge.src_conn for edge in 
sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 1 + + applied = sdfg.apply_transformations_repeated(TaskletFusion) + assert applied == 2 + + assert sdfg.start_state.out_degree(map_entry) == 1 + assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0 if __name__ == '__main__': test_basic() @@ -224,3 +267,4 @@ def test_map_with_tasklets(language: str, with_data: bool): test_map_with_tasklets(language='Python', with_data=True) test_map_with_tasklets(language='CPP', with_data=False) test_map_with_tasklets(language='CPP', with_data=True) + test_none_connector()