diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 05024561f5..1456eb6d28 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -20,6 +20,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation): map_exit = transformation.PatternNode(nodes.MapExit) _EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/'] + _FUNCTIONS = ['min', 'max'] _EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')} _PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'} @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) if tasklet.language is dtypes.Language.Python: # Match a single assignment with a binary operation as RHS @@ -109,8 +111,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) + func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: - continue + if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: + continue + # Same memlet if edge.data.subset != outedge.data.subset: continue @@ -183,6 +189,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): outconn = outedge.src_conn ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS) + funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS) # Change tasklet code if tasklet.language is dtypes.Language.Python: @@ -209,9 +216,24 @@ def apply(self, state: SDFGState, sdfg: SDFG): match = re.match( r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr) if match is None: - continue - op = match.group(2) - expr = match.group(1) + func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_rhs, cstr) + if match is None: + func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs, + re.escape(inconn)) + match = re.match(func_lhs, cstr) + if match is None: + continue + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(1) + expr = match.group(2) + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) @@ -231,7 +253,10 @@ def apply(self, state: SDFGState, sdfg: SDFG): raise NotImplementedError # Change output edge - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' if self.expr_index == 0: # Remove input node and connector @@ -251,6 +276,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): sd = sd.parent_sdfg outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data))) for outedge in nstate.memlet_path(outedge): - outedge.data.wcr = f'lambda a,b: a {op} b' + if op in AugAssignToWCR._FUNCTIONS: + outedge.data.wcr = f'lambda a,b: {op}(a, b)' + else: + outedge.data.wcr = f'lambda a,b: a {op} b' # At this point we are leading to an access node again and can # traverse further up diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py index d4ecc50771..ab3987b041 100644 --- a/tests/transformations/wcr_conversion_test.py +++ b/tests/transformations/wcr_conversion_test.py @@ -128,7 +128,39 @@ def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32]): assert applied == 1 -if __name__ == "__main__": - test_aug_assign_tasklet_lhs_cpp() - test_aug_assign_tasklet_lhs_brackets_cpp() - test_aug_assign_tasklet_rhs_brackets_cpp() +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + b >> A[i] + """ + b = min(a, 0); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + b >> A[i] + """ + b = min(0, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1