From 3c20962a1d7525b79c3e0a2aac07fc397daa3e5c Mon Sep 17 00:00:00 2001 From: Hui Zhou Date: Fri, 23 Aug 2024 07:31:14 -0500 Subject: [PATCH] coll: add nogroup restriction to certain algorithms Some algorithm, e.g. Allgather recexch, caches comm size-related info in communicator, thus won't work with none trivial coll_group. Add a restriction so it will fallback when coll_group != MPIR_SUBGROUP_NONE. --- maint/gen_coll.py | 2 ++ src/mpi/coll/allgather/allgather_intra_recexch.c | 3 +++ .../allreduce/allreduce_intra_k_reduce_scatter_allgather.c | 2 ++ src/mpi/coll/allreduce/allreduce_intra_recexch.c | 3 +++ src/mpi/coll/coll_algorithms.txt | 5 ++++- 5 files changed, 14 insertions(+), 1 deletion(-) diff --git a/maint/gen_coll.py b/maint/gen_coll.py index 83323a951da..afd46f026e6 100644 --- a/maint/gen_coll.py +++ b/maint/gen_coll.py @@ -570,6 +570,8 @@ def dump_fallback(algo): elif a == "displs-ordered": # assume it's allgatherv cond_list.append("MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)") + elif a == "nogroup": + cond_list.append("coll_group == MPIR_SUBGROUP_NONE") else: raise Exception("Unsupported restrictions - %s" % a) (func_name, commkind) = algo['func-commkind'].split('-') diff --git a/src/mpi/coll/allgather/allgather_intra_recexch.c b/src/mpi/coll/allgather/allgather_intra_recexch.c index 2ed596e0ff5..70a4db65cfc 100644 --- a/src/mpi/coll/allgather/allgather_intra_recexch.c +++ b/src/mpi/coll/allgather/allgather_intra_recexch.c @@ -36,6 +36,9 @@ int MPIR_Allgather_intra_recexch(const void *sendbuf, MPI_Aint sendcount, MPIR_Request *rreqs[MAX_RADIX * 2], *sreqs[MAX_RADIX * 2]; MPIR_Request **recv_reqs = NULL, **send_reqs = NULL; + /* it caches data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + is_inplace = (sendbuf == MPI_IN_PLACE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c index 2a68e6ee5d8..e63ccbcbd6b 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c +++ b/src/mpi/coll/allreduce/allreduce_intra_k_reduce_scatter_allgather.c @@ -36,6 +36,8 @@ int MPIR_Allreduce_intra_k_reduce_scatter_allgather(const void *sendbuf, MPIR_CHKLMEM_DECL(2); MPIR_Assert(k > 1); + /* This algorithm uses cached data in comm, thus it won't work with coll_group */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); diff --git a/src/mpi/coll/allreduce/allreduce_intra_recexch.c b/src/mpi/coll/allreduce/allreduce_intra_recexch.c index 5bf8a76ee1c..498da1965aa 100644 --- a/src/mpi/coll/allreduce/allreduce_intra_recexch.c +++ b/src/mpi/coll/allreduce/allreduce_intra_recexch.c @@ -34,6 +34,9 @@ int MPIR_Allreduce_intra_recexch(const void *sendbuf, MPIR_Request **send_reqs = NULL, **recv_reqs = NULL; int send_nreq = 0, recv_nreq = 0, total_phases = 0; + /* uses cached data in comm */ + MPIR_Assert(coll_group == MPIR_SUBGROUP_NONE); + MPIR_COLL_RANK_SIZE(comm, coll_group, rank, nranks); is_commutative = MPIR_Op_is_commutative(op); diff --git a/src/mpi/coll/coll_algorithms.txt b/src/mpi/coll/coll_algorithms.txt index 4251999ff7f..23e2a1d032b 100644 --- a/src/mpi/coll/coll_algorithms.txt +++ b/src/mpi/coll/coll_algorithms.txt @@ -174,10 +174,12 @@ allgather-intra: func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_DOUBLING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup recexch_halving func_name: recexch extra_params: recexch_type=MPIR_ALLGATHER_RECEXCH_TYPE_DISTANCE_HALVING, k, single_phase_recv cvar_params: -, RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup allgather-inter: local_gather_remote_bcast iallgather-intra: @@ -346,10 +348,11 @@ allreduce-intra: recexch extra_params: k, single_phase_recv cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV + restrictions: nogroup ring restrictions: commutative k_reduce_scatter_allgather - restrictions: commutative + restrictions: commutative, nogroup extra_params: k, single_phase_recv cvar_params: RECEXCH_KVAL, RECEXCH_SINGLE_PHASE_RECV allreduce-inter: