Added mpi4py communicator split support #1347

Open · wants to merge 20 commits into base: main
Commits (20) · Changes from all commits
43b60d4
Added comm_split library node and its test
Com1t Jul 20, 2023
f5e93bf
Modified add_memlet_path to add_edge for code consistency
Com1t Jul 20, 2023
e730397
Update __init__.py and added type check for Comm_split
Com1t Jul 20, 2023
dc0f069
Modified comm_split_test.py to take scalar as input
Com1t Jul 20, 2023
36e0a35
Completed a basic version of comm_split replacement and test using bc…
Com1t Jul 21, 2023
8ce1606
Added an access node for comm nodes in bcast replacement to prevent f…
Com1t Jul 21, 2023
d83c587
Added library node for commnunication free, implementing its test
Com1t Jul 21, 2023
a25e94c
Enabled test_process_comm_split_bcast pytest
Com1t Jul 21, 2023
3ebb209
Implemented number input for comm_split function
Com1t Jul 25, 2023
75bb05d
Added a helper function for the access node of an argument
Com1t Jul 25, 2023
1ce0a97
Implemented nested communicator split
Com1t Jul 26, 2023
9e1bfa5
Implemented comm free library node, sdfg test and mpi4py replacement/…
Com1t Jul 26, 2023
e3d4df5
fixed a small bug in _get_int_arg_node()
Com1t Jul 26, 2023
16a2f49
Added side effect flag for comm_free lib node
Com1t Aug 4, 2023
d8c00e9
Removed a debug msg in distr.py
Com1t Aug 4, 2023
edbc3af
Merge branch 'master' into mpi4py_dev
alexnick83 Aug 4, 2023
96e7e03
Merge branch 'master' into mpi4py_dev
alexnick83 Aug 4, 2023
a0259f7
Updated split function name and comments
Com1t Aug 18, 2023
69213b3
Merge branch 'master' into mpi4py_dev
alexnick83 Sep 7, 2023
215eadf
Merge branch 'master' into mpi4py_dev
alexnick83 Sep 15, 2023
24 changes: 24 additions & 0 deletions dace/distr_types.py
@@ -12,6 +12,30 @@
RankType = Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic]


@make_properties
class ProcessComm(object):
"""
ProcessComm is the descriptor class for a communicator created by splitting an existing communicator.
The actual communicator creation is implemented in mpi.nodes.comm_split.Comm_split.
"""

name = Property(dtype=str, desc="The name of the new communicator.")
def __init__(self,
name: str):
self.name = name
self._validate()

def validate(self):
""" Validate the correctness of this object.
Raises an exception on error. """
self._validate()

# Validation of this class is in a separate function, so that this
# class can call `_validate()` without calling the subclasses'
# `validate` function.
def _validate(self):
return True

@make_properties
class ProcessGrid(object):
"""
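As a minimal sketch of the new descriptor (standalone use is illustrative; in practice the SDFG.add_comm() helper added later in this PR creates these), it only carries a name, and its validation is currently a no-op:

    from dace.distr_types import ProcessComm

    comm_desc = ProcessComm('__proc_0')  # name is illustrative
    comm_desc.validate()                 # passes: _validate() currently returns True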
107 changes: 107 additions & 0 deletions dace/frontend/common/distr.py
@@ -15,6 +15,106 @@
RankType = Union[Integral, str, symbolic.symbol, symbolic.SymExpr, symbolic.sympy.Basic]
ProgramVisitor = 'dace.frontend.python.newast.ProgramVisitor'

##### MPI Communicators
# A helper function that returns an access node for an argument;
# creates a transient scalar if the argument is a number or expression.
def _get_int_arg_node(pv: ProgramVisitor,
sdfg: SDFG,
state: SDFGState,
argument: Union[str, sp.Expr, Number]
):
if isinstance(argument, str) and argument in sdfg.arrays.keys():
arg_name = argument
arg_node = state.add_read(arg_name)
else:
# create a transient scalar and take its name
arg_name = _define_local_scalar(pv, sdfg, state, dace.int32)
arg_node = state.add_access(arg_name)
# every tasklet is in a different scope, so there is no need to worry about name conflicts
color_tasklet = state.add_tasklet(f'_set_{arg_name}_', {}, {'__out'}, f'__out = {argument}')
state.add_edge(color_tasklet, '__out', arg_node, None, Memlet.simple(arg_node, '0'))

return arg_name, arg_node


@oprepo.replaces('mpi4py.MPI.COMM_WORLD.Split')
@oprepo.replaces('dace.comm.Split')
def _comm_split(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
color: Union[str, sp.Expr, Number] = 0,
key: Union[str, sp.Expr, Number] = 0,
grid: str = None):
""" Splits communicator
"""
from dace.libraries.mpi.nodes.comm_split import Comm_split

# find a name for the new communicator
comm_name = sdfg.add_comm()

comm_split_node = Comm_split(comm_name, grid)

_, color_node = _get_int_arg_node(pv, sdfg, state, color)
_, key_node = _get_int_arg_node(pv, sdfg, state, key)

state.add_edge(color_node, None, comm_split_node, '_color', Memlet.simple(color_node, "0:1", num_accesses=1))
state.add_edge(key_node, None, comm_split_node, '_key', Memlet.simple(key_node, "0:1", num_accesses=1))

# Pseudo-write for the newast.py #3195 check, completing the ProcessComm creation
_, scal = sdfg.add_scalar(comm_name, dace.int32, transient=True)
wnode = state.add_write(comm_name)
state.add_edge(comm_split_node, "_out", wnode, None, Memlet.from_array(comm_name, scal))

# The return value is the name of the split communicator
return comm_name


@oprepo.replaces_method('Cartcomm', 'Split')
@oprepo.replaces_method('Intracomm', 'Split')
def _intracomm_comm_split(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
comm: Tuple[str, 'Comm'],
color: Union[str, sp.Expr, Number] = 0,
key: Union[str, sp.Expr, Number] = 0):
""" Equivalent to `dace.comm.split(color, key)`. """
from mpi4py import MPI
comm_name, comm_obj = comm
if comm_obj == MPI.COMM_WORLD:
return _comm_split(pv, sdfg, state, color, key)
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')


@oprepo.replaces_method('ProcessComm', 'Split')
def _processcomm_comm_split(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
comm: Tuple[str, 'Comm'],
color: Union[str, sp.Expr, Number] = 0,
key: Union[str, sp.Expr, Number] = 0):
""" Equivalent to `dace.comm.split(color, key)`. """
return _comm_split(pv, sdfg, state, color, key, grid=comm)


@oprepo.replaces_method('ProcessComm', 'Free')
def _processcomm_comm_free(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
comm: Tuple[str, 'Comm']):

from dace.libraries.mpi.nodes.comm_free import Comm_free

comm_free_node = Comm_free("_Comm_free_", comm)

# Pseudo-read of the communicator scalar for the newast.py #3195 check on the ProcessComm
comm_node = state.add_read(comm)
comm_desc = sdfg.arrays[comm]
state.add_edge(comm_node, None, comm_free_node, "_in", Memlet.from_array(comm, comm_desc))

# The return value is a placeholder name for the freed communicator
return f"{comm}_free"


##### MPI Cartesian Communicators


@@ -166,6 +266,11 @@ def _bcast(pv: ProgramVisitor,
desc = sdfg.arrays[buffer]
in_buffer = state.add_read(buffer)
out_buffer = state.add_write(buffer)
if grid:
comm_node = state.add_read(grid)
comm_desc = sdfg.arrays[grid]
state.add_edge(comm_node, None, libnode, None, Memlet.from_array(grid, comm_desc))

if isinstance(root, str) and root in sdfg.arrays.keys():
root_node = state.add_read(root)
else:
@@ -200,6 +305,7 @@ def _intracomm_bcast(pv: 'ProgramVisitor',
return _bcast(pv, sdfg, state, buffer, root, fcomm=comm_name)


@oprepo.replaces_method('ProcessComm', 'Bcast')
@oprepo.replaces_method('ProcessGrid', 'Bcast')
def _pgrid_bcast(pv: 'ProgramVisitor',
sdfg: SDFG,
@@ -278,6 +384,7 @@ def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icom
return _alltoall(pv, sdfg, state, inp_buffer, out_buffer)


@oprepo.replaces_method('ProcessComm', 'Alltoall')
@oprepo.replaces_method('ProcessGrid', 'Alltoall')
def _pgrid_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, pgrid: str, inp_buffer: str, out_buffer: str):
""" Equivalent to `dace.comm.Alltoall(inp_buffer, out_buffer, grid=pgrid)`. """
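For context, a usage sketch of the replacements added in this file, in the mpi4py style the PR's tests appear to exercise (names, shapes, and the exact Bcast signature are illustrative, not taken from the PR's test files):

    from mpi4py import MPI
    import numpy as np
    import dace

    @dace.program
    def split_bcast(data: dace.int32[10], color: dace.int32, key: dace.int32):
        # Intercepted by _comm_split / _intracomm_comm_split above
        new_comm = MPI.COMM_WORLD.Split(color, key)
        # Dispatched through the new 'ProcessComm' method replacements (Bcast, Free)
        new_comm.Bcast(data, 0)
        new_comm.Free()

    if __name__ == '__main__':
        # Run under MPI, e.g.: mpirun -np 4 python split_bcast.py
        rank = MPI.COMM_WORLD.Get_rank()
        data = np.arange(10, dtype=np.int32)
        split_bcast(data, color=rank % 2, key=rank)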
8 changes: 5 additions & 3 deletions dace/frontend/python/newast.py
@@ -1304,9 +1304,9 @@ def defined(self):
result.update(self.sdfg.arrays)

# MPI-related stuff
result.update(
{k: self.sdfg.process_grids[v]
for k, v in self.variables.items() if v in self.sdfg.process_grids})
result.update({k: self.sdfg.process_grids[v] for k, v in self.variables.items() if v in self.sdfg.process_grids})
result.update({k: self.sdfg.process_comms[v] for k, v in self.variables.items() if v in self.sdfg.process_comms})

try:
from mpi4py import MPI
result.update({k: v for k, v in self.globals.items() if isinstance(v, MPI.Comm)})
@@ -4686,6 +4686,8 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]:
for operand in operands:
if isinstance(operand, str) and operand in self.sdfg.process_grids:
result.append((operand, type(self.sdfg.process_grids[operand]).__name__))
elif isinstance(operand, str) and operand in self.sdfg.process_comms:
result.append((operand, type(self.sdfg.process_comms[operand]).__name__))
elif isinstance(operand, str) and operand in self.sdfg.arrays:
result.append((operand, type(self.sdfg.arrays[operand])))
elif isinstance(operand, str) and operand in self.scope_arrays:
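The effect of these two changes, roughly (a conceptual trace, not actual frontend output): a variable bound to the result of Split resolves to its ProcessComm descriptor, so the 'ProcessComm' method replacements in distr.py can dispatch on it:

    # Conceptual, assuming a program variable 'new_comm' created by Split:
    #   self.variables          == {'new_comm': '__proc_0', ...}
    #   self.sdfg.process_comms == {'__proc_0': ProcessComm('__proc_0'), ...}
    # Then:
    #   self.defined['new_comm']      -> ProcessComm('__proc_0')
    #   self._gettype(<new_comm AST>) -> [('__proc_0', 'ProcessComm')]
    # which is the (name, type-name) pair that
    # @oprepo.replaces_method('ProcessComm', 'Split'/'Bcast'/'Free') matches on.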
2 changes: 2 additions & 0 deletions dace/libraries/mpi/nodes/__init__.py
@@ -13,3 +13,5 @@
from .alltoall import Alltoall
from .dummy import Dummy
from .redistribute import Redistribute
from .comm_split import Comm_split
from .comm_free import Comm_free
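With these two exports, the new library nodes are importable directly from the package, e.g.:

    from dace.libraries.mpi.nodes import Comm_split, Comm_free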
50 changes: 50 additions & 0 deletions dace/libraries/mpi/nodes/comm_free.py
@@ -0,0 +1,50 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace.transformation.transformation import ExpandTransformation
from .. import environments
from dace.libraries.mpi.nodes.node import MPINode


@dace.library.expansion
class ExpandFreeMPI(ExpandTransformation):

environments = [environments.mpi.MPI]

@staticmethod
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
code = f"""
MPI_Comm_free(&__state->{node.grid}_comm);
"""
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
code,
language=dace.dtypes.Language.CPP,
side_effects=True)
return tasklet


@dace.library.node
class Comm_free(MPINode):

# Global properties
implementations = {
"MPI": ExpandFreeMPI,
}
default_implementation = "MPI"

grid = dace.properties.Property(dtype=str, allow_none=False, default=None)

def __init__(self, name, grid, *args, **kwargs):
super().__init__(name, *args, inputs={"_in"}, outputs={}, **kwargs)
self.grid = grid

def validate(self, sdfg, state):
"""
:return: None. The communicator to free is identified by name (``grid``),
so there are no data descriptors to validate.
"""

return None
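A sketch (illustrative names, assuming a communicator previously registered via SDFG.add_comm() and backed by a transient scalar, as _comm_split does) of wiring this node by hand, mirroring _processcomm_comm_free in distr.py:

    import dace
    from dace.libraries.mpi.nodes.comm_free import Comm_free

    sdfg = dace.SDFG('free_example')
    comm_name = sdfg.add_comm()                             # helper added in this PR
    sdfg.add_scalar(comm_name, dace.int32, transient=True)  # backing scalar
    state = sdfg.add_state()

    free_node = Comm_free('_Comm_free_', comm_name)
    comm_node = state.add_read(comm_name)
    state.add_edge(comm_node, None, free_node, '_in',
                   dace.Memlet.from_array(comm_name, sdfg.arrays[comm_name]))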
77 changes: 77 additions & 0 deletions dace/libraries/mpi/nodes/comm_split.py
@@ -0,0 +1,77 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
import dace.library
import dace.properties
import dace.sdfg.nodes
from dace.transformation.transformation import ExpandTransformation
from .. import environments
from dace.libraries.mpi.nodes.node import MPINode


@dace.library.expansion
class ExpandCommSplitMPI(ExpandTransformation):

environments = [environments.mpi.MPI]

@staticmethod
def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
color, key = node.validate(parent_sdfg, parent_state)

if node.grid is None:
comm = "MPI_COMM_WORLD"
else:
comm = f"__state->{node.grid}_comm"

comm_name = node.name

node.fields = [
f'MPI_Comm {comm_name}_comm;',
f'int {comm_name}_rank;',
f'int {comm_name}_size;',
]

code = f"""
MPI_Comm_split({comm}, _color, _key, &__state->{comm_name}_comm);
MPI_Comm_rank(__state->{comm_name}_comm, &__state->{comm_name}_rank);
MPI_Comm_size(__state->{comm_name}_comm, &__state->{comm_name}_size);
"""

tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
code,
state_fields=node.fields,
language=dace.dtypes.Language.CPP,
side_effects=True)
return tasklet


@dace.library.node
class Comm_split(MPINode):

# Global properties
implementations = {
"MPI": ExpandCommSplitMPI,
}
default_implementation = "MPI"

grid = dace.properties.Property(dtype=str, allow_none=True, default=None)

def __init__(self, name, grid=None, *args, **kwargs):
super().__init__(name, *args, inputs={"_color", "_key"}, outputs={"_out"}, **kwargs)
self.grid = grid

def validate(self, sdfg, state):
"""
:return: A two-tuple (color, key) of the data descriptors connected to the
_color and _key input connectors in the parent SDFG.
"""

color, key = None, None

for e in state.in_edges(self):
if e.dst_conn == "_color":
color = sdfg.arrays[e.data.data]
if e.dst_conn == "_key":
key = sdfg.arrays[e.data.data]

return color, key
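Analogously, a sketch (illustrative names) of driving Comm_split directly through the SDFG API, with the color/key scalars and the backing output scalar that _comm_split creates:

    import dace
    from dace.libraries.mpi.nodes.comm_split import Comm_split

    sdfg = dace.SDFG('split_example')
    state = sdfg.add_state()
    comm_name = sdfg.add_comm()                             # registers the ProcessComm
    sdfg.add_scalar(comm_name, dace.int32, transient=True)  # backing scalar for '_out'
    sdfg.add_scalar('color', dace.int32)
    sdfg.add_scalar('key', dace.int32)

    split_node = Comm_split(comm_name, grid=None)  # grid=None splits MPI_COMM_WORLD
    state.add_edge(state.add_read('color'), None, split_node, '_color', dace.Memlet('color[0]'))
    state.add_edge(state.add_read('key'), None, split_node, '_key', dace.Memlet('key[0]'))
    state.add_edge(split_node, '_out', state.add_write(comm_name), None,
                   dace.Memlet(comm_name + '[0]'))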
25 changes: 23 additions & 2 deletions dace/sdfg/sdfg.py
@@ -32,7 +32,7 @@
from dace.sdfg.graph import OrderedDiGraph, Edge, SubgraphView
from dace.sdfg.state import SDFGState
from dace.sdfg.propagation import propagate_memlets_sdfg
from dace.distr_types import ProcessGrid, SubArray, RedistrArray
from dace.distr_types import ProcessComm, ProcessGrid, SubArray, RedistrArray
from dace.dtypes import validate_name
from dace.properties import (DebugInfoProperty, EnumProperty, ListProperty, make_properties, Property, CodeProperty,
TransformationHistProperty, OptionalSDFGReferenceProperty, DictProperty, CodeBlock)
@@ -445,6 +445,11 @@ class SDFG(OrderedDiGraph[SDFGState, InterstateEdge]):

debuginfo = DebugInfoProperty(allow_none=True)

_comms = DictProperty(str,
ProcessComm,
desc="Process-comm descriptors for this SDFG",
to_json=_arrays_to_json,
from_json=_arrays_from_json)
_pgrids = DictProperty(str,
ProcessGrid,
desc="Process-grid descriptors for this SDFG",
@@ -517,6 +522,7 @@ def __init__(self,
self._recompile = True

# Grid-distribution-related fields
self._comms = {}
self._pgrids = {}
self._subarrays = {}
self._rdistrarrays = {}
@@ -683,6 +689,11 @@ def arrays(self):
"""
return self._arrays

@property
def process_comms(self):
""" Returns a dictionary of process-comm descriptors (`ProcessComm` objects) used in this SDFG. """
return self._comms

@property
def process_grids(self):
""" Returns a dictionary of process-grid descriptors (`ProcessGrid` objects) used in this SDFG. """
@@ -1707,7 +1718,7 @@ def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False)
def _find_new_name(self, name: str):
""" Tries to find a new name by adding an underscore and a number. """

names = (self._arrays.keys() | self.constants_prop.keys() | self._pgrids.keys() | self._subarrays.keys()
names = (self._arrays.keys() | self.constants_prop.keys() | self._comms.keys() | self._pgrids.keys() | self._subarrays.keys()
| self._rdistrarrays.keys())
return dt.find_new_name(name, names)

@@ -2049,6 +2060,16 @@ def _add_symbols(desc: dt.Data):

return name

def add_comm(self):
""" Adds a comm world to the process-comm descriptor store.
"""

comm_name = self._find_new_name('__proc')

self._comms[comm_name] = ProcessComm(comm_name)

return comm_name

def add_pgrid(self,
shape: ShapeType = None,
parent_grid: str = None,
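A quick sketch of the new helper (the returned name is illustrative; it depends on which names are already taken in the SDFG):

    import dace

    sdfg = dace.SDFG('example')
    name = sdfg.add_comm()                         # e.g. '__proc'
    assert name in sdfg.process_comms              # a ProcessComm descriptor is registered
    assert sdfg.process_comms[name].name == name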