Skip to content

Commit

Permalink
TL/MLX5: addressing all sam's comments on PR 989
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jul 11, 2024
1 parent 11edf29 commit 9d12794
Show file tree
Hide file tree
Showing 16 changed files with 188 additions and 98 deletions.
3 changes: 3 additions & 0 deletions src/components/mc/base/ucc_mc_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ typedef struct ucc_mc_ops {
ucc_status_t (*memcpy)(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);
ucc_status_t (*sync_memcpy)(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);
ucc_status_t (*memset)(void *dst, int value, size_t len);
ucc_status_t (*flush)();
} ucc_mc_ops_t;
Expand Down
1 change: 1 addition & 0 deletions src/components/mc/cpu/mc_cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ ucc_mc_cpu_t ucc_mc_cpu = {
.super.ops.mem_alloc = ucc_mc_cpu_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_cpu_mem_pool_free,
.super.ops.memcpy = ucc_mc_cpu_memcpy,
.super.ops.sync_memcpy = ucc_mc_cpu_memcpy,
.super.ops.memset = ucc_mc_cpu_memset,
.super.ops.flush = NULL,
.super.config_table =
Expand Down
23 changes: 23 additions & 0 deletions src/components/mc/cuda/mc_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,28 @@ ucc_mc_cuda_mem_pool_alloc_with_init(ucc_mc_buffer_header_t **h_ptr,
}
}

static ucc_status_t ucc_mc_cuda_sync_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)
{
ucc_status_t status;

ucc_assert(dst_mem == UCC_MEMORY_TYPE_CUDA ||
src_mem == UCC_MEMORY_TYPE_CUDA ||
dst_mem == UCC_MEMORY_TYPE_CUDA_MANAGED ||
src_mem == UCC_MEMORY_TYPE_CUDA_MANAGED);

status = CUDA_FUNC(cudaMemcpy(dst, src, len, cudaMemcpyDefault));
if (ucc_unlikely(status != UCC_OK)) {
mc_error(&ucc_mc_cuda.super,
"failed to launch cudaMemcpy, dst %p, src %p, len %zd",
dst, src, len);
return status;
}

return status;
}

static ucc_status_t ucc_mc_cuda_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)
Expand Down Expand Up @@ -432,6 +454,7 @@ ucc_mc_cuda_t ucc_mc_cuda = {
.super.ops.mem_alloc = ucc_mc_cuda_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_cuda_mem_pool_free,
.super.ops.memcpy = ucc_mc_cuda_memcpy,
.super.ops.sync_memcpy = ucc_mc_cuda_sync_memcpy,
.super.ops.memset = ucc_mc_cuda_memset,
.super.ops.flush = ucc_mc_cuda_flush_not_supported,
.super.config_table =
Expand Down
1 change: 1 addition & 0 deletions src/components/mc/rocm/mc_rocm.c
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ ucc_mc_rocm_t ucc_mc_rocm = {
.super.ops.mem_alloc = ucc_mc_rocm_mem_pool_alloc_with_init,
.super.ops.mem_free = ucc_mc_rocm_mem_pool_free,
.super.ops.memcpy = ucc_mc_rocm_memcpy,
.super.ops.sync_memcpy = ucc_mc_rocm_memcpy,
.super.ops.memset = ucc_mc_rocm_memset,
.super.config_table =
{
Expand Down
24 changes: 24 additions & 0 deletions src/components/mc/ucc_mc.c
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,30 @@ ucc_status_t ucc_mc_free(ucc_mc_buffer_header_t *h_ptr)
return mc_ops[mt]->mem_free(h_ptr);
}

UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_sync_memcpy,
(dst, src, len, dst_mem, src_mem), void *dst,
const void *src, size_t len, ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem)

{
ucc_memory_type_t mt;
if (src_mem == UCC_MEMORY_TYPE_UNKNOWN ||
dst_mem == UCC_MEMORY_TYPE_UNKNOWN) {
return UCC_ERR_INVALID_PARAM;
} else if (src_mem == UCC_MEMORY_TYPE_HOST &&
dst_mem == UCC_MEMORY_TYPE_HOST) {
UCC_CHECK_MC_AVAILABLE(UCC_MEMORY_TYPE_HOST);
return mc_ops[UCC_MEMORY_TYPE_HOST]->memcpy(dst, src, len,
UCC_MEMORY_TYPE_HOST,
UCC_MEMORY_TYPE_HOST);
}
/* take any non host MC component */
mt = (dst_mem == UCC_MEMORY_TYPE_HOST) ? src_mem : dst_mem;
mt = (mt == UCC_MEMORY_TYPE_CUDA_MANAGED) ? UCC_MEMORY_TYPE_CUDA : mt;
UCC_CHECK_MC_AVAILABLE(mt);
return mc_ops[mt]->sync_memcpy(dst, src, len, dst_mem, src_mem);
}

UCC_MC_PROFILE_FUNC(ucc_status_t, ucc_mc_memcpy,
(dst, src, len, dst_mem, src_mem), void *dst,
const void *src, size_t len, ucc_memory_type_t dst_mem,
Expand Down
4 changes: 4 additions & 0 deletions src/components/mc/ucc_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ ucc_status_t ucc_mc_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);

ucc_status_t ucc_mc_sync_memcpy(void *dst, const void *src, size_t len,
ucc_memory_type_t dst_mem,
ucc_memory_type_t src_mem);

ucc_status_t ucc_mc_memset(void *ptr, int value, size_t size,
ucc_memory_type_t mem_type);

Expand Down
8 changes: 5 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "components/tl/ucc_tl_log.h"
#include "utils/ucc_rcache.h"
#include "core/ucc_service_coll.h"
#include "utils/arch/cuda_def.h"
#include "components/mc/ucc_mc.h"

#define POLL_PACKED 16
#define REL_DONE ((void*)-1)
Expand Down Expand Up @@ -91,7 +91,7 @@ typedef struct mcast_coll_comm_init_spec {
int scq_moderation;
int wsize;
int max_eager;
int device_mem_enabled;
int cuda_mem_enabled;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

Expand Down Expand Up @@ -196,17 +196,19 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_rank_t rank;
ucc_rank_t commsize;
char *grh_buf;
ucc_mc_buffer_header_t *grh_buf_header;
struct ibv_mr *grh_mr;
uint16_t mcast_lid;
union ibv_gid mgid;
unsigned max_inline;
size_t max_eager;
int device_mem_enabled;
int cuda_mem_enabled;
int max_per_packet;
int pending_send;
int pending_recv;
struct ibv_mr *pp_mr;
char *pp_buf;
ucc_mc_buffer_header_t *pp_buf_header;
struct pp_packet *pp;
uint32_t psn;
uint32_t last_psn;
Expand Down
17 changes: 17 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,23 @@ void ucc_tl_mlx5_mcast_collective_progress(ucc_coll_task_t *coll_task)
}
}

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team)
{
ucc_tl_mlx5_team_t *mlx5_team = ucc_derived_of(team, ucc_tl_mlx5_team_t);
ucc_tl_mlx5_mcast_coll_comm_t *comm = mlx5_team->mcast->mcast_comm;
ucc_coll_args_t *args = &coll_args->args;

if ((comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_CUDA) ||
(!comm->cuda_mem_enabled &&
args->src.info.mem_type == UCC_MEMORY_TYPE_HOST)) {
return UCC_OK;
}

return UCC_ERR_NO_RESOURCE;
}

ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_bcast_start;
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req);

