Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coll: add coll_group to collective interfaces #7103

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0499415
comm: store num_local and num_external in MPIR_Comm
hzhou Aug 13, 2024
b7d6412
comm: remove node_count
hzhou Aug 13, 2024
cae3828
comm/csel: remove reference to subcomms in csel prune_tree
hzhou Aug 19, 2024
438e7b8
coll: remove coll.pof2 field
hzhou Aug 21, 2024
e7c88bd
comm: add MPIR_Subgroup
hzhou Aug 11, 2024
2b88398
coll: add macros to get rank/size with coll_group
hzhou Aug 22, 2024
e804cbc
coll: add coll_group argument to coll interfaces
hzhou Aug 16, 2024
e3969cc
continue: add coll_group to collective interfaces
hzhou Aug 16, 2024
024377a
coll: add coll_group argument to MPIC/sched/TSP routines
hzhou Aug 16, 2024
6338e01
continue: add coll_group in MPIC/sched/TSP routines
hzhou Aug 16, 2024
049fab4
ch4: fallback to mpir if coll_group is non-zero
hzhou Aug 19, 2024
98fc2fc
coll: add coll_group to csel signature
hzhou Aug 18, 2024
20b3244
coll: threadcomm coll to use MPIR_SUBGROUP_THREADCOMM
hzhou Aug 18, 2024
a831867
coll: check coll_group in MPIR_Comm_is_parent_comm
hzhou Aug 18, 2024
d2a6412
coll: make non-compositional algorithm coll_group aware
hzhou Aug 18, 2024
a2f92c4
coll: modify bcast_intra_smp to use subgroups
hzhou Aug 18, 2024
370661e
coll: avoid extra intra bcast in bcast_intra_smp
hzhou Aug 18, 2024
ae6fe4e
coll: modify smp algorithms to use MPIR_Subgroup
hzhou Aug 19, 2024
0ba1a80
mpir: replace subcomm usage with subgroups
hzhou Aug 20, 2024
3543718
temp: fix csel
hzhou Aug 22, 2024
7ea94c7
coll: refactor caching tree in the comm struct
hzhou Aug 22, 2024
ce2274d
coll: add coll_group to treealgo routines
hzhou Aug 22, 2024
2bf4890
coll: add nogroup restriction to certain algorithms
hzhou Aug 23, 2024
757066a
coll: check coll_group in MPIR_Sched_next_tag
hzhou Aug 24, 2024
513991d
coll: refactor barrier_intra_k_dissemination
hzhou Aug 24, 2024
10adb96
coll/allreduce: remove a leftover empty branch
hzhou Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def dump_allcomm_auto_blocking(name):
dump_open("MPIR_Csel_coll_sig_s coll_sig = {")
G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME)
G.out.append(".comm_ptr = comm_ptr,")
if not re.match(r'i?neighbor_', func_name, re.IGNORECASE):
G.out.append(".coll_group = coll_group,")
for p in func['parameters']:
if not re.match(r'comm$', p['name']):
G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name']))
Expand Down Expand Up @@ -163,12 +165,16 @@ def dump_allcomm_sched_auto(name):
dump_split(0, "int MPIR_%s_allcomm_sched_auto(%s)" % (Name, func_params))
dump_open('{')
G.out.append("int mpi_errno = MPI_SUCCESS;")
if re.match(r'Ineighbor_', Name):
G.out.append("int coll_group = MPIR_SUBGROUP_NONE;")
G.out.append("")

# -- Csel_search
dump_open("MPIR_Csel_coll_sig_s coll_sig = {")
G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME)
G.out.append(".comm_ptr = comm_ptr,")
if not re.match(r'i?neighbor_', func_name, re.IGNORECASE):
G.out.append(".coll_group = coll_group,")
for p in func['parameters']:
if not re.match(r'comm$', p['name']):
G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name']))
Expand Down Expand Up @@ -363,6 +369,8 @@ def dump_cases(commkind):
dump_split(0, "int MPIR_%s_sched_impl(%s)" % (Name, func_params))
dump_open('{')
G.out.append("int mpi_errno = MPI_SUCCESS;")
if re.match(r'Ineighbor_', Name):
G.out.append("int coll_group = MPIR_SUBGROUP_NONE;")
G.out.append("")

