Skip to content

Commit

Permalink
Merge pull request #1267 from spcl/users/lukas/taskletfusion-bugfix
Browse files Browse the repository at this point in the history
Bugfix: Taskletfusion with map params
  • Loading branch information
tbennun authored Jun 8, 2023
2 parents e7aadae + 642bf2a commit 37b58bb
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 34 deletions.
28 changes: 10 additions & 18 deletions dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def visit_Name(self, node: ast.Name) -> Any:


class CPPConnectorRenamer():

def __init__(self, repl_dict: Dict[str, str]) -> None:
self.repl_dict = repl_dict

Expand All @@ -44,7 +43,6 @@ def rename(self, code: str) -> str:


class PythonInliner(ast.NodeTransformer):

def __init__(self, target_id, target_ast):
self.target_id = target_id
self.target_ast = target_ast
Expand All @@ -57,7 +55,6 @@ def visit_Name(self, node: ast.AST):


class CPPInliner():

def __init__(self, inline_target, inline_val):
self.inline_target = inline_target
self.inline_val = inline_val
Expand Down Expand Up @@ -144,10 +141,7 @@ class TaskletFusion(pm.SingleStateTransformation):

@classmethod
def expressions(cls):
return [
sdutil.node_path_graph(cls.t1, cls.data, cls.t2),
sdutil.node_path_graph(cls.t1, cls.t2)
]
return [sdutil.node_path_graph(cls.t1, cls.data, cls.t2), sdutil.node_path_graph(cls.t1, cls.t2)]

def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG, permissive: bool = False) -> bool:
t1 = self.t1
Expand Down Expand Up @@ -191,14 +185,15 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
t2_in_edge = graph.out_edges(data if data is not None else t1)[0]

# Remove the connector from the second Tasklet.
inputs = {
k: v for k, v in t2.in_connectors.items() if k != t2_in_edge.dst_conn
}
inputs = {k: v for k, v in t2.in_connectors.items() if k != t2_in_edge.dst_conn}

# Copy the first Tasklet's in connectors.
repldict = {}
for in_edge in graph.in_edges(t1):
old_value = in_edge.dst_conn
if old_value is None:
continue

# Check if there is a conflict.
if in_edge.dst_conn in inputs:
# Conflicts are ok if the Memlets are the same.
Expand All @@ -211,8 +206,8 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
break
else:
t2edge = conflict_edges[0]
if t2edge is not None and (in_edge.data != t2edge.data or in_edge.data.data != t2edge.data.data or
in_edge.data is None or in_edge.data.data is None):
if t2edge is not None and (in_edge.data != t2edge.data or in_edge.data.data != t2edge.data.data
or in_edge.data is None or in_edge.data.data is None):
in_edge.dst_conn = dace.data.find_new_name(in_edge.dst_conn, set(inputs))
repldict[old_value] = in_edge.dst_conn
else:
Expand All @@ -228,9 +223,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
if repldict:
assigned_value = PythonConnectorRenamer(repldict).visit(assigned_value)

new_code = [
PythonInliner(t2_in_edge.dst_conn, assigned_value).visit(line) for line in t2.code.code
]
new_code = [PythonInliner(t2_in_edge.dst_conn, assigned_value).visit(line) for line in t2.code.code]
new_code_str = '\n'.join(astunparse.unparse(line) for line in new_code)
elif t1.language == Language.CPP:
assigned_value = t1.code.as_string
Expand All @@ -252,9 +245,8 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
else:
return

new_tasklet = graph.add_tasklet(
t1.label + '_fused_' + t2.label, inputs, t2.out_connectors, new_code_str, t1.language
)
new_tasklet = graph.add_tasklet(t1.label + '_fused_' + t2.label, inputs, t2.out_connectors, new_code_str,
t1.language)

for in_edge in graph.in_edges(t1):
graph.add_edge(in_edge.src, in_edge.src_conn, new_tasklet, in_edge.dst_conn, in_edge.data)
Expand Down
52 changes: 36 additions & 16 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
M = 10
N = 2 * M


