Skip to content

Commit

Permalink
Let LiftStructViews lift interstate edge struct accesses
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 13, 2024
1 parent 896a1e1 commit 200e606
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 22 deletions.
6 changes: 3 additions & 3 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,16 +867,16 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
return None

in_edge = in_edges[0]
out_edge = out_edges[0]
out_edge = out_edges[0] if len(out_edges) > 0 else None

# If there is one incoming and one outgoing edge, and one leads to a code
# node, the one that leads to an access node is the viewed data.
inmpath = state.memlet_path(in_edge)
outmpath = state.memlet_path(out_edge)
outmpath = state.memlet_path(out_edge) if out_edge else None
src_is_data, dst_is_data = False, False
if isinstance(inmpath[0].src, nd.AccessNode):
src_is_data = True
if isinstance(outmpath[-1].dst, nd.AccessNode):
if outmpath and isinstance(outmpath[-1].dst, nd.AccessNode):
dst_is_data = True

if src_is_data and not dst_is_data:
Expand Down
245 changes: 226 additions & 19 deletions dace/transformation/passes/lift_struct_views.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import ast
from collections import defaultdict
from typing import Any, Dict, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from dace import SDFG, Memlet, SDFGState
from dace.frontend.python import astutils
from dace.properties import CodeBlock
from dace.sdfg import nodes as nd
from dace.sdfg.graph import MultiConnectorEdge
from dace.sdfg.graph import Edge, MultiConnectorEdge
from dace.sdfg.sdfg import InterstateEdge, memlets_in_ast
from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion
from dace.transformation import pass_pipeline as ppl
from dace import data as dt
from dace import dtypes
Expand Down Expand Up @@ -187,6 +190,163 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
else:
raise NotImplementedError()

class InterstateEdgeRecoder(ast.NodeTransformer):

sdfg: SDFG
edge: Edge[InterstateEdge]
data_name: str
data: Union[dt.Structure, dt.ContainerArray]
views_constructed: Set[str]
isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]

def __init__(self, sdfg: SDFG, edge: Edge[InterstateEdge], data_name: str,
data: Union[dt.Structure, dt.ContainerArray],
isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState]):
self.sdfg = sdfg
self.edge = edge
self.data_name = data_name
self.data = data
self.views_constructed = set()
self.isedge_lifting_state_dict = isedge_lifting_state_dict

def _handle_simple_name_access(self, node: ast.Attribute) -> Any:
struct: dt.Structure = self.data
if not node.attr in struct.members:
raise RuntimeError(
f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition'
)

# Insert the appropriate view, if it does not exist yet.
view_name = 'v_' + self.data_name + '_' + node.attr
try:
view = self.sdfg.arrays[view_name]
except KeyError:
view = dt.View.view(struct.members[node.attr])
view_name = self.sdfg.add_datadesc(view_name, view, find_new_name=True)
self.views_constructed.add(view_name)

# Construct the correct AST replacement node (direct access, i.e., name node).
replacement = ast.Name()
replacement.ctx = ast.Load()
replacement.id = view_name

# Add access nodes for the view and the original container and connect them appropriately.
lift_state, data_node = self._get_or_create_lifting_state()
view_node = lift_state.add_access(view_name)
lift_state.add_edge(data_node, None, view_node, 'views',
Memlet.from_array(data_node.data + '.' + node.attr, self.data.members[node.attr]))
return self.generic_visit(replacement)

def _handle_sliced_access(self, node: ast.Attribute, val: ast.Subscript) -> Any:
struct = self.data.stype
if not isinstance(struct, dt.Structure):
raise ValueError('Invalid ContainerArray, can only lift ContainerArrays to Structures')
if not node.attr in struct.members:
raise RuntimeError(
f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition'
)

# We first lift the slice into a separate view, and then the attribute access.
slice_view_name = 'v_' + self.data_name + '_slice'
attr_view_name = slice_view_name + '_' + node.attr
try:
slice_view = self.sdfg.arrays[slice_view_name]
except KeyError:
slice_view = dt.View.view(struct)
slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True)
try:
attr_view = self.sdfg.arrays[attr_view_name]
except KeyError:
member: dt.Data = struct.members[node.attr]
attr_view = dt.View.view(member)
attr_view_name = self.sdfg.add_datadesc(attr_view_name, attr_view, find_new_name=True)
self.views_constructed.add(slice_view_name)
self.views_constructed.add(attr_view_name)

