diff --git a/dace/distr_types.py b/dace/distr_types.py index 1b595a1b84..77e4730ad1 100644 --- a/dace/distr_types.py +++ b/dace/distr_types.py @@ -598,3 +598,27 @@ def exit_code(self, sdfg): delete[] __state->{self.name}_self_dst; delete[] __state->{self.name}_self_size; """ + +@make_properties +class RMA_window(object): + """ + RMA_window is the descriptor class for MPI Remote Memory Access window + Real window creation is implemented in mpi.nodes.win_create.Win_create + """ + + name = Property(dtype=str, desc="The name of new window.") + def __init__(self, + name: str): + self.name = name + self._validate() + + def validate(self): + """ Validate the correctness of this object. + Raises an exception on error. """ + self._validate() + + # Validation of this class is in a separate function, so that this + # class can call `_validate()` without calling the subclasses' + # `validate` function. + def _validate(self): + return True diff --git a/dace/frontend/common/distr.py b/dace/frontend/common/distr.py index d6f22da358..26130e0560 100644 --- a/dace/frontend/common/distr.py +++ b/dace/frontend/common/distr.py @@ -15,6 +15,27 @@ RankType = Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic] ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor' +# a helper function for getting an access node by argument name +# creates a scalar if it's a number +def _get_int_arg_node(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + argument: Union[str, sp.Expr, Number] + ): + if isinstance(argument, str) and argument in sdfg.arrays.keys(): + arg_name = argument + arg_node = state.add_read(arg_name) + else: + # create a transient scalar and take its name + arg_name = _define_local_scalar(pv, sdfg, state, dace.int32) + arg_node = state.add_access(arg_name) + # every tasklet is in different scope, no need to worry about name confilct + color_tasklet = state.add_tasklet(f'_set_{arg_name}_', {}, {'__out'}, f'__out = {argument}') + state.add_edge(color_tasklet, '__out', arg_node, None, Memlet.simple(arg_node, '0')) + + return arg_name, arg_node + + ##### MPI Cartesian Communicators @@ -894,6 +915,554 @@ def _wait(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, request: str): return None +def _get_last_rma_op(sdfg: SDFG, + cur_op_name: str, + window_name: str, + is_trans: bool = False): + """ Get last RMA operation name of a window from the SDFG. + And do some logical checks if is_trans is True. + + :param sdfg: The sdfg for searching. + :param cur_op_name: current operation in the window. + :param window_name: The RMA window name for searching. + :param is_trans: check RMA sync is exist before op if this param is true + :return: Name of the last RMA operation. 
+ """ + + all_rma_ops_name = list(sdfg._rma_ops.keys()) + cur_window_rma_ops = [rma_op for rma_op in all_rma_ops_name + if f"{window_name}_" in rma_op] + if len(cur_window_rma_ops) == 1: + last_rma_op_name = window_name + else: + last_rma_op_name = cur_window_rma_ops[cur_window_rma_ops.index(cur_op_name) - 1] + + if is_trans: + # if only odd number of fences or locks, + # that means we're in a ongoing epoch + # if even number, + # that means this operation might have corrupted sync + cur_window_fences = [rma_op for rma_op in cur_window_rma_ops + if f"{window_name}_fence" in rma_op] + cur_window_passive_syncs = [rma_op for rma_op in cur_window_rma_ops + if "lock" in rma_op] + if len(cur_window_fences) % 2 == 0 and len(cur_window_passive_syncs) % 2 == 0: + # if we don't have even number of syncs, give user a warning + print("You might have a bad synchronization of RMA calls!") + + return last_rma_op_name + + +@oprepo.replaces('mpi4py.MPI.Win.Create') +@oprepo.replaces('dace.Win.Create') +def _rma_window_create(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + buffer: str, + comm: Union[str, ShapeType], + grid: str = None): + """ Adds a RMA window to the DaCe Program. + + :param buffer: The name of window buffer. + :param comm: The comm world name of this window + :process_grid: Name of the process-grid for collective scatter/gather operations. + :return: Name of the window. + """ + + from dace.libraries.mpi.nodes.win_create import Win_create + + # if 'comm' is not a 'str' means it's using mpi4py objects + # which can only be deafult the comm world + if not isinstance(comm, str): + comm = None + + # fine a new window name + window_name = sdfg.add_window() + + window_node = Win_create(window_name, comm) + + buf_desc = sdfg.arrays[buffer] + buf_node = state.add_read(buffer) + state.add_edge(buf_node, + None, + window_node, + '_win_buffer', + Memlet.from_array(buffer, buf_desc)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = state.add_write(window_name) + state.add_edge(window_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Fence') +def _rma_fence(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + assertion: Union[str, sp.Expr, Number] = 0): + """ Adds a RMA fence to the DaCe Program. + + :param window_name: The name of the window to be sychronized. + :param assertion: A value or scalar for fence assertion. + :return: Name of the fence. 
+ """ + + from dace.libraries.mpi.nodes.win_fence import Win_fence + + # fine a new fence name + fence_name = sdfg.add_rma_ops(window_name, "fence") + + _, assertion_node = _get_int_arg_node(pv, sdfg, state, assertion) + + fence_node = Win_fence(fence_name, window_name) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, fence_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + + # for window fence ordering + state.add_edge(last_rma_op_node, + None, + fence_node, + None, + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + state.add_edge(assertion_node, + None, + fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = state.add_write(fence_name) + state.add_edge(fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Flush') +def _rma_flush(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + rank: Union[str, sp.Expr, Number]): + """ Adds a RMA flush to the DaCe Program. + flush will completes all outdtanding RMA operations + + :param window_name: The name of the window to be sychronized. + :param rank: A value or scalar to specify the target rank. + :return: Name of the flush. + """ + + from dace.libraries.mpi.nodes.win_flush import Win_flush + + # fine a new flush name + flush_name = sdfg.add_rma_ops(window_name, "flush") + + _, rank_node = _get_int_arg_node(pv, sdfg, state, rank) + + flush_node = Win_flush(flush_name, window_name) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, flush_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + + # for ordering + state.add_edge(last_rma_op_node, + None, + flush_node, + None, + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + state.add_edge(rank_node, + None, + flush_node, + '_rank', + Memlet.simple(rank_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(flush_name, dace.int32, transient=True) + wnode = state.add_write(flush_name) + state.add_edge(flush_node, + "_out", + wnode, + None, + Memlet.from_array(flush_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Free') +def _rma_free(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + assertion: Union[str, sp.Expr, Number] = 0): + """ Adds a RMA free to the DaCe Program. + + :param window_name: The name of the window to be freed. + :return: Name of the free. 
+ """ + + from dace.libraries.mpi.nodes.win_free import Win_free + + # fine a new free name + free_name = sdfg.add_rma_ops(window_name, "free") + + _, assertion_node = _get_int_arg_node(pv, sdfg, state, assertion) + + free_node = Win_free(free_name, window_name) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, free_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + + # for window free ordering + state.add_edge(last_rma_op_node, + None, + free_node, + "_in", + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(free_name, dace.int32, transient=True) + wnode = state.add_write(free_name) + state.add_edge(free_node, + "_out", + wnode, + None, + Memlet.from_array(free_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Lock') +def _rma_lock(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + rank: Union[str, sp.Expr, Number], + lock_type: Union[str, sp.Expr, Number] = 234, # in intel MPI MPI.LOCK_EXCLUSIVE = 234 + assertion: Union[str, sp.Expr, Number] = 0): + """ Adds a RMA lock to the DaCe Program. + + :param window_name: The name of the window to be sychronized. + :param assertion: A value or scalar for lock assertion. + :return: Name of the lock. + """ + + from dace.libraries.mpi.nodes.win_lock import Win_lock + + # fine a new lock name + lock_name = sdfg.add_rma_ops(window_name, "lock") + lock_node = Win_lock(lock_name, window_name) + + # different MPI might get other value + if lock_type == 234: + from mpi4py import MPI + lock_type = MPI.LOCK_EXCLUSIVE + + _, rank_node = _get_int_arg_node(pv, sdfg, state, rank) + _, lock_type_node = _get_int_arg_node(pv, sdfg, state, lock_type) + _, assertion_node = _get_int_arg_node(pv, sdfg, state, assertion) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, lock_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + + # for window lock ordering + state.add_edge(last_rma_op_node, + None, + lock_node, + None, + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + state.add_edge(rank_node, + None, + lock_node, + '_rank', + Memlet.simple(rank_node, "0:1", num_accesses=1)) + + state.add_edge(lock_type_node, + None, + lock_node, + '_lock_type', + Memlet.simple(lock_type_node, "0:1", num_accesses=1)) + + state.add_edge(assertion_node, + None, + lock_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(lock_name, dace.int32, transient=True) + wnode = state.add_write(lock_name) + state.add_edge(lock_node, + "_out", + wnode, + None, + Memlet.from_array(lock_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Unlock') +def _rma_unlock(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + rank: Union[str, sp.Expr, Number]): + """ Adds a RMA unlock to the DaCe Program. + Completes an RMA access epoch at the target process + + :param window_name: The name of the window to be sychronized. + :param rank: A value or scalar to specify the target rank. + :return: Name of the Unlock. 
+ """ + + from dace.libraries.mpi.nodes.win_unlock import Win_unlock + + # fine a new unlock name + unlock_name = sdfg.add_rma_ops(window_name, "unlock") + + _, rank_node = _get_int_arg_node(pv, sdfg, state, rank) + + unlock_node = Win_unlock(unlock_name, window_name) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, unlock_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + + # for ordering + state.add_edge(last_rma_op_node, + None, + unlock_node, + None, + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + state.add_edge(rank_node, + None, + unlock_node, + '_rank', + Memlet.simple(rank_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(unlock_name, dace.int32, transient=True) + wnode = state.add_write(unlock_name) + state.add_edge(unlock_node, + "_out", + wnode, + None, + Memlet.from_array(unlock_name, scal)) + + return window_name + + +@oprepo.replaces_method('RMA_window', 'Put') +def _rma_put(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + origin: str, + target_rank: Union[str, sp.Expr, Number]): + """ Initiate a RMA put for the DaCe Program. + + :param window_name: The name of the window to be sychronized. + :param origin: The name of origin buffer. + :target_rank: A value or scalar of the target rank. + :return: Name of the new RMA put descriptor. + """ + + from dace.libraries.mpi.nodes.win_put import Win_put + + put_name = sdfg.add_rma_ops(window_name, "put") + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, put_name, window_name, is_trans=True) + + put_node = Win_put(put_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + state.add_edge(last_rma_op_node, + None, + put_node, + "_in", + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + origin_node = state.add_read(origin) + origin_desc = sdfg.arrays[origin] + state.add_edge(origin_node, + None, + put_node, + '_inbuffer', + Memlet.from_array(origin, origin_desc)) + + _, target_rank_node = _get_int_arg_node(pv, sdfg, state, target_rank) + state.add_edge(target_rank_node, + None, + put_node, + '_target_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(put_name, dace.int32, transient=True) + wnode = state.add_write(put_name) + state.add_edge(put_node, + "_out", + wnode, + None, + Memlet.from_array(put_name, scal)) + + return put_name + + +@oprepo.replaces_method('RMA_window', 'Get') +def _rma_get(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + origin: str, + target_rank: Union[str, sp.Expr, Number]): + """ Initiate a RMA get for the DaCe Program. + + :param window_name: The name of the window to be sychronized. + :param origin: The name of origin buffer. + :target_rank: A value or scalar of the target rank. + :return: Name of the new RMA get descriptor. 
+ """ + + from dace.libraries.mpi.nodes.win_get import Win_get + + get_name = sdfg.add_rma_ops(window_name, "get") + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, get_name, window_name, is_trans=True) + + get_node = Win_get(get_name, window_name) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + state.add_edge(last_rma_op_node, + None, + get_node, + "_in", + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + _, target_rank_node = _get_int_arg_node(pv, sdfg, state, target_rank) + state.add_edge(target_rank_node, + None, + get_node, + '_target_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + origin_node = state.add_write(origin) + origin_desc = sdfg.arrays[origin] + state.add_edge(get_node, + '_outbuffer', + origin_node, + None, + Memlet.from_array(origin, origin_desc)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(get_name, dace.int32, transient=True) + wnode = state.add_write(get_name) + state.add_edge(get_node, + '_out', + wnode, + None, + Memlet.from_array(get_name, scal)) + + return get_name + + +@oprepo.replaces_method('RMA_window', 'Accumulate') +def _rma_accumulate(pv: ProgramVisitor, + sdfg: SDFG, + state: SDFGState, + window_name: str, + origin: str, + target_rank: Union[str, sp.Expr, Number], + op: str = "MPI_SUM"): + """ Initiate a RMA accumulate for the DaCe Program. + + :param window_name: The name of the window to be sychronized. + :param origin: The name of origin buffer. + :target_rank: A value or scalar of the target rank. + :op: The name of MPI reduction + :return: Name of the new RMA accumulate descriptor. + """ + from mpi4py import MPI + from dace.libraries.mpi.nodes.win_accumulate import Win_accumulate + + accumulate_name = sdfg.add_rma_ops(window_name, "accumulate") + + if isinstance(op, MPI.Op): + op = _mpi4py_to_MPI(MPI, op) + + # check for the last RMA operation + last_rma_op_name = _get_last_rma_op(sdfg, accumulate_name, window_name, is_trans=True) + + accumulate_node = Win_accumulate(accumulate_name, window_name, op) + + last_rma_op_node = state.add_read(last_rma_op_name) + last_rma_op_desc = sdfg.arrays[last_rma_op_name] + state.add_edge(last_rma_op_node, + None, + accumulate_node, + "_in", + Memlet.from_array(last_rma_op_name, last_rma_op_desc)) + + origin_node = state.add_read(origin) + origin_desc = sdfg.arrays[origin] + state.add_edge(origin_node, + None, + accumulate_node, + '_inbuffer', + Memlet.from_array(origin, origin_desc)) + + _, target_rank_node = _get_int_arg_node(pv, sdfg, state, target_rank) + state.add_edge(target_rank_node, + None, + accumulate_node, + '_target_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(accumulate_name, dace.int32, transient=True) + wnode = state.add_write(accumulate_name) + state.add_edge(accumulate_node, + "_out", + wnode, + None, + Memlet.from_array(accumulate_name, scal)) + + return accumulate_name + + @oprepo.replaces('dace.comm.Subarray') def _subarray(pv: ProgramVisitor, sdfg: SDFG, diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..5e92f6b487 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1307,6 +1307,9 @@ def defined(self): result.update( {k: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in 
self.sdfg.process_grids}) + result.update( + {k: self.sdfg.rma_windows[v] + for k, v in self.variables.items() if v in self.sdfg.rma_windows}) try: from mpi4py import MPI result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) @@ -4686,6 +4689,8 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]: for operand in operands: if isinstance(operand, str) and operand in self.sdfg.process_grids: result.append((operand, type(self.sdfg.process_grids[operand]).__name__)) + elif isinstance(operand, str) and operand in self.sdfg.rma_windows: + result.append((operand, type(self.sdfg.rma_windows[operand]).__name__)) elif isinstance(operand, str) and operand in self.sdfg.arrays: result.append((operand, type(self.sdfg.arrays[operand]))) elif isinstance(operand, str) and operand in self.scope_arrays: diff --git a/dace/libraries/mpi/nodes/__init__.py b/dace/libraries/mpi/nodes/__init__.py index 0cd36cc82f..91d97091ac 100644 --- a/dace/libraries/mpi/nodes/__init__.py +++ b/dace/libraries/mpi/nodes/__init__.py @@ -13,3 +13,12 @@ from .alltoall import Alltoall from .dummy import Dummy from .redistribute import Redistribute +from .win_create import Win_create +from .win_fence import Win_fence +from .win_put import Win_put +from .win_get import Win_get +from .win_accumulate import Win_accumulate +from .win_lock import Win_lock +from .win_unlock import Win_unlock +from .win_flush import Win_flush +from .win_free import Win_free diff --git a/dace/libraries/mpi/nodes/win_accumulate.py b/dace/libraries/mpi/nodes/win_accumulate.py new file mode 100644 index 0000000000..6cc13b4bcd --- /dev/null +++ b/dace/libraries/mpi/nodes/win_accumulate.py @@ -0,0 +1,72 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinAccumulateMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + inbuffer, in_count_str = node.validate(parent_sdfg, parent_state) + mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(inbuffer.dtype.base_type) + + window_name = node.window_name + op = node.op + + code = f""" + MPI_Accumulate(_inbuffer, {in_count_str}, {mpi_dtype_str}, \ + _target_rank, 0, {in_count_str}, {mpi_dtype_str}, \ + {op}, __state->{window_name}_window); + """ + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_accumulate(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinAccumulateMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + op = dace.properties.Property(dtype=str, default='MPI_SUM') + + def __init__(self, name, window_name, op="MPI_SUM", *args, **kwargs): + super().__init__(name, *args, inputs={"_in", "_inbuffer", "_target_rank"}, outputs={"_out"}, **kwargs) + self.window_name = window_name + self.op = op + + def validate(self, sdfg, state): + """ + :return: A three-tuple (buffer, root) of the three data descriptors in the + parent SDFG. 
+ """ + + inbuffer = None + for e in state.in_edges(self): + if e.dst_conn == "_inbuffer": + inbuffer = sdfg.arrays[e.data.data] + + in_count_str = "XXX" + for _, _, _, dst_conn, data in state.in_edges(self): + if dst_conn == '_inbuffer': + dims = [str(e) for e in data.subset.size_exact()] + in_count_str = "*".join(dims) + + return inbuffer, in_count_str diff --git a/dace/libraries/mpi/nodes/win_create.py b/dace/libraries/mpi/nodes/win_create.py new file mode 100644 index 0000000000..7abfc02b96 --- /dev/null +++ b/dace/libraries/mpi/nodes/win_create.py @@ -0,0 +1,82 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinCreateMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + win_buffer, win_buf_count_str = node.validate(parent_sdfg, parent_state) + win_buffer_dtype = dace.libraries.mpi.utils.MPI_DDT(win_buffer.dtype.base_type) + window_name = node.name + + node.fields = [ + f"MPI_Win {window_name}_window;" + ] + + comm = "MPI_COMM_WORLD" + if node.comm: + comm = f"__state->{node.comm}_comm" + + code = f""" + MPI_Win_create(_win_buffer, + {win_buf_count_str} * sizeof({win_buffer_dtype}), + sizeof({win_buffer_dtype}), + MPI_INFO_NULL, + {comm}, + &__state->{window_name}_window); + """ + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + state_fields=node.fields, + language=dace.dtypes.Language.CPP, + side_effects=True) + + return tasklet + + +@dace.library.node +class Win_create(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinCreateMPI, + } + default_implementation = "MPI" + + comm = dace.properties.Property(dtype=str, allow_none=True, default=None) + + def __init__(self, name, comm=None, *args, **kwargs): + super().__init__(name, *args, inputs={"_win_buffer"}, outputs={"_out"}, **kwargs) + self.comm = comm + + def validate(self, sdfg, state): + """ + :return: A three-tuple (buffer, root) of the three data descriptors in the + parent SDFG. + """ + + win_buffer = None + for e in state.in_edges(self): + if e.dst_conn == "_win_buffer": + win_buffer = sdfg.arrays[e.data.data] + + win_buf_count_str = "XXX" + for _, _, _, dst_conn, data in state.in_edges(self): + if dst_conn == '_win_buffer': + dims = [str(e) for e in data.subset.size_exact()] + win_buf_count_str = "*".join(dims) + + return win_buffer, win_buf_count_str diff --git a/dace/libraries/mpi/nodes/win_fence.py b/dace/libraries/mpi/nodes/win_fence.py new file mode 100644 index 0000000000..ae2d0a0dda --- /dev/null +++ b/dace/libraries/mpi/nodes/win_fence.py @@ -0,0 +1,43 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. 
import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinFenceMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + window_name = node.window_name + code = f""" + MPI_Win_fence(_assertion, __state->{window_name}_window); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_fence(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinFenceMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_assertion"}, outputs={"_out"}, **kwargs) + self.window_name = window_name diff --git a/dace/libraries/mpi/nodes/win_flush.py b/dace/libraries/mpi/nodes/win_flush.py new file mode 100644 index 0000000000..70e2ac1905 --- /dev/null +++ b/dace/libraries/mpi/nodes/win_flush.py @@ -0,0 +1,43 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinFlushMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + window_name = node.window_name + code = f""" + MPI_Win_flush(_rank, __state->{window_name}_window); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_flush(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinFlushMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_rank"}, outputs={"_out"}, **kwargs) + self.window_name = window_name diff --git a/dace/libraries/mpi/nodes/win_free.py b/dace/libraries/mpi/nodes/win_free.py new file mode 100644 index 0000000000..81009093fc --- /dev/null +++ b/dace/libraries/mpi/nodes/win_free.py @@ -0,0 +1,43 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. 
import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinFreeMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + window_name = node.window_name + code = f""" + MPI_Win_free(&__state->{window_name}_window); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_free(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinFreeMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_in"}, outputs={"_out"}, **kwargs) + self.window_name = window_name diff --git a/dace/libraries/mpi/nodes/win_get.py b/dace/libraries/mpi/nodes/win_get.py new file mode 100644 index 0000000000..fb8f6bacb9 --- /dev/null +++ b/dace/libraries/mpi/nodes/win_get.py @@ -0,0 +1,68 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinGetMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + outbuffer, out_count_str = node.validate(parent_sdfg, parent_state) + mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(outbuffer.dtype.base_type) + + window_name = node.window_name + + code = f""" + MPI_Get(_outbuffer, {out_count_str}, {mpi_dtype_str}, \ + _target_rank, 0, {out_count_str}, {mpi_dtype_str}, \ + __state->{window_name}_window); + """ + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_get(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinGetMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_in", "_target_rank"}, outputs={"_out", "_outbuffer"}, **kwargs) + self.window_name = window_name + + def validate(self, sdfg, state): + """ + :return: A three-tuple (buffer, root) of the three data descriptors in the + parent SDFG. + """ + + outbuffer = None + for e in state.out_edges(self): + if e.src_conn == "_outbuffer": + outbuffer = sdfg.arrays[e.data.data] + out_count_str = "XXX" + for _, src_conn, _, _, data in state.out_edges(self): + if src_conn == '_outbuffer': + dims = [str(e) for e in data.subset.size_exact()] + out_count_str = "*".join(dims) + + return outbuffer, out_count_str diff --git a/dace/libraries/mpi/nodes/win_lock.py b/dace/libraries/mpi/nodes/win_lock.py new file mode 100644 index 0000000000..48a5fe6fd4 --- /dev/null +++ b/dace/libraries/mpi/nodes/win_lock.py @@ -0,0 +1,46 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. 
import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinLockMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + window_name = node.window_name + code = f""" + MPI_Win_lock(_lock_type, + _rank, + _assertion, + __state->{window_name}_window); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_lock(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinLockMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_rank", "_lock_type", "_assertion"}, outputs={"_out"}, **kwargs) + self.window_name = window_name diff --git a/dace/libraries/mpi/nodes/win_put.py b/dace/libraries/mpi/nodes/win_put.py new file mode 100644 index 0000000000..de3811cd7c --- /dev/null +++ b/dace/libraries/mpi/nodes/win_put.py @@ -0,0 +1,69 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinPutMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + inbuffer, in_count_str = node.validate(parent_sdfg, parent_state) + mpi_dtype_str = dace.libraries.mpi.utils.MPI_DDT(inbuffer.dtype.base_type) + + window_name = node.window_name + + code = f""" + MPI_Put(_inbuffer, {in_count_str}, {mpi_dtype_str}, \ + _target_rank, 0, {in_count_str}, {mpi_dtype_str}, \ + __state->{window_name}_window); + """ + + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_put(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinPutMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_in", "_inbuffer", "_target_rank"}, outputs={"_out"}, **kwargs) + self.window_name = window_name + + def validate(self, sdfg, state): + """ + :return: A three-tuple (buffer, root) of the three data descriptors in the + parent SDFG. + """ + + inbuffer = None + for e in state.in_edges(self): + if e.dst_conn == "_inbuffer": + inbuffer = sdfg.arrays[e.data.data] + + in_count_str = "XXX" + for _, _, _, dst_conn, data in state.in_edges(self): + if dst_conn == '_inbuffer': + dims = [str(e) for e in data.subset.size_exact()] + in_count_str = "*".join(dims) + + return inbuffer, in_count_str diff --git a/dace/libraries/mpi/nodes/win_unlock.py b/dace/libraries/mpi/nodes/win_unlock.py new file mode 100644 index 0000000000..7bd6963fa9 --- /dev/null +++ b/dace/libraries/mpi/nodes/win_unlock.py @@ -0,0 +1,43 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
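+# Win_unlock closes the passive-target access epoch that Win_lock opened on the
+# target rank: the MPI expansion below emits
+#     MPI_Win_unlock(_rank, __state->{window_name}_window);
+# The node only reads the target rank through its "_rank" connector and writes
+# the "_out" ordering scalar so later RMA descriptors of the same window can
+# depend on it.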
+import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandWinUnlockMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, **kwargs): + window_name = node.window_name + code = f""" + MPI_Win_unlock(_rank, __state->{window_name}_window); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Win_unlock(MPINode): + + # Global properties + implementations = { + "MPI": ExpandWinUnlockMPI, + } + default_implementation = "MPI" + + window_name = dace.properties.Property(dtype=str, default=None) + + def __init__(self, name, window_name, *args, **kwargs): + super().__init__(name, *args, inputs={"_rank"}, outputs={"_out"}, **kwargs) + self.window_name = window_name diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index a23d2616f9..fa05505635 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -32,7 +32,7 @@ from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView from dace.sdfg.state import SDFGState from dace.sdfg.propagation import propagate_memlets_sdfg -from dace.distr_types import ProcessGrid, SubArray, RedistrArray +from dace.distr_types import ProcessGrid, SubArray, RedistrArray, RMA_window from dace.dtypes import validate_name from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty, TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock) @@ -450,6 +450,16 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]): desc="Process-grid descriptors for this SDFG", to_json=_arrays_to_json, from_json=_arrays_from_json) + _windows = DictProperty(str, + RMA_window, + desc="MPI RMA window descriptors for this SDFG", + to_json=_arrays_to_json, + from_json=_arrays_from_json) + _rma_ops = DictProperty(str, + str, + desc="MPI RMA ops descriptors for this SDFG", + to_json=_arrays_to_json, + from_json=_arrays_from_json) _subarrays = DictProperty(str, SubArray, desc="Sub-array descriptors for this SDFG", @@ -518,6 +528,8 @@ def __init__(self, # Grid-distribution-related fields self._pgrids = {} + self._windows = {} + self._rma_ops = {} self._subarrays = {} self._rdistrarrays = {} @@ -688,6 +700,16 @@ def process_grids(self): """ Returns a dictionary of process-grid descriptors (`ProcessGrid` objects) used in this SDFG. """ return self._pgrids + @property + def rma_windows(self): + """ Returns a dictionary of RMA window descriptors (`RMA_window` objects) used in this SDFG. """ + return self._windows + + @property + def rma_ops(self): + """ Returns a dictionary of RMA operations descriptors (an empty string) used in this SDFG. """ + return self._rma_ops + @property def subarrays(self): """ Returns a dictionary of sub-array descriptors (`SubArray` objects) used in this SDFG. """ @@ -1707,8 +1729,9 @@ def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. 
""" - names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys() - | self._rdistrarrays.keys()) + names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | + self._subarrays.keys() | self._rdistrarrays.keys() | self._windows.keys() | + self._rma_ops.keys()) return dt.find_new_name(name, names) def find_new_constant(self, name: str): @@ -2091,6 +2114,26 @@ def add_pgrid(self, return grid_name + def add_window(self): + """ Adds a RMA window to the RMA window descriptor store. + """ + + window_name = self._find_new_name('__win') + + self._windows[window_name] = RMA_window(window_name) + + return window_name + + def add_rma_ops(self, window_name:str, op:str): + """ Adds a RMA op to the RMA ops descriptor store. + """ + + rma_op_name = self._find_new_name(f'{window_name}_{op}') + + self._rma_ops[rma_op_name] = "" + + return rma_op_name + def add_subarray(self, dtype: dtypes.typeclass, shape: ShapeType, diff --git a/samples/mpi/mat_mul.py b/samples/mpi/mat_mul.py new file mode 100644 index 0000000000..94a67139df --- /dev/null +++ b/samples/mpi/mat_mul.py @@ -0,0 +1,165 @@ +import numpy as np +import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from mpi4py import MPI +import time + + +def matrix_mul(comm_world, dim_1, dim_2): + # comm init + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + a_mat = np.full((dim_1, dim_2), 1 + comm_rank, dtype=np.float32) + b_mat = np.full((dim_2, dim_1), 1 + comm_rank, dtype=np.float32) + c_mat = np.zeros((dim_1, dim_1), dtype=np.float32) + + @dace.program + def dist_mat_mult(a_mat: dace.float32[a_mat.shape[0], a_mat.shape[1]], + b_mat: dace.float32[b_mat.shape[0], b_mat.shape[1]], + c_mat: dace.float32[a_mat.shape[0], b_mat.shape[1]], + comm_rank: dace.int32, + comm_size: dace.int32): + grid_dim = int(np.floor(np.sqrt(comm_size))) + grid_i = comm_rank // grid_dim + grid_j = comm_rank % grid_dim + + local_i_dim = a_mat.shape[0] + local_j_dim = b_mat.shape[1] + local_k_dim = a_mat.shape[1] + + whole_i_dim = grid_dim * a_mat.shape[0] + whole_j_dim = grid_dim * b_mat.shape[1] + whole_k_dim = grid_dim * a_mat.shape[1] + + # local buffers for remote fetching + foreign_a_mat = np.zeros(a_mat.shape, dtype=np.float32) + foreign_b_mat = np.zeros(b_mat.shape, dtype=np.float32) + + # RMA windows + a_win = MPI.Win.Create(a_mat, comm=comm_world) + b_win = MPI.Win.Create(b_mat, comm=comm_world) + for i in range(whole_i_dim // local_i_dim): + for j in range(whole_j_dim // local_j_dim): + for k in range(whole_k_dim // local_k_dim): + # check if this process owns this chunk of data + if i == grid_i and j == grid_j: + target_rank_a = i * grid_dim + k + target_rank_b = k * grid_dim + j + a_win.Lock(target_rank_a) + a_win.Get(foreign_a_mat, target_rank=target_rank_a) + a_win.Flush(target_rank_a) + a_win.Unlock(target_rank_a) + + b_win.Lock(target_rank_b) + b_win.Get(foreign_b_mat, target_rank=target_rank_b) + b_win.Flush(target_rank_b) + b_win.Unlock(target_rank_b) + + c_mat += foreign_a_mat @ foreign_b_mat + + # as MPI barrier + # to ensure every process completed the calculation + a_win.Fence(0) + a_win.Fence(0) + + sdfg = None + if comm_rank == 0: + # ValueError: Node type "Win_lock" not supported for promotion + sdfg = dist_mat_mult.to_sdfg(simplify=False) + func = utils.distributed_compile(sdfg, comm_world) + + start = time.time() + + func(a_mat=a_mat, b_mat=b_mat, c_mat=c_mat, comm_rank=comm_rank, comm_size=comm_size) + + time_con = time.time() - start + + return 
c_mat, time_con + + +def weak_scaling(comm_world, comm_rank, comm_size): + grid_dim = int(np.floor(np.sqrt(comm_size))) + grid_i = comm_rank // grid_dim + grid_j = comm_rank % grid_dim + + dim_1 = 1024 + dim_2 = 1024 + + c_mat, time_con = matrix_mul(comm_world, dim_1, dim_2) + # print(comm_rank, c_mat) + # print(comm_rank, "matrix_mul time:", time_con) + + whole_a = np.ones((dim_1 * grid_dim, dim_2 * grid_dim), dtype=np.float32) + for i in range(grid_dim): + for j in range(grid_dim): + whole_a[i * dim_1:(i+1) * dim_1, j * dim_2:(j+1) * dim_2] += i * grid_dim + j + + whole_b = np.ones((dim_2 * grid_dim, dim_1 * grid_dim), dtype=np.float32) + for i in range(grid_dim): + for j in range(grid_dim): + whole_b[i * dim_2:(i+1) * dim_2, j * dim_1:(j+1) * dim_1] += i * grid_dim + j + + start = time.time() + c_np = np.matmul(whole_a, whole_b) + time_con = time.time() - start + + # print(comm_rank, c_np[grid_i * dim_1:(grid_i+1) * dim_1, grid_j* dim_2:(grid_j+1) * dim_2]) + # print(comm_rank, "np.matmul time:", time_con) + + # print("Result correctness:", np.allclose(c_mat, c_np[grid_i * dim_1:(grid_i+1) * dim_1, grid_j* dim_2:(grid_j+1) * dim_2])) + assert(np.allclose(c_mat, c_np[grid_i * dim_1:(grid_i+1) * dim_1, grid_j* dim_2:(grid_j+1) * dim_2])) + + +def strong_scaling(comm_world, comm_rank, comm_size): + grid_dim = int(np.floor(np.sqrt(comm_size))) + grid_i = comm_rank // grid_dim + grid_j = comm_rank % grid_dim + + total_dim = 8192 + dim_1 = total_dim + dim_2 = total_dim + if total_dim % comm_size > 0: + dim_1 += comm_size - total_dim % comm_size + dim_2 += comm_size - total_dim % comm_size + + local_dim_1 = dim_1 // grid_dim + local_dim_2 = dim_2 // grid_dim + + a = np.ones((local_dim_1, local_dim_2), dtype=np.float32) + b = np.ones((local_dim_2, local_dim_1), dtype=np.float32) + + c_mat, time_con = matrix_mul(comm_world, local_dim_1, local_dim_2) + # print(comm_rank, c_mat) + # print(comm_rank, "matrix_mul time:", time_con) + + # validation, since it will compute the whole matrix in the edge + # whole_a = np.ones((local_dim_1 * grid_dim, local_dim_2 * grid_dim), dtype=np.float32) + # for i in range(grid_dim): + # for j in range(grid_dim): + # whole_a[i * local_dim_1:(i+1) * local_dim_1, j * local_dim_2:(j+1) * local_dim_2] += i * grid_dim + j + + # whole_b = np.ones((local_dim_2 * grid_dim, local_dim_1 * grid_dim), dtype=np.float32) + # for i in range(grid_dim): + # for j in range(grid_dim): + # whole_b[i * local_dim_2:(i+1) * local_dim_2, j * local_dim_1:(j+1) * local_dim_1] += i * grid_dim + j + + # start = time.time() + # c_np = np.matmul(whole_a, whole_b) + # time_con = time.time() - start + # # print("Result correctness:", np.allclose(c_mat, c_np[grid_i * local_dim_1:(grid_i+1) * local_dim_1, grid_j* local_dim_2:(grid_j+1) * local_dim_2])) + # assert(np.allclose(c_mat, c_np[grid_i * local_dim_1:(grid_i+1) * local_dim_1, grid_j* local_dim_2:(grid_j+1) * local_dim_2])) + +if __name__ == "__main__": + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + grid_dim = int(np.floor(np.sqrt(comm_size))) + + if comm_size != grid_dim ** 2: + raise ValueError("Please run this test with a square number of processes.") + + # weak_scaling(comm_world, comm_rank, comm_size) + strong_scaling(comm_world, comm_rank, comm_size) diff --git a/samples/mpi/ping_pong.py b/samples/mpi/ping_pong.py new file mode 100644 index 0000000000..d8cf490f62 --- /dev/null +++ b/samples/mpi/ping_pong.py @@ -0,0 +1,117 @@ +import numpy as np +import dace +from dace.sdfg import 
utils +import dace.dtypes as dtypes +from mpi4py import MPI +import time + +dim_1 = 128 +dim_2 = 128 + +a = np.arange(dim_1 * dim_2).reshape(dim_1, dim_2) +b = np.arange(dim_1 * dim_2).reshape(dim_2, dim_1) + +# to check if this process owns this chunk of data +# compare given i and j with grid_i and grid_j +@dace.program +def owner(i, j, grid_i, grid_j): + if i == grid_i and j == grid_j: + return True + else: + return False + +# get matrix form remote rank +@dace.program +def get_mat(win: dace.RMA_window, buffer: dace.int32[dim_1,dim_2], dim_0: dace.int32, dim_1: dace.int32, grid_dim: dace.int32): + rank = dim_0 * grid_dim + dim_1 + win.Lock(rank) + win.Get(buffer, target_rank=rank) + win.Flush(rank) + win.Unlock(rank) + +def matrix_mul(a, b): + # check if matrix multiplication is valid + if a.shape[1] != b.shape[0]: + raise ValueError("A, B matrix dimension mismatched!") + + # comm init + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + grid_dim = 2 + grid_i = comm_rank // grid_dim + grid_j = comm_rank % grid_dim + + if comm_size != 2: + raise ValueError("Please run this test with two processes.") + + a_mat = np.array(a + comm_rank, dtype=np.int64) + b_mat = np.array(b + comm_rank, dtype=np.int64) + foreign_a_mat = np.zeros(a.shape, dtype=np.int64) + foreign_b_mat = np.zeros(b.shape, dtype=np.int64) + c_mat = np.zeros((a_mat.shape[0], b_mat.shape[1]), dtype=np.int64) + + + # more or less like C stationary + # for i in range(a_mat.shape[0]): + # for j in range(b_mat.shape[1]): + # for k in range(a_mat.shape[1]): + # c_mat[i][j] += a_mat[i][k] * b_mat[k][j] + + + @dace.program + def mpi4py_send_recv(comm_rank: dace.int32, a_mat: dace.int32[dim_1,dim_2], foreign_a_mat: dace.int32[dim_1,dim_2], grid_dim: dace.int32): + a_win = MPI.Win.Create(a_mat, comm=comm_world) + if comm_rank == 0: + get_mat(a_win, foreign_a_mat, 0, 1, grid_dim) + else: + get_mat(a_win, foreign_a_mat, 0, 0, grid_dim) + return foreign_a_mat + + sdfg = None + if comm_rank == 0: + sdfg = mpi4py_send_recv.to_sdfg(simplify=True) + func = utils.distributed_compile(sdfg, comm_world) + + + start = time.time() + + foreign_a_mat = func(comm_rank=comm_rank, a_mat=a_mat, foreign_a_mat=foreign_a_mat, grid_dim=grid_dim) + if comm_rank == 0: + if(np.allclose(a_mat+1, foreign_a_mat)): + print("Good") + else: + if(np.allclose(a_mat-1, foreign_a_mat)): + print("Good") + + time_con = time.time() - start + + + # to ensure every process completed the calculation + comm_world.Barrier() + +matrix_mul(a,b) + + # more or less like C stationary + # for i in range(a_mat.shape[0]): + # for j in range(b_mat.shape[1]): + # for k in range(a_mat.shape[1]): + # c_mat[i][j] += a_mat[i][k] * b_mat[k][j] + + # @dace.program + # def mpi4py_passive_rma_put(a_mat: dace.int32[dim_1,dim_2], b_mat: dace.int32[dim_1,dim_2], c_mat: dace.int32[dim_1,dim_2], tile: dace.int32): + # for i_tile in range(a_mat.shape[0] // tile): + # for j_tile in range(b_mat.shape[1] // tile): + # for k_tile in range(a_mat.shape[1] // tile): + # for i in range(i_tile * tile, min((i_tile + 1) * tile, a_mat.shape[0])): + # for j in range(j_tile * tile, min((j_tile + 1) * tile, b_mat.shape[1])): + # for k in range(k_tile * tile, min((k_tile + 1) * tile, a_mat.shape[1])): + # c_mat[i][j] += a_mat[i][k] * b_mat[k][j] + + # sdfg = None + # sdfg = mpi4py_passive_rma_put.to_sdfg() + # sdfg.openmp_sections = False + # func = sdfg.compile() + + # func(a_mat, b_mat, c_mat, tile) \ No newline at end of file diff --git 
a/tests/library/mpi/mpi4py_test.py b/tests/library/mpi/mpi4py_test.py index 52b5deb7a8..124f4299dd 100644 --- a/tests/library/mpi/mpi4py_test.py +++ b/tests/library/mpi/mpi4py_test.py @@ -39,6 +39,226 @@ def comm_world_bcast(A: dace.int32[10]): assert (np.array_equal(A, A_ref)) +@pytest.mark.mpi +def test_RMA_put(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + @dace.program + def mpi4py_rma_put(win_buf: dace.int32[10], send_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Fence(0) + win.Put(send_buf, target_rank=rank) + win.Fence(0) + win.Free() + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_rma_put.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + win_buffer_ref = np.full(window_size, rank, dtype=np.int32) + send_buffer = np.full(window_size, rank, dtype=np.int32) + + func(win_buf=win_buffer, send_buf=send_buffer, rank=((rank + 1) % size)) + mpi4py_rma_put.f(win_buf=win_buffer_ref, send_buf=send_buffer, rank=((rank + 1) % size)) + + assert (np.array_equal(win_buffer, win_buffer_ref)) + + +@pytest.mark.mpi +def test_RMA_get(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + @dace.program + def mpi4py_rma_get(win_buf: dace.int32[10], recv_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Fence(0) + win.Get(recv_buf, target_rank=rank) + win.Fence(0) + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_rma_get.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + recv_buf = np.full(window_size, rank, dtype=np.int32) + recv_buf_ref = np.full(window_size, rank, dtype=np.int32) + + func(win_buf=win_buffer, recv_buf=recv_buf, rank=((rank + 1) % size)) + mpi4py_rma_get.f(win_buf=win_buffer, recv_buf=recv_buf_ref, rank=((rank + 1) % size)) + + assert (np.array_equal(recv_buf, recv_buf_ref)) + + +@pytest.mark.mpi +def test_RMA_accumulate(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + # sum all rank at rank 0 + @dace.program + def mpi4py_rma_accumulate(win_buf: dace.int32[10], send_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Fence(0) + win.Accumulate(send_buf, target_rank=rank, op=MPI.SUM) + win.Fence(0) + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_rma_accumulate.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + win_buffer_ref = np.full(window_size, rank, dtype=np.int32) + send_buffer = np.full(window_size, rank, dtype=np.int32) + + func(win_buf=win_buffer, send_buf=send_buffer, rank=0) + mpi4py_rma_accumulate.f(win_buf=win_buffer_ref, send_buf=send_buffer, rank=0) + + assert (np.array_equal(win_buffer, win_buffer_ref)) + + +@pytest.mark.mpi +def test_passive_RMA_put(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + @dace.program + def 
mpi4py_passive_rma_put(win_buf: dace.int32[10], send_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Lock(rank) + win.Put(send_buf, target_rank=rank) + win.Flush(rank) + win.Unlock(rank) + + # as MPI barrier + win.Fence(0) + win.Fence(0) + + win.Free() + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_passive_rma_put.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + win_buffer_ref = np.full(window_size, rank, dtype=np.int32) + send_buffer = np.full(window_size, rank, dtype=np.int32) + + + func(win_buf=win_buffer, send_buf=send_buffer, rank=((rank + 1) % size)) + mpi4py_passive_rma_put.f(win_buf=win_buffer_ref, send_buf=send_buffer, rank=((rank + 1) % size)) + + assert (np.array_equal(win_buffer, win_buffer_ref)) + + +@pytest.mark.mpi +def test_passive_RMA_get(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + @dace.program + def mpi4py_passive_rma_get(win_buf: dace.int32[10], recv_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Lock(rank) + win.Get(recv_buf, target_rank=rank) + win.Flush(rank) + win.Unlock(rank) + + # as MPI barrier + win.Fence(0) + win.Fence(0) + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_passive_rma_get.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + recv_buf = np.full(window_size, rank, dtype=np.int32) + recv_buf_ref = np.full(window_size, rank, dtype=np.int32) + + func(win_buf=win_buffer, recv_buf=recv_buf, rank=((rank + 1) % size)) + mpi4py_passive_rma_get.f(win_buf=win_buffer, recv_buf=recv_buf_ref, rank=((rank + 1) % size)) + + assert (np.array_equal(recv_buf, recv_buf_ref)) + + +@pytest.mark.mpi +def test_RMA_passive_accumulate(): + from mpi4py import MPI + commworld = MPI.COMM_WORLD + rank = commworld.Get_rank() + size = commworld.Get_size() + + # sum all rank at rank 0 + @dace.program + def mpi4py_passive_rma_accumulate(win_buf: dace.int32[10], send_buf: dace.int32[10], rank: dace.int32): + win = MPI.Win.Create(win_buf, comm=commworld) + win.Lock(rank) + win.Accumulate(send_buf, target_rank=rank, op=MPI.SUM) + win.Flush(rank) + win.Unlock(rank) + + # as MPI barrier + win.Fence(0) + win.Fence(0) + + if size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = None + if rank == 0: + sdfg = mpi4py_passive_rma_accumulate.to_sdfg() + func = utils.distributed_compile(sdfg, commworld) + + window_size = 10 + win_buffer = np.full(window_size, rank, dtype=np.int32) + win_buffer_ref = np.full(window_size, rank, dtype=np.int32) + send_buffer = np.full(window_size, rank, dtype=np.int32) + + func(win_buf=win_buffer, send_buf=send_buffer, rank=0) + mpi4py_passive_rma_accumulate.f(win_buf=win_buffer_ref, send_buf=send_buffer, rank=0) + + if rank == 0: + assert (np.array_equal(win_buffer, win_buffer_ref)) + + @pytest.mark.mpi def test_external_comm_bcast(): @@ -348,3 +568,9 @@ def mpi4py_alltoall(rank: dace.int32, size: dace.compiletime): test_isend_irecv() test_send_recv() test_alltoall() + test_RMA_put() + test_RMA_get() + test_RMA_accumulate() + test_passive_RMA_put() + test_passive_RMA_get() + test_RMA_passive_accumulate() 
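The tests above drive the new RMA replacements entirely through mpi4py syntax. As an illustration of the same flow outside a test, the sketch below (not part of this patch; the program name, buffer size and argument names are made up) parses a fence-synchronized Get and then inspects the window and operation descriptor stores that sdfg.add_window and sdfg.add_rma_ops populate:

import dace
from mpi4py import MPI

commworld = MPI.COMM_WORLD

@dace.program
def rma_get_example(win_buf: dace.int32[10], recv_buf: dace.int32[10], target: dace.int32):
    # lowered to Win_create / Win_fence / Win_get / Win_free library nodes
    win = MPI.Win.Create(win_buf, comm=commworld)
    win.Fence(0)
    win.Get(recv_buf, target_rank=target)
    win.Fence(0)
    win.Free()

sdfg = rma_get_example.to_sdfg()
print(list(sdfg.rma_windows.keys()))  # one RMA_window descriptor created by sdfg.add_window()
print(list(sdfg.rma_ops.keys()))      # fence/get/fence/free descriptors chained to that window

Passive-target programs (Lock/Unlock) currently need to_sdfg(simplify=False), since the Win_lock node is not supported by scalar-to-symbol promotion (see the comment in samples/mpi/mat_mul.py).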
diff --git a/tests/library/mpi/mpi_free_test.py b/tests/library/mpi/mpi_free_test.py new file mode 100644 index 0000000000..f87220e4bc --- /dev/null +++ b/tests/library/mpi/mpi_free_test.py @@ -0,0 +1,102 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_free") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + free_state = sdfg.add_state("win_free") + + sdfg.add_edge(window_state, free_state, dace.InterstateEdge()) + + free_name = sdfg.add_rma_ops(window_name, "free") + win_free_node = mpi.nodes.win_free.Win_free(free_name, window_name) + + # pseudo access for ordering + window_node = free_state.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + free_state.add_edge(window_node, + None, + win_free_node, + "_in", + Memlet.from_array(window_name, window_desc)) + + _, scal = sdfg.add_scalar(free_name, dace.int32, transient=True) + wnode = free_state.add_write(free_name) + free_state.add_edge(win_free_node, + "_out", + wnode, + None, + Memlet.from_array(free_name, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_free(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.arange(0, window_size, dtype=np_dtype) + + mpi_func(win_buffer=win_buffer, n=window_size) + +if __name__ == "__main__": + test_win_free(dace.int32) + test_win_free(dace.float32) diff --git a/tests/library/mpi/win_accumulate_test.py b/tests/library/mpi/win_accumulate_test.py new file mode 100644 index 0000000000..c5338e12ac --- /dev/null +++ b/tests/library/mpi/win_accumulate_test.py @@ -0,0 +1,205 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
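+# The SDFG built below has four chained states: window creation, an opening
+# fence, the accumulate, and a closing fence. Every rank accumulates its
+# send_buffer into rank 0's window with the default MPI_SUM op. The transient
+# int32 scalars written through each Win_* node's "_out" connector carry no
+# data; they exist only to order the library nodes and states.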
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_accumulate") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("assertion", [1], dtype=dace.int32, transient=False) + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("send_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("target_rank", [1], dace.dtypes.int32, transient=False) + + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + fence_state_1 = sdfg.add_state("win_fence_1") + + sdfg.add_edge(window_state, fence_state_1, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + window_node = fence_state_1.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + fence_state_1.add_edge(window_node, + None, + win_fence_node, + None, + Memlet.from_array(window_name, window_desc)) + + assertion_node = fence_state_1.add_access("assertion") + + fence_state_1.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_1.add_write(fence_name) + fence_state_1.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + +############################################################################### + + accumulate_state = sdfg.add_state("win_accumulate") + + sdfg.add_edge(fence_state_1, accumulate_state, dace.InterstateEdge()) + + accumulate_name = sdfg.add_rma_ops(window_name, "accumulate") + win_accumulate_node = mpi.nodes.win_accumulate.Win_accumulate(accumulate_name, window_name) + + # pseudo access for ordering + fence_node = accumulate_state.add_access(fence_name) + fence_desc = sdfg.arrays[fence_name] + + send_buffer = accumulate_state.add_access("send_buffer") + + target_rank = accumulate_state.add_access("target_rank") + + accumulate_state.add_edge(fence_node, + None, + win_accumulate_node, + "_in", + Memlet.from_array(fence_name, fence_desc)) + + accumulate_state.add_edge(send_buffer, + None, + win_accumulate_node, + "_inbuffer", + Memlet.simple(send_buffer, "0:n", num_accesses=n)) + + accumulate_state.add_edge(target_rank, + None, + win_accumulate_node, + "_target_rank", + Memlet.simple(target_rank, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(accumulate_name, dace.int32, transient=True) + wnode = accumulate_state.add_write(accumulate_name) + accumulate_state.add_edge(win_accumulate_node, + "_out", + wnode, + None, + 
Memlet.from_array(accumulate_name, scal)) + +############################################################################### + + fence_state_2 = sdfg.add_state("win_fence_2") + + sdfg.add_edge(accumulate_state, fence_state_2, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + accumulate_node = fence_state_2.add_access(accumulate_name) + accumulate_desc = sdfg.arrays[accumulate_name] + + fence_state_2.add_edge(accumulate_node, + None, + win_fence_node, + None, + Memlet.from_array(accumulate_name, accumulate_desc)) + + assertion_node = fence_state_2.add_access("assertion") + + fence_state_2.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_2.add_write(fence_name) + fence_state_2.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_accumulate(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + send_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + + # accumulate all ranks in rank 0 + target_rank = np.full([1], 0, dtype=np.int32) + assertion = np.full([1], 0, dtype=np.int32) + + mpi_func(assertion=assertion, + win_buffer=win_buffer, + send_buffer=send_buffer, + target_rank=target_rank, + n=window_size) + + correct_data = np.full(window_size, comm_size * (comm_size - 1) / 2, dtype=np_dtype) + if (comm_rank == 0 and not np.allclose(win_buffer, correct_data)): + raise (ValueError("The received values are not what I expected on root.")) + +if __name__ == "__main__": + test_win_accumulate(dace.int32) + test_win_accumulate(dace.float32) diff --git a/tests/library/mpi/win_create_test.py b/tests/library/mpi/win_create_test.py new file mode 100644 index 0000000000..db9d356c74 --- /dev/null +++ b/tests/library/mpi/win_create_test.py @@ -0,0 +1,80 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
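For reference, the SDFG built in win_accumulate_test.py encodes a fence-delimited accumulate: between two collective fences, every rank adds a buffer full of its own rank id into rank 0's window, so rank 0 ends up with 0 + 1 + ... + (size - 1) = size * (size - 1) / 2 in every element, which is exactly the correct_data the test compares against. A minimal mpi4py sketch of the same exchange (names and the final assertion are illustrative, not taken from the patch):

    # Standalone mpi4py sketch of the fence-delimited accumulate encoded by win_accumulate_test.py.
    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    win_buf = np.full(10, rank, dtype=np.int32)   # rank 0's window therefore starts at 0
    contrib = np.full(10, rank, dtype=np.int32)
    win = MPI.Win.Create(win_buf, comm=comm)

    win.Fence(0)                                        # open the active-target epoch (collective)
    win.Accumulate(contrib, target_rank=0, op=MPI.SUM)  # element-wise sum into rank 0's window
    win.Fence(0)                                        # close the epoch; all accumulates are complete

    if rank == 0:
        assert np.all(win_buf == size * (size - 1) // 2)
    win.Free()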
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_create") + state = sdfg.add_state("start") + + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + win_buffer = state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = state.add_write(window_name) + state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + + return sdfg + + +############################################################################### + + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_create(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.arange(0, window_size, dtype=np_dtype) + + mpi_func(win_buffer=win_buffer, n=window_size) + + +############################################################################### + + +if __name__ == "__main__": + test_win_create(dace.float32) + test_win_create(dace.int32) diff --git a/tests/library/mpi/win_fence_test.py b/tests/library/mpi/win_fence_test.py new file mode 100644 index 0000000000..20a6b11f0f --- /dev/null +++ b/tests/library/mpi/win_fence_test.py @@ -0,0 +1,112 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
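All of these library-node tests repeat one wiring idiom, first visible in win_create_test.py above: a node's only explicit output is a transient int32 scalar named after the operation, and the next RMA node reads that scalar solely to enforce ordering across states. A small hypothetical helper pair (not part of this patch) that captures the idiom, built only from calls already used in the tests:

    # Hypothetical helpers for the "dependency token" idiom used throughout these tests:
    # each RMA library node writes a transient scalar named after itself, and the next
    # node reads that scalar so the SDFG states execute in the intended order.
    import dace
    from dace.memlet import Memlet

    def write_dependency_token(sdfg: dace.SDFG, state, node, token_name: str):
        """Attach `node`'s `_out` connector to a fresh transient scalar `token_name`."""
        _, scal = sdfg.add_scalar(token_name, dace.int32, transient=True)
        wnode = state.add_write(token_name)
        state.add_edge(node, "_out", wnode, None, Memlet.from_array(token_name, scal))
        return wnode

    def read_dependency_token(sdfg: dace.SDFG, state, node, token_name: str, conn=None):
        """Feed an earlier token into `node`; `conn` may be None for pure ordering."""
        rnode = state.add_access(token_name)
        desc = sdfg.arrays[token_name]
        state.add_edge(rnode, None, node, conn, Memlet.from_array(token_name, desc))
        return rnode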
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_fence") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + fence_state = sdfg.add_state("win_fence") + + sdfg.add_edge(window_state, fence_state, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + window_node = fence_state.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + fence_state.add_edge(window_node, + None, + win_fence_node, + None, + Memlet.from_array(window_name, window_desc)) + + sdfg.add_array("assertion", [1], dtype=dace.int32, transient=False) + assertion_node = fence_state.add_access("assertion") + + fence_state.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state.add_write(fence_name) + fence_state.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_fence(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.arange(0, window_size, dtype=np_dtype) + assertion = np.full([1], 0, dtype=np.int32) + + mpi_func(assertion=assertion, win_buffer=win_buffer, n=window_size) + +if __name__ == "__main__": + test_win_fence(dace.int32) + test_win_fence(dace.float32) diff --git a/tests/library/mpi/win_get_test.py b/tests/library/mpi/win_get_test.py new file mode 100644 index 0000000000..9f2e780d69 --- /dev/null +++ b/tests/library/mpi/win_get_test.py @@ -0,0 +1,204 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_get") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("assertion", [1], dtype=dace.int32, transient=False) + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("receive_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("target_rank", [1], dace.dtypes.int32, transient=False) + + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + fence_state_1 = sdfg.add_state("win_fence_1") + + sdfg.add_edge(window_state, fence_state_1, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + window_node = fence_state_1.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + fence_state_1.add_edge(window_node, + None, + win_fence_node, + None, + Memlet.from_array(window_name, window_desc)) + + assertion_node = fence_state_1.add_access("assertion") + + fence_state_1.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_1.add_write(fence_name) + fence_state_1.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + +############################################################################### + + get_state = sdfg.add_state("win_get") + + sdfg.add_edge(fence_state_1, get_state, dace.InterstateEdge()) + + get_name = sdfg.add_rma_ops(window_name, "get") + win_get_node = mpi.nodes.win_get.Win_get(get_name, window_name) + + # pseudo access for ordering + fence_node = get_state.add_access(fence_name) + fence_desc = sdfg.arrays[fence_name] + + target_rank = get_state.add_access("target_rank") + + get_state.add_edge(fence_node, + None, + win_get_node, + "_in", + Memlet.from_array(fence_name, fence_desc)) + + get_state.add_edge(target_rank, + None, + win_get_node, + "_target_rank", + Memlet.simple(target_rank, "0:1", num_accesses=1)) + + receive_buffer = get_state.add_write("receive_buffer") + get_state.add_edge(win_get_node, + "_outbuffer", + receive_buffer, + None, + Memlet.simple(receive_buffer, "0:n", num_accesses=n)) + + _, scal = sdfg.add_scalar(get_name, dace.int32, transient=True) + wnode = get_state.add_write(get_name) + get_state.add_edge(win_get_node, + "_out", + wnode, + None, + Memlet.from_array(get_name, scal)) + +############################################################################### + + fence_state_2 = sdfg.add_state("win_fence_2") + + 
sdfg.add_edge(get_state, fence_state_2, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + get_node = fence_state_2.add_access(get_name) + get_desc = sdfg.arrays[get_name] + + fence_state_2.add_edge(get_node, + None, + win_fence_node, + None, + Memlet.from_array(get_name, get_desc)) + + assertion_node = fence_state_2.add_access("assertion") + + fence_state_2.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_2.add_write(fence_name) + fence_state_2.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_get(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + receive_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + + target_rank = np.array([(comm_rank + 1) % comm_size], dtype=np.int32) + + assertion = np.full([1], 0, dtype=np.int32) + + mpi_func(assertion=assertion, + win_buffer=win_buffer, + receive_buffer=receive_buffer, + target_rank=target_rank, + n=window_size) + + correct_data = np.full(window_size, (comm_rank + 1) % comm_size, dtype=np_dtype) + if (not np.allclose(receive_buffer, correct_data)): + raise (ValueError("The received values are not what I expected on root.")) + +if __name__ == "__main__": + test_win_get(dace.int32) + test_win_get(dace.float32) diff --git a/tests/library/mpi/win_passive_sync_test.py b/tests/library/mpi/win_passive_sync_test.py new file mode 100644 index 0000000000..8a0ac7c3d7 --- /dev/null +++ b/tests/library/mpi/win_passive_sync_test.py @@ -0,0 +1,332 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
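win_get_test.py above is the active-target mirror of the put test: between two collective fences, rank r Gets from its right neighbor (r + 1) % size, whose window was initialized with that neighbor's rank, so every rank's receive_buffer should equal (r + 1) % size, matching the correct_data computed at the end of the test. A minimal mpi4py sketch of the same ring read (variable names are illustrative):

    # Standalone mpi4py sketch of the fence-delimited ring Get encoded by win_get_test.py.
    from mpi4py import MPI
    import numpy as np

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    win_buf = np.full(10, rank, dtype=np.int32)   # each window holds its owner's rank
    recv_buf = np.zeros(10, dtype=np.int32)
    win = MPI.Win.Create(win_buf, comm=comm)

    target = (rank + 1) % size
    win.Fence(0)                             # open the epoch (collective)
    win.Get(recv_buf, target_rank=target)    # read the neighbor's window
    win.Fence(0)                             # close the epoch; the Get is complete locally

    assert np.all(recv_buf == target)        # we read the neighbor's rank id
    win.Free()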
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_passive_sync") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("lock_type", [1], dtype=dace.int32, transient=False) + sdfg.add_array("assertion", [1], dtype=dace.int32, transient=False) + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("send_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("target_rank", [1], dace.dtypes.int32, transient=False) + + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + lock_state = sdfg.add_state("win_lock") + + sdfg.add_edge(window_state, lock_state, dace.InterstateEdge()) + + lock_name = sdfg.add_rma_ops(window_name, "lock") + win_lock_node = mpi.nodes.win_lock.Win_lock(lock_name, window_name) + + # pseudo access for ordering + window_node = lock_state.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + lock_state.add_edge(window_node, + None, + win_lock_node, + None, + Memlet.from_array(window_name, window_desc)) + + lock_type_node = lock_state.add_access("lock_type") + + target_rank_node = lock_state.add_access("target_rank") + + assertion_node = lock_state.add_access("assertion") + + lock_state.add_edge(lock_type_node, + None, + win_lock_node, + '_lock_type', + Memlet.simple(lock_type_node, "0:1", num_accesses=1)) + + lock_state.add_edge(target_rank_node, + None, + win_lock_node, + '_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + lock_state.add_edge(assertion_node, + None, + win_lock_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(lock_name, dace.int32, transient=True) + wnode = lock_state.add_write(lock_name) + lock_state.add_edge(win_lock_node, + "_out", + wnode, + None, + Memlet.from_array(lock_name, scal)) + +############################################################################### + + put_state = sdfg.add_state("win_put") + + sdfg.add_edge(lock_state, put_state, dace.InterstateEdge()) + + put_name = sdfg.add_rma_ops(window_name, "put") + win_put_node = mpi.nodes.win_put.Win_put(put_name, window_name) + + # pseudo access for ordering + lock_node = put_state.add_access(lock_name) + lock_desc = sdfg.arrays[lock_name] + + send_buffer = put_state.add_access("send_buffer") + + target_rank = put_state.add_access("target_rank") + + put_state.add_edge(lock_node, + None, + win_put_node, + "_in", + Memlet.from_array(lock_name, lock_desc)) + + put_state.add_edge(send_buffer, + None, + win_put_node, + "_inbuffer", + Memlet.simple(send_buffer, "0:n", num_accesses=n)) + + put_state.add_edge(target_rank, + None, + win_put_node, + 
"_target_rank", + Memlet.simple(target_rank, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(put_name, dace.int32, transient=True) + wnode = put_state.add_write(put_name) + put_state.add_edge(win_put_node, + "_out", + wnode, + None, + Memlet.from_array(put_name, scal)) + +############################################################################### + + flush_state = sdfg.add_state("win_flush") + + sdfg.add_edge(put_state, flush_state, dace.InterstateEdge()) + + flush_name = sdfg.add_rma_ops(window_name, "flush") + win_flush_node = mpi.nodes.win_flush.Win_flush(flush_name, window_name) + + # pseudo access for ordering + put_node = flush_state.add_access(put_name) + put_desc = sdfg.arrays[put_name] + + flush_state.add_edge(put_node, + None, + win_flush_node, + None, + Memlet.from_array(put_name, put_desc)) + + target_rank_node = flush_state.add_access("target_rank") + + flush_state.add_edge(target_rank_node, + None, + win_flush_node, + '_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(flush_name, dace.int32, transient=True) + wnode = flush_state.add_write(flush_name) + flush_state.add_edge(win_flush_node, + "_out", + wnode, + None, + Memlet.from_array(flush_name, scal)) + +############################################################################### + + unlock_state = sdfg.add_state("win_unlock") + + sdfg.add_edge(flush_state, unlock_state, dace.InterstateEdge()) + + unlock_name = sdfg.add_rma_ops(window_name, "unlock") + win_unlock_node = mpi.nodes.win_unlock.Win_unlock(unlock_name, window_name) + + # pseudo access for ordering + flush_node = unlock_state.add_access(flush_name) + flush_desc = sdfg.arrays[flush_name] + + unlock_state.add_edge(flush_node, + None, + win_unlock_node, + None, + Memlet.from_array(flush_name, flush_desc)) + + target_rank_node = unlock_state.add_access("target_rank") + + unlock_state.add_edge(target_rank_node, + None, + win_unlock_node, + '_rank', + Memlet.simple(target_rank_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(unlock_name, dace.int32, transient=True) + wnode = unlock_state.add_write(unlock_name) + unlock_state.add_edge(win_unlock_node, + "_out", + wnode, + None, + Memlet.from_array(unlock_name, scal)) + +# added these two fences as Barrier to ensure that every rank has completed +# since every rank are running independently +# some ranks might exit(since they completed) the transmission +# while others are still transmitting +############################################################################### + + fence_state_1 = sdfg.add_state("win_fence") + + sdfg.add_edge(unlock_state, fence_state_1, dace.InterstateEdge()) + + fence_name_1 = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name_1, window_name) + + # pseudo access for ordering + unlock_node = fence_state_1.add_access(unlock_name) + unlock_desc = sdfg.arrays[unlock_name] + + fence_state_1.add_edge(unlock_node, + None, + win_fence_node, + None, + Memlet.from_array(unlock_name, unlock_desc)) + + assertion_node = fence_state_1.add_access("assertion") + + fence_state_1.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name_1, dace.int32, transient=True) + wnode = fence_state_1.add_write(fence_name_1) + fence_state_1.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name_1, scal)) + 
+############################################################################### + + fence_state_2 = sdfg.add_state("win_fence") + + sdfg.add_edge(fence_state_1, fence_state_2, dace.InterstateEdge()) + + fence_name_2 = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name_2, window_name) + + # pseudo access for ordering + fence_node = fence_state_2.add_access(fence_name_1) + fence_desc = sdfg.arrays[fence_name_1] + + fence_state_2.add_edge(fence_node, + None, + win_fence_node, + None, + Memlet.from_array(fence_name_1, fence_desc)) + + assertion_node = fence_state_2.add_access("assertion") + + fence_state_2.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name_2, dace.int32, transient=True) + wnode = fence_state_2.add_write(fence_name_2) + fence_state_2.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name_2, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_put(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + send_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + + target_rank = np.array([(comm_rank + 1) % comm_size], dtype=np.int32) + lock_type = np.full([1], MPI.LOCK_SHARED, dtype=np.int32) + assertion = np.full([1], 0, dtype=np.int32) + + mpi_func(lock_type=lock_type, + assertion=assertion, + win_buffer=win_buffer, + send_buffer=send_buffer, + target_rank=target_rank, + n=window_size) + + correct_data = np.full(window_size, (comm_rank - 1) % comm_size, dtype=np_dtype) + if (not np.allclose(win_buffer, correct_data)): + raise (ValueError("The received values are not what I expected on root.")) + +if __name__ == "__main__": + test_win_put(dace.int32) + test_win_put(dace.float32) diff --git a/tests/library/mpi/win_put_test.py b/tests/library/mpi/win_put_test.py new file mode 100644 index 0000000000..0e8af8487b --- /dev/null +++ b/tests/library/mpi/win_put_test.py @@ -0,0 +1,205 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
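win_passive_sync_test.py drives the lock with a runtime lock_type input and passes MPI.LOCK_SHARED from the test driver; the two trailing fences act as a barrier so that no rank tears down or inspects its window while a neighbor's passive epoch is still in flight. For reference, the two values that input can carry come directly from mpi4py, wrapped in 1-element int32 arrays the same way the test feeds scalars into the compiled SDFG:

    # The two lock modes a passive-target epoch can request (constants from mpi4py),
    # wrapped as the 1-element int32 arrays the test passes to the compiled SDFG.
    from mpi4py import MPI
    import numpy as np

    lock_shared = np.full([1], MPI.LOCK_SHARED, dtype=np.int32)        # several origins may hold the lock at once
    lock_exclusive = np.full([1], MPI.LOCK_EXCLUSIVE, dtype=np.int32)  # a single origin at a time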
+import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +############################################################################### + + +def make_sdfg(dtype): + n = dace.symbol("n") + + sdfg = dace.SDFG("mpi_win_put") + window_state = sdfg.add_state("create_window") + + sdfg.add_array("assertion", [1], dtype=dace.int32, transient=False) + sdfg.add_array("win_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("send_buffer", [n], dtype=dtype, transient=False) + sdfg.add_array("target_rank", [1], dace.dtypes.int32, transient=False) + + win_buffer = window_state.add_access("win_buffer") + + window_name = sdfg.add_window() + win_create_node = mpi.nodes.win_create.Win_create(window_name) + + window_state.add_edge(win_buffer, + None, + win_create_node, + '_win_buffer', + Memlet.simple(win_buffer, "0:n", num_accesses=n)) + + # for other nodes depends this window to connect + _, scal = sdfg.add_scalar(window_name, dace.int32, transient=True) + wnode = window_state.add_write(window_name) + window_state.add_edge(win_create_node, + "_out", + wnode, + None, + Memlet.from_array(window_name, scal)) + +############################################################################### + + fence_state_1 = sdfg.add_state("win_fence_1") + + sdfg.add_edge(window_state, fence_state_1, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + window_node = fence_state_1.add_access(window_name) + window_desc = sdfg.arrays[window_name] + + fence_state_1.add_edge(window_node, + None, + win_fence_node, + None, + Memlet.from_array(window_name, window_desc)) + + assertion_node = fence_state_1.add_access("assertion") + + fence_state_1.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_1.add_write(fence_name) + fence_state_1.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + +############################################################################### + + put_state = sdfg.add_state("win_put") + + sdfg.add_edge(fence_state_1, put_state, dace.InterstateEdge()) + + put_name = sdfg.add_rma_ops(window_name, "put") + win_put_node = mpi.nodes.win_put.Win_put(put_name, window_name) + + # pseudo access for ordering + fence_node = put_state.add_access(fence_name) + fence_desc = sdfg.arrays[fence_name] + + send_buffer = put_state.add_access("send_buffer") + + target_rank = put_state.add_access("target_rank") + + put_state.add_edge(fence_node, + None, + win_put_node, + "_in", + Memlet.from_array(fence_name, fence_desc)) + + put_state.add_edge(send_buffer, + None, + win_put_node, + "_inbuffer", + Memlet.simple(send_buffer, "0:n", num_accesses=n)) + + put_state.add_edge(target_rank, + None, + win_put_node, + "_target_rank", + Memlet.simple(target_rank, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(put_name, dace.int32, transient=True) + wnode = put_state.add_write(put_name) + put_state.add_edge(win_put_node, + "_out", + wnode, + None, + Memlet.from_array(put_name, scal)) + +############################################################################### + + fence_state_2 = sdfg.add_state("win_fence_2") + + 
sdfg.add_edge(put_state, fence_state_2, dace.InterstateEdge()) + + fence_name = sdfg.add_rma_ops(window_name, "fence") + win_fence_node = mpi.nodes.win_fence.Win_fence(fence_name, window_name) + + # pseudo access for ordering + put_node = fence_state_2.add_access(put_name) + put_desc = sdfg.arrays[put_name] + + fence_state_2.add_edge(put_node, + None, + win_fence_node, + None, + Memlet.from_array(put_name, put_desc)) + + assertion_node = fence_state_2.add_access("assertion") + + fence_state_2.add_edge(assertion_node, + None, + win_fence_node, + '_assertion', + Memlet.simple(assertion_node, "0:1", num_accesses=1)) + + _, scal = sdfg.add_scalar(fence_name, dace.int32, transient=True) + wnode = fence_state_2.add_write(fence_name) + fence_state_2.add_edge(win_fence_node, + "_out", + wnode, + None, + Memlet.from_array(fence_name, scal)) + + return sdfg + + +############################################################################### + +@pytest.mark.parametrize("implementation, dtype", [ + pytest.param("MPI", dace.float32, marks=pytest.mark.mpi), + pytest.param("MPI", dace.int32, marks=pytest.mark.mpi) +]) +def test_win_put(dtype): + from mpi4py import MPI + np_dtype = getattr(np, dtype.to_string()) + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("This test is supposed to be run with at least two processes!") + + mpi_func = None + for r in range(0, comm_size): + if r == comm_rank: + sdfg = make_sdfg(dtype) + mpi_func = sdfg.compile() + comm_world.Barrier() + + window_size = 10 + win_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + send_buffer = np.full(window_size, comm_rank, dtype=np_dtype) + + target_rank = np.array([(comm_rank + 1) % comm_size], dtype=np.int32) + + assertion = np.full([1], 0, dtype=np.int32) + + mpi_func(assertion=assertion, + win_buffer=win_buffer, + send_buffer=send_buffer, + target_rank=target_rank, + n=window_size) + + correct_data = np.full(window_size, (comm_rank - 1) % comm_size, dtype=np_dtype) + if (not np.allclose(win_buffer, correct_data)): + raise (ValueError("The received values are not what I expected on root.")) + +if __name__ == "__main__": + test_win_put(dace.int32) + test_win_put(dace.float32)
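Both put tests, the fence-based one in win_put_test.py and the lock-based one in win_passive_sync_test.py, verify the same ring invariant: rank r writes a buffer full of its own id into the window of (r + 1) % size, so after synchronization every window holds (r - 1) % size. A tiny pure-Python simulation of that invariant (no MPI required; the process count is an arbitrary example):

    # Pure-Python sanity check of the ring invariant behind correct_data in the put tests.
    size = 4                                   # example process count
    windows = {r: r for r in range(size)}      # each window starts holding its owner's rank
    for r in range(size):
        windows[(r + 1) % size] = r            # rank r puts its id into the right neighbor's window
    assert all(windows[r] == (r - 1) % size for r in range(size))

Since all of these tests carry the mpi marker and compile one SDFG per rank, an assumed invocation (exact launcher and process count depend on the local MPI installation) is something like mpirun -np 2 python -m pytest -m mpi tests/library/mpi/win_put_test.py, or mpirun -np 2 python tests/library/mpi/win_put_test.py to use the __main__ driver.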