Skip to content

Commit

Permalink
TL/MLX5: addressing sam's comments on PR 989
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jul 9, 2024
1 parent 11edf29 commit 5fca2e8
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 32 deletions.
8 changes: 5 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"
#include "core/ucc_service_coll.h"
#include "utils/arch/cuda_def.h"
#include "components/mc/ucc_mc.h"

#define POLL_PACKED 16
#define REL_DONE ((void*)-1)
Expand Down Expand Up @@ -91,7 +91,7 @@ typedef struct mcast_coll_comm_init_spec {
int scq_moderation;
int wsize;
int max_eager;
int device_mem_enabled;
int cuda_mem_enabled;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -196,17 +196,19 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
ucc_mc_buffer_header_t *grh_cuda_header;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int device_mem_enabled;
int cuda_mem_enabled;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
ucc_mc_buffer_header_t *pp_cuda_header;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
Expand Down
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
}
}

/**
 * Check whether the team's mcast communicator matches the memory type of the
 * collective's source buffer.
 *
 * The communicator is configured for exactly one kind of memory: when its
 * cuda_mem_enabled flag is set only UCC_MEMORY_TYPE_CUDA source buffers are
 * accepted, otherwise only UCC_MEMORY_TYPE_HOST source buffers are accepted.
 * Only args.src.info.mem_type is inspected.
 *
 * @param coll_args  collective arguments; the source buffer memory type is read
 * @param team       base team, expected to be a ucc_tl_mlx5_team_t
 *
 * @return UCC_OK when the memory type matches the communicator configuration,
 *         UCC_ERR_NO_RESOURCE otherwise.
 */
ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
                                                     ucc_base_team_t *team)
{
    ucc_tl_mlx5_team_t            *tl_team    = ucc_derived_of(team, ucc_tl_mlx5_team_t);
    ucc_tl_mlx5_mcast_coll_comm_t *mcast_comm = tl_team->mcast->mcast_comm;

    if (mcast_comm->cuda_mem_enabled) {
        /* cuda-enabled communicator: only cuda source buffers qualify */
        if (coll_args->args.src.info.mem_type == UCC_MEMORY_TYPE_CUDA) {
            return UCC_OK;
        }
    } else if (coll_args->args.src.info.mem_type == UCC_MEMORY_TYPE_HOST) {
        /* host-only communicator: only host source buffers qualify */
        return UCC_OK;
    }

    return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
10 changes: 5 additions & 5 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
return UCC_ERR_NO_RESOURCE;
}

if (comm->device_mem_enabled) {
if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
Expand Down Expand Up @@ -482,8 +482,8 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->grh_buf) {
if (comm->device_mem_enabled) {
cudaFree(comm->grh_buf);
if (comm->cuda_mem_enabled) {
ucc_mc_free(comm->grh_cuda_header);
} else {
ucc_free(comm->grh_buf);
}
Expand All @@ -502,8 +502,8 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->pp_buf) {
if (comm->device_mem_enabled) {
cudaFree(comm->pp_buf);
if (comm->cuda_mem_enabled) {
ucc_mc_free(comm->pp_cuda_header);
} else {
ucc_free(comm->pp_buf);
}
Expand Down
10 changes: 7 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,13 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
if (zcopy) {
pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
} else {
if (comm->device_mem_enabled) {
CUDA_FUNC(cudaMemcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset),
length, cudaMemcpyDeviceToDevice));
if (comm->cuda_mem_enabled) {
status = ucc_mc_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
UCC_MEMORY_TYPE_CUDA, UCC_MEMORY_TYPE_CUDA);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy cuda buffer");
return status;
}
} else {
memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
}
Expand Down
9 changes: 7 additions & 2 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,13 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
if (comm->device_mem_enabled) {
cudaMemcpy(dest, (void*) pp->buf, pp->length, cudaMemcpyDeviceToDevice);
if (comm->cuda_mem_enabled) {
status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length,
UCC_MEMORY_TYPE_CUDA, UCC_MEMORY_TYPE_CUDA);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy cuda buffer");
return status;
}
} else {
memcpy(dest, (void*) pp->buf, pp->length);
}
Expand Down
58 changes: 41 additions & 17 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
#include "p2p/ucc_tl_mlx5_mcast_p2p.h"
#include "mcast/tl_mlx5_mcast_helper.h"