dump_open("if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {")
Expand Down Expand Up @@ -552,20 +560,22 @@ def dump_fallback(algo):
elif a == "noinplace":
cond_list.append("sendbuf != MPI_IN_PLACE")
elif a == "power-of-two":
cond_list.append("comm_ptr->local_size == comm_ptr->coll.pof2")
cond_list.append("MPL_is_pof2(MPIR_Coll_size(comm_ptr, coll_group))")
elif a == "size-ge-pof2":
cond_list.append("count >= comm_ptr->coll.pof2")
cond_list.append("count >= MPL_pof2(MPIR_Coll_size(comm_ptr, coll_group))")
elif a == "commutative":
cond_list.append("MPIR_Op_is_commutative(op)")
elif a== "builtin-op":
cond_list.append("HANDLE_IS_BUILTIN(op)")
elif a == "parent-comm":
cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr)")
cond_list.append("MPIR_Comm_is_parent_comm(comm_ptr, coll_group)")
elif a == "node-consecutive":
cond_list.append("MPII_Comm_is_node_consecutive(comm_ptr)")
elif a == "displs-ordered":
# assume it's allgatherv
cond_list.append("MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)")
elif a == "nogroup":
cond_list.append("coll_group == MPIR_SUBGROUP_NONE")
else:
raise Exception("Unsupported restrictions - %s" % a)
(func_name, commkind) = algo['func-commkind'].split('-')
Expand Down Expand Up @@ -644,6 +654,9 @@ def get_algo_extra_params(algo):
# additional wrappers
def get_algo_args(args, algo, kind):
algo_args = args
if not re.match(r'i?neighbor_', algo['func-commkind']):
algo_args += ", coll_group"

if 'extra_params' in algo:
algo_args += ", " + get_algo_extra_args(algo, kind)

Expand All @@ -658,6 +671,9 @@ def get_algo_args(args, algo, kind):

def get_algo_params(params, algo):
algo_params = params
if not re.match(r'i?neighbor_', algo['func-commkind']):
algo_params += ", int coll_group"

if 'extra_params' in algo:
algo_params += ", " + get_algo_extra_params(algo)

Expand All @@ -681,6 +697,8 @@ def get_algo_name(algo):

def get_func_params(params, name, kind):
func_params = params
if not name.startswith('neighbor_'):
func_params += ", int coll_group"
if kind == "blocking":
if not name.startswith('neighbor_'):
func_params += ", MPIR_Errflag_t errflag"
Expand All @@ -701,6 +719,8 @@ def get_func_params(params, name, kind):

def get_func_args(args, name, kind):
func_args = args
if not name.startswith('neighbor_'):
func_args += ", coll_group"
if kind == "blocking":
if not name.startswith('neighbor_'):
func_args += ", errflag"
Expand Down
5 changes: 5 additions & 0 deletions maint/local_python/binding_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,8 @@ def push_impl_decl(func, impl_name=None):
if func['_impl_param_list']:
params = ', '.join(func['_impl_param_list'])
if func['dir'] == 'coll':
if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']):
params = params.replace('comm_ptr', 'comm_ptr, int coll_group')
# block collective use an extra errflag
if not RE.match(r'MPI_(I.*|Neighbor.*|.*_init)$', func['name']):
params = params + ", MPIR_Errflag_t errflag"
Expand Down Expand Up @@ -1726,6 +1728,8 @@ def dump_body_coll(func):
mpir_name = re.sub(r'^MPIX?_', 'MPIR_', func['name'])

args = ", ".join(func['_impl_arg_list'])
if not RE.match(r'MPI_(Ineighbor|Neighbor)', func['name']):
args = args.replace('comm_ptr', 'comm_ptr, MPIR_SUBGROUP_NONE')

if RE.match(r'MPI_(I.*|.*_init)$', func['name'], re.IGNORECASE):
# non-blocking collectives
Expand Down Expand Up @@ -1956,6 +1960,7 @@ def dump_body_reduce_equal(func):
args = ", ".join(func['_impl_arg_list'])
args = re.sub(r'recvbuf, ', '', args)
args = re.sub(r'op, ', 'recvbuf, ', args)
args += ", MPIR_SUBGROUP_NONE"
dump_line_with_break("mpi_errno = %s(%s);" % (impl, args))
dump_error_check("")

