Skip to content

Commit

Permalink
TL/MLX5: address sergey comments part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jul 15, 2024
1 parent 91205f8 commit 13e6845
Show file tree
Hide file tree
Showing 16 changed files with 91 additions and 96 deletions.
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall_mkeys.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/alltoall/alltoall_mkeys.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
155 changes: 75 additions & 80 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,34 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_m

if (MCAST_ALLGATHER_IN_PROGRESS(req, comm)) {
return UCC_INPROGRESS;
} else {
if (ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme) {
if (!req->barrier_req) {
// mcast operations are done and now go to barrier
status = comm->service_coll.barrier_post(comm->p2p_ctx, &req->barrier_req);
if (status != UCC_OK) {
return status;
}
tl_trace(comm->lib, "mcast operations are done and now go to barrier");
return UCC_INPROGRESS;
} else {
status = comm->service_coll.coll_test(req->barrier_req);
if (status == UCC_OK) {
req->barrier_req = NULL;
tl_trace(comm->lib, "barrier at the end of mcast allgather is completed");
} else {
return status;
}
}
}

if (ONE_SIDED_SYNCHRONOUS_PROTO == req->one_sided_reliability_scheme) {
/* mcast operations are all done, now wait until all the processes
* are done with their mcast operations */
if (!req->barrier_req) {
// mcast operations are done and now go to barrier
status = comm->service_coll.barrier_post(comm->p2p_ctx, &req->barrier_req);
if (status != UCC_OK) {
return status;
}
tl_trace(comm->lib, "mcast operations are done and now go to barrier");
}

/* this task is completed */
return UCC_OK;
status = comm->service_coll.coll_test(req->barrier_req);
if (status == UCC_OK) {
req->barrier_req = NULL;
tl_trace(comm->lib, "barrier at the end of mcast allgather is completed");
} else {
return status;
}
}

/* this task is completed */
return UCC_OK;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
static inline ucc_status_t ucc_tl_mlx5_mcast_allgather_test(ucc_tl_mlx5_mcast_coll_req_t* req)
{
ucc_status_t status;

Expand All @@ -272,7 +273,49 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_co

ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
return UCC_OK;
}

void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle;
ucc_status_t status;

ucc_assert(req != NULL);

if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
/* it is not this task's turn for progress */
ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
return;
}

status = ucc_tl_mlx5_mcast_allgather_test(task->coll_mcast.req_handle);
if (UCC_INPROGRESS == status) {
return;
} else if (UCC_OK == status) {
coll_task->status = UCC_OK;
req->comm->allgather_comm.under_progress_counter++;
ucc_free(req);
task->coll_mcast.req_handle = NULL;
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
coll_task->status = status;
if (req->rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
req->rreg = NULL;
}
if (req->recv_rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
req->recv_rreg = NULL;
}
ucc_free(req);
}
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
ucc_coll_task_t *coll_task = &(task->super);
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task);
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast;
ucc_coll_args_t *args = &TASK_ARGS(task);
Expand All @@ -286,21 +329,21 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_tl_mlx5_mcast_coll_req_t *req;


if (!data_size) {
coll_task->status = UCC_OK;
return ucc_task_complete(coll_task);
}

task->coll_mcast.req_handle = NULL;

tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %ld, comm %d, "
tl_trace(comm->lib, "MCAST allgather init, sbuf %p, rbuf %p, size %ld, comm %d, "
"comm_size %d, counter %d",
sbuf, rbuf, data_size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter);

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
if (!req) {
tl_warn(comm->lib, "malloc failed");
tl_error(comm->lib, "malloc failed");
status = UCC_ERR_NO_MEMORY;
goto failed;
}

Expand Down Expand Up @@ -331,14 +374,15 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
goto failed;
}

req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;
req->last_pkt_len = req->length % comm->max_per_packet;

ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet);

if (req->proto == MCAST_PROTO_ZCOPY) {
/* register the send buffer */
status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, &reg);
if (UCC_OK != status) {
tl_error(comm->lib, "sendbuf registeration failed");
goto failed;
}
req->rreg = reg;
Expand All @@ -361,64 +405,15 @@ ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)

task->coll_mcast.req_handle = req;
coll_task->status = UCC_INPROGRESS;
task->super.post = ucc_tl_mlx5_mcast_allgather_start;
task->super.progress = ucc_tl_mlx5_mcast_allgather_progress;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);

failed:
tl_warn(UCC_TASK_LIB(task), "mcast start allgather failed:%d", status);
tl_warn(UCC_TASK_LIB(task), "mcast init allgather failed:%d", status);
if (req) {
ucc_free(req);
}
coll_task->status = status;
return ucc_task_complete(coll_task);
}

void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle;
ucc_status_t status;

if (req != NULL) {
if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
/* it is not this task's turn for progress */
ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
return;
}

status = ucc_tl_mlx5_mcast_test_allgather(task->coll_mcast.req_handle);
if (UCC_INPROGRESS == status) {
return;
} else if (UCC_OK == status) {
coll_task->status = UCC_OK;
req->comm->allgather_comm.under_progress_counter++;
ucc_free(req);
task->coll_mcast.req_handle = NULL;
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
coll_task->status = status;
if (req->rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
req->rreg = NULL;
}
if (req->recv_rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
req->recv_rreg = NULL;
}
ucc_free(req);
ucc_task_complete(coll_task);
}
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed, mcast coll not initialized");
coll_task->status = UCC_ERR_NO_RESOURCE;
ucc_task_complete(coll_task);
}
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task)
{
task->super.post = ucc_tl_mlx5_mcast_allgather_start;
task->super.progress = ucc_tl_mlx5_mcast_allgather_progress;

return UCC_OK;
return status;
}

4 changes: 2 additions & 2 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down Expand Up @@ -98,7 +98,7 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {

{"MCAST_ONE_SIDED_RELIABILITY_ENABLE", "1", "Enable one sided reliability for mcast",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_INT},
UCC_CONFIG_TYPE_BOOL},

{NULL}};

Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_context.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_dm.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_dm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_lib.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_pd.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_pd.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_team.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_wqe.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down
2 changes: 1 addition & 1 deletion src/components/tl/mlx5/tl_mlx5_wqe.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See file LICENSE for terms.
*/
Expand Down

0 comments on commit 13e6845

Please sign in to comment.