diff --git a/src/include/mpir_comm.h b/src/include/mpir_comm.h
index 9c2fb697c02..1a2292199f0 100644
--- a/src/include/mpir_comm.h
+++ b/src/include/mpir_comm.h
@@ -279,18 +279,9 @@ struct MPIR_Comm {
         int **step2_nbrs[MAX_RADIX - 1];
         int nbrs_defined[MAX_RADIX - 1];
         void **recexch_allreduce_nbr_buffer;
-        int topo_aware_tree_root;
-        int topo_aware_tree_k;
-        MPIR_Treealgo_tree_t *topo_aware_tree;
-        int topo_aware_k_tree_root;
-        int topo_aware_k_tree_k;
-        MPIR_Treealgo_tree_t *topo_aware_k_tree;
-        int topo_wave_tree_root;
-        int topo_wave_tree_overhead;
-        int topo_wave_tree_lat_diff_groups;
-        int topo_wave_tree_lat_diff_switches;
-        int topo_wave_tree_lat_same_switches;
-        MPIR_Treealgo_tree_t *topo_wave_tree;
+
+        MPIR_Treealgo_tree_t *cached_tree;
+        MPIR_Treealgo_param_t cached_tree_param;
     } coll;
 
     void *csel_comm;            /* collective selector handle */
diff --git a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c
index 61042c75671..759be0cd6e2 100644
--- a/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c
+++ b/src/mpi/coll/algorithms/recexchalgo/recexchalgo.c
@@ -26,19 +26,7 @@ int MPII_Recexchalgo_comm_init(MPIR_Comm * comm)
     }
     comm->coll.recexch_allreduce_nbr_buffer = NULL;
 
-    comm->coll.topo_aware_tree_root = -1;
-    comm->coll.topo_aware_tree_k = 0;
-    comm->coll.topo_aware_tree = NULL;
-    comm->coll.topo_aware_k_tree_root = -1;
-    comm->coll.topo_aware_k_tree_k = 0;
-    comm->coll.topo_aware_k_tree = NULL;
-    comm->coll.topo_wave_tree_root = -1;
-    comm->coll.topo_wave_tree = NULL;
-    comm->coll.topo_wave_tree_overhead = 0;
-    comm->coll.topo_wave_tree_lat_diff_groups = 0;
-    comm->coll.topo_wave_tree_lat_diff_switches = 0;
-    comm->coll.topo_wave_tree_lat_same_switches = 0;
-
+    comm->coll.cached_tree = NULL;
     return mpi_errno;
 }
 
@@ -66,22 +54,10 @@ int MPII_Recexchalgo_comm_cleanup(MPIR_Comm * comm)
         MPL_free(comm->coll.recexch_allreduce_nbr_buffer);
     }
 
-    if (comm->coll.topo_aware_tree) {
-        MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree);
-        MPL_free(comm->coll.topo_aware_tree);
-        comm->coll.topo_aware_tree = NULL;
-    }
-
-    if (comm->coll.topo_aware_k_tree) {
-        MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree);
-        MPL_free(comm->coll.topo_aware_k_tree);
-        comm->coll.topo_aware_k_tree = NULL;
-    }
-
-    if (comm->coll.topo_wave_tree) {
-        MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree);
-        MPL_free(comm->coll.topo_wave_tree);
-        comm->coll.topo_wave_tree = NULL;
+    if (comm->coll.cached_tree) {
+        MPIR_Treealgo_tree_free(comm->coll.cached_tree);
+        MPL_free(comm->coll.cached_tree);
+        comm->coll.cached_tree = NULL;
     }
 
     return mpi_errno;
diff --git a/src/mpi/coll/algorithms/treealgo/treealgo.c b/src/mpi/coll/algorithms/treealgo/treealgo.c
index 25a291c5f1c..da5ac83eef9 100644
--- a/src/mpi/coll/algorithms/treealgo/treealgo.c
+++ b/src/mpi/coll/algorithms/treealgo/treealgo.c
@@ -33,6 +33,70 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm)
     return mpi_errno;
 }
 
