diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index e95674adc1..7f4fbc654d 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -2,10 +2,14 @@ """ Transformations to convert subgraphs to write-conflict resolutions. """ import ast import re -from dace import registry, nodes, dtypes +import copy +from dace import registry, nodes, dtypes, Memlet from dace.transformation import transformation, helpers as xfh from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState +from dace.sdfg.state import StateSubgraphView +from dace.transformation import helpers +from dace.sdfg.propagation import propagate_memlets_state class AugAssignToWCR(transformation.SingleStateTransformation): @@ -20,6 +24,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -27,6 +32,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0: - return False outedge = graph.edges_between(tasklet, mx)[0] @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -108,18 +110,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) - rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) - if re.match(lhs, cstr) is None: - continue + # rhs: a = (...) op b + rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if (expr_index == 1 and len(outedge.data.subset.free_symbols - & set(me.map.params)) == len(me.map.params)): - continue + if expr_index == 1: + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( + me.map.params): + continue return True else: @@ -132,50 +149,22 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -183,6 +172,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -206,13 +196,40 @@ def apply(self, state: SDFGState, sdfg: SDFG): inconn = edge.dst_conn match = re.match(r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' % (re.escape(outconn), re.escape(inconn), ops), cstr) if match is None: - # match = re.match( - # r'^\s*%s\s*=\s*(.*)\s*(%s)\s*%s;$' % - # (re.escape(outconn), ops, re.escape(inconn)), cstr) - # if match is None: - continue - # op = match.group(2) - # expr = match.group(1) + match = re.match( + r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr) + if match is None: + func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_rhs, cstr) + if match is None: + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_lhs, cstr) + if match is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % ( + re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn)) + match = re.match(rhs2, cstr) + if match is None: + continue + else: + op = match.group(2) + expr = match.group(1) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) @@ -232,16 +249,14 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' - - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' else: - # Remove input edge and dst connector, but not necessarily src - state.remove_memlet_path(inedge) + outedge.data.wcr = f'lambda a,b: a {op} b' + + # Remove input node and connector + state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards @@ -252,6 +267,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py new file mode 100644 index 0000000000..091b2a9db8 --- /dev/null +++ b/tests/transformations/wcr_conversion_test.py @@ -0,0 +1,247 @@ +import dace + +from dace.transformation.dataflow import AugAssignToWCR + + +def test_aug_assign_tasklet_lhs(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + k + + sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + (k + 1) + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = k + a + + sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = (k + 1) + a + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + k; + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + (k + 1); + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = (k + 1) + a; + """ + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(a, c); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_free_map(): + + @dace.program + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ + + sdfg = sdfg_aug_assign_free_map.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_state_fission_map(): + + @dace.program + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet: + a << B[i] + b >> A[i] + b = a + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + sdfg = sdfg_aug_assign_state_fission.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 2 + + +def test_free_map_permissive(): + + @dace.program + def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = k * a; + """ + + sdfg = sdfg_free_map_permissive.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=False) + assert applied == 0 + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1