def _check_neighbor_conflicts(dfg, edge):
    """
    Tests whether another memlet entering the same destination node as
    ``edge`` may write to a subset that overlaps with this edge's subset.

    :param dfg: The graph that is queried for the incoming edges of
                ``edge.dst`` and passed on to memlet propagation.
    :param edge: The edge whose memlet is compared against its siblings.
    :return: True if no sibling can conflict with ``edge``, False if a
             conflict cannot be ruled out.
    """
    scope_node = edge.dst
    propagated = propagation.propagate_memlet(dfg, edge.data, scope_node, False)

    for other in dfg.in_edges(scope_node):
        # Skip the edge itself and any edge that refers to different data
        if other is edge or other.data.data != edge.data.data:
            continue

        other_propagated = propagation.propagate_memlet(dfg, other.data, scope_node, False)

        # Only an explicit "False" answer is treated as proof of disjointness;
        # every other result is indeterminate and reported as a conflict
        definitely_disjoint = subsets.intersects(propagated.subset, other_propagated.subset) == False
        if not definitely_disjoint:
            return False

    # Every relevant sibling was shown not to overlap
    return True
WCR # print('PAR: Continuing from map') @@ -984,10 +1011,9 @@ def unparse_tasklet(sdfg, cfg, state_id, dfg, node, function_stream, callsite_st # To prevent variables-redefinition, build dictionary with all the previously defined symbols defined_symbols = state_dfg.symbols_defined_at(node) - defined_symbols.update({ - k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v)) - for k, v in sdfg.constants.items() - }) + defined_symbols.update( + {k: v.dtype if hasattr(v, 'dtype') else dtypes.typeclass(type(v)) + for k, v in sdfg.constants.items()}) for connector, (memlet, _, _, conntype) in memlets.items(): if connector is not None: @@ -1038,7 +1064,7 @@ def _Name(self, t: ast.Name): # Replace values with their code-generated names (for example, persistent arrays) desc = self.sdfg.arrays[t.id] self.write(ptr(t.id, desc, self.sdfg, self.codegen)) - + def _Attribute(self, t: ast.Attribute): from dace.frontend.python.astutils import rname name = rname(t) @@ -1325,8 +1351,8 @@ def visit_BinOp(self, node: ast.BinOp): evaluated_constant = symbolic.evaluate(unparsed, self.constants) evaluated = symbolic.symstr(evaluated_constant, cpp_mode=True) value = ast.parse(evaluated).body[0].value - if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info - >= (3, 8) else value.n): + if isinstance(evaluated_node, numbers.Number) and evaluated_node != (value.value if sys.version_info >= + (3, 8) else value.n): raise TypeError node.right = ast.parse(evaluated).body[0].value except (TypeError, AttributeError, NameError, KeyError, ValueError, SyntaxError): diff --git a/tests/python_frontend/augassign_wcr_test.py b/tests/python_frontend/augassign_wcr_test.py index a04ed4a623..46c0dd8802 100644 --- a/tests/python_frontend/augassign_wcr_test.py +++ b/tests/python_frontend/augassign_wcr_test.py @@ -158,7 +158,6 @@ def no_wcr(A: dace.int32[5, 5, 5]): assert (np.allclose(A, ref)) -@pytest.mark.skip('Atomic reduction is generated as non-atomic') 
def test_augassign_wcr4(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=False):