+/* Helpers for the single per-communicator tree cache (comm->coll.cached_tree).
+ * match_param_*() report whether the cached parameters describe the requested
+ * tree; set_param_*() record the parameters of a freshly built tree. */
+static bool match_param_topo_aware(const MPIR_Treealgo_param_t * param, int root, int k)
+{
+    return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE &&
+            param->root == root && param->u.topo_aware.k == k);
+}
+
+static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k)
+{
+    param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE;
+    param->root = root;
+    param->u.topo_aware.k = k;
+}
+
+static bool match_param_topo_aware_k(const MPIR_Treealgo_param_t * param, int root, int k)
+{
+    return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K &&
+            param->root == root && param->u.topo_aware.k == k);
+}
+
+static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k)
+{
+    param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K;
+    param->root = root;
+    param->u.topo_aware.k = k;
+}
+
+static bool match_param_topo_wave(const MPIR_Treealgo_param_t * param,
+                                  int root, int overhead, int lat_diff_groups,
+                                  int lat_diff_switches, int lat_same_switches)
+{
+    return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE &&
+            param->root == root &&
+            param->u.topo_wave.overhead == overhead &&
+            param->u.topo_wave.lat_diff_groups == lat_diff_groups &&
+            param->u.topo_wave.lat_diff_switches == lat_diff_switches &&
+            param->u.topo_wave.lat_same_switches == lat_same_switches);
+}
+
+static void set_param_topo_wave(MPIR_Treealgo_param_t * param,
+                                int root, int overhead, int lat_diff_groups,
+                                int lat_diff_switches, int lat_same_switches)
+{
+    param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE;
+    param->root = root;
+    param->u.topo_wave.overhead = overhead;
+    param->u.topo_wave.lat_diff_groups = lat_diff_groups;
+    param->u.topo_wave.lat_diff_switches = lat_diff_switches;
+    param->u.topo_wave.lat_same_switches = lat_same_switches;
+}
+
+/* Release the cached tree, if any, and mark the cache empty so that
+ * cached_tree_param can never be matched against an already-freed tree. */
+static void drop_cached_tree(MPIR_Comm * comm)
+{
+    if (comm->coll.cached_tree) {
+        MPIR_Treealgo_tree_free(comm->coll.cached_tree);
+        MPL_free(comm->coll.cached_tree);
+        comm->coll.cached_tree = NULL;
+    }
+}
+
 
 int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root,
                               MPIR_Treealgo_tree_t * ct)
@@ -84,56 +148,56 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,
 
     switch (tree_type) {
         case MPIR_TREE_TYPE_TOPOLOGY_AWARE:
-            if (!comm->coll.topo_aware_tree || root != comm->coll.topo_aware_tree_root
-                || k != comm->coll.topo_aware_tree_k) {
-                if (comm->coll.topo_aware_tree) {
-                    MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree);
-                } else {
-                    comm->coll.topo_aware_tree =
-                        (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t),
-                                                            MPL_MEM_BUFFER);
-                }
+            if (!comm->coll.cached_tree ||
+                !match_param_topo_aware(&comm->coll.cached_tree_param, root, k)) {
+                MPIR_Treealgo_tree_t *tree;
+
+                /* Build into a private tree: the cache is only updated on success. */
+                drop_cached_tree(comm);
+                tree = MPL_malloc(sizeof(*tree), MPL_MEM_BUFFER);
+                MPIR_ERR_CHKANDJUMP(!tree, mpi_errno, MPI_ERR_OTHER, "**nomem");
                 mpi_errno =
                     MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder,
-                                                           comm->coll.topo_aware_tree);
+                                                           tree);
+                if (mpi_errno) {
+                    MPL_free(tree);
+                }
                 MPIR_ERR_CHECK(mpi_errno);
-                *ct = *comm->coll.topo_aware_tree;
-                comm->coll.topo_aware_tree_root = root;
-                comm->coll.topo_aware_tree_k = k;
+                comm->coll.cached_tree = tree;
+                set_param_topo_aware(&comm->coll.cached_tree_param, root, k);
             }
