Skip to content

Commit

Permalink
coll: refactor caching tree in the comm struct
Browse files Browse the repository at this point in the history
Use a single "cached_tree" rather than 3 different fields for each tree
type.
  • Loading branch information
hzhou committed Aug 23, 2024
1 parent b729bb4 commit 499a9fb
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 92 deletions.
15 changes: 3 additions & 12 deletions src/include/mpir_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,9 @@ struct MPIR_Comm {
int **step2_nbrs[MAX_RADIX - 1];
int nbrs_defined[MAX_RADIX - 1];
void **recexch_allreduce_nbr_buffer;
int topo_aware_tree_root;
int topo_aware_tree_k;
MPIR_Treealgo_tree_t *topo_aware_tree;
int topo_aware_k_tree_root;
int topo_aware_k_tree_k;
MPIR_Treealgo_tree_t *topo_aware_k_tree;
int topo_wave_tree_root;
int topo_wave_tree_overhead;
int topo_wave_tree_lat_diff_groups;
int topo_wave_tree_lat_diff_switches;
int topo_wave_tree_lat_same_switches;
MPIR_Treealgo_tree_t *topo_wave_tree;

MPIR_Treealgo_tree_t *cached_tree;
MPIR_Treealgo_param_t cached_tree_param;
} coll;

void *csel_comm; /* collective selector handle */
Expand Down
34 changes: 5 additions & 29 deletions src/mpi/coll/algorithms/recexchalgo/recexchalgo.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,7 @@ int MPII_Recexchalgo_comm_init(MPIR_Comm * comm)
}
comm->coll.recexch_allreduce_nbr_buffer = NULL;

comm->coll.topo_aware_tree_root = -1;
comm->coll.topo_aware_tree_k = 0;
comm->coll.topo_aware_tree = NULL;
comm->coll.topo_aware_k_tree_root = -1;
comm->coll.topo_aware_k_tree_k = 0;
comm->coll.topo_aware_k_tree = NULL;
comm->coll.topo_wave_tree_root = -1;
comm->coll.topo_wave_tree = NULL;
comm->coll.topo_wave_tree_overhead = 0;
comm->coll.topo_wave_tree_lat_diff_groups = 0;
comm->coll.topo_wave_tree_lat_diff_switches = 0;
comm->coll.topo_wave_tree_lat_same_switches = 0;

comm->coll.cached_tree = NULL;
return mpi_errno;
}

Expand Down Expand Up @@ -66,22 +54,10 @@ int MPII_Recexchalgo_comm_cleanup(MPIR_Comm * comm)
MPL_free(comm->coll.recexch_allreduce_nbr_buffer);
}

if (comm->coll.topo_aware_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree);
MPL_free(comm->coll.topo_aware_tree);
comm->coll.topo_aware_tree = NULL;
}

if (comm->coll.topo_aware_k_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree);
MPL_free(comm->coll.topo_aware_k_tree);
comm->coll.topo_aware_k_tree = NULL;
}

if (comm->coll.topo_wave_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree);
MPL_free(comm->coll.topo_wave_tree);
comm->coll.topo_wave_tree = NULL;
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
MPL_free(comm->coll.cached_tree);
comm->coll.cached_tree = NULL;
}

return mpi_errno;
Expand Down
123 changes: 82 additions & 41 deletions src/mpi/coll/algorithms/treealgo/treealgo.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,56 @@ int MPII_Treealgo_comm_cleanup(MPIR_Comm * comm)
return mpi_errno;
}

static bool match_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE &&
param->root == root && param->u.topo_aware.k == k);
}

static void set_param_topo_aware(MPIR_Treealgo_param_t * param, int root, int k)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE;
param->root = root;
param->u.topo_aware.k = k;
}

static bool match_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_AWARE_K &&
param->root == root && param->u.topo_aware.k == k);
}

static void set_param_topo_aware_k(MPIR_Treealgo_param_t * param, int root, int k)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_AWARE_K;
param->root = root;
param->u.topo_aware.k = k;
}

static inline bool match_param_topo_wave(MPIR_Treealgo_param_t * param,
int root, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches)
{
return (param->type == MPIR_TREE_TYPE_TOPOLOGY_WAVE &&
param->root == root &&
param->u.topo_wave.overhead == overhead &&
param->u.topo_wave.lat_diff_groups == lat_diff_groups &&
param->u.topo_wave.lat_diff_switches == lat_diff_switches &&
param->u.topo_wave.lat_same_switches == lat_same_switches);
}

