Skip to content

Commit

Permalink
Fix codegen generating non-atomic WCR w.r.t. neighboring edges
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Nov 10, 2024
1 parent 1333054 commit efdaacc
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
42 changes: 34 additions & 8 deletions dace/codegen/targets/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from dace.frontend import operations
from dace.frontend.python import astutils
from dace.frontend.python.astutils import ExtNodeTransformer, rname, unparse
from dace.sdfg import nodes, graph as gr, utils
from dace.sdfg import nodes, graph as gr, utils, propagation
from dace.properties import LambdaProperty
from dace.sdfg import SDFG, is_devicelevel_gpu, SDFGState
from dace.codegen.targets import fpga
Expand Down Expand Up @@ -713,6 +713,31 @@ def _check_map_conflicts(map, edge):
return True


def _check_neighbor_conflicts(dfg, edge):
    """
    Looks for other memlets entering the same destination node as ``edge``
    that refer to the same data and whose subsets may overlap with it.

    Both the given edge's memlet and each candidate neighbor's memlet are
    passed through ``propagation.propagate_memlet`` (with ``edge.dst`` as the
    scope node) before their subsets are compared.

    :param dfg: The dataflow graph that contains ``edge``.
    :param edge: The edge whose neighbors are inspected.
    :return: True if every same-data neighbor is known not to intersect
             ``edge``; False as soon as one neighbor may intersect it.
    """
    scope_node = edge.dst
    propagated = propagation.propagate_memlet(dfg, edge.data, scope_node, False)

    for neighbor in dfg.in_edges(scope_node):
        # Only *other* edges that carry the same data container are relevant
        if neighbor is edge or neighbor.data.data != edge.data.data:
            continue

        neighbor_propagated = propagation.propagate_memlet(dfg, neighbor.data, scope_node, False)

        # Compare explicitly against False: only a definite "does not
        # intersect" answer clears this neighbor. Any other answer is treated
        # as a potential conflict, which will make the write atomic.
        definitely_disjoint = subsets.intersects(propagated.subset, neighbor_propagated.subset) == False
        if not definitely_disjoint:
            return False

    # No neighbor in the current scope may overlap
    return True


def write_conflicted_map_params(map, edge):
result = []
for itervar, (_, _, mapskip) in zip(map.params, map.range):
Expand Down Expand Up @@ -769,6 +794,8 @@ def is_write_conflicted_with_reason(dfg, edge, datanode=None, sdfg_schedule=None
for e in path:
if (isinstance(e.dst, nodes.ExitNode) and (e.dst.map.schedule != dtypes.ScheduleType.Sequential
and e.dst.map.schedule != dtypes.ScheduleType.Snitch)):
if not _check_neighbor_conflicts(dfg, e):
return e.dst
if _check_map_conflicts(e.dst.map, e):
# This map is parallel w.r.t. WCR
# print('PAR: Continuing from map')
Expand Down Expand Up @@ -984,10 +1011,9 @@ def unparse_tasklet(sdfg, cfg, state_id, dfg, node, function_stream, callsite_st
# To prevent variables-redefinition, build dictionary with all the previously defined symbols
defined_symbols = state_dfg.symbols_defined_at(node)

defined_symbols.update({
k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()
})
defined_symbols.update(
{k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v))
for k, v in sdfg.constants.items()})

for connector, (memlet, _, _, conntype) in memlets.items():
if connector is not None:
Expand Down Expand Up @@ -1038,7 +1064,7 @@ def _Name(self, t: ast.Name):
# Replace values with their code-generated names (for example, persistent arrays)
desc = self.sdfg.arrays[t.id]
self.write(ptr(t.id, desc, self.sdfg, self.codegen))

def _Attribute(self, t: ast.Attribute):
from dace.frontend.python.astutils import rname
name = rname(t)
Expand Down Expand Up @@ -1325,8 +1351,8 @@ def visit_BinOp(self, node: ast.BinOp):
evaluated_constant = symbolic.evaluate(unparsed, self.constants)
evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True)
value = ast.parse(evaluated).body[0].value
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info
>= (3, 8) else value.n):
if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info >=
(3, 8) else value.n):
raise TypeError
node.right = ast.parse(evaluated).body[0].value
except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError):
Expand Down
1 change: 0 additions & 1 deletion tests/python_frontend/augassign_wcr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def no_wcr(A: dace.int32[5, 5, 5]):
assert (np.allclose(A, ref))


@pytest.mark.skip('Atomic reduction is generated as non-atomic')
def test_augassign_wcr4():

with dace.config.set_temporary('frontend', 'avoid_wcr', value=False):
Expand Down

0 comments on commit efdaacc

Please sign in to comment.