@dace.program
def map_with_tasklets(A: datatype[N], B: datatype[M]):
C = np.zeros_like(B)
Expand Down Expand Up @@ -42,15 +43,11 @@ def _make_sdfg(language: str, with_data: bool = False):
outputs = {
'__out': datatype,
}
ta = state.add_tasklet(
'a', inputs, {
'__out1': datatype,
'__out2': datatype,
'__out3': datatype,
},
f'__out1 = __inp1 + __inp2{endl}__out2 = __out1{endl}__out3 = __out1{endl}',
lang
)
ta = state.add_tasklet('a', inputs, {
'__out1': datatype,
'__out2': datatype,
'__out3': datatype,
}, f'__out1 = __inp1 + __inp2{endl}__out2 = __out1{endl}__out3 = __out1{endl}', lang)
tb = state.add_tasklet('b', inputs, outputs, f'__out = __inp1 * __inp2{endl}', lang)
tc = state.add_tasklet('c', inputs, outputs, f'__out = __inp1 + __inp2{endl}', lang)
td = state.add_tasklet('d', inputs, outputs, f'__out = __inp1 / __inp2{endl}', lang)
Expand All @@ -60,12 +57,12 @@ def _make_sdfg(language: str, with_data: bool = False):
state.add_memlet_path(A, me, tb, memlet=dace.Memlet('A[2*i]'), dst_conn='__inp2')
state.add_memlet_path(B, me, tc, memlet=dace.Memlet('B[i]'), dst_conn='__inp2')
if with_data:
sdfg.add_array('tmp1', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp2', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp3', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp4', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp5', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp6', (1,), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp1', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp2', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp3', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp4', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp5', (1, ), datatype, dtypes.StorageType.Default, None, True)
sdfg.add_array('tmp6', (1, ), datatype, dtypes.StorageType.Default, None, True)
atemp1 = state.add_access('tmp1')
atemp2 = state.add_access('tmp2')
atemp3 = state.add_access('tmp3')
Expand Down Expand Up @@ -101,7 +98,7 @@ def test_basic():
def test_basic_tf(A: datatype[5, 5]):
B = A + 1
return B * 2

sdfg = test_basic_tf.to_sdfg(simplify=True)

num_map_fusions = sdfg.apply_transformations(MapFusion)
Expand Down Expand Up @@ -178,6 +175,28 @@ def test_tasklet_fusion_multiline(A: datatype):
assert (result[0] == 11)


def test_map_param():
@dace.program
def map_uses_param(A: dace.float32[10], B: dace.float32[10], C: dace.float32[10]):
for i in dace.map[0:10]:
a = i - A[i]
b = B[i] * i
C[i] = a + b

sdfg = map_uses_param.to_sdfg(simplify=True)

num_tasklet_fusions = sdfg.apply_transformations_repeated(TaskletFusion)
assert (num_tasklet_fusions == 3)

A = np.zeros([10], dtype=np.float32)
B = np.ones([10], dtype=np.float32)
C = np.empty([10], dtype=np.float32)
sdfg(A=A, B=B, C=C)

ref = np.array(range(0, 10, 1)) * 2.0
assert (C == ref).all()


@pytest.mark.parametrize('with_data', [pytest.param(True), pytest.param(False)])
@pytest.mark.parametrize('language', [pytest.param('CPP'), pytest.param('Python')])
def test_map_with_tasklets(language: str, with_data: bool):
Expand All @@ -200,6 +219,7 @@ def test_map_with_tasklets(language: str, with_data: bool):
test_same_name()
test_same_name_different_memlet()
test_tasklet_fusion_multiline()
test_map_param()
test_map_with_tasklets(language='Python', with_data=False)
test_map_with_tasklets(language='Python', with_data=True)
test_map_with_tasklets(language='CPP', with_data=False)
Expand Down

0 comments on commit 37b58bb

Please sign in to comment.