Expand Down
2 changes: 1 addition & 1 deletion src/binding/c/comm_api.txt
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ MPI_Intercomm_merge:
* error to make */
acthigh = high ? 1 : 0; /* Clamp high into 1 or 0 */
mpi_errno = MPIR_Allreduce(MPI_IN_PLACE, &acthigh, 1, MPI_INT,
MPI_SUM, intercomm_ptr->local_comm, MPIR_ERR_NONE);
MPI_SUM, intercomm_ptr->local_comm, MPIR_SUBGROUP_NONE, MPIR_ERR_NONE);
MPIR_ERR_CHECK(mpi_errno);
/* acthigh must either == 0 or the size of the local comm */
if (acthigh != 0 && acthigh != intercomm_ptr->local_size) {
Expand Down
81 changes: 66 additions & 15 deletions src/include/mpir_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,52 @@

#include "coll_impl.h"
#include "coll_algos.h"
#include "mpir_threadcomm.h"

/* MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_):
 * Resolve the calling thread's rank and the total size for the threadcomm
 * attached to `comm`.  rank_ and size_ are output lvalues.
 * size_ comes from the last entry of rank_offset_table (presumably a
 * cumulative per-process thread count -- confirm against threadcomm code);
 * rank_ is the calling thread's threadcomm rank derived from its tid.
 * When ENABLE_THREADCOMM is off this path must be unreachable: it asserts. */
#ifdef ENABLE_THREADCOMM
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
MPIR_Threadcomm *threadcomm = (comm)->threadcomm; \
MPIR_Assert(threadcomm); \
int intracomm_size = (comm)->local_size; \
size_ = threadcomm->rank_offset_table[intracomm_size - 1]; \
rank_ = MPIR_THREADCOMM_TID_TO_RANK(threadcomm, MPIR_threadcomm_get_tid(threadcomm)); \
} while (0)
#else
#define MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_) do { \
MPIR_Assert(0); \
size_ = 0; \
rank_ = -1; \
} while (0)
#endif

/* MPIR_COLL_RANK_SIZE(comm, coll_group, rank_, size_):
 * Map (comm, coll_group) to the effective rank and size for a collective:
 *   - MPIR_SUBGROUP_NONE: the communicator's own rank/local_size,
 *   - MPIR_SUBGROUP_THREADCOMM: delegate to MPIR_THREADCOMM_RANK_SIZE,
 *   - otherwise: coll_group is an index into comm->subgroups[]. */
#define MPIR_COLL_RANK_SIZE(comm, coll_group, rank_, size_) do { \
if (coll_group == MPIR_SUBGROUP_NONE) { \
rank_ = (comm)->rank; \
size_ = (comm)->local_size; \
} else if (coll_group == MPIR_SUBGROUP_THREADCOMM) { \
MPIR_THREADCOMM_RANK_SIZE(comm, rank_, size_); \
} else { \
rank_ = (comm)->subgroups[coll_group].rank; \
size_ = (comm)->subgroups[coll_group].size; \
} \
} while (0)

/* sometimes it is convenient to get just the rank or just the size */
/* Convenience wrapper around MPIR_COLL_RANK_SIZE that returns only the
 * effective size for (comm, coll_group); the rank is discarded. */
static inline int MPIR_Coll_size(MPIR_Comm * comm, int coll_group)
{
    int grp_rank;
    int grp_size;
    MPIR_COLL_RANK_SIZE(comm, coll_group, grp_rank, grp_size);
    (void) grp_rank;            /* only the size is wanted here */
    return grp_size;
}

/* Convenience wrapper around MPIR_COLL_RANK_SIZE that returns only the
 * effective rank for (comm, coll_group); the size is discarded. */
static inline int MPIR_Coll_rank(MPIR_Comm * comm, int coll_group)
{
    int grp_rank;
    int grp_size;
    MPIR_COLL_RANK_SIZE(comm, coll_group, grp_rank, grp_size);
    (void) grp_size;            /* only the rank is wanted here */
    return grp_rank;
}

