From 7798792a612f2ba6271188786501c284a8653339 Mon Sep 17 00:00:00 2001
From: Mamzi Bayatpour
Date: Thu, 13 Jun 2024 12:14:07 -0700
Subject: [PATCH] TL/MLX5: add device mem mcast bcast

---
 src/components/mc/base/ucc_mc_base.h          |  3 +
 src/components/mc/cpu/mc_cpu.c                |  1 +
 src/components/mc/cuda/mc_cuda.c              | 31 ++++++-
 src/components/mc/rocm/mc_rocm.c              |  1 +
 src/components/mc/ucc_mc.c                    | 25 ++++++
 src/components/mc/ucc_mc.h                    |  4 +
 src/components/tl/mlx5/mcast/tl_mlx5_mcast.h  |  5 ++
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.c        | 17 ++++
 .../tl/mlx5/mcast/tl_mlx5_mcast_coll.h        |  2 +
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.c      | 12 ++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_helper.h      |  9 +-
 .../tl/mlx5/mcast/tl_mlx5_mcast_progress.c    | 15 +++-
 .../tl/mlx5/mcast/tl_mlx5_mcast_team.c        | 87 +++++++++++++------
 src/components/tl/mlx5/tl_mlx5.c              |  4 +
 src/components/tl/mlx5/tl_mlx5_coll.c         |  5 ++
 15 files changed, 185 insertions(+), 36 deletions(-)

diff --git a/src/components/mc/base/ucc_mc_base.h b/src/components/mc/base/ucc_mc_base.h
index 442088a09d..0c7182dbdc 100644
--- a/src/components/mc/base/ucc_mc_base.h
+++ b/src/components/mc/base/ucc_mc_base.h
@@ -109,6 +109,9 @@ typedef struct ucc_mc_ops {
     ucc_status_t (*memcpy)(void *dst, const void *src, size_t len,
                            ucc_memory_type_t dst_mem,
                            ucc_memory_type_t src_mem);
+    ucc_status_t (*sync_memcpy)(void *dst, const void *src, size_t len,
+                                ucc_memory_type_t dst_mem,
+                                ucc_memory_type_t src_mem);
     ucc_status_t (*memset)(void *dst, int value, size_t len);
     ucc_status_t (*flush)();
 } ucc_mc_ops_t;
diff --git a/src/components/mc/cpu/mc_cpu.c b/src/components/mc/cpu/mc_cpu.c
index 624d54c1ac..ec73aff958 100644
--- a/src/components/mc/cpu/mc_cpu.c
+++ b/src/components/mc/cpu/mc_cpu.c
@@ -209,6 +209,7 @@ ucc_mc_cpu_t ucc_mc_cpu = {
     .super.ops.mem_alloc   = ucc_mc_cpu_mem_pool_alloc_with_init,
     .super.ops.mem_free    = ucc_mc_cpu_mem_pool_free,
     .super.ops.memcpy      = ucc_mc_cpu_memcpy,
+    .super.ops.sync_memcpy = ucc_mc_cpu_memcpy,
     .super.ops.memset      = ucc_mc_cpu_memset,
     .super.ops.flush       = NULL,
     .super.config_table    =
diff --git a/src/components/mc/cuda/mc_cuda.c b/src/components/mc/cuda/mc_cuda.c
index 72b73b4e67..f57cb63332 100644
--- a/src/components/mc/cuda/mc_cuda.c
+++ b/src/components/mc/cuda/mc_cuda.c
@@ -228,9 +228,31 @@ ucc_mc_cuda_mem_pool_alloc_with_init(ucc_mc_buffer_header_t **h_ptr,
     }
 }
 
-static ucc_status_t ucc_mc_cuda_memcpy(void *dst, const void *src, size_t len,
-                                       ucc_memory_type_t dst_mem,
-                                       ucc_memory_type_t src_mem)
+static ucc_status_t ucc_mc_cuda_sync_memcpy(void *dst, const void *src, size_t len,
+                                            ucc_memory_type_t dst_mem,
+                                            ucc_memory_type_t src_mem)
+{
+    ucc_status_t status;
+
+    ucc_assert(dst_mem == UCC_MEMORY_TYPE_CUDA ||
+               src_mem == UCC_MEMORY_TYPE_CUDA ||
+               dst_mem == UCC_MEMORY_TYPE_CUDA_MANAGED ||
+               src_mem == UCC_MEMORY_TYPE_CUDA_MANAGED);
+
+    /* blocking copy: data is usable as soon as this call returns */
+    status = CUDA_FUNC(cudaMemcpy(dst, src, len, cudaMemcpyDefault));
+    if (ucc_unlikely(status != UCC_OK)) {
+        mc_error(&ucc_mc_cuda.super,
+                 "failed to launch cudaMemcpy, dst %p, src %p, len %zd",
+                 dst, src, len);
+    }
+
+    return status;
+}
+
+static ucc_status_t ucc_mc_cuda_async_memcpy(void *dst, const void *src, size_t len,
+                                             ucc_memory_type_t dst_mem,
+                                             ucc_memory_type_t src_mem)
 {
     ucc_status_t status;
     ucc_mc_cuda_resources_t *resources;
@@ -431,7 +453,8 @@ ucc_mc_cuda_t ucc_mc_cuda = {
     .super.ops.mem_query   = ucc_mc_cuda_mem_query,
     .super.ops.mem_alloc   = ucc_mc_cuda_mem_pool_alloc_with_init,
     .super.ops.mem_free    = ucc_mc_cuda_mem_pool_free,
-    .super.ops.memcpy      = ucc_mc_cuda_memcpy,
+    .super.ops.memcpy      = ucc_mc_cuda_async_memcpy,
+    .super.ops.sync_memcpy = ucc_mc_cuda_sync_memcpy,
     .super.ops.memset      = ucc_mc_cuda_memset,
     .super.ops.flush       = ucc_mc_cuda_flush_not_supported,
     .super.config_table    =
diff --git a/src/components/mc/rocm/mc_rocm.c b/src/components/mc/rocm/mc_rocm.c
index fbe7643ab7..0f830473f3 100644
--- a/src/components/mc/rocm/mc_rocm.c
+++ b/src/components/mc/rocm/mc_rocm.c
@@ -370,6 +370,7 @@ ucc_mc_rocm_t ucc_mc_rocm = {
     .super.ops.mem_alloc   = ucc_mc_rocm_mem_pool_alloc_with_init,
     .super.ops.mem_free    = ucc_mc_rocm_mem_pool_free,
     .super.ops.memcpy      = ucc_mc_rocm_memcpy,
+    .super.ops.sync_memcpy = ucc_mc_rocm_memcpy,
     .super.ops.memset      = ucc_mc_rocm_memset,
     .super.config_table =
     {
diff --git a/src/components/mc/ucc_mc.c b/src/components/mc/ucc_mc.c
index 997355443e..fbddc93403 100644
--- a/src/components/mc/ucc_mc.c
+++ b/src/components/mc/ucc_mc.c
@@ -132,6 +132,7 @@ ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type)
     return mc->get_attr(attr);
 }
 
+/* TODO: add the flexibility to bypass the mpool if the user asks for it */
 UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_alloc, (h_ptr, size, mem_type),
                     ucc_mc_buffer_header_t **h_ptr, size_t size,
                     ucc_memory_type_t mem_type)
@@ -152,6 +153,30 @@ ucc_status_t ucc_mc_free(ucc_mc_buffer_header_t *h_ptr)
     return mc_ops[mt]->mem_free(h_ptr);
 }
 
+UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_sync_memcpy,
+                    (dst, src, len, dst_mem, src_mem), void *dst,
+                    const void *src, size_t len, ucc_memory_type_t dst_mem,
+                    ucc_memory_type_t src_mem)
+{
+    ucc_memory_type_t mt;
+
+    if (src_mem == UCC_MEMORY_TYPE_UNKNOWN ||
+        dst_mem == UCC_MEMORY_TYPE_UNKNOWN) {
+        return UCC_ERR_INVALID_PARAM;
+    } else if (src_mem == UCC_MEMORY_TYPE_HOST &&
+               dst_mem == UCC_MEMORY_TYPE_HOST) {
+        UCC_CHECK_MC_AVAILABLE(UCC_MEMORY_TYPE_HOST);
+        return mc_ops[UCC_MEMORY_TYPE_HOST]->memcpy(dst, src, len,
+                                                    UCC_MEMORY_TYPE_HOST,
+                                                    UCC_MEMORY_TYPE_HOST);
+    }
+    /* take any non-host MC component */
+    mt = (dst_mem == UCC_MEMORY_TYPE_HOST) ? src_mem : dst_mem;
+    mt = (mt == UCC_MEMORY_TYPE_CUDA_MANAGED) ? UCC_MEMORY_TYPE_CUDA : mt;
+    UCC_CHECK_MC_AVAILABLE(mt);
+    return mc_ops[mt]->sync_memcpy(dst, src, len, dst_mem, src_mem);
+}
+
 UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_memcpy,
                     (dst, src, len, dst_mem, src_mem), void *dst,
                     const void *src, size_t len, ucc_memory_type_t dst_mem,
diff --git a/src/components/mc/ucc_mc.h b/src/components/mc/ucc_mc.h
index e98396b2f7..6b3ad6ee45 100644
--- a/src/components/mc/ucc_mc.h
+++ b/src/components/mc/ucc_mc.h
@@ -37,6 +37,10 @@ ucc_status_t ucc_mc_memcpy(void *dst, const void *src, size_t len,
                            ucc_memory_type_t dst_mem,
                            ucc_memory_type_t src_mem);
 
+ucc_status_t ucc_mc_sync_memcpy(void *dst, const void *src, size_t len,
+                                ucc_memory_type_t dst_mem,
+                                ucc_memory_type_t src_mem);
+
 ucc_status_t ucc_mc_memset(void *ptr, int value, size_t size,
                            ucc_memory_type_t mem_type);
 
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 1208226bda..4a0b0a49aa 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -17,6 +17,7 @@
 #include "components/tl/ucc_tl_log.h"
 #include "utils/ucc_rcache.h"
 #include "core/ucc_service_coll.h"
+#include "components/mc/ucc_mc.h"
 
 #define POLL_PACKED 16
 #define REL_DONE ((void*)-1)
@@ -98,6 +99,7 @@ typedef struct mcast_coll_comm_init_spec {
     int scq_moderation;
     int wsize;
     int max_eager;
+    int cuda_mem_enabled;
     void *oob;
 } ucc_tl_mlx5_mcast_coll_comm_init_spec_t;
 
@@ -251,6 +253,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     ucc_rank_t rank;
     ucc_rank_t commsize;
     char *grh_buf;
+    ucc_mc_buffer_header_t *grh_buf_header;
     struct ibv_mr *grh_mr;
     uint16_t mcast_lid;
     union ibv_gid mgid;
@@ -261,6 +264,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     int pending_recv;
     struct ibv_mr *pp_mr;
     char *pp_buf;
+    ucc_mc_buffer_header_t *pp_buf_header;
     struct pp_packet *pp;
     uint32_t psn;
     uint32_t last_psn;
@@ -293,6 +297,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     int n_prep_reliable;
     int n_mcast_reliable;
     int wsize;
+    int cuda_mem_enabled;
     ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
     ucc_service_coll_req_t *group_setup_info_req;
     ucc_tl_mlx5_mcast_service_coll_t service_coll;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index 9696ba8c82..a6725698ce 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
     }
 }
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team)
+{
+    ucc_tl_mlx5_team_t            *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
+    ucc_tl_mlx5_mcast_coll_comm_t *comm      = mlx5_team->mcast->mcast_comm;
+    ucc_coll_args_t               *args      = &coll_args->args;
+
+    if ((comm->cuda_mem_enabled &&
+         args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
+        (!comm->cuda_mem_enabled &&
+         args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
+        return UCC_OK;
+    }
+
+    return UCC_ERR_NO_RESOURCE;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
 {
     task->super.post = ucc_tl_mlx5_mcast_bcast_start;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
index 74385b1573..a5725915f7 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
@@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);
 
 ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);
 
+ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
+                                                     ucc_base_team_t *team);
 #endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
index f57daeab5e..7a5f21a81c 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
@@ -300,7 +300,12 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
         return UCC_ERR_NO_RESOURCE;
     }
 
-    comm->max_inline = qp_init_attr.cap.max_inline_data;
+    if (comm->cuda_mem_enabled) {
+        /* disable inline sends: inlining from device memory segfaults during ibv send */
+        comm->max_inline = 0;
+    } else {
+        comm->max_inline = qp_init_attr.cap.max_inline_data;
+    }
 
     return UCC_OK;
 }
@@ -609,8 +614,9 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
             return UCC_ERR_NO_RESOURCE;
         }
     }
+
     if (comm->grh_buf) {
-        ucc_free(comm->grh_buf);
+        ucc_mc_free(comm->grh_buf_header);
     }
 
     if (comm->pp) {
@@ -626,7 +632,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
     }
 
     if (comm->pp_buf) {
-        ucc_free(comm->pp_buf);
+        ucc_mc_free(comm->pp_buf_header);
     }
 
     if (comm->call_rwr) {
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
index 9d66f3453e..59a7a2cb5a 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
@@ -47,6 +47,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
     int               rc;
     int               length;
     ucc_status_t      status;
+    ucc_memory_type_t mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
+                                                        : UCC_MEMORY_TYPE_HOST;
 
     for (i = 0; i < num_packets; i++) {
         if (comm->params.sx_depth <=
@@ -75,7 +77,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
         if (zcopy) {
             pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
         } else {
-            memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
+            status = ucc_mc_sync_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset),
+                                        length, mem_type, mem_type);
+            if (ucc_unlikely(status != UCC_OK)) {
+                tl_error(comm->lib, "failed to copy the send buffer");
+                return status;
+            }
             ssg[0].addr = (uint64_t) pp->buf;
         }
 
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
index 4522097973..b4f04d0874 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -371,6 +371,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
 {
     ucc_status_t status = UCC_OK;
     void *dest;
+    ucc_memory_type_t mem_type;
 
     ucc_assert(pp->psn >= req->start_psn &&
                pp->psn < req->start_psn + req->num_packets);
@@ -379,7 +380,19 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
 
     if (pp->length > 0 ) {
         dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
-        memcpy(dest, (void*) pp->buf, pp->length);
+
+        if (comm->cuda_mem_enabled) {
+            mem_type = UCC_MEMORY_TYPE_CUDA;
+        } else {
+            mem_type = UCC_MEMORY_TYPE_HOST;
+        }
+
+        status = ucc_mc_sync_memcpy(dest, (void*) pp->buf, pp->length,
+                                    mem_type, mem_type);
+        if (ucc_unlikely(status != UCC_OK)) {
+            tl_error(comm->lib, "failed to copy the received packet");
+            return status;
+        }
     }
 
     comm->r_window[pp->psn & (comm->wsize-1)] = pp;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
index 402ff84472..bdf4fb2857 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -13,6 +13,17 @@
 #include "mcast/tl_mlx5_mcast_helper.h"
 #include "mcast/tl_mlx5_mcast_service_coll.h"
 
+static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver(void)
+{
+    const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";
+
+    if (!access(file, F_OK)) {
+        return UCC_OK;
+    }
+
+    return UCC_ERR_NO_RESOURCE;
+}
+
 ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
                                          ucc_tl_mlx5_mcast_team_t **mcast_team,
                                          ucc_tl_mlx5_mcast_context_t *ctx,
@@ -88,23 +99,14 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
     memcpy(&comm->params, conf_params, sizeof(*conf_params));
 
-    comm->wsize     = conf_params->wsize;
-    comm->max_eager = conf_params->max_eager;
-    comm->comm_id   = team_params->id;
-    comm->ctx       = mcast_context;
-    comm->grh_buf   = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf");
-    if (!comm->grh_buf) {
-        status = UCC_ERR_NO_MEMORY;
-        goto cleanup;
-    }
-
-    memset(comm->grh_buf, 0, GRH_LENGTH);
-
-    comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH,
-                              IBV_ACCESS_REMOTE_WRITE |
-                              IBV_ACCESS_LOCAL_WRITE);
-    if (!comm->grh_mr) {
-        tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno);
+    comm->wsize            = conf_params->wsize;
+    comm->max_eager        = conf_params->max_eager;
+    comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
+    comm->comm_id          = team_params->id;
+    comm->ctx              = mcast_context;
+
+    if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
+        tl_warn(mcast_context->lib, "CUDA-aware mcast is unavailable: the GPUDirect RDMA driver is not loaded");
         status = UCC_ERR_NO_RESOURCE;
         goto cleanup;
     }
@@ -162,9 +164,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
 ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm)
 {
-    ucc_status_t status;
-    size_t       page_size;
-    int          buf_size, i, ret;
+    ucc_status_t      status;
+    size_t            page_size;
+    int               buf_size, i, ret;
+    ucc_memory_type_t supported_mem_type;
 
     status = ucc_tl_mlx5_mcast_init_qps(comm->ctx, comm);
     if (UCC_OK != status) {
@@ -197,19 +200,49 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
     comm->pending_recv = 0;
     comm->buf_n        = comm->params.rx_depth * 2;
 
-    ret = posix_memalign((void**) &comm->pp_buf, page_size, buf_size * comm->buf_n);
-    if (ret) {
-        tl_error(comm->ctx->lib, "posix_memalign failed");
-        return UCC_ERR_NO_MEMORY;
+    supported_mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
+                                                : UCC_MEMORY_TYPE_HOST;
+
+    status = ucc_mc_alloc(&comm->grh_buf_header, GRH_LENGTH * sizeof(char),
+                          UCC_MEMORY_TYPE_HOST);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(comm->ctx->lib, "failed to allocate memory for GRH");
+        return status;
+    }
+    comm->grh_buf = comm->grh_buf_header->addr;
+
+    status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_HOST);
+    if (status != UCC_OK) {
+        tl_error(comm->ctx->lib, "failed to memset the GRH buffer");
+        goto error;
+    }
+
+    comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
+                              IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
+    if (!comm->grh_mr) {
+        tl_error(comm->ctx->lib, "could not register memory for GRH, errno %d", errno);
+        status = UCC_ERR_NO_RESOURCE;
+        goto error;
+    }
+
+    status = ucc_mc_alloc(&comm->pp_buf_header, buf_size * comm->buf_n, supported_mem_type);
+    if (ucc_unlikely(status != UCC_OK)) {
+        tl_error(comm->ctx->lib, "failed to allocate the packet buffers");
+        goto error;
+    }
+    comm->pp_buf = comm->pp_buf_header->addr;
+
+    status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, supported_mem_type);
+    if (status != UCC_OK) {
+        tl_error(comm->ctx->lib, "failed to memset the packet buffers");
+        goto error;
     }
 
-    memset(comm->pp_buf, 0, buf_size * comm->buf_n);
-
     comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
                              IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
     if (!comm->pp_mr) {
-        tl_error(comm->ctx->lib, "could not register pp_buf mr, errno %d", errno);
-        status = UCC_ERR_NO_MEMORY;
+        tl_error(comm->ctx->lib, "could not register pp_buf device mr, errno %d", errno);
+        status = UCC_ERR_NO_RESOURCE;
         goto error;
     }
 
diff --git a/src/components/tl/mlx5/tl_mlx5.c b/src/components/tl/mlx5/tl_mlx5.c
index 75e6f517cc..38ac2e4ca0 100644
--- a/src/components/tl/mlx5/tl_mlx5.c
+++ b/src/components/tl/mlx5/tl_mlx5.c
@@ -92,6 +92,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
      ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.max_eager),
      UCC_CONFIG_TYPE_MEMUNITS},
 
+    {"MCAST_CUDA_MEM_ENABLE", "0", "Enable GPU CUDA memory support for mcast. GPUDirect RDMA must be enabled",
+     ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.cuda_mem_enabled),
+     UCC_CONFIG_TYPE_INT},
+
     {NULL}};
 
 static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
index e918be166e..a8add9715e 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.c
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -19,6 +19,11 @@ ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
         tl_trace(team->context->lib, "mcast bcast not supported for active sets");
         return UCC_ERR_NOT_SUPPORTED;
     }
+
+    if (UCC_OK != ucc_tl_mlx5_mcast_check_memory_type_cap(coll_args, team)) {
+        tl_trace(team->context->lib, "mcast bcast not compatible with this memory type");
+        return UCC_ERR_NOT_SUPPORTED;
+    }
 
     task = ucc_tl_mlx5_get_task(coll_args, team);
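
Note (illustration only, not part of the patch): sync_memcpy is the blocking counterpart of the stream-based memcpy op. The mcast path needs the blocking flavor because a staging packet buffer is handed to the HCA, or recycled into the receive queue, as soon as the copy call returns. A minimal sketch of a caller follows; the helper name is hypothetical:

    #include "components/mc/ucc_mc.h"

    /* hypothetical helper: synchronously stage 'len' bytes between two CUDA
     * buffers. ucc_mc_sync_memcpy() routes CUDA and CUDA-managed requests to
     * ucc_mc_cuda_sync_memcpy(), i.e. a blocking
     * cudaMemcpy(dst, src, len, cudaMemcpyDefault), so both buffers are safe
     * to reuse as soon as the call returns */
    static ucc_status_t stage_device_payload(void *dst, const void *src,
                                             size_t len)
    {
        return ucc_mc_sync_memcpy(dst, src, len, UCC_MEMORY_TYPE_CUDA,
                                  UCC_MEMORY_TYPE_CUDA);
    }

Going by the config-table entry added in tl_mlx5.c, the CUDA path would presumably be switched on at runtime through the matching environment variable (name inferred from UCC's usual prefixing of TL/MLX5 lib config fields):

    UCC_TL_MLX5_MCAST_CUDA_MEM_ENABLE=1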