Skip to content

Commit

Permalink
TL/MLX5: add device mem mcast bcast
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Aug 21, 2024
1 parent 777df69 commit 7798792
Show file tree
Hide file tree
Showing 15 changed files with 185 additions and 36 deletions.
3 changes: 3 additions & 0 deletions src/components/mc/base/ucc_mc_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ typedef struct ucc_mc_ops {
ucc_status_t (*memcpy)(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);
ucc_status_t (*sync_memcpy)(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);
ucc_status_t (*memset)(void *dst, int value, size_t len);
ucc_status_t (*flush)();
} ucc_mc_ops_t;
Expand Down
1 change: 1 addition & 0 deletions src/components/mc/cpu/mc_cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ ucc_mc_cpu_t ucc_mc_cpu = {
.super.ops.mem_alloc = ucc_mc_cpu_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_cpu_mem_pool_free,
.super.ops.memcpy = ucc_mc_cpu_memcpy,
.super.ops.sync_memcpy = ucc_mc_cpu_memcpy,
.super.ops.memset = ucc_mc_cpu_memset,
.super.ops.flush = NULL,
.super.config_table =
Expand Down
31 changes: 27 additions & 4 deletions src/components/mc/cuda/mc_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,31 @@ ucc_mc_cuda_mem_pool_alloc_with_init(ucc_mc_buffer_header_t **h_ptr,
}
}

static ucc_status_t ucc_mc_cuda_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)
static ucc_status_t ucc_mc_cuda_sync_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)
{
ucc_status_t status;

ucc_assert(dst_mem == UCC_MEMORY_TYPE_CUDA ||
src_mem == UCC_MEMORY_TYPE_CUDA ||
dst_mem == UCC_MEMORY_TYPE_CUDA_MANAGED ||
src_mem == UCC_MEMORY_TYPE_CUDA_MANAGED);

status = CUDA_FUNC(cudaMemcpy(dst, src, len, cudaMemcpyDefault));
if (ucc_unlikely(status != UCC_OK)) {
mc_error(&ucc_mc_cuda.super,
"failed to launch cudaMemcpy, dst %p, src %p, len %zd",
dst, src, len);
return status;
}

return status;
}

static ucc_status_t ucc_mc_cuda_async_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)
{
ucc_status_t status;
ucc_mc_cuda_resources_t *resources;
Expand Down Expand Up @@ -431,7 +453,8 @@ ucc_mc_cuda_t ucc_mc_cuda = {
.super.ops.mem_query = ucc_mc_cuda_mem_query,
.super.ops.mem_alloc = ucc_mc_cuda_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_cuda_mem_pool_free,
.super.ops.memcpy = ucc_mc_cuda_memcpy,
.super.ops.memcpy = ucc_mc_cuda_async_memcpy,
.super.ops.sync_memcpy = ucc_mc_cuda_sync_memcpy,
.super.ops.memset = ucc_mc_cuda_memset,
.super.ops.flush = ucc_mc_cuda_flush_not_supported,
.super.config_table =
Expand Down
1 change: 1 addition & 0 deletions src/components/mc/rocm/mc_rocm.c
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ ucc_mc_rocm_t ucc_mc_rocm = {
.super.ops.mem_alloc = ucc_mc_rocm_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_rocm_mem_pool_free,
.super.ops.memcpy = ucc_mc_rocm_memcpy,
.super.ops.sync_memcpy = ucc_mc_rocm_memcpy,
.super.ops.memset = ucc_mc_rocm_memset,
.super.config_table =
{
Expand Down
25 changes: 25 additions & 0 deletions src/components/mc/ucc_mc.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ ucc_status_t ucc_mc_get_attr(ucc_mc_attr_t *attr, ucc_memory_type_t mem_type)
return mc->get_attr(attr);
}

/* TODO: add the flexbility to bypass the mpool if the user asks for it */
UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_alloc, (h_ptr, size, mem_type),
ucc_mc_buffer_header_t **h_ptr, size_t size,
ucc_memory_type_t mem_type)
Expand All @@ -152,6 +153,30 @@ ucc_status_t ucc_mc_free(ucc_mc_buffer_header_t *h_ptr)
return mc_ops[mt]->mem_free(h_ptr);
}

UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_sync_memcpy,
(dst, src, len, dst_mem, src_mem), void *dst,
const void *src, size_t len, ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)