-            *ct = *comm->coll.topo_aware_tree;
+            *ct = *comm->coll.cached_tree;
             utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
             for (int i = 0; i < ct->num_children; i++) {
                 utarray_push_back(ct->children,
-                                  &ut_int_array(comm->coll.topo_aware_tree->children)[i],
-                                  MPL_MEM_COLL);
+                                  &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
             }
             break;
 
         case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K:
-            if (!comm->coll.topo_aware_k_tree || root != comm->coll.topo_aware_k_tree_root
-                || k != comm->coll.topo_aware_k_tree_k) {
-                if (comm->coll.topo_aware_k_tree) {
-                    MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree);
-                } else {
-                    comm->coll.topo_aware_k_tree =
-                        (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t),
-                                                            MPL_MEM_BUFFER);
-                }
+            if (!comm->coll.cached_tree ||
+                !match_param_topo_aware_k(&comm->coll.cached_tree_param, root, k)) {
+                MPIR_Treealgo_tree_t *tree;
+
+                /* Build into a private tree: the cache is only updated on success. */
+                drop_cached_tree(comm);
+                tree = MPL_malloc(sizeof(*tree), MPL_MEM_BUFFER);
+                MPIR_ERR_CHKANDJUMP(!tree, mpi_errno, MPI_ERR_OTHER, "**nomem");
                 mpi_errno =
                     MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder,
-                                                             comm->coll.topo_aware_k_tree);
+                                                             tree);
+                if (mpi_errno) {
+                    MPL_free(tree);
+                }
                 MPIR_ERR_CHECK(mpi_errno);
-                *ct = *comm->coll.topo_aware_k_tree;
-                comm->coll.topo_aware_k_tree_root = root;
-                comm->coll.topo_aware_k_tree_k = k;
+                comm->coll.cached_tree = tree;
+                set_param_topo_aware_k(&comm->coll.cached_tree_param, root, k);
             }
-            *ct = *comm->coll.topo_aware_k_tree;
+            *ct = *comm->coll.cached_tree;
             utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
             for (int i = 0; i < ct->num_children; i++) {
                 utarray_push_back(ct->children,
-                                  &ut_int_array(comm->coll.topo_aware_k_tree->children)[i],
-                                  MPL_MEM_COLL);
+                                  &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
             }
             break;
 
@@ -164,34 +228,32 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root,
 
     MPIR_FUNC_ENTER;
 
-    if (!comm->coll.topo_wave_tree || root != comm->coll.topo_wave_tree_root
-        || overhead != comm->coll.topo_wave_tree_overhead
-        || lat_diff_groups != comm->coll.topo_wave_tree_lat_diff_groups
-        || lat_diff_switches != comm->coll.topo_wave_tree_lat_diff_switches
-        || lat_same_switches != comm->coll.topo_wave_tree_lat_same_switches) {
-        if (comm->coll.topo_wave_tree) {
-            MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree);
-        } else {
-            comm->coll.topo_wave_tree =
-                (MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER);
-        }
+    if (!comm->coll.cached_tree ||
+        !match_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
+                               lat_diff_groups, lat_diff_switches, lat_same_switches)) {
+        MPIR_Treealgo_tree_t *tree;
+
+        /* Build into a private tree: the cache is only updated on success. */
+        drop_cached_tree(comm);
+        tree = MPL_malloc(sizeof(*tree), MPL_MEM_BUFFER);
+        MPIR_ERR_CHKANDJUMP(!tree, mpi_errno, MPI_ERR_OTHER, "**nomem");
         mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead,
                                                           lat_diff_groups, lat_diff_switches,
                                                           lat_same_switches,
-                                                          comm->coll.topo_wave_tree);
+                                                          tree);
+        if (mpi_errno) {
+            MPL_free(tree);
+        }
         MPIR_ERR_CHECK(mpi_errno);