# Construct the correct AST replacement node (direct access, i.e., name node).
replacement = ast.Name()
replacement.ctx = ast.Load()
replacement.id = attr_view_name

# Add access nodes for the views to the slice and attribute and connect them appropriately to the original data
# container.
lift_state, data_node = self._get_or_create_lifting_state()
slice_view_node = lift_state.add_access(slice_view_name)
attr_view_node = lift_state.add_access(attr_view_name)
idx = astutils.unparse(val.slice)
if isinstance(val.slice, ast.Tuple):
idx = idx.strip('()')
slice_memlet = Memlet(data_node.data + '[' + idx + ']')
lift_state.add_edge(data_node, None, slice_view_node, 'views', slice_memlet)
attr_memlet = Memlet.from_array(slice_view_name + '.' + node.attr, struct.members[node.attr])
lift_state.add_edge(slice_view_node, None, attr_view_node, 'views', attr_memlet)
return self.generic_visit(replacement)

def _get_or_create_lifting_state(self) -> Tuple[SDFGState, nd.AccessNode]:
# Add a state for lifting before the edge, if there isn't one that was created already.
if self.edge.data in self.isedge_lifting_state_dict:
lift_state = self.isedge_lifting_state_dict[self.edge.data]
else:
pre_node: ControlFlowBlock = self.edge.src
lift_state = pre_node.parent_graph.add_state_after(pre_node, self.data_name + '_lifting')
self.isedge_lifting_state_dict[self.edge.data] = lift_state

# Add a node for the original data container so the view can be connected to it. This may already be a view from
# a previous iteration of lifting, but in that case it is already correctly connected to a root data container.
data_node = None
for dn in lift_state.data_nodes():
if dn.data == self.data_name:
data_node = dn
break
if data_node is None:
data_node = lift_state.add_access(self.data_name)

return lift_state, data_node

def visit_Attribute(self, node: ast.Attribute) -> Any:
if not node.value:
return self.generic_visit(node)

if isinstance(self.data, dt.Structure):
if isinstance(node.value, ast.Name) and node.value.id == self.data_name:
return self._handle_simple_name_access(node)
elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.slice, ast.Constant) and
node.value.slice.value == 0 and isinstance(node.value.value, ast.Name) and
node.value.value.id == self.data_name):
return self._handle_simple_name_access(node)
return self.generic_visit(node)
else:
# ContainerArray case.
if isinstance(node.value, ast.Name) and node.value.id == self.data_name:
# We are directly accessing a slice of a container array / view. That needs an inserted view to the
# container first.
slice_view_name = 'v_' + self.data_name + '_slice'
try:
slice_view = self.sdfg.arrays[slice_view_name]
except KeyError:
slice_view = dt.View.view(self.data.stype)
slice_view_name = self.sdfg.add_datadesc(slice_view_name, slice_view, find_new_name=True)
self.views_constructed.add(slice_view_name)

# Add an access node for the slice view and connect it appropriately to the root data container.
lift_state, data_node = self._get_or_create_lifting_state()
slice_view_node = lift_state.add_access(slice_view_name)
lift_state.add_edge(data_node, None, slice_view_node, 'views',
Memlet.from_array(self.data_name, self.sdfg.data(self.data_name)))
elif (isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) and
node.value.value.id == self.data_name):
return self._handle_sliced_access(node, node.value)
return self.generic_visit(node)


def _data_containers_in_ast(node: ast.AST, arrnames: Set[str]) -> Set[str]:
result: Set[str] = set()
for subnode in ast.walk(node):
if isinstance(subnode, (ast.Attribute, ast.Subscript)):
data = astutils.rname(subnode.value)
if data in arrnames:
result.add(data)
return result