/**
 * Probe for the peer-memory driver needed by cuda-aware mcast.
 *
 * The check is a plain existence test (access() with F_OK) on a single sysfs
 * entry; no other path is probed and the file content is not read.
 * NOTE(review): the entry is presumably published by the NVIDIA peer-memory
 * kernel module; setups exposing the driver under a different sysfs path
 * would be reported as unavailable - confirm against supported drivers.
 *
 * @return UCC_OK if the sysfs entry exists, UCC_ERR_NO_RESOURCE otherwise.
 */
static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver(void)
{
    const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";

    /* access() returns 0 when the path exists */
    if (access(file, F_OK) != 0) {
        return UCC_ERR_NO_RESOURCE;
    }

    return UCC_OK;
}

static ucc_status_t ucc_tl_mlx5_mcast_service_bcast_post(void *arg, void *buf, size_t size, ucc_rank_t root,
ucc_service_coll_req_t **bcast_req)
{
Expand Down Expand Up @@ -120,10 +132,16 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

comm->wsize = conf_params->wsize;
comm->max_eager = conf_params->max_eager;
comm->device_mem_enabled = conf_params->device_mem_enabled;
comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
comm->comm_id = team_params->id;
comm->ctx = mcast_context;

if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
tl_warn(mcast_context->lib, "cuda-aware mcast not available as gpu direct is not ready");
status = UCC_ERR_NO_RESOURCE;
goto cleanup;
}

comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0);
if (!comm->rcq) {
ibv_dereg_mr(comm->grh_mr);
Expand Down Expand Up @@ -212,35 +230,41 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
comm->pending_recv = 0;
comm->buf_n = comm->params.rx_depth * 2;

if (comm->device_mem_enabled) {
/* TODO add check to make sure GPUDirect is enabled
* lsmod | grep nv_peer */
CUDA_FUNC(cudaMalloc((void **)&comm->grh_buf, GRH_LENGTH * sizeof(char)));
if (!comm->grh_buf) {
tl_error(comm->ctx->lib, "cuda memcpy failed");
status = UCC_ERR_NO_MEMORY;
goto error;
if (comm->cuda_mem_enabled) {
status = ucc_mc_alloc(&comm->grh_cuda_header, GRH_LENGTH *
sizeof(char), UCC_MEMORY_TYPE_CUDA);
comm->grh_buf = comm->grh_cuda_header->addr;
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->ctx->lib, "failed to allocate cuda memory");
return status;
}

CUDA_FUNC(cudaMemset(comm->grh_buf, 0, GRH_LENGTH));
status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_CUDA);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not cuda memset");
goto error;
}

comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
if (!comm->grh_mr) {
tl_error(comm->ctx->lib, "Could not register device memory for GRH, errno %d", errno);
tl_error(comm->ctx->lib, "could not register device memory for GRH, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}

// assuming the device page size is same as host page size
CUDA_FUNC(cudaMalloc((void**) &comm->pp_buf, buf_size * comm->buf_n));
if (!comm->pp_buf) {
tl_error(comm->ctx->lib, "cuda memcpy failed");
status = UCC_ERR_NO_MEMORY;
status = ucc_mc_alloc(&comm->pp_cuda_header, buf_size * comm->buf_n, UCC_MEMORY_TYPE_CUDA);
comm->pp_buf = comm->pp_cuda_header->addr;
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->ctx->lib, "failed to allocate cuda memory");
goto error;
}

CUDA_FUNC(cudaMemset(comm->pp_buf, 0, buf_size * comm->buf_n));
status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, UCC_MEMORY_TYPE_CUDA);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not cuda memset");
goto error;
}

comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE |
Expand Down
4 changes: 2 additions & 2 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager),
UCC_CONFIG_TYPE_MEMUNITS},

{"MCAST_DEVICE_MEM_ENABLE", "0", "Enable GPU memory support for Mcast",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.device_mem_enabled),
{"MCAST_CUDA_MEM_ENABLE", "0", "Enable GPU CUDA memory support for Mcast. GPUDirect RDMA must be enabled",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled),
UCC_CONFIG_TYPE_INT},

{NULL}};
Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/mlx5/tl_mlx5_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
tl_trace(team->context->lib, "mcast bcast not supported for active sets");
return UCC_ERR_NOT_SUPPORTED;
}

if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
return UCC_ERR_NOT_SUPPORTED;
}

task = ucc_tl_mlx5_get_task(coll_args, team);

Expand Down

0 comments on commit 5fca2e8

Please sign in to comment.