/* During init, not all algorithms are safe to use. For example, the csel
* may not have been initialized. We define a set of fallback routines that
Expand All @@ -28,36 +74,41 @@ int MPIC_Wait(MPIR_Request * request_ptr);
int MPIC_Probe(int source, int tag, MPI_Comm comm, MPI_Status * status);

int MPIC_Send(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag);
int MPIC_Recv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag,
MPIR_Comm * comm_ptr, MPI_Status * status);
MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status);
int MPIC_Sendrecv(const void *sendbuf, MPI_Aint sendcount, MPI_Datatype sendtype,
int dest, int sendtag, void *recvbuf, MPI_Aint recvcount,
MPI_Datatype recvtype, int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype,
int dest, int sendtag,
int source, int recvtag,
MPIR_Comm * comm_ptr, MPI_Status * status, MPIR_Errflag_t errflag);
MPIR_Comm * comm_ptr, int coll_group, MPI_Status * status,
MPIR_Errflag_t errflag);
int MPIC_Sendrecv_replace(void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int sendtag,
int source, int recvtag, MPIR_Comm * comm_ptr, int coll_group,
MPI_Status * status, MPIR_Errflag_t errflag);
int MPIC_Isend(const void *buf, MPI_Aint count, MPI_Datatype datatype, int dest, int tag,
MPIR_Comm * comm_ptr, MPIR_Request ** request, MPIR_Errflag_t errflag);
int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source,
int tag, MPIR_Comm * comm_ptr, MPIR_Request ** request);
MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request,
MPIR_Errflag_t errflag);
int MPIC_Irecv(void *buf, MPI_Aint count, MPI_Datatype datatype, int source, int tag,
MPIR_Comm * comm_ptr, int coll_group, MPIR_Request ** request);
int MPIC_Waitall(int numreq, MPIR_Request * requests[], MPI_Status * statuses);

int MPIR_Reduce_local(const void *inbuf, void *inoutbuf, MPI_Aint count, MPI_Datatype datatype,
MPI_Op op);

int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, MPIR_Errflag_t errflag);
int MPIR_Barrier_intra_dissemination(MPIR_Comm * comm_ptr, int coll_group, MPIR_Errflag_t errflag);

/* TSP auto */
int MPIR_TSP_Iallreduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op,
MPIR_Comm * comm, MPIR_TSP_sched_t sched);
MPIR_Comm * comm, int coll_group,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibcast_sched_intra_tsp_auto(void *buffer, MPI_Aint count, MPI_Datatype datatype,
int root, MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, MPIR_TSP_sched_t sched);
int root, MPIR_Comm * comm_ptr, int coll_group,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ibarrier_sched_intra_tsp_auto(MPIR_Comm * comm, int coll_group,
MPIR_TSP_sched_t sched);
int MPIR_TSP_Ireduce_sched_intra_tsp_auto(const void *sendbuf, void *recvbuf, MPI_Aint count,
MPI_Datatype datatype, MPI_Op op, int root,
MPIR_Comm * comm_ptr, MPIR_TSP_sched_t sched);
MPIR_Comm * comm_ptr, int coll_group,
MPIR_TSP_sched_t sched);
#endif /* MPIR_COLL_H_INCLUDED */
70 changes: 53 additions & 17 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,51 @@ enum MPIR_COMM_HINT_PREDEFINED_t {
MPIR_COMM_HINT_PREDEFINED_COUNT
};

/* MPIR_Subgroup is similar to MPIR_Group, but only used to describe subgroups within
* an intra communicator. The proc_table refers to ranks within the communicator.
* It is only used internally for group collectives.
*/
typedef struct MPIR_Subgroup {
    int size;                   /* number of processes in the subgroup */
    int rank;                   /* this process's rank within the subgroup */
    int *proc_table;            /* members' ranks within the enclosing comm;
                                 * can be NULL if the group is trivial */
} MPIR_Subgroup;

#define MPIR_MAX_SUBGROUPS 16

/* reserved subgroup indexes */
enum {
    MPIR_SUBGROUP_THREADCOMM = -1,      /* handled specially in MPIR_COLL_RANK_SIZE;
                                         * not an index into subgroups[] */
    MPIR_SUBGROUP_NONE = 0,     /* no subgroup: collective spans the whole comm */
    MPIR_SUBGROUP_NODE, /* i.e. nodecomm */
    MPIR_SUBGROUP_NODE_CROSS, /* node_roots_comm, node_rank_1_comm, ... */
    MPIR_SUBGROUP_NUMA1, /* 1-level below node in topology */
    MPIR_SUBGROUP_NUMA1_CROSS, /* cross-link group at NUMA1 within NODE */
    MPIR_SUBGROUP_NUMA2, /* and so on */
    MPIR_SUBGROUP_NUMA2_CROSS,
    MPIR_SUBGROUP_NUM_RESERVED, /* number of reserved indexes */
};

/* Macros to create dynamic subgroups.
 * The caller is expected to fill out the proc_table after MPIR_COMM_PUSH_SUBGROUP.
 */
/* MPIR_COMM_PUSH_SUBGROUP(comm, _size, _rank, newgrp, proc_table_out):
 * Claim the next subgroup slot in `comm`: store _size/_rank there, return the
 * new index in `newgrp`, and malloc a proc_table of `_size` ints returned via
 * `proc_table_out` for the caller to fill in.
 * NOTE(review): the assert runs after the increment, so the last slot
 * (index MPIR_MAX_SUBGROUPS - 1) can never be handed out -- confirm this
 * off-by-one is intended. */