class LiftStructViews(ppl.Pass):
"""
Expand All @@ -200,6 +360,8 @@ class LiftStructViews(ppl.Pass):

CATEGORY: str = 'Optimization Preparation'

_isedge_lifting_state_dict: Dict[InterstateEdge, SDFGState] = dict()

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Descriptors | ppl.Modifies.AccessNodes | ppl.Modifies.Tasklets | ppl.Modifies.Memlets

Expand All @@ -209,6 +371,40 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {}

def _lift_isedge(self, cfg: ControlFlowRegion, edge: Edge[InterstateEdge], result: Dict[str, Set[str]]) -> bool:
lifted_something = False
for k in edge.data.assignments.keys():
assignment = edge.data.assignments[k]
assignment_str = str(assignment)
assignment_ast = ast.parse(assignment_str)
data_in_edge = _data_containers_in_ast(assignment_ast, cfg.sdfg.arrays.keys())
for data in data_in_edge:
if '.' in data:
continue
container = cfg.sdfg.arrays[data]
if isinstance(container, (dt.Structure, dt.ContainerArray)):
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict)
new_code = visitor.visit(assignment_ast)
edge.data.assignments[k] = astutils.unparse(new_code)
assignment_ast = new_code
result[data].update(visitor.views_constructed)
lifted_something = True
if not edge.data.is_unconditional():
condition_ast = edge.data.condition.code[0]
data_in_edge = _data_containers_in_ast(condition_ast, cfg.sdfg.arrays.keys())
for data in data_in_edge:
if '.' in data:
continue
container = cfg.sdfg.arrays[data]
if isinstance(container, (dt.Structure, dt.ContainerArray)):
visitor = InterstateEdgeRecoder(cfg.sdfg, edge, data, container, self._isedge_lifting_state_dict)
new_code = visitor.visit(condition_ast)
edge.data.condition = CodeBlock([new_code])
condition_ast = new_code
result[data].update(visitor.views_constructed)
lifted_something = True
return lifted_something

def _lift_tasklet(self, state: SDFGState, data_node: nd.AccessNode, tasklet: nd.Tasklet,
edge: MultiConnectorEdge[Memlet], data: dt.Structure, connector: str,
direction: dirtype) -> Set[str]:
Expand Down Expand Up @@ -251,23 +447,34 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Dict[str, Set[str]]]:
result = defaultdict(set)

lifted_something = False
for nsdfg in sdfg.all_sdfgs_recursive():
for state in nsdfg.states():
for node in state.data_nodes():
cont = nsdfg.data(node.data)
if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or
(isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and
isinstance(cont.stype, dt.Structure))):
for oedge in state.out_edges(node):
if isinstance(oedge.dst, nd.Tasklet):
res = self._lift_tasklet(state, node, oedge.dst, oedge, cont, oedge.dst_conn, 'in')
result[node.data].update(res)
lifted_something = True
for iedge in state.in_edges(node):
if isinstance(iedge.src, nd.Tasklet):
res = self._lift_tasklet(state, node, iedge.src, iedge, cont, iedge.src_conn, 'out')
result[node.data].update(res)
lifted_something = True
while True:
lifted_something_this_round = False
for cfg in sdfg.all_control_flow_regions(recursive=True):
for block in cfg.nodes():
if isinstance(block, SDFGState):
for node in block.data_nodes():
cont = cfg.sdfg.data(node.data)
if (isinstance(cont, (dt.Structure, dt.StructureView, dt.StructureReference)) or
(isinstance(cont, (dt.ContainerView, dt.ContainerArray, dt.ContainerArrayReference)) and
isinstance(cont.stype, dt.Structure))):
for oedge in block.out_edges(node):
if isinstance(oedge.dst, nd.Tasklet):
res = self._lift_tasklet(block, node, oedge.dst, oedge, cont, oedge.dst_conn,
'in')
result[node.data].update(res)
lifted_something_this_round = True
for iedge in block.in_edges(node):
if isinstance(iedge.src, nd.Tasklet):
res = self._lift_tasklet(block, node, iedge.src, iedge, cont, iedge.src_conn,
'out')
result[node.data].update(res)
lifted_something_this_round = True
for edge in cfg.edges():
lifted_something_this_round |= self._lift_isedge(cfg, edge, result)
if not lifted_something_this_round:
break
else:
lifted_something = True

if not lifted_something:
return None
Expand Down

0 comments on commit 200e606

Please sign in to comment.