{
ucc_memory_type_t mt;
if (src_mem == UCC_MEMORY_TYPE_UNKNOWN ||
dst_mem == UCC_MEMORY_TYPE_UNKNOWN) {
return UCC_ERR_INVALID_PARAM;
} else if (src_mem == UCC_MEMORY_TYPE_HOST &&
dst_mem == UCC_MEMORY_TYPE_HOST) {
UCC_CHECK_MC_AVAILABLE(UCC_MEMORY_TYPE_HOST);
return mc_ops[UCC_MEMORY_TYPE_HOST]->memcpy(dst, src, len,
UCC_MEMORY_TYPE_HOST,
UCC_MEMORY_TYPE_HOST);
}
/* take any non host MC component */
mt = (dst_mem == UCC_MEMORY_TYPE_HOST) ? src_mem : dst_mem;
mt = (mt == UCC_MEMORY_TYPE_CUDA_MANAGED) ? UCC_MEMORY_TYPE_CUDA : mt;
UCC_CHECK_MC_AVAILABLE(mt);
return mc_ops[mt]->sync_memcpy(dst, src, len, dst_mem, src_mem);
}

UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_memcpy,
(dst, src, len, dst_mem, src_mem), void *dst,
const void *src, size_t len, ucc_memory_type_t dst_mem,
Expand Down
4 changes: 4 additions & 0 deletions src/components/mc/ucc_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ ucc_status_t ucc_mc_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);

ucc_status_t ucc_mc_sync_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);

ucc_status_t ucc_mc_memset(void *ptr, int value, size_t size,
ucc_memory_type_t mem_type);

Expand Down
5 changes: 5 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"
#include "core/ucc_service_coll.h"
#include "components/mc/ucc_mc.h"

#define POLL_PACKED 16
#define REL_DONE ((void*)-1)
Expand Down Expand Up @@ -98,6 +99,7 @@ typedef struct mcast_coll_comm_init_spec {
int scq_moderation;
int wsize;
int max_eager;
int cuda_mem_enabled;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -251,6 +253,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
ucc_mc_buffer_header_t *grh_buf_header;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
Expand All @@ -261,6 +264,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
ucc_mc_buffer_header_t *pp_buf_header;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
Expand Down Expand Up @@ -293,6 +297,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
int n_prep_reliable;
int n_mcast_reliable;
int wsize;
int cuda_mem_enabled;
ucc_tl_mlx5_mcast_join_info_t *group_setup_info;
ucc_service_coll_req_t *group_setup_info_req;
ucc_tl_mlx5_mcast_service_coll_t service_coll;
Expand Down
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
}
}

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
ucc_coll_args_t *args = &coll_args->args;

if ((comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
(!comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
return UCC_OK;
}

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
12 changes: 9 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,12 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
return UCC_ERR_NO_RESOURCE;
}

comm->max_inline = qp_init_attr.cap.max_inline_data;
if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
comm->max_inline = qp_init_attr.cap.max_inline_data;
}

return UCC_OK;
}
Expand Down Expand Up @@ -609,8 +614,9 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
return UCC_ERR_NO_RESOURCE;
}
}

if (comm->grh_buf) {
ucc_free(comm->grh_buf);
ucc_mc_free(comm->grh_buf_header);
}

if (comm->pp) {
Expand All @@ -626,7 +632,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->pp_buf) {
ucc_free(comm->pp_buf);
ucc_mc_free(comm->pp_buf_header);
}

if (comm->call_rwr) {
Expand Down
9 changes: 8 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
int rc;
int length;
ucc_status_t status;
ucc_memory_type_t mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
: UCC_MEMORY_TYPE_HOST;

for (i = 0; i < num_packets; i++) {
if (comm->params.sx_depth <=
Expand Down Expand Up @@ -75,7 +77,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
if (zcopy) {
pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
} else {
memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
status = ucc_mc_sync_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy cuda buffer");
return status;
}
ssg[0].addr = (uint64_t) pp->buf;
}

Expand Down
15 changes: 14 additions & 1 deletion src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);

Expand All @@ -379,7 +380,19 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
memcpy(dest, (void*) pp->buf, pp->length);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_sync_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}
}

comm->r_window[pp->psn & (comm->wsize-1)] = pp;
Expand Down
87 changes: 60 additions & 27 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
#include "mcast/tl_mlx5_mcast_helper.h"
#include "mcast/tl_mlx5_mcast_service_coll.h"