static inline void set_param_topo_wave(MPIR_Treealgo_param_t * param,
int root, int overhead, int lat_diff_groups,
int lat_diff_switches, int lat_same_switches)
{
param->type = MPIR_TREE_TYPE_TOPOLOGY_WAVE;
param->root = root;
param->u.topo_wave.overhead = overhead;
param->u.topo_wave.lat_diff_groups = lat_diff_groups;
param->u.topo_wave.lat_diff_switches = lat_diff_switches;
param->u.topo_wave.lat_same_switches = lat_same_switches;
}


int MPIR_Treealgo_tree_create(int rank, int nranks, int tree_type, int k, int root,
MPIR_Treealgo_tree_t * ct)
Expand Down Expand Up @@ -84,56 +134,52 @@ int MPIR_Treealgo_tree_create_topo_aware(MPIR_Comm * comm, int tree_type, int k,

switch (tree_type) {
case MPIR_TREE_TYPE_TOPOLOGY_AWARE:
if (!comm->coll.topo_aware_tree || root != comm->coll.topo_aware_tree_root
|| k != comm->coll.topo_aware_tree_k) {
if (comm->coll.topo_aware_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_aware_tree);
if (!comm->coll.cached_tree ||
!match_param_topo_aware(&comm->coll.cached_tree_param, root, k)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
comm->coll.topo_aware_tree =
comm->coll.cached_tree =
(MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t),
MPL_MEM_BUFFER);
}
mpi_errno =
MPII_Treeutil_tree_topology_aware_init(comm, k, root, enable_reorder,
comm->coll.topo_aware_tree);
comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.topo_aware_tree;
comm->coll.topo_aware_tree_root = root;
comm->coll.topo_aware_tree_k = k;
*ct = *comm->coll.cached_tree;
set_param_topo_aware(&comm->coll.cached_tree_param, root, k);
}
*ct = *comm->coll.topo_aware_tree;
*ct = *comm->coll.cached_tree;
utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
for (int i = 0; i < ct->num_children; i++) {
utarray_push_back(ct->children,
&ut_int_array(comm->coll.topo_aware_tree->children)[i],
MPL_MEM_COLL);
&ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
}
break;

case MPIR_TREE_TYPE_TOPOLOGY_AWARE_K:
if (!comm->coll.topo_aware_k_tree || root != comm->coll.topo_aware_k_tree_root
|| k != comm->coll.topo_aware_k_tree_k) {
if (comm->coll.topo_aware_k_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_aware_k_tree);
if (!comm->coll.cached_tree ||
!match_param_topo_aware_k(&comm->coll.cached_tree_param, root, k)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
comm->coll.topo_aware_k_tree =
comm->coll.cached_tree =
(MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t),
MPL_MEM_BUFFER);
}
mpi_errno =
MPII_Treeutil_tree_topology_aware_k_init(comm, k, root, enable_reorder,
comm->coll.topo_aware_k_tree);
comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.topo_aware_k_tree;
comm->coll.topo_aware_k_tree_root = root;
comm->coll.topo_aware_k_tree_k = k;
*ct = *comm->coll.cached_tree;
set_param_topo_aware_k(&comm->coll.cached_tree_param, root, k);
}
*ct = *comm->coll.topo_aware_k_tree;
*ct = *comm->coll.cached_tree;
utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
for (int i = 0; i < ct->num_children; i++) {
utarray_push_back(ct->children,
&ut_int_array(comm->coll.topo_aware_k_tree->children)[i],
MPL_MEM_COLL);
&ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
}
break;

