Skip to content

Commit

Permalink
coll: add nogroup restriction to certain algorithms
Browse files Browse the repository at this point in the history
Some algorithm, e.g. Allgather recexch, caches comm size-related info in
communicator, thus won't work with none trivial coll_group. Add a
restriction so it will fallback when coll_group != MPIR_SUBGROUP_NONE.
  • Loading branch information
hzhou committed Aug 25, 2024
1 parent 2e58b5d commit 3c20962
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 1 deletion.
2 changes: 2 additions & 0 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,8 @@ def dump_fallback(algo):
elif a == "displs-ordered":
# assume it's allgatherv
cond_list.append("MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)")
elif a == "nogroup":
cond_list.append("coll_group == MPIR_SUBGROUP_NONE")
else:
raise Exception("Unsupported restrictions - %s" % a)
(func_name, commkind) = algo['func-commkind'].split('-')
Expand Down
3 changes: 3 additions & 0 deletions src/mpi/coll/allgather/allgather_intra_recexch.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount,
MPIR_Request *rreqs[MAX_RADIX * 2], *sreqs[MAX_RADIX * 2];
MPIR_Request **recv_reqs = NULL, **send_reqs = NULL;

/* it caches data in comm */
MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE);

is_inplace = (sendbuf == MPI_IN_PLACE);
MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf,
MPIR_CHKLMEM_DECL(2);

MPIR_Assert(k > 1);
/* This algorithm uses cached data in comm, thus it won't work with coll_group */
MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE);

MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

Expand Down
3 changes: 3 additions & 0 deletions src/mpi/coll/allreduce/allreduce_intra_recexch.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf,
MPIR_Request **send_reqs = NULL, **recv_reqs = NULL;
int send_nreq = 0, recv_nreq = 0, total_phases = 0;

/* uses cached data in comm */
MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE);

MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks);

is_commutative = MPIR_Op_is_commutative(op);
Expand Down
5 changes: 4 additions & 1 deletion src/mpi/coll/coll_algorithms.txt
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,12 @@ allgather-intra:
func_name: recexch
extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_DOUBLING, k, single_phase_recv
cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV
restrictions: nogroup
recexch_halving
func_name: recexch
extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING, k, single_phase_recv
cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV
restrictions: nogroup
allgather-inter:
local_gather_remote_bcast
iallgather-intra:
Expand Down Expand Up @@ -346,10 +348,11 @@ allreduce-intra:
recexch
extra_params: k, single_phase_recv
cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV
restrictions: nogroup
ring
restrictions: commutative
k_reduce_scatter_allgather
restrictions: commutative
restrictions: commutative, nogroup
extra_params: k, single_phase_recv
cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV
allreduce-inter:
Expand Down

0 comments on commit 3c20962

Please sign in to comment.