-        *ct = *comm->coll.topo_wave_tree;
-        comm->coll.topo_wave_tree_root = root;
-        comm->coll.topo_wave_tree_overhead = overhead;
-        comm->coll.topo_wave_tree_lat_diff_groups = lat_diff_groups;
-        comm->coll.topo_wave_tree_lat_diff_switches = lat_diff_switches;
-        comm->coll.topo_wave_tree_lat_same_switches = lat_same_switches;
+        comm->coll.cached_tree = tree;
+        set_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
+                            lat_diff_groups, lat_diff_switches, lat_same_switches);
     }
-    *ct = *comm->coll.topo_wave_tree;
+    *ct = *comm->coll.cached_tree;
     utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
     for (int i = 0; i < ct->num_children; i++) {
         utarray_push_back(ct->children,
-                          &ut_int_array(comm->coll.topo_wave_tree->children)[i], MPL_MEM_COLL);
+                          &ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
     }
 
     MPIR_FUNC_EXIT;
diff --git a/src/mpi/coll/algorithms/treealgo/treealgo_types.h b/src/mpi/coll/algorithms/treealgo/treealgo_types.h
index 5db2c5ae931..bb15947046e 100644
--- a/src/mpi/coll/algorithms/treealgo/treealgo_types.h
+++ b/src/mpi/coll/algorithms/treealgo/treealgo_types.h
@@ -8,6 +8,16 @@
 
 #include <utarray.h>
 
+/* enumerator for different tree types */
+typedef enum MPIR_Tree_type_t {
+    MPIR_TREE_TYPE_KARY = 0,
+    MPIR_TREE_TYPE_KNOMIAL_1,
+    MPIR_TREE_TYPE_KNOMIAL_2,
+    MPIR_TREE_TYPE_TOPOLOGY_AWARE,
+    MPIR_TREE_TYPE_TOPOLOGY_AWARE_K,
+    MPIR_TREE_TYPE_TOPOLOGY_WAVE,
+} MPIR_Tree_type_t;
+
 typedef struct {
     int rank;
     int nranks;
@@ -16,4 +26,23 @@ typedef struct {
     UT_array *children;
 } MPIR_Treealgo_tree_t;
 
+/* Parameters that identify a cached topology tree.  `type` selects the active
+ * member of `u`: topo_aware is used by both TOPOLOGY_AWARE and TOPOLOGY_AWARE_K,
+ * topo_wave by TOPOLOGY_WAVE. */
+typedef struct {
+    MPIR_Tree_type_t type;
+    int root;
+    union {
+        struct {
+            int k;
+        } topo_aware;
+        struct {
+            int overhead;
+            int lat_diff_groups;
+            int lat_diff_switches;
+            int lat_same_switches;
+        } topo_wave;
+    } u;
+} MPIR_Treealgo_param_t;
+
 #endif /* TREEALGO_TYPES_H_INCLUDED */
diff --git a/src/mpi/coll/include/coll_types.h b/src/mpi/coll/include/coll_types.h
index a32ce6c551d..22fbad4716b 100644
--- a/src/mpi/coll/include/coll_types.h
+++ b/src/mpi/coll/include/coll_types.h
@@ -13,16 +13,6 @@
 #define MPIR_COLL_FLAG_REDUCE_L 1
 #define MPIR_COLL_FLAG_REDUCE_R 0
 
-/* enumerator for different tree types */
-typedef enum MPIR_Tree_type_t {
-    MPIR_TREE_TYPE_KARY = 0,
-    MPIR_TREE_TYPE_KNOMIAL_1,
-    MPIR_TREE_TYPE_KNOMIAL_2,
-    MPIR_TREE_TYPE_TOPOLOGY_AWARE,
-    MPIR_TREE_TYPE_TOPOLOGY_AWARE_K,
-    MPIR_TREE_TYPE_TOPOLOGY_WAVE,
-} MPIR_Tree_type_t;
-
 /* enumerator for different recexch types */
 enum {
     MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER = 0,