Skip to content

Commit

Permalink
AugAssignToWCR: Support for more cases and increased test coverage (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Dec 2, 2023
1 parent f90cf5b commit 5cacfdd
Showing 1 changed file with 47 additions and 57 deletions.
104 changes: 47 additions & 57 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
""" Transformations to convert subgraphs to write-conflict resolutions. """
import ast
import re
from dace import registry, nodes, dtypes
import copy
from dace import registry, nodes, dtypes, Memlet
from dace.transformation import transformation, helpers as xfh
from dace.sdfg import graph as gr, utils as sdutil
from dace import SDFG, SDFGState
from dace.sdfg.state import StateSubgraphView
from dace.transformation import helpers
from dace.sdfg.propagation import propagate_memlets_state


class AugAssignToWCR(transformation.SingleStateTransformation):
Expand All @@ -20,13 +24,15 @@ class AugAssignToWCR(transformation.SingleStateTransformation):
map_exit = transformation.PatternNode(nodes.MapExit)

_EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/']
_FUNCTIONS = ['min', 'max']
_EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')}
_PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'}

@classmethod
def expressions(cls):
return [
sdutil.node_path_graph(cls.input, cls.tasklet, cls.output),
sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output)
]

def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
Expand All @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

# Free tasklet
if expr_index == 0:
# Only free tasklets supported for now
if graph.entry_node(tasklet) is not None:
return False

Expand All @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Make sure augmented assignment can be fissioned as necessary
if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0:
return False

outedge = graph.edges_between(tasklet, outarr)[0]
else: # Free map
Expand All @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if len(graph.edges_between(tasklet, mx)) > 1:
return False

# Currently no fission is supported
# Make sure augmented assignment can be fissioned as necessary
if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
for e in graph.in_edges(me) + graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0:
return False

outedge = graph.edges_between(tasklet, mx)[0]

Expand All @@ -84,6 +85,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

if tasklet.language is dtypes.Language.Python:
# Match a single assignment with a binary operation as RHS
Expand Down Expand Up @@ -114,9 +116,23 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Try to match a single C assignment that can be converted to WCR
inconn = edge.dst_conn
lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops)
rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
if re.match(lhs, cstr) is None:
continue
# rhs: a = (...) op b
rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None:
if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None:
inconns = list(self.tasklet.in_connectors)
if len(inconns) != 2:
continue

# Special case: a = <other> op b
other_inconn = inconns[0] if inconns[0] != inconn else inconns[1]
rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops,
re.escape(inconn))
if re.match(rhs2, cstr) is None:
continue

# Same memlet
if edge.data.subset != outedge.data.subset:
continue
Expand All @@ -129,57 +145,30 @@ def apply(self, state: SDFGState, sdfg: SDFG):
input: nodes.AccessNode = self.input
tasklet: nodes.Tasklet = self.tasklet
output: nodes.AccessNode = self.output
if self.expr_index == 1:
me = self.map_entry
mx = self.map_exit

# If state fission is necessary to keep semantics, do it first
if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0):
newstate = sdfg.add_state_after(state)
newstate.add_node(tasklet)
new_input, new_output = None, None

# Keep old edges for after we remove tasklet from the original state
in_edges = list(state.in_edges(tasklet))
out_edges = list(state.out_edges(tasklet))

for e in in_edges:
r = newstate.add_read(e.src.data)
newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
if e.src is input:
new_input = r
for e in out_edges:
w = newstate.add_write(e.dst.data)
newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
if e.dst is output:
new_output = w

# Remove tasklet and resulting isolated nodes
state.remove_node(tasklet)
for e in in_edges:
if state.degree(e.src) == 0:
state.remove_node(e.src)
for e in out_edges:
if state.degree(e.dst) == 0:
state.remove_node(e.dst)

# Reset state and nodes for rest of transformation
input = new_input
output = new_output
state = newstate
# End of state fission
if state.in_degree(input) > 0:
subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)])
subgraph_nodes.add(input)

subgraph = StateSubgraphView(state, subgraph_nodes)
helpers.state_fission(sdfg, subgraph)

if self.expr_index == 0:
inedges = state.edges_between(input, tasklet)
outedge = state.edges_between(tasklet, output)[0]
else:
me = self.map_entry
mx = self.map_exit

inedges = state.edges_between(me, tasklet)
outedge = state.edges_between(tasklet, mx)[0]

# Get relevant output connector
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

# Change tasklet code
if tasklet.language is dtypes.Language.Python:
Expand Down Expand Up @@ -258,16 +247,14 @@ def apply(self, state: SDFGState, sdfg: SDFG):
raise NotImplementedError

# Change output edge
outedge.data.wcr = f'lambda a,b: a {op} b'

if self.expr_index == 0:
# Remove input node and connector
state.remove_edge_and_connectors(inedge)
if state.degree(input) == 0:
state.remove_node(input)
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
# Remove input edge and dst connector, but not necessarily src
state.remove_memlet_path(inedge)
outedge.data.wcr = f'lambda a,b: a {op} b'

# Remove input node and connector
state.remove_memlet_path(inedge)
propagate_memlets_state(sdfg, state)

# If outedge leads to non-transient, and this is a nested SDFG,
# propagate outwards
Expand All @@ -278,6 +265,9 @@ def apply(self, state: SDFGState, sdfg: SDFG):
sd = sd.parent_sdfg
outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
for outedge in nstate.memlet_path(outedge):
outedge.data.wcr = f'lambda a,b: a {op} b'
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
outedge.data.wcr = f'lambda a,b: a {op} b'
# At this point we are leading to an access node again and can
# traverse further up

0 comments on commit 5cacfdd

Please sign in to comment.