static ucc_status_t ucc_tl_mlx5_check_gpudirect_driver()
{
const char *file = "/sys/kernel/mm/memory_peers/nv_mem/version";

if (!access(file, F_OK)) {
return UCC_OK;
}

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,
Expand Down Expand Up @@ -88,23 +99,14 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

memcpy(&comm->params, conf_params, sizeof(*conf_params));

comm->wsize = conf_params->wsize;
comm->max_eager = conf_params->max_eager;
comm->comm_id = team_params->id;
comm->ctx = mcast_context;
comm->grh_buf = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf");
if (!comm->grh_buf) {
status = UCC_ERR_NO_MEMORY;
goto cleanup;
}
comm->wsize = conf_params->wsize;
comm->max_eager = conf_params->max_eager;
comm->cuda_mem_enabled = conf_params->cuda_mem_enabled;
comm->comm_id = team_params->id;
comm->ctx = mcast_context;

memset(comm->grh_buf, 0, GRH_LENGTH);

comm->grh_mr = ibv_reg_mr(mcast_context->pd, comm->grh_buf, GRH_LENGTH,
IBV_ACCESS_REMOTE_WRITE |
IBV_ACCESS_LOCAL_WRITE);
if (!comm->grh_mr) {
tl_error(mcast_context->lib, "could not register memory for GRH, errno %d", errno);
if (comm->cuda_mem_enabled && (UCC_OK != ucc_tl_mlx5_check_gpudirect_driver())) {
tl_warn(mcast_context->lib, "cuda-aware mcast not available as gpu direct is not ready");
status = UCC_ERR_NO_RESOURCE;
goto cleanup;
}
Expand Down Expand Up @@ -162,9 +164,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,

ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_comm_t *comm)
{
ucc_status_t status;
size_t page_size;
int buf_size, i, ret;
ucc_status_t status;
size_t page_size;
int buf_size, i, ret;
ucc_memory_type_t supported_mem_type;

status = ucc_tl_mlx5_mcast_init_qps(comm->ctx, comm);
if (UCC_OK != status) {
Expand Down Expand Up @@ -197,19 +200,49 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
comm->pending_recv = 0;
comm->buf_n = comm->params.rx_depth * 2;

ret = posix_memalign((void**) &comm->pp_buf, page_size, buf_size * comm->buf_n);
if (ret) {
tl_error(comm->ctx->lib, "posix_memalign failed");
return UCC_ERR_NO_MEMORY;
supported_mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
: UCC_MEMORY_TYPE_HOST;

status = ucc_mc_alloc(&comm->grh_buf_header, GRH_LENGTH *
sizeof(char), UCC_MEMORY_TYPE_HOST);
comm->grh_buf = comm->grh_buf_header->addr;
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->ctx->lib, "failed to allocate cuda memory");
return status;
}

status = ucc_mc_memset(comm->grh_buf, 0, GRH_LENGTH, UCC_MEMORY_TYPE_HOST);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not cuda memset");
goto error;
}

comm->grh_mr = ibv_reg_mr(comm->ctx->pd, comm->grh_buf, GRH_LENGTH,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
if (!comm->grh_mr) {
tl_error(comm->ctx->lib, "could not register device memory for GRH, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}

status = ucc_mc_alloc(&comm->pp_buf_header, buf_size * comm->buf_n, supported_mem_type);
comm->pp_buf = comm->pp_buf_header->addr;
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->ctx->lib, "failed to allocate cuda memory");
goto error;
}

status = ucc_mc_memset(comm->pp_buf, 0, buf_size * comm->buf_n, supported_mem_type);
if (status != UCC_OK) {
tl_error(comm->ctx->lib, "could not memset");
goto error;
}

memset(comm->pp_buf, 0, buf_size * comm->buf_n);

comm->pp_mr = ibv_reg_mr(comm->ctx->pd, comm->pp_buf, buf_size * comm->buf_n,
IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_LOCAL_WRITE);
if (!comm->pp_mr) {
tl_error(comm->ctx->lib, "could not register pp_buf mr, errno %d", errno);
status = UCC_ERR_NO_MEMORY;
tl_error(comm->ctx->lib, "could not register pp_buf device mr, errno %d", errno);
status = UCC_ERR_NO_RESOURCE;
goto error;
}

Expand Down
Loading

0 comments on commit 7798792

Please sign in to comment.