diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 8af7bfacb5..e1242eb289 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -111,12 +111,22 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Try to match a single C assignment that can be converted to WCR inconn = edge.dst_conn lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops) + # rhs: a = (...) op b rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)) func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn)) if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None: if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None: - continue + inconns = [edge.dst_conn for edge in inedges] + if len(inconns) != 2: + continue + + # Special case: a = c op b (c is the tasklet's other input connector) + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops, + re.escape(inconn)) + if re.match(rhs2, cstr) is None: + continue # Same memlet if edge.data.subset != outedge.data.subset: @@ -225,7 +235,20 @@ def apply(self, state: SDFGState, sdfg: SDFG): re.escape(inconn)) match = re.match(func_lhs, cstr) if match is None: - continue + inconns = [edge.dst_conn for edge in inedges] + if len(inconns) != 2: + continue + + # Special case: a = c op b (c is the tasklet's other input connector) + other_inconn = inconns[0] if inconns[0] != inconn else inconns[1] + rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % ( + re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn)) + match = re.match(rhs2, cstr) + if match is None: + continue + else: + op = match.group(2) + expr = match.group(1) else: op = match.group(1) expr = match.group(2) diff --git 
a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py index 58e0bb39b3..d1524e2dd0 100644 --- a/tests/transformations/wcr_conversion_test.py +++ b/tests/transformations/wcr_conversion_test.py @@ -6,12 +6,13 @@ def test_aug_assign_tasklet_lhs(): @dace.program - def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet: a << A[i] + k << B[i] b >> A[i] - b = a + 1 + b = a + k sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() sdfg.simplify() @@ -23,12 +24,13 @@ def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32]): def test_aug_assign_tasklet_lhs_brackets(): @dace.program - def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet: a << A[i] + k << B[i] b >> A[i] - b = a + (1 + 1) + b = a + (k + 1) sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() sdfg.simplify() @@ -40,12 +42,13 @@ def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32]): def test_aug_assign_tasklet_rhs(): @dace.program - def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet: a << A[i] + k << B[i] b >> A[i] - b = 1 + a + b = k + a sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() sdfg.simplify() @@ -57,12 +60,13 @@ def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32]): def test_aug_assign_tasklet_rhs_brackets(): @dace.program - def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet: a << A[i] + k << B[i] b >> A[i] - b = (1 + 1) + a + b = (k + 1) + a sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() sdfg.simplify() @@ -74,13 +78,14 @@ def sdfg_aug_assign_tasklet_rhs_brackets(A: 
dace.float64[32]): def test_aug_assign_tasklet_lhs_cpp(): @dace.program - def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet(language=dace.Language.CPP): a << A[i] + k << B[i] b >> A[i] """ - b = a + 1; + b = a + k; """ sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() @@ -93,13 +98,14 @@ def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32]): def test_aug_assign_tasklet_lhs_brackets_cpp(): @dace.program - def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet(language=dace.Language.CPP): a << A[i] + k << B[i] b >> A[i] """ - b = a + (1 + 1); + b = a + (k + 1); """ sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() @@ -112,13 +118,14 @@ def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32]): def test_aug_assign_tasklet_rhs_brackets_cpp(): @dace.program - def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32]): + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): for i in range(32): with dace.tasklet(language=dace.Language.CPP): a << A[i] + k << B[i] b >> A[i] """ - b = (1 + 1) + a; + b = (k + 1) + a; """ sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() @@ -171,12 +178,15 @@ def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32 def test_aug_assign_free_map(): @dace.program - def sdfg_aug_assign_free_map(A: dace.float64[32]): + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): for i in dace.map[0:32]: - with dace.tasklet: - a << A[i] - b >> A[i] - b = a * 2 + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ sdfg = sdfg_aug_assign_free_map.to_sdfg() sdfg.simplify()