Skip to content

Commit

Permalink
TL/UCP: reduce_scatter_ring (#413)
Browse files Browse the repository at this point in the history
* TL/UCP: tl_ucp_schedule_t

* UTIL: ep_map reverse

* UTIL: buffer division into blocks

* TL/UCP: reduce_scatter_ring

* TEST: reduce_scatter in gtest

* TEST: mpi reduce_scatter fix

    Allow running test_mpi reduce_scatter when count % comm_size != 0

* TL/UCP: rs ring 2 cb

* REVIEW: address comments
  • Loading branch information
valentin petrov authored Feb 7, 2022
1 parent 88f9706 commit 112735e
Show file tree
Hide file tree
Showing 19 changed files with 1,089 additions and 68 deletions.
4 changes: 3 additions & 1 deletion src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ reduce = \

reduce_scatter = \
reduce_scatter/reduce_scatter.h \
reduce_scatter/reduce_scatter_knomial.c
reduce_scatter/reduce_scatter_knomial.c \
reduce_scatter/reduce_scatter_ring.c \
reduce_scatter/reduce_scatter.c

scatter = \
scatter/scatter.h \
Expand Down
12 changes: 6 additions & 6 deletions src/components/tl/ucp/allreduce/allreduce_sra_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ static ucc_status_t ucc_tl_ucp_allreduce_sra_knomial_frag_init(
ucc_base_team_t *team, ucc_schedule_t **frag_p)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_schedule_t *schedule = ucc_tl_ucp_get_schedule(tl_team, coll_args);
size_t count = coll_args->args.dst.info.count;
ucc_base_coll_args_t args = *coll_args;
ucc_schedule_t *schedule =
&ucc_tl_ucp_get_schedule(tl_team, coll_args)->super.super;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix;
Expand Down Expand Up @@ -159,13 +160,12 @@ static inline void get_sra_n_frags(ucc_base_coll_args_t *coll_args,
static ucc_status_t
ucc_tl_ucp_allreduce_sra_knomial_finalize(ucc_coll_task_t *task)
{
ucc_schedule_pipelined_t *schedule =
ucc_derived_of(task, ucc_schedule_pipelined_t);
ucc_schedule_t *schedule = ucc_derived_of(task, ucc_schedule_t);
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(schedule, "ucp_allreduce_sra_kn_done", 0);
status = ucc_schedule_pipelined_finalize(task);
ucc_tl_ucp_put_schedule_pipelined(schedule);
ucc_tl_ucp_put_schedule(schedule);
return status;
}

Expand All @@ -184,7 +184,7 @@ ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_tl_ucp_lib_config_t *cfg = &UCC_TL_UCP_TEAM_LIB(tl_team)->cfg;
int n_frags, pipeline_depth;
ucc_schedule_pipelined_t *schedule_p =
ucc_tl_ucp_get_schedule_pipelined(tl_team);
&ucc_tl_ucp_get_schedule(tl_team, NULL)->super;
ucc_status_t status;

if (!schedule_p) {
Expand All @@ -198,7 +198,7 @@ ucc_tl_ucp_allreduce_sra_knomial_init(ucc_base_coll_args_t *coll_args,
cfg->allreduce_sra_kn_seq, schedule_p);
if (UCC_OK != status) {
tl_error(team->context->lib, "failed to init pipelined schedule");
ucc_tl_ucp_put_schedule_pipelined(schedule_p);
ucc_tl_ucp_put_schedule(&schedule_p->super);
return status;
}
schedule_p->super.super.finalize =
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/ucp/bcast/bcast_sag_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ ucc_tl_ucp_bcast_sag_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_coll_task_t **task_h)
{
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(team, ucc_tl_ucp_team_t);
ucc_schedule_t *schedule = ucc_tl_ucp_get_schedule(tl_team, coll_args);
size_t count = coll_args->args.src.info.count;
ucc_base_coll_args_t args = *coll_args;
ucc_schedule_t *schedule =
&ucc_tl_ucp_get_schedule(tl_team, coll_args)->super.super;
ucc_coll_task_t *task, *rs_task;
ucc_status_t status;
ucc_kn_radix_t radix, cfg_radix;
Expand Down
17 changes: 17 additions & 0 deletions src/components/tl/ucp/reduce_scatter/reduce_scatter.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/**
* Copyright (C) Mellanox Technologies Ltd. 2022. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#include "tl_ucp.h"
#include "reduce_scatter.h"
#include "utils/ucc_coll_utils.h"

/* Table describing the reduce_scatter algorithms, indexed by the
 * UCC_TL_UCP_REDUCE_SCATTER_ALG_* ids. It holds one entry per algorithm
 * plus a trailing entry at index UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST whose
 * name and desc are NULL. */
ucc_base_coll_alg_info_t
ucc_tl_ucp_reduce_scatter_algs[UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST + 1] = {
[UCC_TL_UCP_REDUCE_SCATTER_ALG_RING] =
{.id = UCC_TL_UCP_REDUCE_SCATTER_ALG_RING,
.name = "ring",
.desc = "O(N) ring"},
/* Trailing entry: NULL name/desc.
 * NOTE(review): presumably acts as an end-of-table sentinel for generic
 * consumers of the table - confirm. */
[UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};
30 changes: 29 additions & 1 deletion src/components/tl/ucp/reduce_scatter/reduce_scatter.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
/**
* Copyright (C) Mellanox Technologies Ltd. 2021. ALL RIGHTS RESERVED.
* Copyright (C) Mellanox Technologies Ltd. 2021-2022. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#ifndef REDUCE_SCATTER_H_
#define REDUCE_SCATTER_H_
#include "../tl_ucp_reduce.h"

/* Ids of the reduce_scatter algorithms. The values are used as indices into
 * ucc_tl_ucp_reduce_scatter_algs; UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST is not
 * an algorithm itself - it is the number of algorithms and must stay last. */
enum
{
UCC_TL_UCP_REDUCE_SCATTER_ALG_RING,
UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST
};

/* Algorithm description table: one slot per algorithm id plus one extra
 * trailing slot at index UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST. */
extern ucc_base_coll_alg_info_t
ucc_tl_ucp_reduce_scatter_algs[UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST + 1];

/* Default algorithm selection string for reduce_scatter; it names the "ring"
 * algorithm. NOTE(review): the selection-string syntax is parsed outside this
 * header - confirm the exact format against the parser. */
#define UCC_TL_UCP_REDUCE_SCATTER_DEFAULT_ALG_SELECT_STR \
"reduce_scatter:@ring"

/**
 * Translate an algorithm name into its reduce_scatter algorithm id.
 *
 * The name is compared case-insensitively against the names stored in
 * ucc_tl_ucp_reduce_scatter_algs, in increasing id order.
 *
 * @param str  Algorithm name to look up.
 *
 * @return Id of the first algorithm whose name matches, or
 *         UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST when no algorithm matches.
 */
static inline int ucc_tl_ucp_reduce_scatter_alg_from_str(const char *str)
{
    int alg_id = 0;

    /* Advance until either the table is exhausted or the name matches */
    while (alg_id < UCC_TL_UCP_REDUCE_SCATTER_ALG_LAST &&
           0 != strcasecmp(str, ucc_tl_ucp_reduce_scatter_algs[alg_id].name)) {
        alg_id++;
    }
    return alg_id;
}

/* Base interface signature: uses reduce_scatter_kn_radix from config. */

ucc_status_t
Expand All @@ -18,4 +41,9 @@ ucc_tl_ucp_reduce_scatter_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_status_t ucc_tl_ucp_reduce_scatter_knomial_init_r(
ucc_base_coll_args_t *coll_args, ucc_base_team_t *team,
ucc_coll_task_t **task_h, ucc_kn_radix_t radix);

/* Init entry point of the ring reduce_scatter algorithm; the definition lives
 * outside this header.
 * NOTE(review): presumably builds the collective task for the given args/team
 * and hands it back through task_h - confirm against the implementation. */
ucc_status_t
ucc_tl_ucp_reduce_scatter_ring_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t * team,
ucc_coll_task_t ** task_h);
#endif
Loading

0 comments on commit 112735e

Please sign in to comment.