From 5cacfdd9f14db343059eba55f1191628512d4a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lukas=20Tr=C3=BCmper?= Date: Fri, 3 Nov 2023 09:07:31 +0100 Subject: [PATCH] AugAssignToWCR: Support for more cases and increased test coverage (#1359) --- .../transformation/dataflow/wcr_conversion.py | 104 ++++++++---------- 1 file changed, 47 insertions(+), 57 deletions(-) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 8badd2bf9b..3ef508f7e5 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -2,10 +2,14 @@ """ Transformations to convert subgraphs to write-conflict resolutions. """ import ast import re -from dace import registry, nodes, dtypes +import copy +from dace import registry, nodes, dtypes, Memlet from dace.transformation import transformation, helpers as xfh from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState +from dace.sdfg.state import StateSubgraphView +from dace.transformation import helpers +from dace.sdfg.propagation import propagate_memlets_state class AugAssignToWCR(transformation.SingleStateTransformation): @@ -20,6 +24,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -27,6 +32,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): def expressions(cls): return [ sdutil.node_path_graph(cls.input, cls.tasklet, cls.output), + sdutil.node_path_graph(cls.input, cls.map_entry, cls.tasklet, cls.map_exit, cls.output) ] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): @@ -38,7 +44,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -49,8 +54,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -65,12 +68,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0: - return False outedge = graph.edges_between(tasklet, mx)[0] @@ -84,6 +85,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -114,9 +116,23 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) - rhs = r'^\s*%s\s*=\s*.*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) - if re.match(lhs, cstr) is None: - continue + # rhs: a = (...) op b + rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + inconns = list(self.tasklet.in_connectors) + if len(inconns) != 2: + continue + + # Special case: a = op b + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue @@ -129,50 +145,22 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -180,6 +168,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -258,16 +247,14 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' - - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' else: - # Remove input edge and dst connector, but not necessarily src - state.remove_memlet_path(inedge) + outedge.data.wcr = f'lambda a,b: a {op} b' + + # Remove input node and connector + state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards @@ -278,6 +265,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up