#define MPIR_COMM_PUSH_SUBGROUP(comm, _size, _rank, newgrp, proc_table_out) \
do { \
(newgrp) = (comm)->num_subgroups++; \
MPIR_Assert((comm)->num_subgroups < MPIR_MAX_SUBGROUPS); \
(comm)->subgroups[newgrp].size = _size; \
(comm)->subgroups[newgrp].rank = _rank; \
(proc_table_out) = MPL_malloc((_size) * sizeof(int), MPL_MEM_OTHER); \
(comm)->subgroups[newgrp].proc_table = (proc_table_out); \
} while (0)

/* MPIR_COMM_POP_SUBGROUP(comm):
 * Release the most recently pushed subgroup and free its proc_table.
 * The assert keeps index 0 (MPIR_SUBGROUP_NONE) from ever being popped;
 * presumably callers are responsible for not popping the other reserved
 * indexes -- confirm. */
#define MPIR_COMM_POP_SUBGROUP(comm) \
do { \
int i = --(comm)->num_subgroups; \
MPIR_Assert(i > 0); \
MPL_free((comm)->subgroups[i].proc_table); \
} while (0)

/*S
MPIR_Comm - Description of the Communicator data structure

Expand Down Expand Up @@ -187,7 +232,8 @@ struct MPIR_Comm {
int *internode_table; /* internode_table[i] gives the rank in
* node_roots_comm of rank i in this comm.
* It is of size 'local_size'. */
int node_count; /* number of nodes this comm is spread over */
int num_local; /* number of procs in this comm on local node */
int num_external; /* number of nodes this comm is spread over */

int is_low_group; /* For intercomms only, this boolean is
* set for all members of one of the
Expand All @@ -196,6 +242,8 @@ struct MPIR_Comm {
* intercommunicator collective operations
* that wish to use half-duplex operations
* to implement a full-duplex operation */
MPIR_Subgroup subgroups[MPIR_MAX_SUBGROUPS];
int num_subgroups;

struct MPIR_Comm *comm_next; /* Provides a chain through all active
* communicators */
Expand All @@ -222,9 +270,6 @@ struct MPIR_Comm {
* use int array for fast access */

struct {
int pof2; /* Nearest (smaller than or equal to) power of 2
* to the number of ranks in the communicator.
* To be used during collective communication */
int pofk[MAX_RADIX - 1];
int k[MAX_RADIX - 1];
int step1_sendto[MAX_RADIX - 1];
Expand All @@ -234,18 +279,9 @@ struct MPIR_Comm {
int **step2_nbrs[MAX_RADIX - 1];
int nbrs_defined[MAX_RADIX - 1];
void **recexch_allreduce_nbr_buffer;
int topo_aware_tree_root;
int topo_aware_tree_k;
MPIR_Treealgo_tree_t *topo_aware_tree;
int topo_aware_k_tree_root;
int topo_aware_k_tree_k;
MPIR_Treealgo_tree_t *topo_aware_k_tree;
int topo_wave_tree_root;
int topo_wave_tree_overhead;
int topo_wave_tree_lat_diff_groups;
int topo_wave_tree_lat_diff_switches;
int topo_wave_tree_lat_same_switches;
MPIR_Treealgo_tree_t *topo_wave_tree;

MPIR_Treealgo_tree_t *cached_tree;
MPIR_Treealgo_param_t cached_tree_param;
} coll;

void *csel_comm; /* collective selector handle */
Expand Down Expand Up @@ -374,7 +410,7 @@ int MPIR_Comm_create_inter(MPIR_Comm * comm_ptr, MPIR_Group * group_ptr, MPIR_Co
int MPIR_Comm_create_subcomms(MPIR_Comm * comm);
int MPIR_Comm_commit(MPIR_Comm *);

int MPIR_Comm_is_parent_comm(MPIR_Comm *);
int MPIR_Comm_is_parent_comm(MPIR_Comm * comm, int coll_group);

/* peer intercomm is an internal 1-to-1 intercomm used for connecting dynamic processes */
int MPIR_peer_intercomm_create(MPIR_Context_id_t context_id, MPIR_Context_id_t recvcontext_id,
Expand Down
1 change: 1 addition & 0 deletions src/include/mpir_csel.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ typedef enum {
typedef struct {
MPIR_Csel_coll_type_e coll_type;
MPIR_Comm *comm_ptr;
int coll_group;

union {
struct {
Expand Down
Loading