# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
"""Container groups: hierarchical descriptors that mirror a ``Structure`` /
``ContainerArray`` tree so it can be flattened into plain data containers
(used by the StructToContainerGroups pass)."""
from collections import OrderedDict
from typing import Set, Union
import typing
from dace import data
from dace.data import Data
from dace import serialize, symbolic
from dace.properties import ListProperty, OrderedDictProperty, Property, make_properties
from enum import Enum

import numpy
import sympy


class ContainerGroupFlatteningMode(Enum):
    """How a Structure hierarchy is flattened into plain containers."""
    ArrayOfStructs = 1
    StructOfArrays = 2


def _members_to_json(members):
    """Serialize the ordered member mapping as (name, json) pairs so member
    order survives a JSON round-trip."""
    if members is None:
        return None
    return [(k, serialize.to_json(v)) for k, v in members.items()]


def _members_from_json(obj, context=None):
    """Inverse of :func:`_members_to_json`."""
    if obj is None:
        return {}
    return OrderedDict((k, serialize.from_json(v, context)) for k, v in obj)


@make_properties
class ContainerGroup:
    """A node in a container-group tree.

    ``is_cg`` is True when the node was derived from a ``data.Structure``,
    ``is_ca`` when it was derived from a ``data.ContainerArray``; exactly one
    of the two must hold (see :meth:`from_struct`). ``members`` maps member
    names to either leaf ``Data`` descriptors or nested ``ContainerGroup``s.
    """

    name = Property(dtype=str, default="", allow_none=False)
    members = OrderedDictProperty(
        default=OrderedDict(),
        desc="Dictionary of structure members",
        from_json=_members_from_json,
        to_json=_members_to_json,
    )
    is_cg = Property(dtype=bool, default=False, allow_none=False)
    is_ca = Property(dtype=bool, default=False, allow_none=False)
    shape = Property(dtype=tuple, default=(1, ), allow_none=False)

    def __init__(self, name, is_cg, is_ca, shape):
        self.name = name
        self.members = OrderedDict()
        self.is_cg = is_cg
        self.is_ca = is_ca
        self.shape = shape
        self._validate()

    def add_member(self, name: str, member: Union[Data, "ContainerGroup"]):
        """Adds a member under ``name``; unnamed members get a positional name.

        Bug fix: the fallback name is now ``str(len(...))`` instead of the raw
        ``int``, keeping all member keys strings (as assumed by serialization
        and by name-hierarchy lookups).
        """
        if name is None or name == "":
            name = str(len(self.members))
        self.members[name] = member

    @property
    def free_symbols(self) -> Set[symbolic.SymbolicType]:
        """Returns a set of undefined symbols in this data descriptor."""
        result = set()
        for k, v in self.members.items():
            result |= v.free_symbols
        return result

    def __call__(self):
        return self

    def validate(self):
        self._validate()

    def _validate(self):
        return True

    @staticmethod
    def from_json(json_obj, context=None):
        if json_obj['type'] != 'ContainerGroup':
            raise TypeError("Invalid data type")

        # Bug fix: the constructor requires name/is_cg/is_ca/shape; the former
        # ``ContainerGroup({})`` call always raised TypeError. The placeholder
        # values below are overwritten from JSON immediately afterwards.
        ret = ContainerGroup(name="", is_cg=True, is_ca=False, shape=(1, ))
        serialize.set_properties_from_json(ret, json_obj, context=context)

        return ret

    def to_json(self):
        attrs = serialize.all_properties_to_json(self)
        retdict = {"type": type(self).__name__, "attributes": attrs}
        return retdict

    def is_equivalent(self, other):
        raise NotImplementedError

    def __eq__(self, other):
        # Structural equality via the serialized representation.
        return serialize.dumps(self) == serialize.dumps(other)

    def __hash__(self):
        return hash(serialize.dumps(self))

    def __repr__(self):
        members_repr = ", ".join(
            f"{k}: {v.__repr__()}" for k, v in self.members.items()
        )
        return f"ContainerGroup(name='{self.name}', is_cg={self.is_cg}, is_ca={self.is_ca}, shape={self.shape}, members={{ {members_repr} }})"

    def __str__(self):
        return self.__repr__()

    def _soa_from_struct(self, name, structure, acc_shape):
        # Bug fix: the previous implementation called the nonexistent
        # ``self._add_members`` and always raised AttributeError; raise an
        # explicit NotImplementedError until this is actually implemented.
        raise NotImplementedError("_soa_from_struct is not implemented")

    @classmethod
    def _convert_member(cls, member_name, member, allow_container_array: bool = True):
        """Maps a Structure/ContainerArray member to a ContainerGroup member.

        Nested containers recurse through :meth:`from_struct`; leaf arrays and
        scalars pass through; symbolic/integer members become scalars.
        :param allow_container_array: False when the parent is itself a
            ContainerArray (nested container arrays are unsupported).
        """
        if isinstance(member, data.Structure):
            return cls.from_struct(name=member_name,
                                   struct_or_container_array=member,
                                   is_cg=True,
                                   is_ca=False,
                                   shape=(1, ))
        if isinstance(member, data.ContainerArray):
            if not allow_container_array:
                raise Exception("Two container arrays in a row is currently not supported")
            return cls.from_struct(name=member_name,
                                   struct_or_container_array=member,
                                   is_cg=False,
                                   is_ca=True,
                                   shape=member.shape)
        if isinstance(member, (data.Array, data.Scalar)):
            return member
        if isinstance(member, (sympy.Basic, symbolic.SymExpr, int, numpy.integer)):
            return data.Scalar(symbolic.symtype(member))
        raise TypeError(f"Unsupported member type in Structure: {type(member)}")

    @classmethod
    def from_struct(
        cls,
        name: str,
        struct_or_container_array: typing.Union[data.Structure, data.ContainerArray],
        is_cg: bool,
        is_ca: bool,
        shape: tuple
    ) -> "ContainerGroup":
        """Recursively builds a ContainerGroup tree from a Structure or a
        ContainerArray descriptor.

        :param name: Name of the resulting group.
        :param struct_or_container_array: Source descriptor.
        :param is_cg: True when the source is a Structure.
        :param is_ca: True when the source is a ContainerArray
            (exactly one of ``is_cg``/``is_ca`` must be set).
        :param shape: Shape contributed by this level of the hierarchy.
        """
        dg = cls(name=name, is_cg=is_cg, is_ca=is_ca, shape=shape)
        assert is_cg ^ is_ca

        if isinstance(struct_or_container_array, data.Structure):
            for member_name, member in struct_or_container_array.members.items():
                dg.add_member(name=f"{member_name}",
                              member=cls._convert_member(member_name, member))
        else:
            assert isinstance(struct_or_container_array, data.ContainerArray)
            member = struct_or_container_array.stype
            if isinstance(member, data.Structure):
                # Recursively convert the nested Structure element type
                dg.add_member(name=member.name,
                              member=cls._convert_member(member.name, member))
            else:
                # Unnamed element type (plain array/scalar): store as "Leaf"
                dg.add_member(name="Leaf",
                              member=cls._convert_member(None, member, allow_container_array=False))

        return dg
