diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py
index 613b4f8557..6a2b89e4ae 100644
--- a/dace/codegen/targets/cpu.py
+++ b/dace/codegen/targets/cpu.py
@@ -319,7 +319,6 @@ def declare_array(self,
         name = node.root_data
         ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame)
 
-        print("D2", name, nodedesc)
 
         if nodedesc.transient is False:
             return
@@ -344,7 +343,7 @@ def declare_array(self,
         if nodedesc.transient and nodedesc.storage == dtypes.StorageType.CPU_Heap:
             size_desc_name = sdfg.arrays[name].size_desc_name
             if size_desc_name is not None:
-                size_desc = sdfg.size_arrays[size_desc_name]
+                size_desc = sdfg.arrays[size_desc_name]
                 size_ctypedef = dtypes.pointer(size_desc.dtype).ctype
                 self._dispatcher.declared_arrays.add(size_desc_name, DefinedType.Pointer, size_ctypedef)
             return
@@ -513,9 +512,13 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
             declaration_stream.write(f'{nodedesc.dtype.ctype} *{name};\n', cfg, state_id, node)
             # Initialize size array
             size_str = ",".join(["0" if cpp.sym2cpp(dim).startswith("__dace_defer") else cpp.sym2cpp(dim) for dim in nodedesc.shape])
-            size_desc_name = nodedesc.size_desc_name
-            size_nodedesc = sdfg.size_arrays[size_desc_name]
-            declaration_stream.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node)
+            if (nodedesc.transient and (
+                nodedesc.storage == dtypes.StorageType.CPU_Heap or
+                nodedesc.storage == dtypes.StorageType.GPU_Global)
+            ):
+                size_desc_name = nodedesc.size_desc_name
+                size_nodedesc = sdfg.arrays[size_desc_name]
+                declaration_stream.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node)
             if deferred_allocation:
                 allocation_stream.write(
                     "%s = nullptr; // Deferred Allocation" %
diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py
index e5ad6dc9dc..bbc485c336 100644
--- a/dace/codegen/targets/cuda.py
+++ b/dace/codegen/targets/cuda.py
@@ -613,7 +613,7 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
         if nodedesc.transient:
             size_desc_name = nodedesc.size_desc_name
             if size_desc_name is not None:
-                size_nodedesc = sdfg.size_arrays[size_desc_name]
+                size_nodedesc = sdfg.arrays[size_desc_name]
                 result_decl.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n')
                 self._dispatcher.defined_vars.add(size_desc_name, DefinedType.Pointer, size_nodedesc.dtype.ctype)
             self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer, ctypedef)
@@ -1586,7 +1586,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub
             if aname in sdfg.arrays:
                 size_arr_name = data_desc.size_desc_name
                 if size_arr_name is not None:
-                    size_arr = sdfg.size_arrays[data_desc.size_desc_name]
+                    size_arr = sdfg.arrays[data_desc.size_desc_name]
                     host_size_args[size_arr_name] = size_arr
 
         kernel_args_typed = [('const ' if k in const_params else '') + v.as_arg(name=k)
diff --git a/dace/data.py b/dace/data.py
index 8a606eac77..355532208b 100644
--- a/dace/data.py
+++ b/dace/data.py
@@ -183,6 +183,9 @@ def _transient_setter(self, value):
                                 default=dtypes.AllocationLifetime.Scope)
     location = DictProperty(key_type=str, value_type=str, desc='Full storage location identifier (e.g., rank, GPU ID)')
     debuginfo = DebugInfoProperty(allow_none=True)
+    size_desc_name = Property(dtype=str, default=None, allow_none=True,
+                              desc='Name of the 1-D size array (of length len(shape)) that holds the '
+                                   'runtime shape of this array (usually "<name>_size")')
 
     def __init__(self, dtype, shape, transient, storage, location, lifetime, debuginfo):
         self.dtype = dtype
@@ -192,6 +193,7 @@ def __init__(self, dtype, shape, transient, storage, location, lifetime, debugin
         self.location = location if location is not None else {}
         self.lifetime = lifetime
         self.debuginfo = debuginfo
+        self.size_desc_name = None
         self._validate()
 
     def __call__(self):
@@ -1385,9 +1387,6 @@ class Array(Data):
                             'it is inferred by other properties and the OptionalArrayInference pass.')
     pool = Property(dtype=bool, default=False, desc='Hint to the allocator that using a memory pool is preferred')
 
-    size_desc_name = Property(dtype=str, default=None, allow_none=True, desc='The name of the size array (1D, length is the shape of thte current array)'
-                              'Of the array (usually _size)')
-
     def __init__(self,
                  dtype,
                  shape,
diff --git a/dace/memlet.py b/dace/memlet.py
index 85bd0a348d..8d396d8e4c 100644
--- a/dace/memlet.py
+++ b/dace/memlet.py
@@ -68,9 +68,9 @@ def __init__(self,
                  debuginfo: Optional[dtypes.DebugInfo] = None,
                  wcr_nonatomic: bool = False,
                  allow_oob: bool = False):
-        """ 
+        """
         Constructs a Memlet.
-        
+
         :param expr: A string expression of the this memlet, given as an ease
                      of use API. Must follow one of the following forms:
                      1. ``ARRAY``,
@@ -82,7 +82,7 @@ def __init__(self,
         :param subset: The subset to take from the data attached to the edge,
                        represented either as a string or a Subset object.
         :param other_subset: The subset to offset into the other side of the
-                             memlet, represented either as a string or a Subset 
+                             memlet, represented either as a string or a Subset
                              object.
         :param volume: The exact number of elements moved using this
                        memlet, or the maximum number of elements if
@@ -91,14 +91,14 @@ def __init__(self,
                        is runtime-defined and unbounded.
         :param dynamic: If True, the number of elements moved in this memlet
                         is defined dynamically at runtime.
-        :param wcr: A lambda function (represented as a string or Python AST) 
+        :param wcr: A lambda function (represented as a string or Python AST)
                     specifying how write-conflicts are resolved. The syntax
-                    of the lambda function receives two elements: ``current`` 
-                    value and `new` value, and returns the value after 
+                    of the lambda function receives two elements: ``current``
+                    value and `new` value, and returns the value after
                     resolution. For example, summation is represented by
                     ``'lambda cur, new: cur + new'``.
         :param debuginfo: Line information from the generating source code.
-        :param wcr_nonatomic: If True, overrides the automatic code generator 
+        :param wcr_nonatomic: If True, overrides the automatic code generator
                               decision and treat all write-conflict resolution
                               operations as non-atomic, which might cause race
                               conditions in the general case.
@@ -225,16 +225,16 @@ def __deepcopy__(self, memo):
         return node
 
     def is_empty(self) -> bool:
-        """ 
+        """
         Returns True if this memlet carries no data. Memlets without data are
-        primarily used for connecting nodes to scopes without transferring 
-        data to them. 
+        primarily used for connecting nodes to scopes without transferring
+        data to them.
         """
         return (self.data is None and self.subset is None and self.other_subset is None)
 
     @property
     def num_accesses(self):
-        """ 
+        """
         Returns the total memory movement volume (in elements) of this memlet.
         """
         return self.volume
@@ -255,7 +255,7 @@ def simple(data,
         """
         DEPRECATED: Constructs a Memlet from string-based expressions.
 
-        :param data: The data object or name to access. 
+        :param data: The data object or name to access.
        :param subset_str: The subset of `data` that is going to
                           be accessed in string format. Example: '0:N'.
        :param wcr_str: A lambda function (as a string) specifying
@@ -335,7 +335,7 @@ def _parse_from_subexpr(self, expr: str):
         # [subset] syntax
         if expr.startswith('['):
             return None, SubsetProperty.from_string(expr[1:-1])
-        
+
         # array[subset] syntax
         arrname, subset_str = expr[:-1].split('[')
         if not dtypes.validate_name(arrname):
@@ -385,8 +385,8 @@ def _parse_memlet_from_str(self, expr: str):
 
     def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
                        edge: 'dace.sdfg.graph.MultiConnectorEdge'):
-        """ 
-        Tries to initialize the internal fields of the memlet (e.g., src/dst 
+        """
+        Tries to initialize the internal fields of the memlet (e.g., src/dst
         subset) once it is added to an SDFG as an edge.
         """
         from dace.sdfg.nodes import AccessNode, CodeNode  # Avoid import loops
@@ -435,7 +435,7 @@ def get_dst_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dac
     @staticmethod
     def from_array(dataname, datadesc, wcr=None):
-        """ 
+        """
         Constructs a Memlet that transfers an entire array's contents.
 
         :param dataname: The name of the data descriptor in the SDFG.
@@ -456,7 +456,7 @@ def __eq__(self, other):
     def replace(self, repl_dict):
         """
         Substitute a given set of symbols with a different set of symbols.
-        
+
         :param repl_dict: A dict of string symbol names to symbols with
                           which to replace them.
         """
@@ -538,8 +538,8 @@ def validate(self, sdfg, state):
     def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]:
         """
-        Returns a set of symbols used in this edge's properties. 
-        
+        Returns a set of symbols used in this edge's properties.
+
         :param all_symbols: If False, only returns the set of symbols that will be used
                             in the generated code and are needed as arguments.
         :param edge: If given, provides richer context-based tests for the case
@@ -606,7 +606,7 @@ def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[
     def get_stride(self, sdfg: 'dace.sdfg.SDFG', map: 'dace.sdfg.nodes.Map', dim: int = -1) -> 'dace.symbolic.SymExpr':
         """
         Returns the stride of the underlying memory when traversing a Map.
-        
+
         :param sdfg: The SDFG in which the memlet resides.
         :param map: The map in which the memlet resides.
         :param dim: The dimension that is incremented. By default it is the innermost.
diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py
index c05708670e..6219784cfe 100644
--- a/dace/sdfg/infer_types.py
+++ b/dace/sdfg/infer_types.py
@@ -52,9 +52,9 @@ def infer_out_connector_type(sdfg: SDFG, state: SDFGState, node: nodes.CodeNode,
 
 
 def infer_connector_types(sdfg: SDFG):
-    """ 
+    """
     Infers connector types throughout an SDFG and its nested SDFGs in-place.
-    
+
     :param sdfg: The SDFG to infer.
     """
     # Loop over states, and in a topological sort over each state's nodes
@@ -125,13 +125,13 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E
                                            use_parent_schedule: bool = False,
                                            state: SDFGState = None,
                                            child_nodes: Dict[nodes.Node, List[nodes.Node]] = None):
-    """ 
+    """
     Sets default storage and schedule types throughout SDFG in-place.
     Replaces ``ScheduleType.Default`` and ``StorageType.Default``
-    with the corresponding types according to the parent scope's schedule. 
-    
+    with the corresponding types according to the parent scope's schedule.
+
     The defaults for storage types are determined by the
-    ``dtypes.SCOPEDEFAULT_STORAGE`` dictionary (for example, a GPU device 
+    ``dtypes.SCOPEDEFAULT_STORAGE`` dictionary (for example, a GPU device
     schedule, by default, will allocate containers on the shared memory).
     Following storage type inference for a scope, nested scopes (e.g., map entry, nested SDFG) are evaluated,
     using the ``dtypes.STORAGEDEFAULT_SCHEDULE`` dictionary (for example, a
diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py
index e3bea0b807..e34c6228e3 100644
--- a/dace/sdfg/replace.py
+++ b/dace/sdfg/replace.py
@@ -54,7 +54,7 @@ def _replsym(symlist, symrepl):
 def replace_dict(subgraph: 'StateSubgraphView',
                  repl: Dict[str, str],
                  symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None):
-    """ 
+    """
     Finds and replaces all occurrences of a set of symbols/arrays in the given subgraph.
 
     :param subgraph: The given graph or subgraph to replace in.
@@ -86,7 +86,7 @@ def replace_dict(subgraph: 'StateSubgraphView',
 def replace(subgraph: 'StateSubgraphView', name: str, new_name: str):
     """ Finds and replaces all occurrences of a symbol or array in the given subgraph.
-    
+
     :param subgraph: The given graph or subgraph to replace in.
     :param name: Name to find.
     :param new_name: Name to replace.
diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py
index 8d8ea82484..eaa0717c86 100644
--- a/dace/sdfg/sdfg.py
+++ b/dace/sdfg/sdfg.py
@@ -101,8 +101,8 @@ def _nested_arrays_from_json(obj, context=None):
     return NestedDict({k: dace.serialize.from_json(v, context) for k, v in obj.items()})
 
 
-def _replace_dict_keys(d, old, new):
-    if old in d:
+def _replace_dict_keys(d, old, new, filter=None):
+    if old in d and (filter is None or old in filter):
         if new in d:
             warnings.warn('"%s" already exists in SDFG' % new)
         d[new] = d[old]
@@ -418,7 +418,3 @@ class SDFG(ControlFlowRegion):
                        desc="Data descriptors for this SDFG",
                        to_json=_arrays_to_json,
                        from_json=_nested_arrays_from_json)
-    _size_arrays = Property(dtype=NestedDict,
-                            desc="Data size descriptors for this SDFG",
-                            to_json=_arrays_to_json,
-                            from_json=_nested_arrays_from_json)
@@ -500,7 +500,6 @@ def __init__(self,
         self._parent_sdfg = None
         self._parent_nsdfg_node = None
         self._arrays = NestedDict()  # type: Dict[str, dt.Array]
-        self._size_arrays = NestedDict()
         self.arg_names = []
         self._labels: Set[str] = set()
         self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)}
@@ -689,8 +689,4 @@ def arrays(self):
         return self._arrays
 
-    @property
-    def size_arrays(self):
-        return self._size_arrays
-
     @property
     def process_grids(self):
@@ -749,18 +749,30 @@ def replace_dict(self,
         }
 
         # Replace in arrays and symbols (if a variable name)
+        size_arrays = {v.size_desc_name for v in self.arrays.values()
+                       if v.size_desc_name is not None and v.size_desc_name in self.arrays}
+        non_size_arrays = {k for k in self.arrays if k not in size_arrays}
+        size_desc_map = dict()
+
         if replace_keys:
             # Filter out nested data names, as we cannot and do not want to replace names in nested data descriptors
             repldict_filtered = {k: v for k, v in repldict.items() if '.' not in k}
             for name, new_name in repldict_filtered.items():
                 if validate_name(new_name):
-                    _replace_dict_keys(self._arrays, name, new_name)
-                    _replace_dict_keys(self._size_arrays, name + "_size", new_name + "_size")
+                    _replace_dict_keys(self.arrays, name, new_name, non_size_arrays)
+                    if new_name != "__return":
+                        size_desc_map[new_name] = new_name + "_size"
+                    _replace_dict_keys(self.arrays, name + "_size", new_name + "_size", size_arrays)
                     _replace_dict_keys(self.symbols, name, new_name)
                     _replace_dict_keys(self.constants_prop, name, new_name)
                     _replace_dict_keys(self.callback_mapping, name, new_name)
                     _replace_dict_values(self.callback_mapping, name, new_name)
 
+        # Update size descriptors ("__return_size" breaks code generation,
+        # so such a descriptor is unlinked from its array instead)
+        for arr_name, size_desc_name in size_desc_map.items():
+            arr = self.arrays[arr_name]
+            arr.size_desc_name = size_desc_name if size_desc_name != "__return_size" else None
+
         # Replace inside data descriptors
         for array in self.arrays.values():
             replace_properties_dict(array, repldict, symrepl)
@@ -1162,9 +1175,11 @@ def remove_data(self, name, validate=True):
                                        f"{node} in state {state}.")
 
         size_desc_name = self._arrays[name].size_desc_name
+        # The size descriptor may already have been removed by an optimization pass if unused
+        if size_desc_name is not None and size_desc_name in self._arrays:
+            del self._arrays[size_desc_name]
         del self._arrays[name]
-        if size_desc_name is not None:
-            del self._size_arrays[size_desc_name]
+
 
     def reset_sdfg_list(self):
@@ -1689,14 +1704,12 @@ def _find_new_name(self, name: str):
         """ Tries to find a new name by adding an underscore and a number. """
         names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys()
-                 | self._rdistrarrays.keys() | self.symbols.keys() | self._size_arrays.keys())
+                 | self._rdistrarrays.keys() | self.symbols.keys())
         return dt.find_new_name(name, names)
 
     def is_name_used(self, name: str) -> bool:
         """ Checks if `name` is already used inside the SDFG. """
         if name in self._arrays:
             return True
-        if name in self._size_arrays:
-            return True
         if name in self.symbols:
             return True
@@ -1768,22 +1783,6 @@ def add_array(self,
         if isinstance(dtype, type) and dtype in dtypes._CONSTANT_TYPES[:-1]:
             dtype = dtypes.typeclass(dtype)
 
-        if transient:
-            size_desc = dt.Array(dtype=dace.uint64,
-                                 shape=(len(shape),),
-                                 storage=dtypes.StorageType.Default,
-                                 location=None,
-                                 allow_conflicts=False,
-                                 transient=True,
-                                 strides=(1,),
-                                 offset=(0,),
-                                 lifetime=lifetime,
-                                 alignment=alignment,
-                                 debuginfo=debuginfo,
-                                 total_size=len(shape),
-                                 may_alias=False,
-                                 size_desc_name=None)
-
         desc = dt.Array(dtype=dtype,
                         shape=shape,
                         storage=storage,
@@ -1800,12 +1799,6 @@ def add_array(self,
                         size_desc_name=None)
 
         array_name = self.add_datadesc(name, desc, find_new_name=find_new_name)
-        if transient:
-            size_desc_name = f"{array_name}_size"
-            self.add_size_datadesc(size_desc_name, size_desc)
-            # In case find_new_name and a new name is returned
-            # we need to update the size descriptor name of the array
-            desc.size_desc_name = size_desc_name
         return array_name, desc
 
     def add_view(self,
@@ -2053,15 +2046,14 @@ def add_temp_transient_like(self, desc: Union[dt.Array, dt.Scalar], dtype=None,
         newdesc.debuginfo = debuginfo
         return self.add_datadesc(self.temp_data_name(), newdesc), newdesc
 
-    @staticmethod
-    def _add_symbols(sdfg, desc: dt.Data):
+    def _add_symbols(self, desc: dt.Data):
         if isinstance(desc, dt.Structure):
             for v in desc.members.values():
                 if isinstance(v, dt.Data):
-                    SDFG._add_symbols(sdfg, v)
+                    self._add_symbols(v)
         for sym in desc.free_symbols:
-            if sym.name not in sdfg.symbols:
-                sdfg.add_symbol(sym.name, sym.dtype)
+            if sym.name not in self.symbols:
+                self.add_symbol(sym.name, sym.dtype)
 
     def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str:
         """ Adds an existing data descriptor to the SDFG array store.
@@ -2092,7 +2084,7 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
         else:
             # We do not check for data constant, because there is a link between the constants and
             # the data descriptors.
-            if name in self.arrays or name in self.size_arrays:
+            if name in self.arrays:
                 raise FileExistsError(f'Data descriptor "{name}" already exists in SDFG')
             if name in self.symbols:
                 raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a symbol.')
@@ -2105,36 +2097,30 @@ def add_datadesc(self, name: str, datadesc: dt.Data, find_new_name=False) -> str
 
         # Add the data descriptor to the SDFG and all symbols that are not yet known.
         self._arrays[name] = datadesc
-        SDFG._add_symbols(self, datadesc)
-
-        return name
-
-    def add_size_datadesc(self, name: str, datadesc: dt.Data) -> str:
-        """ Adds an existing data descriptor to the SDFG array store.
-
-        :param name: Name to use.
-        :param datadesc: Data descriptor to add.
-        :param find_new_name: If True and data descriptor with this name
-                              exists, finds a new name to add.
-        :return: Name of the new data descriptor
-        """
-        if not isinstance(name, str):
-            raise TypeError("Data descriptor name must be a string. Got %s" % type(name).__name__)
-
-        if name in self.arrays or name in self.size_arrays:
-            raise FileExistsError(f'Data descriptor "{name}" already exists in SDFG')
-        if name in self.symbols:
-            raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a symbol.')
-        if name in self._subarrays:
-            raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a subarray.')
-        if name in self._rdistrarrays:
-            raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a RedistrArray.')
-        if name in self._pgrids:
-            raise FileExistsError(f'Can not create data descriptor "{name}", the name is used by a ProcessGrid.')
-
-        # Add the data descriptor to the SDFG and all symbols that are not yet known.
-        self._size_arrays[name] = datadesc
-        SDFG._add_symbols(self, datadesc)
+        self._add_symbols(datadesc)
+
+        if (datadesc.transient and
+                isinstance(datadesc, dt.Array) and
+                name != "__return"):
+            size_desc_name = f"{name}_size"
+            size_desc = dt.Array(dtype=dace.uint64,
+                                 shape=(len(datadesc.shape),),
+                                 storage=dtypes.StorageType.Default,
+                                 location=None,
+                                 allow_conflicts=False,
+                                 transient=True,
+                                 strides=(1,),
+                                 offset=(0,),
+                                 lifetime=datadesc.lifetime,
+                                 alignment=datadesc.alignment,
+                                 debuginfo=datadesc.debuginfo,
+                                 total_size=len(datadesc.shape),
+                                 may_alias=False,
+                                 size_desc_name=None)
+            self._arrays[size_desc_name] = size_desc
+            # `name` is the final name here (find_new_name may have changed it),
+            # so the array and its size descriptor stay consistently linked
+            datadesc.size_desc_name = size_desc_name
+            self._add_symbols(size_desc)
 
         return name
diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py
index 5212147c03..73d4913630 100644
--- a/dace/sdfg/validation.py
+++ b/dace/sdfg/validation.py
@@ -901,10 +901,11 @@ def validate_state(state: 'dace.sdfg.SDFGState',
 
         # Check dimensionality of memory access
         if isinstance(e.data.subset, (sbs.Range, sbs.Indices)):
-            if e.data.subset.dims() != len(sdfg.arrays[e.data.data].shape):
+            desc = sdfg.arrays[e.data.data]
+            if e.data.subset.dims() != len(desc.shape):
                 raise InvalidSDFGEdgeError(
                     "Memlet subset uses the wrong dimensions"
-                    " (%dD for a %dD data node)" % (e.data.subset.dims(), len(sdfg.arrays[e.data.data].shape)),
+                    " (%dD for a %dD data node)" % (e.data.subset.dims(), len(desc.shape)),
                     sdfg,
                     state_id,
                     eid,
@@ -913,8 +914,8 @@ def validate_state(state: 'dace.sdfg.SDFGState',
         # Verify that source and destination subsets contain the same
         # number of elements
         if not e.data.allow_oob and e.data.other_subset is not None and not (
-                (isinstance(src_node, nd.AccessNode) and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or
-                (isinstance(dst_node, nd.AccessNode) and isinstance(sdfg.arrays[dst_node.data], dt.Stream))):
+                (isinstance(src_node, nd.AccessNode) and src_node.data in sdfg.arrays and isinstance(sdfg.arrays[src_node.data], dt.Stream)) or
+                (isinstance(dst_node, nd.AccessNode) and dst_node.data in sdfg.arrays and isinstance(sdfg.arrays[dst_node.data], dt.Stream))):
             src_expr = (e.data.src_subset.num_elements() * sdfg.arrays[src_node.data].veclen)
             dst_expr = (e.data.dst_subset.num_elements() * sdfg.arrays[dst_node.data].veclen)
             if symbolic.inequal_symbols(src_expr, dst_expr):
diff --git a/dace/transformation/dataflow/redundant_array.py b/dace/transformation/dataflow/redundant_array.py
index 5e5072ff32..ebccf93047 100644
--- a/dace/transformation/dataflow/redundant_array.py
+++ b/dace/transformation/dataflow/redundant_array.py
@@ -1675,7 +1675,7 @@ def _offset_subset(self, mapping: Dict[int, int], subset: subsets.Range, edge_su
 
 class RemoveIntermediateWrite(pm.SingleStateTransformation):
     """ Moves intermediate writes insde a Map's subgraph outside the Map.
-    
+
     Currently, the transformation supports only the case `WriteAccess -> MapExit`, where the edge has an empty Memlet.
     """
diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py
index 46411478d5..f579180ff2 100644
--- a/dace/transformation/passes/array_elimination.py
+++ b/dace/transformation/passes/array_elimination.py
@@ -34,7 +34,7 @@ def depends_on(self):
     def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Set[str]]:
         """
         Removes redundant arrays and access nodes.
-        
+
         :param sdfg: The SDFG to modify.
         :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                                  results as ``{Pass subclass name: returned object from pass}``. If not run in a
@@ -84,7 +84,12 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S
             result.update({n.data for n in removed_nodes})
 
         # If node is completely removed from graph, erase data descriptor
-        for aname, desc in list(sdfg.arrays.items()):
+        array_items = list(sdfg.arrays.items())
+        size_descriptors = {v.size_desc_name for v in sdfg.arrays.values() if v.size_desc_name is not None}
+        for aname, desc in array_items:
+            # Size descriptors are removed only together with their original array (below)
+            if aname in size_descriptors:
+                continue
             if not desc.transient or isinstance(desc, data.Scalar):
                 continue
             if aname not in access_sets or not access_sets[aname]:
@@ -92,7 +97,10 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S
             if isinstance(desc, data.Structure) and len(desc.members) > 0:
                 continue
             sdfg.remove_data(aname, validate=False)
+            if desc.size_desc_name is not None:
+                sdfg.remove_data(desc.size_desc_name, validate=False)
+                result.add(desc.size_desc_name)
             result.add(aname)
 
         return result or None
diff --git a/tests/deferred_alloc_test.py b/tests/deferred_alloc_test.py
index 6459ee6105..35b6c6c16b 100644
--- a/tests/deferred_alloc_test.py
+++ b/tests/deferred_alloc_test.py
@@ -1,4 +1,6 @@
 import dace
+from dace.transformation.dataflow.redundant_array import RedundantArray, RedundantSecondArray
+from dace.transformation.interstate.state_fusion import StateFusion
 import numpy
 import cupy
 import pytest
@@ -28,7 +30,8 @@ def _get_trivial_alloc_sdfg(storage_type: dace.dtypes.StorageType, transient: bo
     an_1 = state.add_access('A')
     an_1.add_in_connector('_write_size')
 
-    an_2 = state.add_array(name="user_size", shape=(2,), dtype=dace.uint64)
+    sdfg.add_array(name="user_size", shape=(2,), dtype=dace.uint64)
+    an_2 = state.add_access("user_size")
 
     state.add_edge(an_2, None, an_1, '_write_size',
                    dace.Memlet(expr=f"user_size[{write_size}]") )
@@ -48,7 +51,8 @@ def _get_assign_map_sdfg(storage_type: dace.dtypes.StorageType, transient: bool,
     an_1.add_in_connector('_write_size')
     an_1.add_out_connector('_read_size')
 
-    an_2 = state.add_array(name="user_size", shape=(2,), dtype=dace.uint64)
+    sdfg.add_array(name="user_size", shape=(2,), dtype=dace.uint64)
+    an_2 = state.add_access("user_size")
 
     state.add_edge(an_2, None, an_1, '_write_size',
                    dace.Memlet(expr="user_size[0:2]") )
@@ -116,6 +120,11 @@ def test_trivial_realloc(storage_type: dace.dtypes.StorageType, transient: bool)
 
     sdfg.compile()
 
+    sdfg.simplify()
+    sdfg.apply_transformations_repeated([StateFusion, RedundantArray, RedundantSecondArray])
+    sdfg.validate()
+    sdfg.compile()
+
 
 def test_realloc_use(storage_type: dace.dtypes.StorageType, transient: bool, schedule_type: dace.dtypes.ScheduleType):
     sdfg = _get_assign_map_sdfg(storage_type, transient, schedule_type)
     try:
@@ -133,14 +142,28 @@ def test_realloc_use(storage_type: dace.dtypes.StorageType, transient: bool, sch
     if storage_type == dace.dtypes.StorageType.CPU_Heap:
         arr = numpy.array([-1.0]).astype(numpy.float32)
         user_size = numpy.array([10, 10]).astype(numpy.uint64)
-        compiled_sdfg (user_size=user_size, example_array=arr)
+        compiled_sdfg(user_size=user_size, example_array=arr)
         assert ( arr[0] == 3.0 )
     if storage_type == dace.dtypes.StorageType.GPU_Global:
         arr = cupy.array([-1.0]).astype(cupy.float32)
         user_size = numpy.array([10, 10]).astype(numpy.uint64)
-        compiled_sdfg (user_size=user_size, example_array=arr)
+        compiled_sdfg(user_size=user_size, example_array=arr)
         assert ( arr.get()[0] == 3.0 )
+    sdfg.simplify()
+    sdfg.apply_transformations_repeated([StateFusion, RedundantArray, RedundantSecondArray])
+    sdfg.validate()
+    compiled_sdfg = sdfg.compile()
+    if storage_type == dace.dtypes.StorageType.CPU_Heap:
+        arr = numpy.array([-1.0]).astype(numpy.float32)
+        user_size = numpy.array([10, 10]).astype(numpy.uint64)
+        compiled_sdfg(user_size=user_size, example_array=arr)
+        assert ( arr[0] == 3.0 )
+    if storage_type == dace.dtypes.StorageType.GPU_Global:
+        arr = cupy.array([-1.0]).astype(cupy.float32)
+        user_size = numpy.array([10, 10]).astype(numpy.uint64)
+        compiled_sdfg(user_size=user_size, example_array=arr)
+        assert ( arr.get()[0] == 3.0 )
 
 def test_realloc_inside_map():
     pass
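
The net effect of this patch is that size descriptors now live in the regular `sdfg.arrays` store instead of the removed `size_arrays` dict, and are created, renamed, and removed together with their owning transient array. The following is a minimal sketch of that intended contract, assuming the patch applies as-is; the array names and shapes are illustrative and not taken from the patch:

```python
import dace

sdfg = dace.SDFG('size_desc_demo')

# Transient arrays get a companion "<name>_size" descriptor (one uint64 per
# dimension), registered in the same sdfg.arrays dict as the array itself.
name, desc = sdfg.add_array('tmp', shape=(10, 20), dtype=dace.float32, transient=True)
assert desc.size_desc_name == 'tmp_size'
assert sdfg.arrays['tmp_size'].shape == (2,)
assert sdfg.arrays['tmp_size'].dtype == dace.uint64

# Non-transient arrays (and "__return") get no size descriptor.
sdfg.add_array('inp', shape=(10,), dtype=dace.float32)
assert sdfg.arrays['inp'].size_desc_name is None

# remove_data() drops the companion descriptor together with its array.
sdfg.remove_data('tmp', validate=False)
assert 'tmp' not in sdfg.arrays and 'tmp_size' not in sdfg.arrays
```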