Expand Down Expand Up @@ -164,34 +210,29 @@ int MPIR_Treealgo_tree_create_topo_wave(MPIR_Comm * comm, int k, int root,

MPIR_FUNC_ENTER;

if (!comm->coll.topo_wave_tree || root != comm->coll.topo_wave_tree_root
|| overhead != comm->coll.topo_wave_tree_overhead
|| lat_diff_groups != comm->coll.topo_wave_tree_lat_diff_groups
|| lat_diff_switches != comm->coll.topo_wave_tree_lat_diff_switches
|| lat_same_switches != comm->coll.topo_wave_tree_lat_same_switches) {
if (comm->coll.topo_wave_tree) {
MPIR_Treealgo_tree_free(comm->coll.topo_wave_tree);
if (!comm->coll.cached_tree ||
!match_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
lat_diff_groups, lat_diff_switches, lat_same_switches)) {
if (comm->coll.cached_tree) {
MPIR_Treealgo_tree_free(comm->coll.cached_tree);
} else {
comm->coll.topo_wave_tree =
comm->coll.cached_tree =
(MPIR_Treealgo_tree_t *) MPL_malloc(sizeof(MPIR_Treealgo_tree_t), MPL_MEM_BUFFER);
}
mpi_errno = MPII_Treeutil_tree_topology_wave_init(comm, k, root, enable_reorder, overhead,
lat_diff_groups, lat_diff_switches,
lat_same_switches,
comm->coll.topo_wave_tree);
comm->coll.cached_tree);
MPIR_ERR_CHECK(mpi_errno);
*ct = *comm->coll.topo_wave_tree;
comm->coll.topo_wave_tree_root = root;
comm->coll.topo_wave_tree_overhead = overhead;
comm->coll.topo_wave_tree_lat_diff_groups = lat_diff_groups;
comm->coll.topo_wave_tree_lat_diff_switches = lat_diff_switches;
comm->coll.topo_wave_tree_lat_same_switches = lat_same_switches;
*ct = *comm->coll.cached_tree;
set_param_topo_wave(&comm->coll.cached_tree_param, root, overhead,
lat_diff_groups, lat_diff_switches, lat_same_switches);
}
*ct = *comm->coll.topo_wave_tree;
*ct = *comm->coll.cached_tree;
utarray_new(ct->children, &ut_int_icd, MPL_MEM_COLL);
for (int i = 0; i < ct->num_children; i++) {
utarray_push_back(ct->children,
&ut_int_array(comm->coll.topo_wave_tree->children)[i], MPL_MEM_COLL);
&ut_int_array(comm->coll.cached_tree->children)[i], MPL_MEM_COLL);
}

MPIR_FUNC_EXIT;
Expand Down
26 changes: 26 additions & 0 deletions src/mpi/coll/algorithms/treealgo/treealgo_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@

#include <utarray.h>

/* enumerator for different tree types */
typedef enum MPIR_Tree_type_t {
MPIR_TREE_TYPE_KARY = 0,
MPIR_TREE_TYPE_KNOMIAL_1,
MPIR_TREE_TYPE_KNOMIAL_2,
MPIR_TREE_TYPE_TOPOLOGY_AWARE,
MPIR_TREE_TYPE_TOPOLOGY_AWARE_K,
MPIR_TREE_TYPE_TOPOLOGY_WAVE,
} MPIR_Tree_type_t;

typedef struct {
int rank;
int nranks;
Expand All @@ -16,4 +26,20 @@ typedef struct {
UT_array *children;
} MPIR_Treealgo_tree_t;

typedef struct {
MPIR_Tree_type_t type;
int root;
union {
struct {
int k;
} topo_aware;
struct {
int overhead;
int lat_diff_groups;
int lat_diff_switches;
int lat_same_switches;
} topo_wave;
} u;
} MPIR_Treealgo_param_t;

#endif /* TREEALGO_TYPES_H_INCLUDED */
10 changes: 0 additions & 10 deletions src/mpi/coll/include/coll_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,6 @@
#define MPIR_COLL_FLAG_REDUCE_L 1
#define MPIR_COLL_FLAG_REDUCE_R 0

/* enumerator for different tree types */
typedef enum MPIR_Tree_type_t {
MPIR_TREE_TYPE_KARY = 0,
MPIR_TREE_TYPE_KNOMIAL_1,
MPIR_TREE_TYPE_KNOMIAL_2,
MPIR_TREE_TYPE_TOPOLOGY_AWARE,
MPIR_TREE_TYPE_TOPOLOGY_AWARE_K,
MPIR_TREE_TYPE_TOPOLOGY_WAVE,
} MPIR_Tree_type_t;

/* enumerator for different recexch types */
enum {
MPIR_IALLREDUCE_RECEXCH_TYPE_SINGLE_BUFFER = 0,
Expand Down

0 comments on commit 499a9fb

Please sign in to comment.