diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 26b6629a81..8160b1de72 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -867,16 +867,16 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg return None in_edge = in_edges[0] - out_edge = out_edges[0] + out_edge = out_edges[0] if len(out_edges) > 0 else None # If there is one incoming and one outgoing edge, and one leads to a code # node, the one that leads to an access node is the viewed data. inmpath = state.memlet_path(in_edge) - outmpath = state.memlet_path(out_edge) + outmpath = state.memlet_path(out_edge) if out_edge else None src_is_data, dst_is_data = False, False if isinstance(inmpath[0].src, nd.AccessNode): src_is_data = True - if isinstance(outmpath[-1].dst, nd.AccessNode): + if outmpath and isinstance(outmpath[-1].dst, nd.AccessNode): dst_is_data = True if src_is_data and not dst_is_data: diff --git a/dace/transformation/passes/lift_struct_views.py b/dace/transformation/passes/lift_struct_views.py index 619a86d3ed..6744161000 100644 --- a/dace/transformation/passes/lift_struct_views.py +++ b/dace/transformation/passes/lift_struct_views.py @@ -1,12 +1,15 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from collections import defaultdict -from typing import Any, Dict, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from dace import SDFG, Memlet, SDFGState from dace.frontend.python import astutils +from dace.properties import CodeBlock from dace.sdfg import nodes as nd -from dace.sdfg.graph import MultiConnectorEdge +from dace.sdfg.graph import Edge, MultiConnectorEdge +from dace.sdfg.sdfg import InterstateEdge, memlets_in_ast +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.transformation import pass_pipeline as ppl from dace import data as dt from dace import dtypes @@ -187,6 +190,163 @@ def visit_Attribute(self, node: ast.Attribute) -> Any: else: raise NotImplementedError() +class InterstateEdgeRecoder(ast.NodeTransformer): + + sdfg: SDFG + edge: Edge[InterstateEdge] + data_name: str + data: Union[dt.Structure, dt.ContainerArray] + views_constructed: Set[str] + isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] + + def __init__(self, sdfg: SDFG, edge: Edge[InterstateEdge], data_name: str, + data: Union[dt.Structure, dt.ContainerArray], + isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]): + self.sdfg = sdfg + self.edge = edge + self.data_name = data_name + self.data = data + self.views_constructed = set() + self.isedge_lifting_state_dict = isedge_lifting_state_dict + + def _handle_simple_name_access(self, node: ast.Attribute) -> Any: + struct: dt.Structure = self.data + if not node.attr in struct.members: + raise RuntimeError( + f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition' + ) + + # Insert the appropriate view, if it does not exist yet. + view_name = 'v_' + self.data_name + '_' + node.attr + try: + view = self.sdfg.arrays[view_name] + except KeyError: + view = dt.View.view(struct.members[node.attr]) + view_name = self.sdfg.add_datadesc(view_name, view, find_new_name=True) + self.views_constructed.add(view_name) + + # Construct the correct AST replacement node (direct access, i.e., name node). + replacement = ast.Name() + replacement.ctx = ast.Load() + replacement.id = view_name + + # Add access nodes for the view and the original container and connect them appropriately. + lift_state, data_node = self._get_or_create_lifting_state() + view_node = lift_state.add_access(view_name) + lift_state.add_edge(data_node, None, view_node, 'views', + Memlet.from_array(data_node.data + '.' + node.attr, self.data.members[node.attr])) + return self.generic_visit(replacement) + + def _handle_sliced_access(self, node: ast.Attribute, val: ast.Subscript) -> Any: + struct = self.data.stype + if not isinstance(struct, dt.Structure): + raise ValueError('Invalid ContainerArray, can only lift ContainerArrays to Structures') + if not node.attr in struct.members: + raise RuntimeError( + f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition' + ) + + # We first lift the slice into a separate view, and then the attribute access. + slice_view_name = 'v_' + self.data_name + '_slice' + attr_view_name = slice_view_name + '_' + node.attr + try: + slice_view = self.sdfg.arrays[slice_view_name] + except KeyError: + slice_view = dt.View.view(struct) + slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True) + try: + attr_view = self.sdfg.arrays[attr_view_name] + except KeyError: + member: dt.Data = struct.members[node.attr] + attr_view = dt.View.view(member) + attr_view_name = self.sdfg.add_datadesc(attr_view_name, attr_view, find_new_name=True) + self.views_constructed.add(slice_view_name) + self.views_constructed.add(attr_view_name) + + # Construct the correct AST replacement node (direct access, i.e., name node). + replacement = ast.Name() + replacement.ctx = ast.Load() + replacement.id = attr_view_name + + # Add access nodes for the views to the slice and attribute and connect them appropriately to the original data + # container. + lift_state, data_node = self._get_or_create_lifting_state() + slice_view_node = lift_state.add_access(slice_view_name) + attr_view_node = lift_state.add_access(attr_view_name) + idx = astutils.unparse(val.slice) + if isinstance(val.slice, ast.Tuple): + idx = idx.strip('()') + slice_memlet = Memlet(data_node.data + '[' + idx + ']') + lift_state.add_edge(data_node, None, slice_view_node, 'views', slice_memlet) + attr_memlet = Memlet.from_array(slice_view_name + '.' + node.attr, struct.members[node.attr]) + lift_state.add_edge(slice_view_node, None, attr_view_node, 'views', attr_memlet) + return self.generic_visit(replacement) + + def _get_or_create_lifting_state(self) -> Tuple[SDFGState, nd.AccessNode]: + # Add a state for lifting before the edge, if there isn't one that was created already. + if self.edge.data in self.isedge_lifting_state_dict: + lift_state = self.isedge_lifting_state_dict[self.edge.data] + else: + pre_node: ControlFlowBlock = self.edge.src + lift_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting') + self.isedge_lifting_state_dict[self.edge.data] = lift_state + + # Add a node for the original data container so the view can be connected to it. This may already be a view from + # a previous iteration of lifting, but in that case it is already correctly connected to a root data container. + data_node = None + for dn in lift_state.data_nodes(): + if dn.data == self.data_name: + data_node = dn + break + if data_node is None: + data_node = lift_state.add_access(self.data_name) + + return lift_state, data_node + + def visit_Attribute(self, node: ast.Attribute) -> Any: + if not node.value: + return self.generic_visit(node) + + if isinstance(self.data, dt.Structure): + if isinstance(node.value, ast.Name) and node.value.id == self.data_name: + return self._handle_simple_name_access(node) + elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.slice, ast.Constant) and + node.value.slice.value == 0 and isinstance(node.value.value, ast.Name) and + node.value.value.id == self.data_name): + return self._handle_simple_name_access(node) + return self.generic_visit(node) + else: + # ContainerArray case. + if isinstance(node.value, ast.Name) and node.value.id == self.data_name: + # We are directly accessing a slice of a container array / view. That needs an inserted view to the + # container first. + slice_view_name = 'v_' + self.data_name + '_slice' + try: + slice_view = self.sdfg.arrays[slice_view_name] + except KeyError: + slice_view = dt.View.view(self.data.stype) + slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True) + self.views_constructed.add(slice_view_name) + + # Add an access node for the slice view and connect it appropriately to the root data container. + lift_state, data_node = self._get_or_create_lifting_state() + slice_view_node = lift_state.add_access(slice_view_name) + lift_state.add_edge(data_node, None, slice_view_node, 'views', + Memlet.from_array(self.data_name, self.sdfg.data(self.data_name))) + elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) and + node.value.value.id == self.data_name): + return self._handle_sliced_access(node, node.value) + return self.generic_visit(node) + + +def _data_containers_in_ast(node: ast.AST, arrnames: Set[str]) -> Set[str]: + result: Set[str] = set() + for subnode in ast.walk(node): + if isinstance(subnode, (ast.Attribute, ast.Subscript)): + data = astutils.rname(subnode.value) + if data in arrnames: + result.add(data) + return result class LiftStructViews(ppl.Pass): """ @@ -200,6 +360,8 @@ class LiftStructViews(ppl.Pass): CATEGORY: str = 'Optimization Preparation' + _isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] = dict() + def modifies(self) -> ppl.Modifies: return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes | ppl.Modifies.Tasklets | ppl.Modifies.Memlets @@ -209,6 +371,40 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {} + def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], result: Dict[str, Set[str]]) -> bool: + lifted_something = False + for k in edge.data.assignments.keys(): + assignment = edge.data.assignments[k] + assignment_str = str(assignment) + assignment_ast = ast.parse(assignment_str) + data_in_edge = _data_containers_in_ast(assignment_ast, cfg.sdfg.arrays.keys()) + for data in data_in_edge: + if '.' in data: + continue + container = cfg.sdfg.arrays[data] + if isinstance(container, (dt.Structure, dt.ContainerArray)): + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + new_code = visitor.visit(assignment_ast) + edge.data.assignments[k] = astutils.unparse(new_code) + assignment_ast = new_code + result[data].update(visitor.views_constructed) + lifted_something = True + if not edge.data.is_unconditional(): + condition_ast = edge.data.condition.code[0] + data_in_edge = _data_containers_in_ast(condition_ast, cfg.sdfg.arrays.keys()) + for data in data_in_edge: + if '.' in data: + continue + container = cfg.sdfg.arrays[data] + if isinstance(container, (dt.Structure, dt.ContainerArray)): + visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict) + new_code = visitor.visit(condition_ast) + edge.data.condition = CodeBlock([new_code]) + condition_ast = new_code + result[data].update(visitor.views_constructed) + lifted_something = True + return lifted_something + def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet, edge: MultiConnectorEdge[Memlet], data: dt.Structure, connector: str, direction: dirtype) -> Set[str]: @@ -251,23 +447,34 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]: result = defaultdict(set) lifted_something = False - for nsdfg in sdfg.all_sdfgs_recursive(): - for state in nsdfg.states(): - for node in state.data_nodes(): - cont = nsdfg.data(node.data) - if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or - (isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and - isinstance(cont.stype, dt.Structure))): - for oedge in state.out_edges(node): - if isinstance(oedge.dst, nd.Tasklet): - res = self._lift_tasklet(state, node, oedge.dst, oedge, cont, oedge.dst_conn, 'in') - result[node.data].update(res) - lifted_something = True - for iedge in state.in_edges(node): - if isinstance(iedge.src, nd.Tasklet): - res = self._lift_tasklet(state, node, iedge.src, iedge, cont, iedge.src_conn, 'out') - result[node.data].update(res) - lifted_something = True + while True: + lifted_something_this_round = False + for cfg in sdfg.all_control_flow_regions(recursive=True): + for block in cfg.nodes(): + if isinstance(block, SDFGState): + for node in block.data_nodes(): + cont = cfg.sdfg.data(node.data) + if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or + (isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and + isinstance(cont.stype, dt.Structure))): + for oedge in block.out_edges(node): + if isinstance(oedge.dst, nd.Tasklet): + res = self._lift_tasklet(block, node, oedge.dst, oedge, cont, oedge.dst_conn, + 'in') + result[node.data].update(res) + lifted_something_this_round = True + for iedge in block.in_edges(node): + if isinstance(iedge.src, nd.Tasklet): + res = self._lift_tasklet(block, node, iedge.src, iedge, cont, iedge.src_conn, + 'out') + result[node.data].update(res) + lifted_something_this_round = True + for edge in cfg.edges(): + lifted_something_this_round |= self._lift_isedge(cfg, edge, result) + if not lifted_something_this_round: + break + else: + lifted_something = True if not lifted_something: return None