diff --git a/dace/distr_types.py b/dace/distr_types.py
index 1b595a1b84..d7ccad8061 100644
--- a/dace/distr_types.py
+++ b/dace/distr_types.py
@@ -12,6 +12,30 @@
 RankType = Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic]
 
 
+@make_properties
+class ProcessComm(object):
+    """
+    ProcessComm is the descriptor class for a communicator created by splitting another
+    communicator (e.g., MPI_COMM_WORLD). The actual communicator creation is implemented
+    in mpi.nodes.comm_split.Comm_split.
+    """
+
+    name = Property(dtype=str, desc="The name of the new communicator.")
+
+    def __init__(self,
+                 name: str):
+        self.name = name
+        self._validate()
+
+    def validate(self):
+        """ Validate the correctness of this object.
+            Raises an exception on error. """
+        self._validate()
+
+    # Validation of this class is in a separate function, so that this
+    # class can call `_validate()` without calling the subclasses'
+    # `validate` function.
+    def _validate(self):
+        return True
+
+
 @make_properties
 class ProcessGrid(object):
     """
diff --git a/dace/frontend/common/distr.py b/dace/frontend/common/distr.py
index d6f22da358..6b8752804e 100644
--- a/dace/frontend/common/distr.py
+++ b/dace/frontend/common/distr.py
@@ -15,6 +15,106 @@
 RankType = Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic]
 ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'
 
+##### MPI Communicators
+# A helper function that returns an access node for an integer argument given by name.
+# If the argument is a number or expression, a transient scalar is created to hold it.
+def _get_int_arg_node(pv: ProgramVisitor,
+                      sdfg: SDFG,
+                      state: SDFGState,
+                      argument: Union[str, sp.Expr, Number]):
+    if isinstance(argument, str) and argument in sdfg.arrays.keys():
+        arg_name = argument
+        arg_node = state.add_read(arg_name)
+    else:
+        # Create a transient scalar and use its name.
+        arg_name = _define_local_scalar(pv, sdfg, state, dace.int32)
+        arg_node = state.add_access(arg_name)
+        # Every tasklet lives in a different scope, so there is no need to worry about name conflicts.
+        color_tasklet = state.add_tasklet(f'_set_{arg_name}_', {}, {'__out'}, f'__out = {argument}')
+        state.add_edge(color_tasklet, '__out', arg_node, None, Memlet.simple(arg_node, '0'))
+
+    return arg_name, arg_node
+
+
+@oprepo.replaces('mpi4py.MPI.COMM_WORLD.Split')
+@oprepo.replaces('dace.comm.Split')
+def _comm_split(pv: 'ProgramVisitor',
+                sdfg: SDFG,
+                state: SDFGState,
+                color: Union[str, sp.Expr, Number] = 0,
+                key: Union[str, sp.Expr, Number] = 0,
+                grid: str = None):
+    """ Splits a communicator, following MPI_Comm_split semantics. """
+    from dace.libraries.mpi.nodes.comm_split import Comm_split
+
+    # Find a name for the new communicator.
+    comm_name = sdfg.add_comm()
+
+    comm_split_node = Comm_split(comm_name, grid)
+
+    _, color_node = _get_int_arg_node(pv, sdfg, state, color)
+    _, key_node = _get_int_arg_node(pv, sdfg, state, key)
+
+    state.add_edge(color_node, None, comm_split_node, '_color', Memlet.simple(color_node, "0:1", num_accesses=1))
+    state.add_edge(key_node, None, comm_split_node, '_key', Memlet.simple(key_node, "0:1", num_accesses=1))
+
+    # Pseudo-write for the newast.py #3195 check and to complete the ProcessComm creation.
+    _, scal = sdfg.add_scalar(comm_name, dace.int32, transient=True)
+    wnode = state.add_write(comm_name)
+    state.add_edge(comm_split_node, "_out", wnode, None, Memlet.from_array(comm_name, scal))
+
+    # The return value is the name of the split communicator.
+    return comm_name
+
+
+@oprepo.replaces_method('Cartcomm', 'Split')
+@oprepo.replaces_method('Intracomm', 'Split')
+def _intracomm_comm_split(pv: 'ProgramVisitor',
+                          sdfg: SDFG,
+                          state: SDFGState,
+                          comm: Tuple[str, 'Comm'],
+                          color: Union[str, sp.Expr, Number] = 0,
+                          key: Union[str, sp.Expr, Number] = 0):
+    """ Equivalent to `dace.comm.Split(color, key)`. """
+    from mpi4py import MPI
+    comm_name, comm_obj = comm
+    if comm_obj == MPI.COMM_WORLD:
+        return _comm_split(pv, sdfg, state, color, key)
+    raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
+
+
+@oprepo.replaces_method('ProcessComm', 'Split')
+def _processcomm_comm_split(pv: 'ProgramVisitor',
+                            sdfg: SDFG,
+                            state: SDFGState,
+                            comm: Tuple[str, 'Comm'],
+                            color: Union[str, sp.Expr, Number] = 0,
+                            key: Union[str, sp.Expr, Number] = 0):
+    """ Equivalent to `dace.comm.Split(color, key)`. """
+    return _comm_split(pv, sdfg, state, color, key, grid=comm)
+
+
+@oprepo.replaces_method('ProcessComm', 'Free')
+def _processcomm_comm_free(pv: 'ProgramVisitor',
+                           sdfg: SDFG,
+                           state: SDFGState,
+                           comm: Tuple[str, 'Comm']):
+
+    from dace.libraries.mpi.nodes.comm_free import Comm_free
+
+    comm_free_node = Comm_free("_Comm_free_", comm)
+
+    # Pseudo-access for the newast.py #3195 check, connecting the ProcessComm to the free node.
+    comm_node = state.add_read(comm)
+    comm_desc = sdfg.arrays[comm]
+    state.add_edge(comm_node, None, comm_free_node, "_in", Memlet.from_array(comm, comm_desc))
+
+    # The return value is a dummy name derived from the freed communicator.
+    return f"{comm}_free"
+
+
 ##### MPI Cartesian Communicators
 
 
@@ -166,6 +266,11 @@ def _bcast(pv: ProgramVisitor,
     desc = sdfg.arrays[buffer]
     in_buffer = state.add_read(buffer)
     out_buffer = state.add_write(buffer)
+    if grid:
+        comm_node = state.add_read(grid)
+        comm_desc = sdfg.arrays[grid]
+        state.add_edge(comm_node, None, libnode, None, Memlet.from_array(grid, comm_desc))
+
     if isinstance(root, str) and root in sdfg.arrays.keys():
         root_node = state.add_read(root)
     else:
@@ -200,6 +305,7 @@ def _intracomm_bcast(pv: 'ProgramVisitor',
     return _bcast(pv, sdfg, state, buffer, root, fcomm=comm_name)
 
 
+@oprepo.replaces_method('ProcessComm', 'Bcast')
 @oprepo.replaces_method('ProcessGrid', 'Bcast')
 def _pgrid_bcast(pv: 'ProgramVisitor',
                  sdfg: SDFG,
@@ -278,6 +384,7 @@ def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icom
     return _alltoall(pv, sdfg, state, inp_buffer, out_buffer)
 
 
+@oprepo.replaces_method('ProcessComm', 'Alltoall')
 @oprepo.replaces_method('ProcessGrid', 'Alltoall')
 def _pgrid_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, pgrid: str, inp_buffer: str, out_buffer: str):
     """ Equivalent to `dace.comm.Alltoall(inp_buffer, out_buffer, grid=pgrid)`.
""" diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..0c5032ba01 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1304,9 +1304,9 @@ def defined(self): result.update(self.sdfg.arrays) # MPI-related stuff - result.update( - {k: self.sdfg.process_grids[v] - for k, v in self.variables.items() if v in self.sdfg.process_grids}) + result.update({k: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in self.sdfg.process_grids}) + result.update({k: self.sdfg.process_comms[v] for k, v in self.variables.items() if v in self.sdfg.process_comms}) + try: from mpi4py import MPI result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)}) @@ -4686,6 +4686,8 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]: for operand in operands: if isinstance(operand, str) and operand in self.sdfg.process_grids: result.append((operand, type(self.sdfg.process_grids[operand]).__name__)) + elif isinstance(operand, str) and operand in self.sdfg.process_comms: + result.append((operand, type(self.sdfg.process_comms[operand]).__name__)) elif isinstance(operand, str) and operand in self.sdfg.arrays: result.append((operand, type(self.sdfg.arrays[operand]))) elif isinstance(operand, str) and operand in self.scope_arrays: diff --git a/dace/libraries/mpi/nodes/__init__.py b/dace/libraries/mpi/nodes/__init__.py index 0cd36cc82f..ae4b723c49 100644 --- a/dace/libraries/mpi/nodes/__init__.py +++ b/dace/libraries/mpi/nodes/__init__.py @@ -13,3 +13,5 @@ from .alltoall import Alltoall from .dummy import Dummy from .redistribute import Redistribute +from .comm_split import Comm_split +from .comm_free import Comm_free diff --git a/dace/libraries/mpi/nodes/comm_free.py b/dace/libraries/mpi/nodes/comm_free.py new file mode 100644 index 0000000000..7d375b62fa --- /dev/null +++ b/dace/libraries/mpi/nodes/comm_free.py @@ -0,0 +1,50 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +import dace.library +import dace.properties +import dace.sdfg.nodes +from dace.transformation.transformation import ExpandTransformation +from .. import environments +from dace.libraries.mpi.nodes.node import MPINode + + +@dace.library.expansion +class ExpandFreeMPI(ExpandTransformation): + + environments = [environments.mpi.MPI] + + @staticmethod + def expansion(node, parent_state, parent_sdfg, n=None, **kwargs): + code = f""" + MPI_Comm_free(&__state->{node.grid}_comm); + """ + tasklet = dace.sdfg.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP, + side_effects=True) + return tasklet + + +@dace.library.node +class Comm_free(MPINode): + + # Global properties + implementations = { + "MPI": ExpandFreeMPI, + } + default_implementation = "MPI" + + grid = dace.properties.Property(dtype=str, allow_none=False, default=None) + + def __init__(self, name, grid, *args, **kwargs): + super().__init__(name, *args, inputs={"_in"}, outputs={}, **kwargs) + self.grid = grid + + def validate(self, sdfg, state): + """ + :return: A three-tuple (buffer, root) of the three data descriptors in the + parent SDFG. + """ + + return None diff --git a/dace/libraries/mpi/nodes/comm_split.py b/dace/libraries/mpi/nodes/comm_split.py new file mode 100644 index 0000000000..d625796c47 --- /dev/null +++ b/dace/libraries/mpi/nodes/comm_split.py @@ -0,0 +1,77 @@ +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. 
+import dace.library
+import dace.properties
+import dace.sdfg.nodes
+from dace.transformation.transformation import ExpandTransformation
+from .. import environments
+from dace.libraries.mpi.nodes.node import MPINode
+
+
+@dace.library.expansion
+class ExpandCommSplitMPI(ExpandTransformation):
+
+    environments = [environments.mpi.MPI]
+
+    @staticmethod
+    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
+        color, key = node.validate(parent_sdfg, parent_state)
+
+        if node.grid is None:
+            comm = "MPI_COMM_WORLD"
+        else:
+            comm = f"__state->{node.grid}_comm"
+
+        comm_name = node.name
+
+        node.fields = [
+            f'MPI_Comm {comm_name}_comm;',
+            f'int {comm_name}_rank;',
+            f'int {comm_name}_size;',
+        ]
+
+        code = f"""
+            MPI_Comm_split({comm}, _color, _key, &__state->{comm_name}_comm);
+            MPI_Comm_rank(__state->{comm_name}_comm, &__state->{comm_name}_rank);
+            MPI_Comm_size(__state->{comm_name}_comm, &__state->{comm_name}_size);
+        """
+
+        tasklet = dace.sdfg.nodes.Tasklet(node.name,
+                                          node.in_connectors,
+                                          node.out_connectors,
+                                          code,
+                                          state_fields=node.fields,
+                                          language=dace.dtypes.Language.CPP,
+                                          side_effects=True)
+        return tasklet
+
+
+@dace.library.node
+class Comm_split(MPINode):
+
+    # Global properties
+    implementations = {
+        "MPI": ExpandCommSplitMPI,
+    }
+    default_implementation = "MPI"
+
+    grid = dace.properties.Property(dtype=str, allow_none=True, default=None)
+
+    def __init__(self, name, grid=None, *args, **kwargs):
+        super().__init__(name, *args, inputs={"_color", "_key"}, outputs={"_out"}, **kwargs)
+        self.grid = grid
+
+    def validate(self, sdfg, state):
+        """
+        :return: A two-tuple (color, key) of the data descriptors connected to the node's
+                 input connectors in the parent SDFG.
+        """
+
+        color, key = None, None
+
+        for e in state.in_edges(self):
+            if e.dst_conn == "_color":
+                color = sdfg.arrays[e.data.data]
+            if e.dst_conn == "_key":
+                key = sdfg.arrays[e.data.data]
+
+        return color, key
diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py
index a23d2616f9..59e3531108 100644
--- a/dace/sdfg/sdfg.py
+++ b/dace/sdfg/sdfg.py
@@ -32,7 +32,7 @@
 from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView
 from dace.sdfg.state import SDFGState
 from dace.sdfg.propagation import propagate_memlets_sdfg
-from dace.distr_types import ProcessGrid, SubArray, RedistrArray
+from dace.distr_types import ProcessComm, ProcessGrid, SubArray, RedistrArray
 from dace.dtypes import validate_name
 from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty,
                              TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock)
@@ -445,6 +445,11 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]):
 
     debuginfo = DebugInfoProperty(allow_none=True)
 
+    _comms = DictProperty(str,
+                          ProcessComm,
+                          desc="Process-communicator descriptors for this SDFG",
+                          to_json=_arrays_to_json,
+                          from_json=_arrays_from_json)
     _pgrids = DictProperty(str,
                            ProcessGrid,
                            desc="Process-grid descriptors for this SDFG",
@@ -517,6 +522,7 @@ def __init__(self,
         self._recompile = True
 
         # Grid-distribution-related fields
+        self._comms = {}
         self._pgrids = {}
         self._subarrays = {}
         self._rdistrarrays = {}
@@ -683,6 +689,11 @@ def arrays(self):
         """
         return self._arrays
 
+    @property
+    def process_comms(self):
+        """ Returns a dictionary of process-communicator descriptors (`ProcessComm` objects) used in this SDFG. """
+        return self._comms
+
     @property
     def process_grids(self):
         """ Returns a dictionary of process-grid descriptors (`ProcessGrid` objects) used in this SDFG.
""" @@ -1707,7 +1718,7 @@ def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ - names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys() + names = (self._arrays.keys() | self.constants_prop.keys() | self._comms.keys() | self._pgrids.keys() | self._subarrays.keys() | self._rdistrarrays.keys()) return dt.find_new_name(name, names) @@ -2049,6 +2060,16 @@ def _add_symbols(desc: dt.Data): return name + def add_comm(self): + """ Adds a comm world to the process-comm descriptor store. + """ + + comm_name = self._find_new_name('__proc') + + self._comms[comm_name] = ProcessComm(comm_name) + + return comm_name + def add_pgrid(self, shape: ShapeType = None, parent_grid: str = None, diff --git a/tests/library/mpi/comm_free_test.py b/tests/library/mpi/comm_free_test.py new file mode 100644 index 0000000000..34c1f8a0ea --- /dev/null +++ b/tests/library/mpi/comm_free_test.py @@ -0,0 +1,63 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.sdfg import utils +import dace.dtypes as dtypes +from dace.memlet import Memlet +import dace.libraries.mpi as mpi +import dace.frontend.common.distr as comm +import numpy as np +import pytest + + +@pytest.mark.mpi +def test_comm_free(): + from mpi4py import MPI + comm_world = MPI.COMM_WORLD + comm_rank = comm_world.Get_rank() + comm_size = comm_world.Get_size() + + if comm_size < 2: + raise ValueError("Please run this test with at least two processes.") + + sdfg = dace.SDFG("mpi_free_test") + start_state = sdfg.add_state("start") + + sdfg.add_scalar("color", dace.dtypes.int32, transient=False) + sdfg.add_scalar("key", dace.dtypes.int32, transient=False) + + color = start_state.add_read("color") + key = start_state.add_read("key") + + # color and key needs to be variable + comm_name = sdfg.add_comm() + comm_split_node = mpi.nodes.comm_split.Comm_split(comm_name) + + start_state.add_edge(color, None, comm_split_node, '_color', Memlet.simple(color, "0:1", num_accesses=1)) + start_state.add_edge(key, None, comm_split_node, '_key', Memlet.simple(key, "0:1", num_accesses=1)) + + # Pseudo-writing for newast.py #3195 check and complete Processcomm creation + _, scal = sdfg.add_scalar(comm_name, dace.int32, transient=True) + wnode = start_state.add_write(comm_name) + start_state.add_edge(comm_split_node, "_out", wnode, None, Memlet.from_array(comm_name, scal)) + + main_state = sdfg.add_state("main") + + sdfg.add_edge(start_state, main_state, dace.InterstateEdge()) + + comm_free_node = mpi.nodes.comm_free.Comm_free("_Comm_free_", comm_name) + + comm_node = main_state.add_read(comm_name) + comm_desc = sdfg.arrays[comm_name] + main_state.add_edge(comm_node, None, comm_free_node, "_in", Memlet.from_array(comm_name, comm_desc)) + + func = utils.distributed_compile(sdfg, comm_world) + + # split world + color = comm_rank % 2 + key = comm_rank + + func(color=color, key=key) + + +if __name__ == "__main__": + test_comm_free() diff --git a/tests/library/mpi/comm_split_test.py b/tests/library/mpi/comm_split_test.py new file mode 100644 index 0000000000..4165372352 --- /dev/null +++ b/tests/library/mpi/comm_split_test.py @@ -0,0 +1,87 @@ +# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. 
+import dace
+from dace.sdfg import utils
+import dace.dtypes as dtypes
+from dace.memlet import Memlet
+import dace.libraries.mpi as mpi
+import dace.frontend.common.distr as comm
+import numpy as np
+import pytest
+
+
+@pytest.mark.mpi
+def test_comm_split():
+    from mpi4py import MPI
+    comm_world = MPI.COMM_WORLD
+    comm_rank = comm_world.Get_rank()
+    comm_size = comm_world.Get_size()
+
+    if comm_size < 2:
+        raise ValueError("Please run this test with at least two processes.")
+
+    sdfg = dace.SDFG("mpi_comm_split")
+    state = sdfg.add_state("start")
+
+    sdfg.add_scalar("color", dace.dtypes.int32, transient=False)
+    sdfg.add_scalar("key", dace.dtypes.int32, transient=False)
+    sdfg.add_array("new_rank", [1], dtype=dace.int32, transient=False)
+    sdfg.add_array("new_size", [1], dtype=dace.int32, transient=False)
+
+    color = state.add_read("color")
+    key = state.add_read("key")
+
+    # color and key need to be data (variables), not constants
+    comm_name = sdfg.add_comm()
+    comm_split_node = mpi.nodes.comm_split.Comm_split(comm_name)
+
+    state.add_edge(color, None, comm_split_node, '_color', Memlet.simple(color, "0:1", num_accesses=1))
+    state.add_edge(key, None, comm_split_node, '_key', Memlet.simple(key, "0:1", num_accesses=1))
+
+    # Pseudo-write for the newast.py #3195 check and to complete the ProcessComm creation
+    _, scal = sdfg.add_scalar(comm_name, dace.int32, transient=True)
+    wnode = state.add_write(comm_name)
+    state.add_edge(comm_split_node, "_out", wnode, None, Memlet.from_array(comm_name, scal))
+
+    state2 = sdfg.add_state("main")
+
+    sdfg.add_edge(state, state2, dace.InterstateEdge())
+
+    tasklet = state2.add_tasklet(
+        "new_comm_get",
+        {},
+        {'_rank', '_size'},
+        f"_rank = __state->{comm_name}_rank;\n_size = __state->{comm_name}_size;",
+        dtypes.Language.CPP)
+
+    new_rank = state2.add_write("new_rank")
+    new_size = state2.add_write("new_size")
+
+    state2.add_edge(tasklet, '_rank', new_rank, None, Memlet.simple(new_rank, "0:1", num_accesses=1))
+    state2.add_edge(tasklet, '_size', new_size, None, Memlet.simple(new_size, "0:1", num_accesses=1))
+
+    func = utils.distributed_compile(sdfg, comm_world)
+
+    # split the world communicator
+    color = comm_rank % 2
+    key = comm_rank
+    new_rank = np.zeros((1, ), dtype=np.int32)
+    new_size = np.zeros((1, ), dtype=np.int32)
+
+    func(color=color, key=key, new_rank=new_rank, new_size=new_size)
+
+    correct_new_rank = np.arange(0, comm_size, dtype=np.int32) // 2
+    assert (correct_new_rank[comm_rank] == new_rank[0])
+
+    # reverse the rank order
+    color = 0
+    key = comm_size - comm_rank
+    new_rank = np.zeros((1, ), dtype=np.int32)
+    new_size = np.zeros((1, ), dtype=np.int32)
+
+    func(color=color, key=key, new_rank=new_rank, new_size=new_size)
+
+    correct_new_rank = np.flip(np.arange(0, comm_size, dtype=np.int32), 0)
+    assert (correct_new_rank[comm_rank] == new_rank[0])
+
+
+if __name__ == "__main__":
+    test_comm_split()
diff --git a/tests/library/mpi/mpi4py_test.py b/tests/library/mpi/mpi4py_test.py
index 52b5deb7a8..60f280ee53 100644
--- a/tests/library/mpi/mpi4py_test.py
+++ b/tests/library/mpi/mpi4py_test.py
@@ -77,6 +77,135 @@ def external_comm_bcast(A: dace.int32[10]):
     assert (np.array_equal(A, A_ref))
 
 
+@pytest.mark.mpi
+def test_process_comm_split_bcast():
+
+    from mpi4py import MPI
+    commworld = MPI.COMM_WORLD
+    rank = commworld.Get_rank()
+    size = commworld.Get_size()
+
+    @dace.program
+    def comm_split_bcast(rank: dace.int32, A: dace.int32[10]):
+        # new_comm = commworld.Split(rank % 2, 0)
+        color = np.full((1,), rank % 2, dtype=np.int32)
+        new_comm = commworld.Split(color, 0)
+        new_comm.Bcast(A)
+
+    if size < 2:
+        raise ValueError("Please run this test with at least two processes.")
+
+    sdfg = None
+    if rank == 0:
+        sdfg = comm_split_bcast.to_sdfg()
+        # disable OpenMP sections for split completeness
+        sdfg.openmp_sections = False
+    func = utils.distributed_compile(sdfg, commworld)
+
+    if rank == 0:
+        A = np.arange(10, dtype=np.int32)
+        A_ref = A.copy()
+    elif rank == 1:
+        A = np.arange(10, 20, dtype=np.int32)
+        A_ref = A.copy()
+    else:
+        A = np.zeros((10, ), dtype=np.int32)
+        A_ref = A.copy()
+
+    func(rank=rank, A=A)
+    comm_split_bcast.f(rank, A_ref)
+
+    assert (np.array_equal(A, A_ref))
+
+
+# This test is disabled, since we currently cannot guarantee the ordering of Free and Bcast.
+# @pytest.mark.mpi
+def test_process_comm_free():
+
+    from mpi4py import MPI
+    commworld = MPI.COMM_WORLD
+    rank = commworld.Get_rank()
+    size = commworld.Get_size()
+
+    @dace.program
+    def comm_free_test(rank: dace.int32, A: dace.int32[10]):
+        # new_comm = commworld.Split(rank % 2, 0)
+        color = np.full((1,), rank % 2, dtype=np.int32)
+        new_comm = commworld.Split(color, 0)
+        new_comm.Bcast(A)
+        new_comm.Free()
+
+    if size < 2:
+        raise ValueError("Please run this test with at least two processes.")
+
+    sdfg = None
+    if rank == 0:
+        sdfg = comm_free_test.to_sdfg()
+        # disable OpenMP sections for split completeness
+        sdfg.openmp_sections = False
+    func = utils.distributed_compile(sdfg, commworld)
+
+    if rank == 0:
+        A = np.arange(10, dtype=np.int32)
+        A_ref = A.copy()
+    elif rank == 1:
+        A = np.arange(10, 20, dtype=np.int32)
+        A_ref = A.copy()
+    else:
+        A = np.zeros((10, ), dtype=np.int32)
+        A_ref = A.copy()
+
+    func(rank=rank, A=A)
+    comm_free_test.f(rank, A_ref)
+
+    assert (np.array_equal(A, A_ref))
+
+
+@pytest.mark.mpi
+def test_nested_process_comm_split_bcast():
+
+    from mpi4py import MPI
+    commworld = MPI.COMM_WORLD
+    rank = commworld.Get_rank()
+    size = commworld.Get_size()
+
+    @dace.program
+    def nested_comm_split_bcast(rank: dace.int32, A: dace.int32[10]):
+        # Rank assignment from commworld to new_comm:
+        # 0 2 4 6 -> 0 1 2 3
+        # 1 3 5 7 -> 0 1 2 3
+        color = np.full((1,), rank % 2, dtype=np.int32)
+        new_comm = commworld.Split(color, rank)
+
+        # Rank assignment from commworld to new_comm to new_comm2:
+        # (0,1) (4,5) -> 0 2 -> 0 1
+        # (2,3) (6,7) -> 1 3 -> 0 1
+        color2 = np.full((1,), (rank // 2) % 2, dtype=np.int32)
+        new_comm2 = new_comm.Split(color2, rank)
+        new_comm2.Bcast(A)
+
+    if size < 2:
+        raise ValueError("Please run this test with at least two processes.")
+
+    sdfg = None
+    if rank == 0:
+        sdfg = nested_comm_split_bcast.to_sdfg()
+        # disable OpenMP sections for split completeness
+        sdfg.openmp_sections = False
+    func = utils.distributed_compile(sdfg, commworld)
+
+    if rank < 4:
+        A = np.arange(rank * 10, (rank + 1) * 10, dtype=np.int32)
+        A_ref = A.copy()
+    else:
+        A = np.zeros((10, ), dtype=np.int32)
+        A_ref = A.copy()
+
+    func(rank=rank, A=A)
+    nested_comm_split_bcast.f(rank, A_ref)
+
+    assert (np.array_equal(A, A_ref))
+
+
 @pytest.mark.mpi
 def test_process_grid_bcast():
 
@@ -339,12 +468,46 @@ def mpi4py_alltoall(rank: dace.int32, size: dace.compiletime):
         raise (ValueError("The received values are not what I expected."))
 
 
+@pytest.mark.mpi
+def test_comm_split_alltoall():
+    from mpi4py import MPI
+    commworld = MPI.COMM_WORLD
+    rank = commworld.Get_rank()
+    size = commworld.Get_size()
+
+    @dace.program
+    def mpi4py_comm_split_alltoall(rank: dace.int32, size: dace.compiletime):
+        color = np.full((1,), rank % 2, dtype=np.int32)
+        key = np.full((1,), 0, dtype=np.int32)
+        new_comm = commworld.Split(color, key)
+
+        sbuf = np.full((size // 2,), rank, dtype=np.int32)
+        rbuf = np.zeros((size // 2, ), dtype=np.int32)
+        new_comm.Alltoall(sbuf, rbuf)
+        return rbuf
+
+    sdfg = None
+    if rank == 0:
+        sdfg = mpi4py_comm_split_alltoall.to_sdfg(simplify=True, size=size)
+    func = utils.distributed_compile(sdfg, commworld)
+
+    val = func(rank=rank)
+    ref = mpi4py_comm_split_alltoall.f(rank, size)
+
+    if (not np.allclose(val, ref)):
+        raise (ValueError("The received values are not what I expected."))
+
+
 if __name__ == "__main__":
     test_comm_world_bcast()
     test_external_comm_bcast()
+    test_process_comm_split_bcast()
+    # test_process_comm_free()
+    test_nested_process_comm_split_bcast()
     test_process_grid_bcast()
     test_sub_grid_bcast()
     test_3mm()
     test_isend_irecv()
     test_send_recv()
     test_alltoall()
+    test_comm_split_alltoall()
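
For quick orientation, below is a minimal usage sketch condensed from the tests in this patch: splitting MPI.COMM_WORLD by color/key from a DaCe Python program and broadcasting over the resulting sub-communicator. It is not part of the patch itself; it assumes at least two MPI ranks (run under mpirun), mirrors the rank-0 compile / distributed_compile pattern used in the tests, and uses an illustrative function name (split_and_bcast).

# Usage sketch only, condensed from the tests above (assumes >= 2 MPI ranks).
import dace
import numpy as np
from dace.sdfg import utils
from mpi4py import MPI

commworld = MPI.COMM_WORLD
rank = commworld.Get_rank()


@dace.program
def split_and_bcast(rank: dace.int32, A: dace.int32[10]):
    # Ranks with the same color join the same sub-communicator;
    # the key (0 here) orders the ranks within it (MPI_Comm_split semantics).
    color = np.full((1,), rank % 2, dtype=np.int32)
    new_comm = commworld.Split(color, 0)
    new_comm.Bcast(A)


# Compile on rank 0 and distribute the binary, as the tests do.
sdfg = split_and_bcast.to_sdfg() if rank == 0 else None
if sdfg is not None:
    # disable OpenMP sections for split completeness, as in the tests
    sdfg.openmp_sections = False
func = utils.distributed_compile(sdfg, commworld)

# Ranks 0 and 1 act as the roots of their respective sub-communicators.
A = np.arange(rank * 10, (rank + 1) * 10, dtype=np.int32) if rank < 2 else np.zeros((10,), dtype=np.int32)
func(rank=rank, A=A)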