ucc_status_t ucc_tl_mlx5_mcast_check_memory_type_cap(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team);
#endif
14 changes: 3 additions & 11 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx,
return UCC_ERR_NO_RESOURCE;
}

if (comm->device_mem_enabled) {
if (comm->cuda_mem_enabled) {
/* max inline send otherwise it segfault during ibv send */
comm->max_inline = 0;
} else {
Expand Down Expand Up @@ -482,11 +482,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->grh_buf) {
if (comm->device_mem_enabled) {
cudaFree(comm->grh_buf);
} else {
ucc_free(comm->grh_buf);
}
ucc_mc_free(comm->grh_buf_header);
}

if (comm->pp) {
Expand All @@ -502,11 +498,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
}

if (comm->pp_buf) {
if (comm->device_mem_enabled) {
cudaFree(comm->pp_buf);
} else {
ucc_free(comm->pp_buf);
}
ucc_mc_free(comm->pp_buf_header);
}

if (comm->call_rwr) {
Expand Down
12 changes: 7 additions & 5 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
int rc;
int length;
ucc_status_t status;
ucc_memory_type_t mem_type = comm->cuda_mem_enabled ? UCC_MEMORY_TYPE_CUDA
: UCC_MEMORY_TYPE_HOST;

for (i = 0; i < num_packets; i++) {
if (comm->params.sx_depth <=
Expand Down Expand Up @@ -75,11 +77,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
if (zcopy) {
pp->context = (uintptr_t) PTR_OFFSET(req->ptr, offset);
} else {
if (comm->device_mem_enabled) {
CUDA_FUNC(cudaMemcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset),
length, cudaMemcpyDeviceToDevice));
} else {
memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length);
status = ucc_mc_sync_memcpy((void*) pp->buf, PTR_OFFSET(req->ptr, offset), length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy cuda buffer");
return status;
}
ssg[0].addr = (uint64_t) pp->buf;
}
Expand Down
15 changes: 12 additions & 3 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
{
ucc_status_t status = UCC_OK;
void *dest;
ucc_memory_type_t mem_type;
ucc_assert(pp->psn >= req->start_psn &&
pp->psn < req->start_psn + req->num_packets);

Expand All @@ -379,10 +380,18 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com

if (pp->length > 0 ) {
dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
if (comm->device_mem_enabled) {
cudaMemcpy(dest, (void*) pp->buf, pp->length, cudaMemcpyDeviceToDevice);

if (comm->cuda_mem_enabled) {
mem_type = UCC_MEMORY_TYPE_CUDA;
} else {
memcpy(dest, (void*) pp->buf, pp->length);
mem_type = UCC_MEMORY_TYPE_HOST;
}

status = ucc_mc_sync_memcpy(dest, (void*) pp->buf, pp->length,
mem_type, mem_type);
if (ucc_unlikely(status != UCC_OK)) {
tl_error(comm->lib, "failed to copy buffer");
return status;
}
}

Expand Down
Loading

0 comments on commit 9d12794

Please sign in to comment.