warnings import dace @@ -40,6 +41,7 @@ from dace.codegen.compiled_sdfg import CompiledSDFG from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeScope +from dace.sdfg.container_group import ContainerGroup, ContainerGroupFlatteningMode class NestedDict(dict): @@ -51,8 +53,11 @@ def __getitem__(self, key): tokens = key.split('.') if isinstance(key, str) else [key] token = tokens.pop(0) result = super(NestedDict, self).__getitem__(token) + while tokens: token = tokens.pop(0) + if isinstance(result, dt.ContainerArray): + result = result.stype result = result.members[token] return result @@ -424,6 +429,10 @@ class SDFG(ControlFlowRegion): desc="Data descriptors for this SDFG", to_json=_arrays_to_json, from_json=_nested_arrays_from_json) + container_groups = Property(dtype=NestedDict, + desc="Data group descriptors for this SDFG", + to_json=_arrays_to_json, + from_json=_nested_arrays_from_json) symbols = DictProperty(str, dtypes.typeclass, desc="Global symbols for this SDFG") instrument = EnumProperty(dtype=dtypes.InstrumentationType, @@ -445,6 +454,7 @@ class SDFG(ControlFlowRegion): debuginfo = DebugInfoProperty(allow_none=True) + _pgrids = DictProperty(str, ProcessGrid, desc="Process-grid descriptors for this SDFG", @@ -503,6 +513,7 @@ def __init__(self, self._parent_sdfg = None self._parent_nsdfg_node = None self._arrays = NestedDict() # type: Dict[str, dt.Array] + self.container_groups = NestedDict() self.arg_names = [] self._labels: Set[str] = set() self.global_code = {'frame': CodeBlock("", dtypes.Language.CPP)} @@ -1328,11 +1339,16 @@ def _used_symbols_internal(self, defined_syms |= set(self.constants_prop.keys()) + init_code_symbols=set() + exit_code_symbols=set() # Add used symbols from init and exit code for code in self.init_code.values(): - free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + init_code_symbols |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) for code in self.exit_code.values(): - free_syms 
|= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + exit_code_symbols |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) + + #free_syms|=set(filter(lambda x: not str(x).startswith('__f2dace_ARRAY'),init_code_symbols)) + #free_syms|=set(filter(lambda x: not str(x).startswith('__f2dace_ARRAY'),exit_code_symbols)) return super()._used_symbols_internal(all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, @@ -1413,7 +1429,9 @@ def arglist(self, scalars_only=False, free_symbols=None) -> Dict[str, dt.Data]: } # Add global free symbols used in the generated code to scalar arguments + #TODO LATER investiagte why all_symbols=False leads to bug free_symbols = free_symbols if free_symbols is not None else self.used_symbols(all_symbols=False) + free_symbols = set(filter(lambda x: not str(x).startswith('__f2dace_STRUCTARRAY'), free_symbols)) scalar_args.update({k: dt.Scalar(self.symbols[k]) for k in free_symbols if not k.startswith('__dace')}) # Fill up ordered dictionary @@ -1774,6 +1792,12 @@ def add_array(self, return self.add_datadesc(name, desc, find_new_name=find_new_name), desc + def add_container_group(self, + name: str, + find_new_name: bool = False) -> Tuple[str, ContainerGroup]: + dg_desc = ContainerGroup(name) + return self.add_container_group_desc(name, dg_desc, find_new_name=find_new_name), dg_desc + def add_view(self, name: str, shape, @@ -1914,7 +1938,7 @@ def add_scalar(self, storage=storage, transient=transient, lifetime=lifetime, - debuginfo=debuginfo, + debuginfo=debuginfo ) return self.add_datadesc(name, desc, find_new_name=find_new_name), desc @@ -2074,6 +2098,42 @@ def _add_symbols(sdfg: SDFG, desc: dt.Data): return name + def add_container_group_desc(self, name: str, container_group_desc: ContainerGroup, find_new_name=False) -> str: + if not isinstance(name, str): + raise TypeError("Data descriptor name must be a string. 
Got %s" % type(name).__name__) + + if find_new_name: + name = self._find_new_name(name) + name = name.replace('.', '_') + if self.is_name_used(name): + name = self._find_new_name(name) + else: + if name in self.arrays: + raise FileExistsError(f'Data group descriptor "{name}" already exists in SDFG') + if name in self.symbols: + raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a symbol.') + if name in self._subarrays: + raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a subarray.') + if name in self._rdistrarrays: + raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a RedistrArray.') + if name in self._pgrids: + raise FileExistsError(f'Can not create data group descriptor "{name}", the name is used by a ProcessGrid.') + + def _add_symbols(sdfg: SDFG, desc: dt.Data): + if isinstance(desc, dt.Structure): + for v in desc.members.values(): + if isinstance(v, dt.Data): + _add_symbols(sdfg, v) + for sym in desc.free_symbols: + if sym.name not in sdfg.symbols: + sdfg.add_symbol(sym.name, sym.dtype) + + # Add the data descriptor to the SDFG and all symbols that are not yet known. + self.container_groups[name] = container_group_desc + _add_symbols(self, container_group_desc) + + return name + def add_datadesc_view(self, name: str, datadesc: dt.Data, find_new_name=False) -> str: """ Adds a view of a given data descriptor to the SDFG array store. 
@@ -2865,3 +2925,96 @@ def recheck_using_explicit_control_flow(self) -> bool: break self.root_sdfg.using_explicit_control_flow = found_explicit_cf_block return found_explicit_cf_block + + def register_container_group_members(self, flattening_mode): + for name, dg in self.container_groups.items(): + self._register_container_group_members(flattening_mode=flattening_mode, container_group_or_array=dg, prefix_name=f'__CG_{name}', acc_shape=()) + + def _register_container_group_members(self, flattening_mode, + container_group_or_array: typing.Union[ContainerGroup, dace.data.ContainerArray], + prefix_name: str, acc_shape: tuple): + if flattening_mode == ContainerGroupFlatteningMode.StructOfArrays: + if isinstance(container_group_or_array, ContainerGroup): + container_group = container_group_or_array + for name, member in container_group.members.items(): + if isinstance(member, ContainerGroup): + if member.is_cg: + dg_prefix = prefix_name + f'__CG_{member.name}' + else: + dg_prefix = prefix_name + f'__CA_{member.name}' + acc_shape += member.shape + self._register_container_group_members( + container_group_or_array=member, + flattening_mode=flattening_mode, + prefix_name=dg_prefix, + acc_shape=acc_shape) + elif isinstance(member, dace.data.ContainerArray): + assert False + else: + # Add the dimensions accumulated while iterating from root to the leaf node of the trees + member_demangled_name = prefix_name + f'__m_{name}' + if isinstance(member, dace.data.Scalar): + datadesc = dace.data.Array( + dtype=member.dtype, shape=acc_shape, transient=member.transient, + allow_conflicts=member.allow_conflicts, storage=member.storage, + location=member.location, may_alias=member.may_alias, lifetime=member.lifetime, + debuginfo=member.debuginfo, start_offset=member.start_offset + ) + elif isinstance(member, dace.data.Array): + datadesc = dace.data.Array( + dtype=member.dtype, shape=acc_shape + member.shape, transient=member.transient, + allow_conflicts=member.allow_conflicts, 
storage=member.storage, + location=member.location, may_alias=member.may_alias, lifetime=member.lifetime, + debuginfo=member.debuginfo + ) + else: + raise Exception("Leaf member in a container group needs to be scalar or array") + self.add_datadesc(name=member_demangled_name, datadesc=datadesc, find_new_name=False) + elif isinstance(container_group_or_array, dace.data.ContainerArray): + assert False + else: + raise Exception("?") + elif flattening_mode == ContainerGroupFlatteningMode.ArrayOfStructs: + raise Exception("TODO Support for ArrayOfStructs Flattening") + else: + raise Exception("Unsupported Flattening Mode") + + def get_demangled_container_group_member_name(self, name_hierarchy: List[Type[str]]): + current_dg = None + demangled_name = '' + for i, name in enumerate(name_hierarchy): + if current_dg is None: + current_dg = self.container_groups[name] + demangled_name += f"__CG_{current_dg.name}" + elif name in current_dg.members: + if isinstance(current_dg.members[name], ContainerGroup): + current_dg = current_dg.members[name] + if current_dg.is_cg: + demangled_name += f"__CG_{current_dg.name}" + else: + demangled_name += f"__CA_{current_dg.name}" + elif (isinstance(current_dg.members[name], dace.data.ContainerArray)): + assert (False) + else: + assert isinstance(current_dg.members[name], dace.data.Data) + assert i == len(name_hierarchy) - 1 + demangled_name += f"__m_{name}" + return demangled_name + else: + # if we are at last element and it is a "Leaf" (data had no name) it is not an error + if i == len(name_hierarchy) - 1 and len(current_dg.members) == 1 and "Leaf" in current_dg.members: + demangled_name += f"__m_Leaf" + return demangled_name + raise Exception(f'Name Hierarchy {name_hierarchy} Not in ContainerGroups {self.container_groups}, {self._arrays} 1') + + if i == len(name_hierarchy) - 1 and len(current_dg.members) == 1 and "Leaf" in current_dg.members: + demangled_name += f"__m_Leaf" + return demangled_name + raise Exception(f'Name Hierarchy 
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
"""This module implements the StructToContainerGroups pass, which rewrites
Structure/ContainerArray accesses in an SDFG into accesses to flat,
demangled container-group arrays (struct-of-arrays layout)."""
# Docstring fix: the previous module docstring was copy-pasted from the
# grid-strided map tiling transformation and did not describe this module.

import copy
from typing import Any, Dict
import dace
from dace.sdfg import SDFG, NestedDict, SDFGState
from dace.sdfg.state import ControlFlowBlock
from dace.properties import make_properties
from dace.sdfg import nodes
from dace.data import Structure, View
from dace.transformation import pass_pipeline as ppl
from dace.sdfg.container_group import ContainerGroupFlatteningMode


@make_properties
class StructToContainerGroups(ppl.Pass):
    """Flattens struct <-> view access patterns into container-group arrays.

    The pass first generates container groups from all Structure descriptors,
    registers their flattened members as plain arrays, then rewrites every
    Structure/View access-node pair to access the flat arrays directly, and
    finally removes the Structure descriptors.
    """

    def __init__(
        self,
        flattening_mode: ContainerGroupFlatteningMode = ContainerGroupFlatteningMode.StructOfArrays,
        simplify: bool = True,
        validate: bool = True,
        validate_all: bool = False,
        clean_container_grous: bool = True,
    ):
        # NOTE: ``clean_container_grous`` is misspelled ("groups") but kept
        # for backward compatibility with existing callers.
        # The unreachable ArrayOfStructs check that followed this raise has
        # been removed (any non-StructOfArrays mode already raises here).
        if flattening_mode != ContainerGroupFlatteningMode.StructOfArrays:
            raise Exception("Only StructOfArrays is supported")
        self._simplify = simplify
        self._validate = validate
        self._validate_all = validate_all
        self._clean_container_grous = clean_container_grous
        # Maps replaced (view) data names to their demangled replacements.
        self._access_names_map = dict()
        # struct guid -> ([incoming nodes], [outgoing nodes]) for
        # View -> Struct -> View patterns that must be reconnected later.
        self._data_connected_to_vsv_struct = dict()
        self._flattening_mode = flattening_mode
        super().__init__()

    def modifies(self) -> ppl.Modifies:
        return (
            ppl.Modifies.Nodes
            | ppl.Modifies.Edges
            | ppl.Modifies.AccessNodes
            | ppl.Modifies.Memlets
            | ppl.Modifies.Descriptors
        )

    def should_reapply(self, modified: ppl.Modifies) -> bool:
        return False

    def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> int:
        """Applies the flattening to ``sdfg``.

        :return: The number of rewritten struct/view access patterns.
            (Bug fix: the method was annotated ``-> int`` but returned None.)
        """
        sdfg.generate_container_groups_from_structs(self._flattening_mode)
        sdfg.register_container_group_members(self._flattening_mode)

        # A -> B (both access nodes) triggers the applicability check.
        i = 0
        for state in sdfg.states():
            # Renamed from ``nodes`` to avoid shadowing the imported module.
            state_nodes = state.nodes()
            removed_nodes = set()
            for node in state_nodes:
                if node in removed_nodes:
                    continue
                if isinstance(node, dace.nodes.AccessNode):
                    for oe in state.out_edges(node):
                        if oe.dst in removed_nodes:
                            continue
                        if isinstance(oe.dst, dace.nodes.AccessNode):
                            src_access = node
                            dst_access = oe.dst
                            if self._can_be_applied(state, sdfg, src_access, dst_access):
                                i += 1
                                newly_removed_nodes = self._apply(state, sdfg, src_access, dst_access)
                                removed_nodes = removed_nodes.union(newly_removed_nodes)

        # Redirect memlets that still name replaced views (views within
        # container groups).
        for state in sdfg.states():
            for edge in state.edges():
                if edge.data.data in self._access_names_map:
                    edge.data.data = self._access_names_map[edge.data.data]

        # View -> Struct -> View patterns leave disconnected components;
        # reconnect them using the saved information.
        for state in sdfg.states():
            state_nodes = state.nodes()
            for node in state_nodes:
                if node not in state.nodes():
                    continue
                for (
                    in_connected_nodes,
                    out_connected_nodes,
                ) in self._data_connected_to_vsv_struct.values():
                    assert (
                        len(in_connected_nodes) <= 1 and len(out_connected_nodes) <= 1
                    )
                    if len(in_connected_nodes) == 1 and len(out_connected_nodes) == 1:
                        if node in in_connected_nodes:
                            src = node
                            dst = out_connected_nodes[0]
                            for oe in state.out_edges(dst):
                                assert oe.src_conn is None
                                state.add_edge(
                                    src,
                                    None,
                                    oe.dst,
                                    oe.dst_conn,
                                    copy.deepcopy(oe.data),
                                )
                            state.remove_node(dst)
                        elif node in out_connected_nodes:
                            continue

        # Remove the now-unused structure descriptors.
        # (Debug artifact removed: a stray ``sdfg.save("a.sdfg")`` wrote a
        # file on every invocation.)
        to_rm = []
        for name, desc in sdfg.arrays.items():
            if isinstance(desc, dace.data.Structure):
                to_rm.insert(0, name)
        for name in to_rm:
            sdfg.remove_data(name=name, validate=True)

        if self._simplify:
            sdfg.simplify(self._validate, self._validate_all)

        if self._clean_container_grous:
            sdfg.container_groups = NestedDict()

        return i

    def _can_be_applied(
        self,
        state: SDFGState,
        sdfg: SDFG,
        src_access: nodes.AccessNode,
        dst_access: nodes.AccessNode,
    ):
        """Checks for Pattern1 (Struct -> View) or Pattern2 (View -> Struct).

        Precondition: container groups have been generated
        (``sdfg.generate_container_groups_from_structs()``).
        """
        (struct_to_view_pattern, view_to_struct_pattern) = self._get_pattern_type(
            state, sdfg, src_access, dst_access
        )
        if (not struct_to_view_pattern) and (not view_to_struct_pattern):
            return False
        if struct_to_view_pattern and view_to_struct_pattern:
            raise Exception(
                "A -> B and B -> A found in structure + view access (impossible cycle)"
            )

        (struct_access, view_access, struct_data, view_data) = (
            self._assign_src_dst_to_struct_view(sdfg, src_access, dst_access)
        )
        if struct_access is None or view_access is None:
            return False

        if not (isinstance(struct_data, Structure)):
            return False
        if not (isinstance(view_data, View)):
            return False

        return True

    def _assign_src_dst_to_struct_view(
        self, sdfg: SDFG, src_access: nodes.AccessNode, dst_access: nodes.AccessNode
    ):
        """Classifies the two access nodes into (struct, view) roles; entries
        are None when the corresponding role is absent."""
        struct_access = None
        view_access = None
        struct_data = None
        view_data = None

        src_data = sdfg.arrays[src_access.data]
        dst_data = sdfg.arrays[dst_access.data]

        if isinstance(src_data, Structure):
            struct_access = src_access
            struct_data = src_data
        elif isinstance(dst_data, Structure):
            struct_access = dst_access
            struct_data = dst_data

        if isinstance(src_data, View):
            view_access = src_access
            view_data = src_data
        elif isinstance(dst_data, View):
            view_access = dst_access
            view_data = dst_data

        return (struct_access, view_access, struct_data, view_data)

    def _get_pattern_type(
        self,
        state: SDFGState,
        sdfg: SDFG,
        src_access: nodes.AccessNode,
        dst_access: nodes.AccessNode,
    ):
        """Returns (struct_to_view, view_to_struct) booleans for the pair."""
        (struct_access, view_access, struct_data, view_data) = (
            self._assign_src_dst_to_struct_view(sdfg, src_access, dst_access)
        )

        struct_to_view_edges = (
            set(
                [
                    v
                    for _, _, v, _, _ in state.out_edges(struct_access)
                    if v == view_access
                ]
            )
            if struct_access
            else set()
        )
        view_to_struct_edges = (
            set(
                [
                    v
                    for _, _, v, _, _ in state.out_edges(view_access)
                    if v == struct_access
                ]
            )
            if view_access
            else set()
        )

        struct_to_view_pattern = False
        view_to_struct_pattern = False

        if len(struct_to_view_edges) == 0 and len(view_to_struct_edges) == 0:
            return (False, False)
        elif len(struct_to_view_edges) != 0 and len(view_to_struct_edges) != 0:
            raise Exception(
                "A -> B and B -> A found in structure + view access (impossible cycle)"
            )
        elif len(struct_to_view_edges) != 0:
            struct_to_view_pattern = True
        elif len(view_to_struct_edges) != 0:
            view_to_struct_pattern = True

        return (struct_to_view_pattern, view_to_struct_pattern)

    def _get_struct_to_view_view_chain(
        self, state: SDFGState, sdfg: SDFG, first_view_access: nodes.AccessNode
    ):
        """Follows single out-edges forward, collecting consecutive views."""
        view_accesses = [first_view_access]
        current_view_access = first_view_access
        while True:
            out_edges = state.out_edges(current_view_access)
            assert len(out_edges) == 1
            u, uc, v, vc, memlet = out_edges[0]
            if isinstance(v, nodes.AccessNode) and isinstance(sdfg.arrays[v.data], View):
                current_view_access = v
                view_accesses.append(v)
            else:
                return view_accesses

    def _get_view_to_struct_view_chain(
        self, state: SDFGState, sdfg: SDFG, last_view_access: nodes.AccessNode
    ):
        """Follows single in-edges backward, collecting consecutive views."""
        view_accesses = [last_view_access]
        current_view_access = last_view_access
        while True:
            in_edges = state.in_edges(current_view_access)
            assert len(in_edges) == 1
            u, uc, v, vc, memlet = in_edges[0]
            if isinstance(u, nodes.AccessNode) and isinstance(sdfg.arrays[u.data], View):
                current_view_access = u
                view_accesses.insert(0, u)
            else:
                return view_accesses

    def _process_edges(self, edge_list, name_hierarchy, take_last=False):
        """Appends the (dot-separated) data tokens of the single edge in
        ``edge_list`` to ``name_hierarchy``.

        NOTE(review): ``take_last`` is currently unused; kept for interface
        compatibility.
        """
        assert len(edge_list) == 1
        edge = edge_list[0]
        tokenized_data = edge.data.data.split(".")
        assert len(tokenized_data) == 2 or len(tokenized_data) == 1
        name_hierarchy += tokenized_data

    def _apply(
        self,
        state: SDFGState,
        sdfg: SDFG,
        src_access: nodes.AccessNode,
        dst_access: nodes.AccessNode,
    ):
        """Rewrites one struct/view access pattern; returns removed nodes."""
        removed_nodes = set()

        struct_to_view, view_to_struct = self._get_pattern_type(
            state, sdfg, src_access, dst_access
        )
        if not (struct_to_view or view_to_struct):
            raise Exception("StructToDataGroup not applicable")
        assert not (struct_to_view and view_to_struct)

        if struct_to_view:
            struct_access = src_access
            view_access = dst_access
        else:  # view_to_struct
            view_access = src_access
            struct_access = dst_access

        view_chain = (
            self._get_struct_to_view_view_chain(state, sdfg, view_access)
            if struct_to_view
            else self._get_view_to_struct_view_chain(state, sdfg, view_access)
        )

        assert len(view_chain) >= 1
        name_hierarchy = []

        if struct_to_view:
            struct_to_view_edges = [
                e for e in state.out_edges(struct_access) if e.dst == view_chain[0]
            ]
            self._process_edges(
                edge_list=struct_to_view_edges, name_hierarchy=name_hierarchy
            )
            for current_view_access in view_chain[:-1]:
                view_to_next_edges = state.out_edges(current_view_access)
                self._process_edges(
                    edge_list=view_to_next_edges,
                    name_hierarchy=name_hierarchy,
                    take_last=True,
                )

        if view_to_struct:
            view_to_struct_edges = [
                e for e in state.in_edges(struct_access) if e.src == view_chain[-1]
            ]
            self._process_edges(
                edge_list=view_to_struct_edges, name_hierarchy=name_hierarchy
            )
            for current_view_access in view_chain[:-1]:
                view_to_next_edges = state.out_edges(current_view_access)
                self._process_edges(
                    edge_list=view_to_next_edges,
                    name_hierarchy=name_hierarchy,
                    take_last=True,
                )

        demangled_name = sdfg.get_demangled_container_group_member_name(name_hierarchy)

        an = nodes.AccessNode(data=demangled_name)

        assert len(view_chain) <= 2
        if struct_to_view:
            assert len(state.out_edges(view_chain[0])) == 1
            src_edge = state.out_edges(view_chain[0])[0]
            assert len(state.out_edges(view_chain[-1])) == 1
            dst_edge = state.out_edges(view_chain[-1])[0]
        else:  # view_to_struct
            assert len(state.in_edges(view_chain[0])) == 1
            src_edge = state.in_edges(view_chain[0])[0]
            assert len(state.out_edges(view_chain[-1])) == 1
            dst_edge = state.out_edges(view_chain[-1])[0]

        if self._flattening_mode == ContainerGroupFlatteningMode.StructOfArrays:
            # Accumulate the memlet subset ranges along the chain; subsets of
            # views that view a Structure directly are skipped (they carry no
            # array dimensions).
            memlet_shape = ()

            if struct_to_view:
                assert len(struct_to_view_edges) == 1
                struct_to_view_edge = struct_to_view_edges[0]
                memlet_shape += tuple(struct_to_view_edge.data.subset.ranges)

                skip = (isinstance(sdfg.arrays[struct_to_view_edge.src.data], dace.data.Structure)
                        and not isinstance(sdfg.arrays[struct_to_view_edge.data.data], dace.data.ContainerArray))
                for vc in view_chain:
                    if skip:
                        skip = False
                        continue
                    # Loop variable renamed from ``dst_edge`` to avoid
                    # shadowing the outer edge (their final values coincide).
                    chain_edge = state.out_edges(vc)[0]
                    if (isinstance(sdfg.arrays[vc.data], dace.data.Structure)
                            and not isinstance(sdfg.arrays[chain_edge.data.data], dace.data.ContainerArray)):
                        skip = True
                    memlet_shape += tuple(chain_edge.data.subset.ranges)

            if view_to_struct:
                assert len(view_to_struct_edges) == 1
                view_to_struct_edge = view_to_struct_edges[0]
                memlet_shape += tuple(view_to_struct_edge.data.subset.ranges)

                skip = (isinstance(sdfg.arrays[view_to_struct_edge.dst.data], dace.data.Structure)
                        and not isinstance(sdfg.arrays[view_to_struct_edge.data.data], dace.data.ContainerArray))
                for vc in reversed(view_chain):
                    if skip:
                        skip = False
                        continue
                    chain_edge = state.in_edges(vc)[0]
                    memlet_shape += tuple(chain_edge.data.subset.ranges)
                    if (isinstance(sdfg.arrays[vc.data], dace.data.Structure)
                            and not isinstance(sdfg.arrays[chain_edge.data.data], dace.data.ContainerArray)):
                        skip = True
        else:
            raise Exception("ArrayOfStructs mode is not implemented yet")

        mc = dace.memlet.Memlet(
            subset=dace.subsets.Range(memlet_shape), data=demangled_name
        )

        # If Struct -> View -> Dst:
        # Then Struct (uc) -> (None) \ View \ (None) -> (vc) Dst
        # Becomes NewData (None) -> (vc) Dst

        # If View -> Struct -> Dst:
        # Then Src (uc) -> (None) \ View \ (None) -> (vc) Struct
        # Becomes Src (uc) -> (None) NewData
        state.add_node(an)
        # Simplify manages to remove this intermediate view later.
        if struct_to_view:
            view_name = "v_" + demangled_name
            a = sdfg.arrays[dst_edge.data.data]
            if view_name not in sdfg.arrays:
                sdfg.add_view(
                    name=view_name,
                    shape=a.shape,
                    dtype=a.dtype,
                    storage=a.storage,
                    strides=a.strides,
                    offset=a.offset,
                    allow_conflicts=a.allow_conflicts,
                    find_new_name=False,
                    may_alias=a.may_alias,
                )
            view_access = state.add_access(view_name)
            state.add_edge(an, None, view_access, "views", mc)
            nm = copy.deepcopy(dst_edge.data)
            nm.data = view_name
            state.add_edge(view_access, None, dst_edge.dst, dst_edge.dst_conn, nm)

            if struct_access.guid in self._data_connected_to_vsv_struct:
                self._data_connected_to_vsv_struct[struct_access.guid][1].append(an)
            else:
                self._data_connected_to_vsv_struct[struct_access.guid] = ([], [an])

        else:  # view_to_struct
            view_name = "v_" + demangled_name
            if view_name not in sdfg.arrays:
                # NOTE(review): this reads ``dst_edge`` (the struct-side edge)
                # for the view shape — confirm ``src_edge`` was not intended.
                a = sdfg.arrays[dst_edge.data.data]
                sdfg.add_view(
                    name=view_name,
                    shape=a.shape,
                    dtype=a.dtype,
                    storage=a.storage,
                    strides=a.strides,
                    offset=a.offset,
                    allow_conflicts=a.allow_conflicts,
                    find_new_name=False,
                    may_alias=a.may_alias,
                )
            view_access = state.add_access(view_name)
            nm = copy.deepcopy(src_edge.data)
            nm.data = view_name
            state.add_edge(src_edge.src, src_edge.src_conn, view_access, None, nm)
            state.add_edge(view_access, "views", an, None, mc)

            if struct_access.guid in self._data_connected_to_vsv_struct:
                self._data_connected_to_vsv_struct[struct_access.guid][0].append(an)
            else:
                self._data_connected_to_vsv_struct[struct_access.guid] = ([an], [])

        # Clean-up
        for view_node in view_chain:
            state.remove_node(view_node)
            removed_nodes.add(view_node)
        if (len(state.in_edges(struct_access)) == 0) and (
            len(state.out_edges(struct_access)) == 0
        ):
            state.remove_node(struct_access)
            removed_nodes.add(struct_access)

        # All accesses through the old view need to be mapped to the newly
        # added array.
        self._access_names_map[view_chain[-1 if struct_to_view else 0].data] = view_name

        return removed_nodes

    def _get_src_dst(self, state: SDFGState, n1: nodes.Any, n2: nodes.Any):
        """Orders two connected nodes as (source, destination); raises if they
        are unconnected (E1) or connected in both directions (E2)."""
        n1_to_n2 = [e.dst for e in state.out_edges(n1) if e.dst == n2]
        n2_to_n1 = [e.dst for e in state.out_edges(n2) if e.dst == n1]
        if len(n2_to_n1) == 0 and len(n1_to_n2) == 0:
            raise Exception("E1")
        elif len(n2_to_n1) != 0 and len(n1_to_n2) != 0:
            raise Exception("E2")
        elif len(n2_to_n1) == 0:
            assert len(n1_to_n2) > 0
            return (n1, n2)
        else:
            assert len(n2_to_n1) > 0
            return (n2, n1)

    def annotates_memlets(self) -> bool:
        # Bug fix: was defined without ``self`` and raised TypeError when
        # invoked on an instance by the pass framework.
        return False
dace.data.Array( + shape=[N, N], + storage=dace.dtypes.StorageType.CPU_Heap, + dtype=dace.typeclass(np.float32), + ), + }, + name="AB", + storage=dace.dtypes.StorageType.CPU_Heap, + ) + else: + assert container_variant == "Baseline" + for n in ["v_A", "v_B"]: + jacobi_sdfg.add_array( + name=n, + shape=[N, N], + storage=dace.dtypes.StorageType.CPU_Heap, + dtype=dace.typeclass(np.float32), + ) + + if container_variant != "Baseline": + jacobi_sdfg.add_datadesc(name="AB", datadesc=struct) + + v_A_name, v_A = jacobi_sdfg.add_view( + name="v_A", + shape=[N, N], + storage=dace.dtypes.StorageType.CPU_Heap, + dtype=dace.typeclass(np.float32), + ) + + v_B_name, v_B = jacobi_sdfg.add_view( + name="v_B", + shape=[N, N], + storage=dace.dtypes.StorageType.CPU_Heap, + dtype=dace.typeclass(np.float32), + ) + + ab_access = dace.nodes.AccessNode(data="AB") + kernel.add_node(ab_access) + ab2_access = dace.nodes.AccessNode(data="AB") + kernel.add_node(ab2_access) + ab3_access = dace.nodes.AccessNode(data="AB") + kernel.add_node(ab3_access) + + if container_variant == "ContainerArray": + jacobi_sdfg.add_view( + name="v_AB_As", + shape=[N, N], + storage=dace.dtypes.StorageType.CPU_Heap, + dtype=dace.typeclass(np.float32), + ) + ab4_access = dace.nodes.AccessNode(data="v_AB_As") + kernel.add_node(ab4_access) + ab5_access = dace.nodes.AccessNode(data="v_AB_As") + kernel.add_node(ab5_access) + + if container_variant != "Baseline": + a_access = dace.nodes.AccessNode(data="v_A") + a_access.add_in_connector("views") + b_dst_access = dace.nodes.AccessNode(data="v_B") + b_dst_access.add_out_connector("views") + kernel.add_node(a_access) + b_access = dace.nodes.AccessNode(data="v_B") + b_access.add_in_connector("views") + a_dst_access = dace.nodes.AccessNode(data="v_A") + a_dst_access.add_out_connector("views") + kernel.add_node(b_access) + + if container_variant == "Baseline": + a_access = dace.nodes.AccessNode(data="v_A") + kernel.add_node(a_access) + a_dst_access = 
dace.nodes.AccessNode(data="v_A") + b_dst_access = dace.nodes.AccessNode(data="v_B") + + if container_variant == "Struct": + kernel.add_edge( + ab_access, + None, + a_access, + "views", + dace.Memlet(data="AB.A", subset=dace.subsets.Range.from_string("0:N, 0:N")), + ) + kernel.add_edge( + ab2_access, + None, + b_access, + "views", + dace.Memlet(data="AB.B", subset=dace.subsets.Range.from_string("0:N, 0:N")), + ) + elif container_variant == "ContainerArray": + kernel.add_edge( + ab_access, + None, + ab4_access, + "views", + dace.Memlet(data="AB.As", subset=dace.subsets.Range.from_string("0:1")), + ) + kernel.add_edge( + ab2_access, + None, + ab5_access, + "views", + dace.Memlet(data="AB.As", subset=dace.subsets.Range.from_string("1:2")), + ) + kernel.add_edge( + ab4_access, + None, + a_access, + "views", + dace.Memlet( + data="v_AB_As", subset=dace.subsets.Range.from_string("0:N, 0:N") + ), + ) + kernel.add_edge( + ab5_access, + None, + b_access, + "views", + dace.Memlet( + data="v_AB_As", subset=dace.subsets.Range.from_string("0:N, 0:N") + ), + ) + else: + assert container_variant == "Baseline" + + if container_variant != "Baseline": + vars = [("A", "B", a_access, b_dst_access), ("B", "A", b_access, a_dst_access)] + else: + vars = [ + ("A", "B", a_access, b_dst_access), + ("B", "A", b_dst_access, a_dst_access), + ] + for j, (src, dst, src_access, dst_access) in enumerate(vars): + update_map_entry, update_map_exit = kernel.add_map( + name=f"{dst}_update", + ndrange={ + "i": dace.subsets.Range(ranges=[(0, N - 3, 1)]), + "j": dace.subsets.Range(ranges=[(0, N - 3, 1)]), + }, + ) + + update_map_entry.add_in_connector(f"IN_v_{src}") + update_map_entry.add_out_connector(f"OUT_v_{src}") + update_map_exit.add_in_connector(f"IN_v_{dst}") + update_map_exit.add_out_connector(f"OUT_v_{dst}") + + kernel.add_edge( + src_access, + None, + update_map_entry, + f"IN_v_{src}", + dace.Memlet(expr=f"v_{src}[0:N,0:N]"), + ) + + jacobi_sdfg.add_scalar( + name=f"acc{j}", + 
dtype=dace.float32, + transient=True, + storage=dace.dtypes.StorageType.Register, + lifetime=dace.dtypes.AllocationLifetime.Scope, + ) + jacobi_sdfg.add_scalar( + name=f"acc_2_{j}", + dtype=dace.float32, + transient=True, + storage=dace.dtypes.StorageType.Register, + lifetime=dace.dtypes.AllocationLifetime.Scope, + ) + san = kernel.add_access(f"acc{j}") + + kernel.add_edge( + update_map_entry, + None, + san, + None, + dace.Memlet(None), + ) + + if container_variant != "Baseline": + sub_domain_access = dace.nodes.AccessNode(data=f"v_{src}") + sub_domain_access_2 = dace.nodes.AccessNode(data=f"v_{dst}") + + kernel.add_edge( + update_map_entry, + f"OUT_v_{src}", + sub_domain_access, + None, + dace.Memlet(expr=f"v_{src}[i:i+2,j:j+2]"), + ) + + inner_map_entry, inner_map_exit = kernel.add_map( + name=f"{dst}_inner_stencil", + ndrange={ + "_i": dace.subsets.Range(ranges=[(1, 3, 2)]), + "_j": dace.subsets.Range(ranges=[(1, 3, 2)]), + }, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + inner_map_entry.add_in_connector(f"IN_v_{src}") + inner_map_entry.add_out_connector(f"OUT_v_{src}") + inner_map_entry.add_in_connector(f"IN_acc") + inner_map_entry.add_out_connector(f"OUT_acc") + inner_map_exit.add_in_connector(f"IN_v_{dst}") + inner_map_exit.add_out_connector(f"OUT_v_{dst}") + + if container_variant != "Baseline": + kernel.add_edge( + sub_domain_access, + None, + inner_map_entry, + f"IN_v_{src}", + dace.Memlet(expr=f"v_{src}[i:i+2,j:j+2]"), + ) + else: + kernel.add_edge( + update_map_entry, + f"OUT_v_{src}", + inner_map_entry, + f"IN_v_{src}", + dace.Memlet(expr=f"v_{src}[i:i+2,j:j+2]"), + ) + + kernel.add_edge( + san, + None, + inner_map_entry, + f"IN_acc", + dace.Memlet(expr=f"acc{j}[0]"), + ) + + if container_variant != "Baseline": + kernel.add_edge( + inner_map_exit, + f"OUT_v_{dst}", + sub_domain_access_2, + None, + dace.Memlet(expr=f"v_{dst}[i:i+2,j:j+2]"), + ) + kernel.add_edge( + sub_domain_access_2, + None, + update_map_exit, + f"IN_v_{dst}", + 
dace.Memlet(expr=f"v_{dst}[i:i+2,j:j+2]"), + ) + else: + kernel.add_edge( + inner_map_exit, + f"OUT_v_{dst}", + update_map_exit, + f"IN_v_{dst}", + dace.Memlet(expr=f"v_{dst}[i:i+2,j:j+2]"), + ) + + access_str = f"v_{src}[i+_i,j+_j]" + t1 = kernel.add_tasklet( + name="Add", inputs={"_in"}, outputs={"_out"}, code=f"_out = 0.2 * _in " + ) + t2 = kernel.add_tasklet( + name="Acc", + inputs={"_in", "_acc_in"}, + outputs={"_out"}, + code=f"_out = _acc_in + _in", + ) + e1 = kernel.add_edge( + inner_map_entry, f"OUT_v_{src}", t1, "_in", dace.Memlet(expr=access_str) + ) + e2 = kernel.add_edge(t1, "_out", t2, "_in", dace.Memlet(expr=f"acc_2_{j}")) + e3 = kernel.add_edge( + t2, + "_out", + inner_map_exit, + f"IN_v_{dst}", + dace.Memlet(expr=f"v_{dst}[i+_i,j+_j]"), + ) + e4 = kernel.add_edge( + inner_map_entry, f"OUT_acc", t2, "_acc_in", dace.Memlet(expr=f"acc{j}") + ) + + update_map_exit.add_out_connector(f"OUT_v_{dst}") + kernel.add_edge( + update_map_exit, + f"OUT_v_{dst}", + dst_access, + None, + dace.Memlet(expr=f"v_{dst}[0:N,0:N]"), + ) + + if j == 0: + if container_variant == "ContainerArray": + ab6_access = kernel.add_access("v_AB_As") + kernel.add_edge( + dst_access, + f"views", + ab6_access, + None, + dace.Memlet(expr=f"v_AB_As[0:N, 0:N]"), + ) + kernel.add_edge( + ab6_access, + f"views", + ab2_access, + None, + dace.Memlet(expr=f"AB.As[{j}:{j}+1]"), + ) + elif container_variant == "Struct": + kernel.add_edge( + dst_access, + f"views", + ab2_access, + None, + dace.Memlet(expr=f"AB.{dst}[0:N,0:N]"), + ) + else: + # kernel.add_edge(dst_access, f"views", bb_access, None, dace.Memlet(expr=f"v_B[0:N,0:N]")) + pass + if j == 1: + if container_variant == "ContainerArray": + ab7_access = kernel.add_access("v_AB_As") + kernel.add_edge( + dst_access, + f"views", + ab7_access, + None, + dace.Memlet(expr=f"v_AB_As[0:N, 0:N]"), + ) + kernel.add_edge( + ab7_access, + f"views", + ab3_access, + None, + dace.Memlet(expr=f"AB.As[{j}:{j}+1]"), + ) + elif container_variant == "Struct": 
+ kernel.add_edge( + dst_access, + f"views", + ab3_access, + None, + dace.Memlet(expr=f"AB.{dst}[0:N,0:N]"), + ) + else: + # kernel.add_edge(src_access, None, b_dst_access, None, dace.Memlet(expr=f"v_A[0:N,0:N]")) + pass + + jacobi_sdfg.validate() + return jacobi_sdfg + + +@pytest.fixture(params=["ContainerArray", "Struct"]) +def container_variant(request): + return request.param + + +def test_struct_to_container_group(container_variant: str): + baseline_sdfg = _get_jacobi_sdfg("Baseline") + baseline_sdfg.simplify(validate_all=True) + _N = 256 + _NS = 512 + np.random.seed(42) + A_ref = np.random.rand(_N, _N).astype(np.float32) + B_ref = np.random.rand(_N, _N).astype(np.float32) + baseline_sdfg(v_A=A_ref, v_B=B_ref, N=_N, NUM_STEPS=_NS) + + sdfg = _get_jacobi_sdfg(container_variant) + use_container_array = container_variant == "ContainerArray" + sdfg.simplify(validate_all=True) + + StructToContainerGroups( + flattening_mode=ContainerGroupFlatteningMode.StructOfArrays, + simplify=True, + validate=True, + validate_all=True, + ).apply_pass(sdfg, {}) + + for arr in sdfg.arrays.values(): + assert isinstance(arr, (dace.data.Array, dace.data.Scalar)) + + np.random.seed(42) + if use_container_array is True: + AB = np.random.rand(2, _N, _N).astype(np.float32) + sdfg(__CG_AB__CA_As__m_Leaf=AB, NUM_STEPS=_NS, N=_N) + + A_view = AB[0, :, :] + B_view = AB[1, :, :] + assert np.allclose(A_ref, A_view) + assert np.allclose(B_ref, B_view) + else: + A = np.random.rand(_N, _N).astype(np.float32) + B = np.random.rand(_N, _N).astype(np.float32) + sdfg(__CG_AB__m_A=A, __CG_AB__m_B=B, NUM_STEPS=_NS, N=_N) + assert np.allclose(A_ref, A) + assert np.allclose(B_ref, B) + + +if __name__ == "__main__": + test_struct_to_container_group("ContainerArray") + test_struct_to_container_group("Struct")