Skip to content

Commit

Permalink
AugAssignWCR: Special case added & tests with multiple inconns
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Sep 2, 2023
1 parent 07e4a25 commit 7c9e8e2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 21 deletions.
27 changes: 25 additions & 2 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,22 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Try to match a single C assignment that can be converted to WCR
inconn = edge.dst_conn
lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops)
# rhs: a = (...) op b
rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None:
if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None:
continue
inconns = [edge.dst_conn for edge in inedges]
if len(inconns) != 2:
continue

# Special case: a = <other> op b
other_inconn = inconns[0] if inconns[0] != inconn else inconns[1]
rhs2 = r'^\s*%s\s*=\s*%s\s*%s\s*%s;$' % (re.escape(outconn), re.escape(other_inconn), ops,
re.escape(inconn))
if re.match(rhs2, cstr) is None:
continue

# Same memlet
if edge.data.subset != outedge.data.subset:
Expand Down Expand Up @@ -225,7 +235,20 @@ def apply(self, state: SDFGState, sdfg: SDFG):
re.escape(inconn))
match = re.match(func_lhs, cstr)
if match is None:
continue
inconns = [edge.dst_conn for edge in inedges]
if len(inconns) != 2:
continue

# Special case: a = <other> op b
other_inconn = inconns[0] if inconns[0] != inconn else inconns[1]
rhs2 = r'^\s*%s\s*=\s*(%s)\s*(%s)\s*%s;$' % (
re.escape(outconn), re.escape(other_inconn), ops, re.escape(inconn))
match = re.match(rhs2, cstr)
if match is None:
continue
else:
op = match.group(2)
expr = match.group(1)
else:
op = match.group(1)
expr = match.group(2)
Expand Down
48 changes: 29 additions & 19 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
def test_aug_assign_tasklet_lhs():

@dace.program
def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet:
a << A[i]
k << B[i]
b >> A[i]
b = a + 1
b = a + k

sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg()
sdfg.simplify()
Expand All @@ -23,12 +24,13 @@ def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32]):
def test_aug_assign_tasklet_lhs_brackets():

@dace.program
def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet:
a << A[i]
k << B[i]
b >> A[i]
b = a + (1 + 1)
b = a + (k + 1)

sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg()
sdfg.simplify()
Expand All @@ -40,12 +42,13 @@ def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32]):
def test_aug_assign_tasklet_rhs():

@dace.program
def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet:
a << A[i]
k << B[i]
b >> A[i]
b = 1 + a
b = k + a

sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg()
sdfg.simplify()
Expand All @@ -57,12 +60,13 @@ def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32]):
def test_aug_assign_tasklet_rhs_brackets():

@dace.program
def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet:
a << A[i]
k << B[i]
b >> A[i]
b = (1 + 1) + a
b = (k + 1) + a

sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg()
sdfg.simplify()
Expand All @@ -74,13 +78,14 @@ def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32]):
def test_aug_assign_tasklet_lhs_cpp():

@dace.program
def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
k << B[i]
b >> A[i]
"""
b = a + 1;
b = a + k;
"""

sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg()
Expand All @@ -93,13 +98,14 @@ def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32]):
def test_aug_assign_tasklet_lhs_brackets_cpp():

@dace.program
def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
k << B[i]
b >> A[i]
"""
b = a + (1 + 1);
b = a + (k + 1);
"""

sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg()
Expand All @@ -112,13 +118,14 @@ def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32]):
def test_aug_assign_tasklet_rhs_brackets_cpp():

@dace.program
def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32]):
def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]):
for i in range(32):
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
k << B[i]
b >> A[i]
"""
b = (1 + 1) + a;
b = (k + 1) + a;
"""

sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg()
Expand Down Expand Up @@ -171,12 +178,15 @@ def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32
def test_aug_assign_free_map():

@dace.program
def sdfg_aug_assign_free_map(A: dace.float64[32]):
def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]):
for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> A[i]
b = a * 2
with dace.tasklet(language=dace.Language.CPP):
a << A[0]
k << B[i]
b >> A[0]
"""
b = k * a;
"""

sdfg = sdfg_aug_assign_free_map.to_sdfg()
sdfg.simplify()
Expand Down

0 comments on commit 7c9e8e2

Please sign in to comment.