
Updated the strict dataflow mode.
It is now the default.
Also added some tests to check that it works properly.

This mostly refactors the implementation; the behaviour should be the same as before.
The only thing that I really removed, and that could become a problem, is the check whether the data is also used in the same state.
That check was always a hack; the proper check would have been that the data is not used in the same dataflow.
But I think it was redundant anyway.
philip-paul-mueller committed Dec 3, 2024
1 parent 244e3ea commit 11a509e
Showing 2 changed files with 142 additions and 69 deletions.
96 changes: 28 additions & 68 deletions dace/transformation/dataflow/map_fusion.py
@@ -20,14 +20,19 @@ class MapFusion(transformation.SingleStateTransformation):
connections appropriately. Depending on the situation, the transformation will
either fully remove the intermediate or make it a new output of the second map.
By default, the transformation does not use the strict data flow mode. However,
it might be useful in some cases to enable it.
By default `strict_dataflow` is enabled. In this mode, the transformation
will not fuse maps that could potentially lead to a data race, i.e. where the
resulting combined map reads from and writes to the same underlying data.
If strict dataflow is disabled, the transformation might fuse such maps.
However, it will ensure that the accesses are pointwise, meaning that in
each iteration the map only reads from the location that it also writes
to. Note that this could still lead to data races, because the order in which
DaCe generates the reads and writes is non-deterministic.
Args:
only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
only_toplevel_maps: Only consider Maps that are at the top.
strict_dataflow: If `True`, the transformation ensures a
stricter version of the data flow.
strict_dataflow: Which dataflow mode should be used, see above.
Notes:
- This transformation modifies more nodes than it matches.
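A minimal usage sketch of the two modes, mirroring the tests added in this
commit (the import path is assumed from this file's location; `sdfg` stands
for any SDFG containing two fusible maps):

    from dace.transformation.dataflow import MapFusion

    # Strict mode (the default): never fuse maps whose fusion would make
    # the same data both an input and an output of the combined map.
    fused = sdfg.apply_transformations_repeated(MapFusion(strict_dataflow=True))

    # Relaxed mode: such maps may be fused, provided the accesses are pointwise.
    fused = sdfg.apply_transformations_repeated(MapFusion(strict_dataflow=False))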
@@ -302,17 +307,6 @@ def partition_first_outputs(
# Set of intermediate nodes that we have already processed.
processed_inter_nodes: Set[nodes.Node] = set()

# These are the data that are written to multiple times in _this_ state.
# If data is written to multiple times in a state, it could be
# classified as shared. However, it might happen that the node has zero
# degree. This is not a problem, as the maps also induce a before-after
# relationship. But some DaCe transformations do not catch this.
# Thus we will never modify such intermediate nodes and fail instead.
if self.strict_dataflow:
multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg)
else:
multi_write_data = set()

# Now scan all output edges of the first exit and classify them
for out_edge in state.out_edges(map_exit_1):
intermediate_node: nodes.Node = out_edge.dst
@@ -356,16 +350,6 @@
if self.is_view(intermediate_node, sdfg):
return None

# Checks if the intermediate node refers to data that is accessed by
# _other_ access nodes in _this_ state. If this is the case, then never
# touch this intermediate node.
# TODO(phimuell): Technically it would be enough to turn the node into
# a shared output node, because this would still fulfil the dependencies.
# However, some DaCe transformations can not handle this properly, so we
# are _forced_ to reject this node.
if intermediate_node.data in multi_write_data:
return None

# Empty Memlets are only allowed if they are in `\mathbb{P}`, which
# is also the only place they really make sense (for a map exit).
# Thus, if we now find an empty Memlet, we reject it.
@@ -1033,11 +1017,6 @@ def has_read_write_dependency(
if not real_write_map_1.isdisjoint(real_write_map_2):
return True

# If there is no overlap in what is (totally) read and written, there will be no conflict.
# This must come before the check of disjoint write.
if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2):
return False

# These are the names (unresolved) and the access nodes of the data that is used
# to transmit information between the maps. The partition function ensures that
# these nodes are directly connected to the two maps.
@@ -1062,11 +1041,29 @@
if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names):
return True

# In strict dataflow mode we require that the input and the output of
# the fused map are distinct.
# NOTE: The code below is able to handle cases where an input of the first
# map is also used as an output of the second map. In that case the
# function checks whether the accesses are pointwise, i.e. every
# iteration reads from the same location it later writes to. However,
# even then it might cause problems, because the order in which the reads
# and writes are done is non-deterministic. But if this is handled
# through other means, it allows powerful optimizations.
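# A concrete illustration, taken from the tests added in this commit:
# with map parameter `__i0`, reading `A[__i0]` and writing `A[__i0]` is
# pointwise, while reading `A[__i0]` and writing `A[9 - __i0]` is not;
# only the former may be fused once strict dataflow is disabled.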
if self.strict_dataflow:
if len(fused_inout_data_names) != 0:
return True

# A data container can be used as both input and output, but we do not allow that
# it is also used as intermediate or exchange data. This is an important check.
# it is also used as intermediate or exchange data.
if not fused_inout_data_names.isdisjoint(exchange_names):
return True

# If there is no intersection between the input and output data, then
# we have nothing to check.
if len(fused_inout_data_names) == 0:
return False

# Get the replacement dict for changing the map variables from the subsets of
# the second map.
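# For illustration (parameter names borrowed from the tests below): if the
# first map iterates over `__i0` and the second over `__i1`, the remapping
# would be `{'__i1': '__i0'}`, i.e. the second map's subsets are rewritten
# in terms of the first map's parameters.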
repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map)
@@ -1251,43 +1248,6 @@ def _compute_shared_data_in(
self._shared_data[sdfg] = shared_data


def _compute_multi_write_data(
self,
state: SDFGState,
sdfg: SDFG,
) -> Set[str]:
"""Computes data inside a _single_ state, that is written multiple times.
Essentially this function computes the set of data that does not follow
the single static assignment idiom. The function also resolves views:
if an access node refers to a view, not only the view itself but also
the data it refers to is added to the set.
Args:
state: The state that should be examined.
sdfg: The SDFG object.
Note:
This information is used by the partition function, in strict data
flow mode only. The current implementation is rather simple, as it only
checks whether data is written to multiple times in the same state.
"""
data_written_to: Set[str] = set()
multi_write_data: Set[str] = set()

for access_node in state.data_nodes():
if state.in_degree(access_node) == 0:
continue
if access_node.data in data_written_to:
multi_write_data.add(access_node.data)
elif self.is_view(access_node, sdfg):
# This is an over approximation.
multi_write_data.update([access_node.data, self.track_view(access_node, state, sdfg).data])
data_written_to.add(access_node.data)
return multi_write_data


def find_parameter_remapping(self, first_map: nodes.Map, second_map: nodes.Map) -> Union[Dict[str, str], None]:
"""Computes the parameter remapping for the parameters of the _second_ map.
115 changes: 114 additions & 1 deletion tests/transformations/mapfusion_test.py
@@ -1,10 +1,11 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from typing import Any, Union
from typing import Any, Union, Tuple, Optional

import numpy as np
import os
import dace
import copy
import uuid

from dace import SDFG, SDFGState
from dace.sdfg import nodes
@@ -719,7 +720,119 @@ def _make_sdfg(N: int, M: int) -> dace.SDFG:
assert np.allclose(ref, res)


def _make_strict_dataflow_sdfg_pointwise(
input_data: str = "A",
intermediate_data: str = "T",
output_data: Optional[str] = None,
input_read: str = "__i0",
output_write: Optional[str] = None,
) -> Tuple[dace.SDFG, dace.SDFGState]:
"""
Creates the SDFG for the strict data flow tests.
By default the SDFG reads from and writes to `A`, but the accesses are pointwise,
thus the maps can be fused. Furthermore, this particular SDFG guarantees that no data race occurs.
"""
if output_data is None:
output_data = input_data
if output_write is None:
output_write = input_read

sdfg = dace.SDFG(f"strict_dataflow_sdfg_pointwise_{str(uuid.uuid1()).replace('-', '_')}")
state = sdfg.add_state(is_start_block=True)
for name in {input_data, intermediate_data, output_data}:
sdfg.add_array(
name,
shape=(10,),
dtype=dace.float64,
transient=False,
)

if intermediate_data not in {input_data, output_data}:
sdfg.arrays[intermediate_data].transient = True

input_node, intermediate_node, output_node = (state.add_access(name) for name in [input_data, intermediate_data, output_data])

state.add_mapped_tasklet(
"first_comp",
map_ranges={"__i0": "0:10"},
inputs={"__in1": dace.Memlet(f"{input_data}[{input_read}]")},
code="__out = __in1 + 2.0",
outputs={"__out": dace.Memlet(f"{intermediate_data}[__i0]")},
input_nodes={input_node},
output_nodes={intermediate_node},
external_edges=True,
)
state.add_mapped_tasklet(
"second_comp",
map_ranges={"__i1": "0:10"},
inputs={"__in1": dace.Memlet(f"{intermediate_data}[__i1]")},
code="__out = __in1 + 3.0",
outputs={"__out": dace.Memlet(f"{output_data}[{output_write}]")},
input_nodes={intermediate_node},
output_nodes={output_node},
external_edges=True,
)
sdfg.validate()
return sdfg, state


def test_fusion_strict_dataflow_pointwise():
sdfg, state = _make_strict_dataflow_sdfg_pointwise(input_data="A")

# Because `A` is used as both input and output, the maps cannot be
# fused in strict dataflow mode.
count = sdfg.apply_transformations_repeated(
MapFusion(strict_dataflow=True),
validate=True,
validate_all=True,
)
assert count == 0

# However, if strict dataflow is disabled, the maps can be fused.
count = sdfg.apply_transformations_repeated(
MapFusion(strict_dataflow=False),
validate=True,
validate_all=True,
)
assert count == 1


def test_fusion_strict_dataflow_not_pointwise():
sdfg, state = _make_strict_dataflow_sdfg_pointwise(
input_data="A",
input_read="__i0",
output_write="9 - __i0",
)

# Because the dependency is not pointwise, even disabling strict dataflow
# will not allow the maps to be fused.
count = sdfg.apply_transformations_repeated(
MapFusion(strict_dataflow=False),
validate=True,
validate_all=True,
)
assert count == 0


def test_fusion_dataflow_intermediate():
sdfg, _ = _make_strict_dataflow_sdfg_pointwise(
input_data="A",
intermediate_data="O",
output_data="O",
)
count = sdfg.apply_transformations_repeated(
MapFusion(strict_dataflow=True),
validate=True,
validate_all=True,
)
assert count == 0


if __name__ == '__main__':
test_fusion_strict_dataflow_pointwise()
test_fusion_strict_dataflow_not_pointwise()
test_fusion_dataflow_intermediate()
test_indirect_accesses()
test_fusion_shared()
test_fusion_with_transient()
