Skip to content

Commit

Permalink
AugAssignToWCR: Support AugAssingToWCR for map scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Sep 1, 2023
1 parent 07644f8 commit 5148896
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 5 deletions.
72 changes: 67 additions & 5 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
""" Transformations to convert subgraphs to write-conflict resolutions. """
import ast
import re
from dace import registry, nodes, dtypes
import copy
from dace import registry, nodes, dtypes, Memlet
from dace.transformation import transformation, helpers as xfh
from dace.sdfg import graph as gr, utils as sdutil
from dace import SDFG, SDFGState
from dace.sdfg.state import StateSubgraphView


class AugAssignToWCR(transformation.SingleStateTransformation):
Expand All @@ -28,6 +30,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation):
def expressions(cls):
return [
sdutil.node_path_graph(cls.input, cls.tasklet, cls.output),
sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output)
]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
Expand Down Expand Up @@ -70,8 +73,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
for e in graph.in_edges(me) + graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0:
return False

outedge = graph.edges_between(tasklet, mx)[0]

Expand Down Expand Up @@ -264,8 +265,69 @@ def apply(self, state: SDFGState, sdfg: SDFG):
if state.degree(input) == 0:
state.remove_node(input)
else:
# Remove input edge and dst connector, but not necessarily src
state.remove_memlet_path(inedge)
# Put into NestedSDFG to retain input dependencies
map_entry = self.map_entry
map_exit = self.map_exit
subgraph_nodes = set(state.all_nodes_between(map_entry, map_exit))
subgraph_nodes.add(map_entry)
subgraph_nodes.add(map_exit)

in_access_nodes = set()
out_access_nodes = set()
for edge in state.in_edges(map_entry):
subgraph_nodes.add(edge.src)
in_access_nodes.add(edge.src)
for edge in state.out_edges(map_exit):
subgraph_nodes.add(edge.dst)
out_access_nodes.add(edge.dst)

subgraph = StateSubgraphView(state, subgraph_nodes)

# Add subgraph as nested SDFG
nested_sdfg = SDFG("nested_" + map_entry.label)
inputs = set()
for data in in_access_nodes:
inputs.add(data.data)
nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data])
outputs = set()
for data in out_access_nodes:
outputs.add(data.data)
nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data])

nested_state = nested_sdfg.add_state("nested_" + map_entry.label + "_state", is_start_state=True)
node_map = {}
new_inedge = None
for node in subgraph.nodes():
new_node = copy.deepcopy(node)
nested_state.add_node(new_node)
node_map[node] = new_node
for edge in subgraph.edges():
new_edge = nested_state.add_edge(node_map[edge.src], edge.src_conn, node_map[edge.dst], edge.dst_conn,
edge.data)
if edge == inedge:
new_inedge = new_edge

nested_sdfg_node = state.add_nested_sdfg(nested_sdfg, sdfg, inputs=inputs, outputs=outputs)
for in_access in in_access_nodes:
nested_sdfg_node.add_in_connector(in_access.data)
state.add_edge(in_access, None, nested_sdfg_node, in_access.data,
Memlet.from_array(in_access.data, sdfg.arrays[in_access.data]))

for out_access in out_access_nodes:
nested_sdfg_node.add_in_connector(out_access.data)
state.add_edge(nested_sdfg_node, out_access.data, out_access, None,
Memlet.from_array(out_access.data, sdfg.arrays[out_access.data]))

# Remove subgraph from state
for edge in subgraph.edges():
state.remove_edge(edge)
for node in subgraph.nodes():
if node in in_access_nodes or node in out_access_nodes:
continue

state.remove_node(node)

nested_state.remove_memlet_path(new_inedge)

# If outedge leads to non-transient, and this is a nested SDFG,
# propagate outwards
Expand Down
46 changes: 46 additions & 0 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,49 @@ def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32]):

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1


def test_aug_assign_free_map():

@dace.program
def sdfg_aug_assign_free_map(A: dace.float64[32]):
for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> A[i]
b = a * 2

sdfg = sdfg_aug_assign_free_map.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1


def test_aug_assign_dependent_map():

@dace.program
def sdfg_aug_assign_dependent_map(A: dace.float64[32], B: dace.float64[32]):
for i in dace.map[0:32]:
with dace.tasklet:
a << B[i]
b >> A[i]
b = a

for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> A[i]
b = a * 2

for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> B[i]
b = a

sdfg = sdfg_aug_assign_dependent_map.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1

0 comments on commit 5148896

Please sign in to comment.