Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/UCP: Add linear alltoall and allgather algorithms based on xgvmi ucp get #992

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions contrib/doca_urom_ucc_plugin/dpu/worker_ucc.c
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ static void *urom_worker_ucc_ctx_progress_thread(void *arg)
if (UCC_OK != ucc_context_config_modify(
ctx_config,
"tl/ucp", "TUNE",
"allreduce:0-inf:@sliding_window")) {
"allreduce:0-inf:@sliding_window#allgather:0-inf:@linear_xgvmi#alltoall:0-inf:@linear_xgvmi")) {
DOCA_LOG_ERR("Failed to modify TL_UCP_TUNE UCC lib config");
status = DOCA_ERROR_DRIVER;
goto cfg_release;
Expand Down Expand Up @@ -1956,8 +1956,7 @@ urom_worker_ucc_coll_init(struct urom_worker_ucc *ucc_worker,
/* Cannot support callbacks to host data.. just won't work */
coll_args->mask = coll_args->mask & (~UCC_COLL_ARGS_FIELD_CB);

if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALL ||
coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV) {
if (coll_args->coll_type == UCC_COLL_TYPE_ALLTOALLV) {
if (!ucc_cmd->coll_cmd.use_xgvmi) {
size_mod = urom_worker_get_dt_size(coll_args->src.info.datatype);
size = coll_args->src.info.count * size_mod;
Expand Down Expand Up @@ -2013,6 +2012,7 @@ urom_worker_ucc_coll_init(struct urom_worker_ucc *ucc_worker,
}
}
} else if (coll_args->coll_type == UCC_COLL_TYPE_ALLREDUCE ||
coll_args->coll_type == UCC_COLL_TYPE_ALLTOALL ||
coll_args->coll_type == UCC_COLL_TYPE_ALLGATHER) {
if (!ucc_cmd->coll_cmd.use_xgvmi) {
DOCA_LOG_ERR("Failed to initialize UCC collective:"
Expand Down
28 changes: 15 additions & 13 deletions src/components/tl/ucp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,27 @@ SUBDIRS = .
include makefile.coll_plugins.am
endif

allgather = \
allgather/allgather.h \
allgather/allgather.c \
allgather/allgather_ring.c \
allgather/allgather_neighbor.c \
allgather/allgather_bruck.c \
allgather/allgather_sparbit.c \
allgather/allgather_knomial.c
allgather = \
allgather/allgather.h \
allgather/allgather.c \
allgather/allgather_ring.c \
allgather/allgather_neighbor.c \
allgather/allgather_bruck.c \
allgather/allgather_sparbit.c \
allgather/allgather_knomial.c \
allgather/allgather_linear_xgvmi.c

allgatherv = \
allgatherv/allgatherv.h \
allgatherv/allgatherv.c \
allgatherv/allgatherv_ring.c

alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c \
alltoall/alltoall_onesided.c \
alltoall/alltoall_pairwise.c \
alltoall = \
alltoall/alltoall.h \
alltoall/alltoall.c \
alltoall/alltoall_onesided.c \
alltoall/alltoall_pairwise.c \
alltoall/alltoall_linear_xgvmi.c \
alltoall/alltoall_bruck.c

alltoallv = \
Expand Down
6 changes: 5 additions & 1 deletion src/components/tl/ucp/allgather/allgather.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -31,6 +31,10 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLGATHER_ALG_SPARBIT,
.name = "sparbit",
.desc = "O(log(N)) SPARBIT algorithm"},
[UCC_TL_UCP_ALLGATHER_ALG_LINEAR_XGVMI] =
{.id = UCC_TL_UCP_ALLGATHER_ALG_LINEAR_XGVMI,
.name = "linear_xgvmi",
.desc = "Offloaded linear xgvmi algorithm"},
[UCC_TL_UCP_ALLGATHER_ALG_LAST] = {
.id = 0, .name = NULL, .desc = NULL}};

Expand Down
7 changes: 6 additions & 1 deletion src/components/tl/ucp/allgather/allgather.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
/**
* Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#ifndef ALLGATHER_H_
#define ALLGATHER_H_
#include "../tl_ucp.h"
Expand All @@ -14,6 +15,7 @@ enum {
UCC_TL_UCP_ALLGATHER_ALG_NEIGHBOR,
UCC_TL_UCP_ALLGATHER_ALG_BRUCK,
UCC_TL_UCP_ALLGATHER_ALG_SPARBIT,
UCC_TL_UCP_ALLGATHER_ALG_LINEAR_XGVMI,
UCC_TL_UCP_ALLGATHER_ALG_LAST
};

Expand Down Expand Up @@ -74,6 +76,9 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

/* XGVMI */
void ucc_tl_ucp_dpu_allgather_linear_xgvmi_rdma_progress(ucc_coll_task_t *coll_task);

/* Uses allgather_kn_radix from config */
ucc_status_t ucc_tl_ucp_allgather_knomial_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
76 changes: 76 additions & 0 deletions src/components/tl/ucp/allgather/allgather_linear_xgvmi.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/**
* Copyright(c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "tl_ucp_ep.h"
#include "tl_ucp_coll.h"
#include "tl_ucp_dpu_offload.h"

void
ucc_tl_ucp_dpu_allgather_linear_xgvmi_rdma_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_datatype_t dtype = TASK_ARGS(task).src.info.datatype;
size_t dt_size = ucc_dt_size(dtype);
ucc_count_t count = coll_task->bargs.args.src.info.count;
ucc_base_team_t *base_team = coll_task->team;
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team,
ucc_tl_ucp_team_t);
ucc_coll_task_t *allgather_task = task->dpu_xgvmi.allgather_task;
ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
ucc_rank_t host_team_size = UCC_TL_TEAM_SIZE(tl_team);
ucp_request_param_t req_param = {0};
int i = 0;
ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team);
size_t data_size = (count * dt_size);
ucs_status_ptr_t *requests = task->dpu_xgvmi.requests;
int *posted = &task->dpu_xgvmi.gets_posted;
int *completed = &task->dpu_xgvmi.gets_completed;
void *src_addr;
void *dst_addr;
ucp_rkey_h rkey;
ucp_ep_h ep;
ucc_rank_t offset;

if (allgather_task != NULL) {
ucc_tl_ucp_dpu_xgvmi_key_exchange_progress(coll_task);
return;
}

req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;

for (i = *posted; i < host_team_size; i++) {
offset = (i + rank) % host_team_size;
req_param.memh = task->dpu_xgvmi.bufs->dst_ebuf->memh;
src_addr = task->dpu_xgvmi.bufs->sbufs[offset];
dst_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->rbufs[rank],
offset * data_size);
rkey = task->dpu_xgvmi.bufs->src_rkeys[offset];
ucc_tl_ucp_get_ep(tl_team, offset, &ep);

requests[i] = ucp_get_nbx(
ep, dst_addr,
data_size, (uint64_t)src_addr,
rkey, &req_param);

*posted += 1;
}

ucp_worker_progress(tl_ctx->worker.ucp_worker);

for (i = *completed; i < *posted; i++) {
if (ucc_tl_ucp_dpu_xgvmi_req_test(requests[i], task) == UCC_OK) {
if (requests[i]) ucp_request_free(requests[i]);
*completed += 1;
} else {
break;
}
}

if (*completed == host_team_size) {
task->super.status = UCC_OK;
}
}
6 changes: 5 additions & 1 deletion src/components/tl/ucp/alltoall/alltoall.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -43,6 +43,10 @@ ucc_base_coll_alg_info_t
{.id = UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
.name = "onesided",
.desc = "naive, linear one-sided implementation"},
[UCC_TL_UCP_ALLTOALL_ALG_LINEAR_XGVMI] =
{.id = UCC_TL_UCP_ALLTOALL_ALG_LINEAR_XGVMI,
.name = "linear_xgvmi",
.desc = "linear xgvmi-based implementation"},
[UCC_TL_UCP_ALLTOALL_ALG_LAST] = {.id = 0, .name = NULL, .desc = NULL}};

ucc_status_t ucc_tl_ucp_alltoall_init(ucc_tl_ucp_task_t *task)
Expand Down
6 changes: 5 additions & 1 deletion src/components/tl/ucp/alltoall/alltoall.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand All @@ -14,6 +14,7 @@ enum {
UCC_TL_UCP_ALLTOALL_ALG_PAIRWISE,
UCC_TL_UCP_ALLTOALL_ALG_BRUCK,
UCC_TL_UCP_ALLTOALL_ALG_ONESIDED,
UCC_TL_UCP_ALLTOALL_ALG_LINEAR_XGVMI,
nsarka marked this conversation as resolved.
Show resolved Hide resolved
UCC_TL_UCP_ALLTOALL_ALG_LAST
};

Expand Down Expand Up @@ -42,6 +43,9 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
ucc_coll_task_t **task_h);

/* XGVMI */
void ucc_tl_ucp_dpu_alltoall_linear_xgvmi_rdma_progress(ucc_coll_task_t *coll_task);

#define ALLTOALL_CHECK_INPLACE(_args, _team) \
do { \
if (UCC_IS_INPLACE(_args)) { \
Expand Down
77 changes: 77 additions & 0 deletions src/components/tl/ucp/alltoall/alltoall_linear_xgvmi.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/**
* Copyright(c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/

#include "tl_ucp_ep.h"
#include "tl_ucp_coll.h"
#include "tl_ucp_dpu_offload.h"

void
ucc_tl_ucp_dpu_alltoall_linear_xgvmi_rdma_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_datatype_t dtype = TASK_ARGS(task).src.info.datatype;
size_t dt_size = ucc_dt_size(dtype);
ucc_count_t count = coll_task->bargs.args.src.info.count;
ucc_base_team_t *base_team = coll_task->team;
ucc_tl_ucp_team_t *tl_team = ucc_derived_of(base_team,
ucc_tl_ucp_team_t);
ucc_rank_t host_team_size = UCC_TL_TEAM_SIZE(tl_team);
ucc_coll_task_t *allgather_task = task->dpu_xgvmi.allgather_task;
ucc_tl_ucp_context_t *tl_ctx = UCC_TL_UCP_TEAM_CTX(tl_team);
ucp_request_param_t req_param = {0};
int i = 0;
ucc_rank_t rank = UCC_TL_TEAM_RANK(tl_team);
size_t data_size = (count * dt_size) / host_team_size;
ucs_status_ptr_t *requests = task->dpu_xgvmi.requests;
int *posted = &task->dpu_xgvmi.gets_posted;
int *completed = &task->dpu_xgvmi.gets_completed;
void *src_addr;
void *dst_addr;
ucp_rkey_h rkey;
ucp_ep_h ep;
ucc_rank_t offset;

if (allgather_task != NULL) {
ucc_tl_ucp_dpu_xgvmi_key_exchange_progress(coll_task);
return;
}

req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;

for (i = *posted; i < host_team_size; i++) {
offset = (i + rank) % host_team_size;
req_param.memh = task->dpu_xgvmi.bufs->dst_ebuf->memh;
src_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->sbufs[offset],
rank * data_size);
dst_addr = PTR_OFFSET(task->dpu_xgvmi.bufs->rbufs[rank],
offset * data_size);
rkey = task->dpu_xgvmi.bufs->src_rkeys[offset];
ucc_tl_ucp_get_ep(tl_team, offset, &ep);

requests[i] = ucp_get_nbx(
ep, dst_addr,
data_size, (uint64_t)src_addr,
rkey, &req_param);

*posted += 1;
}

ucp_worker_progress(tl_ctx->worker.ucp_worker);

for (i = *completed; i < *posted; i++) {
if (ucc_tl_ucp_dpu_xgvmi_req_test(requests[i], task) == UCC_OK) {
if (requests[i]) ucp_request_free(requests[i]);
*completed += 1;
} else {
break;
}
}

if (*completed == host_team_size) {
task->super.status = UCC_OK;
}
}
7 changes: 7 additions & 0 deletions src/components/tl/ucp/tl_ucp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "fanin/fanin.h"
#include "fanout/fanout.h"
#include "scatterv/scatterv.h"
#include "tl_ucp_dpu_offload.h"

const ucc_tl_ucp_default_alg_desc_t
ucc_tl_ucp_default_alg_descs[UCC_TL_UCP_N_DEFAULT_ALG_SELECT_STR] = {
Expand Down Expand Up @@ -268,6 +269,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLGATHER_ALG_SPARBIT:
*init = ucc_tl_ucp_allgather_sparbit_init;
break;
case UCC_TL_UCP_ALLGATHER_ALG_LINEAR_XGVMI:
*init = ucc_tl_ucp_dpu_xgvmi_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
Expand Down Expand Up @@ -319,6 +323,9 @@ ucc_status_t ucc_tl_ucp_alg_id_to_init(int alg_id, const char *alg_id_str,
case UCC_TL_UCP_ALLTOALL_ALG_ONESIDED:
*init = ucc_tl_ucp_alltoall_onesided_init;
break;
case UCC_TL_UCP_ALLTOALL_ALG_LINEAR_XGVMI:
*init = ucc_tl_ucp_dpu_xgvmi_init;
break;
default:
status = UCC_ERR_INVALID_PARAM;
break;
Expand Down
20 changes: 14 additions & 6 deletions src/components/tl/ucp/tl_ucp_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,21 @@ typedef struct ucc_tl_ucp_task {
ucc_ee_executor_t *executor;
} allreduce_kn;
struct {
ucc_tl_ucp_allreduce_sw_pipeline *pipe;
ucs_status_ptr_t *put_requests;
ucc_tl_ucp_allreduce_sw_host_allgather *allgather_data;
ucc_coll_task_t *allgather_task;
ucc_ee_executor_task_t *reduce_task;
ucc_tl_ucp_dpu_offload_buf_info_t *bufs;
ucc_tl_ucp_allreduce_sw_pipeline *pipe;
ucs_status_ptr_t *put_requests;
ucc_tl_ucp_allreduce_sw_host_allgather *allgather_data;
ucc_coll_task_t *allgather_task;
ucc_ee_executor_task_t *reduce_task;
ucc_tl_ucp_dpu_offload_buf_info_t *bufs;
} allreduce_sliding_window;
struct {
ucc_tl_ucp_allreduce_sw_host_allgather *allgather_data;
ucc_coll_task_t *allgather_task;
ucc_tl_ucp_dpu_offload_buf_info_t *bufs;
ucs_status_ptr_t *requests;
int gets_posted;
int gets_completed;
} dpu_xgvmi;
struct {
int phase;
ucc_knomial_pattern_t p;
Expand Down
Loading
Loading