Skip to content

Commit

Permalink
TL/MLX5: Addressing Sam's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
MamziB committed Jul 3, 2024
1 parent 044e785 commit 569af1f
Show file tree
Hide file tree
Showing 13 changed files with 403 additions and 425 deletions.
297 changes: 148 additions & 149 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h

Large diffs are not rendered by default.

199 changes: 101 additions & 98 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
#include "tl_mlx5_mcast_allgather.h"
#include <inttypes.h>

/* 32 here is the bit count of ib mcast packet's immediate data */
#define TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT 32

/*
 * Computes the maximum number of mcast packets a single allgather may carry.
 *
 * The 32-bit immediate data of an IB mcast packet (see
 * TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT) is partitioned into
 * log2(max team size) bits for the sender rank and
 * log2(max allgather counter) bits for the collective counter; whatever
 * bits remain index the packet, so 2^remaining is the packet-count limit.
 *
 * Both ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE and
 * ONE_SIDED_MAX_ALLGATHER_COUNTER are assumed to be powers of two (the
 * previous float-based log()/pow() version made the same assumption).
 * Integer shifts avoid libm and the floating-point rounding that can be
 * off by a whole bit at exact powers of two.
 *
 * @param [out] max_count  receives the maximum packet count per allgather
 */
static inline void ucc_tl_mlx5_mcast_get_max_allgather_packet_count(int *max_count)
{
    int          remaining_bits = TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT;
    unsigned int v;

    /* subtract the bits needed to encode the sender's team rank */
    for (v = ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE; v > 1; v >>= 1) {
        remaining_bits--;
    }

    /* subtract the bits needed to encode the allgather counter */
    for (v = ONE_SIDED_MAX_ALLGATHER_COUNTER; v > 1; v >>= 1) {
        remaining_bits--;
    }

    /* NOTE(review): assumes remaining_bits < 31 so the shift fits in an
     * int; the original pow(2, tmp)-into-int had the same constraint */
    *max_count = 1 << remaining_bits;
}

/* An allgather request is still in flight while it has packets left to send
 * or receive, the communicator has pending sends, a reliability RDMA-read
 * round is running, or the remote-key exchange request has not completed.
 * Arguments are parenthesized (CERT PRE01-C) and the stray trailing line
 * continuation after the last line is removed so the macro can no longer
 * silently absorb the following source line. */
#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm)                       \
        ((_req)->to_send || (_req)->to_recv || (_comm)->pending_send ||\
         (_comm)->one_sided.rdma_read_in_progress ||                   \
         (NULL != (_req)->allgather_rkeys_req))

static inline ucc_status_t ucc_tl_mlx5_mcast_check_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req)
{
Expand Down Expand Up @@ -80,7 +98,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reset_reliablity(ucc_tl_mlx5_mcast_
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_status_t status;

ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);

if (comm->one_sided.reliability_enabled && !comm->one_sided.reliability_ready) {
        /* initialize the structures needed by the reliability protocol */
Expand Down Expand Up @@ -129,7 +147,7 @@ static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mc
ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
void *dest;

ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);

if (ONE_SIDED_ASYNCHRONOUS_PROTO == req->one_sided_reliability_scheme &&
ONE_SIDED_INVALID == comm->one_sided.slots_state) {
Expand All @@ -147,7 +165,7 @@ static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mc
}
}

static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll_req_t *req)
static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_mlx5_mcast_coll_req_t *req)
{
ucc_status_t status = UCC_OK;
ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
Expand All @@ -162,12 +180,16 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll
}

if (req->to_send || req->to_recv) {
ucc_assert(comm->max_push_send >= comm->pending_send);
ucc_assert(comm->allgather_comm.max_push_send >= comm->pending_send);
if (req->to_send &&
(comm->max_push_send - comm->pending_send) > 0) {
ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->max_push_send -
comm->pending_send, req->to_send),
zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX);
(comm->allgather_comm.max_push_send - comm->pending_send) > 0) {
status = ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->allgather_comm.max_push_send -
comm->pending_send, req->to_send),
zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX);
if (status < 0) {
tl_error(comm->lib, "a failure happend during send packets");
return status;
}
}

ucc_tl_mlx5_mcast_init_async_reliability_slots(req);
Expand Down Expand Up @@ -223,11 +245,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll
}
}

ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
{
ucc_status_t status;

status = ucc_tl_mlx5_mcast_do_allgather(req);
status = ucc_tl_mlx5_mcast_do_staging_based_allgather(req);
if (UCC_OK == status) {
ucc_assert(req->comm->ctx != NULL);
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
Expand All @@ -248,17 +270,44 @@ ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
return status;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void *rbuf, int size,
ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req)
ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_status_t status;
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task);
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast;
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_datatype_t dt = args->src.info.datatype;
size_t count = args->src.info.count;
ucc_status_t status = UCC_OK;
size_t data_size = ucc_dt_size(dt) * count;
void *sbuf = args->src.info.buffer;
void *rbuf = args->dst.info.buffer;
ucc_tl_mlx5_mcast_coll_comm_t *comm = team->mcast_comm;
ucc_tl_mlx5_mcast_reg_t *reg = NULL;
ucc_tl_mlx5_mcast_coll_req_t *req;


if (!data_size) {
coll_task->status = UCC_OK;
return ucc_task_complete(coll_task);
}

task->coll_mcast.req_handle = NULL;

tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %ld, comm %d, "
"comm_size %d, counter %d",
sbuf, rbuf, data_size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter);

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
if (!req) {
tl_warn(comm->lib, "malloc failed");
goto failed;
}

req->comm = comm;
req->ptr = sbuf;
req->rptr = rbuf;
req->length = size;
req->length = data_size;
req->mr = comm->pp_mr;
req->rreg = NULL;
/* - zero copy protocol only provides zero copy design at sender side
Expand All @@ -267,27 +316,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
req->proto = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER :
MCAST_PROTO_ZCOPY;

if (comm->commsize > ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) {
tl_warn(comm->lib,
"team size is %d but max supported team size of mcast allgather is %d",
comm->commsize, ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
return UCC_ERR_NOT_SUPPORTED;
}
assert(comm->commsize <= ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);

req->offset = 0;
req->num_packets = (req->length + comm->max_per_packet - 1)/comm->max_per_packet;
req->num_packets = ucc_max(1, (req->length + comm->max_per_packet - 1)/comm->max_per_packet);

if (req->num_packets == 0) {
req->num_packets = 1;
}
ucc_tl_mlx5_mcast_get_max_allgather_packet_count(&comm->allgather_comm.max_num_packets);

ONE_SIDED_MAX_PACKET_COUNT(comm->ag_max_num_packets);

if (comm->ag_max_num_packets < req->num_packets) {
if (comm->allgather_comm.max_num_packets < req->num_packets) {
tl_warn(comm->lib,
"msg size is %ld but max supported msg size of mcast allgather is %d",
req->length, comm->ag_max_num_packets * comm->max_per_packet);
return UCC_ERR_NOT_SUPPORTED;
req->length, comm->allgather_comm.max_num_packets * comm->max_per_packet);
status = UCC_ERR_NOT_SUPPORTED;
goto failed;
}

req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;
Expand All @@ -298,7 +339,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
/* register the send buffer */
status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, &reg);
if (UCC_OK != status) {
return status;
ucc_free(req);
goto failed;
}
req->rreg = reg;
req->mr = reg->mr;
Expand All @@ -312,84 +354,32 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
req->one_sided_reliability_scheme = ONE_SIDED_NO_RELIABILITY;
}

req->ag_counter = comm->ag_counter;
req->ag_counter = comm->allgather_comm.coll_counter;
req->to_send = req->num_packets;
req->to_recv = comm->commsize * req->num_packets;

comm->ag_counter++;
return UCC_OK;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_coll_do_allgather(void* sbuf, void *rbuf, int size,
ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t **task_req_handle)
{
ucc_tl_mlx5_mcast_coll_req_t *req;
ucc_status_t status;
comm->allgather_comm.coll_counter++;

tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %d, comm %d, "
"comm_size %d, counter %d",
sbuf, rbuf, size, comm->comm_id, comm->commsize, comm->ag_counter);

req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
if (!req) {
tl_error(comm->lib, "malloc failed");
return UCC_ERR_NO_MEMORY;
}

status = ucc_tl_mlx5_mcast_prepare_allgather(sbuf, rbuf, size, comm, req);
if (UCC_OK != status) {
tl_warn(comm->lib, "prepare mcast allgather failed");
ucc_free(req);
return status;
}

status = UCC_INPROGRESS;

*task_req_handle = req;

return status;
}

ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_team_t *mlx5_team = TASK_TEAM(task);
ucc_tl_mlx5_mcast_team_t *team = mlx5_team->mcast;
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_datatype_t dt = args->src.info.datatype;
size_t count = args->src.info.count;
ucc_status_t status = UCC_OK;
size_t data_size = ucc_dt_size(dt) * count;
void *sbuf = args->src.info.buffer;
void *rbuf = args->dst.info.buffer;
ucc_tl_mlx5_mcast_coll_comm_t *comm = team->mcast_comm;

task->coll_mcast.req_handle = NULL;

status = ucc_tl_mlx5_mcast_coll_do_allgather(sbuf, rbuf, data_size, comm, &task->coll_mcast.req_handle);
if (status < 0) {
tl_warn(UCC_TASK_LIB(task), "do mcast allgather failed:%d", status);
coll_task->status = status;
return ucc_task_complete(coll_task);
}
task->coll_mcast.req_handle = req;
coll_task->status = UCC_INPROGRESS;
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);

failed:
tl_warn(UCC_TASK_LIB(task), "mcast start allgather failed:%d", status);
coll_task->status = status;

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
return ucc_task_complete(coll_task);
}

void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle;
ucc_tl_mlx5_task_t *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
ucc_tl_mlx5_mcast_coll_req_t *req = task->coll_mcast.req_handle;
ucc_status_t status;

if (task->coll_mcast.req_handle != NULL) {
req = task->coll_mcast.req_handle;
if (req->ag_counter != req->comm->ag_under_progress_counter) {
if (req != NULL) {
if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
/* it is not this task's turn for progress */
ucc_assert(req->comm->ag_under_progress_counter < req->ag_counter);
ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
return;
}

Expand All @@ -398,14 +388,27 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
return;
} else if (UCC_OK == status) {
coll_task->status = UCC_OK;
req->comm->ag_under_progress_counter++;
req->comm->allgather_comm.under_progress_counter++;
ucc_free(req);
task->coll_mcast.req_handle = NULL;
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
coll_task->status = status;
if (req->rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
req->rreg = NULL;
}
if (req->recv_rreg) {
ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
req->recv_rreg = NULL;
}
ucc_free(req);
ucc_task_complete(coll_task);
}
} else {
tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed, mcast coll not initialized");
coll_task->status = UCC_ERR_NO_RESOURCE;
ucc_task_complete(coll_task);
}

return;
Expand Down
6 changes: 0 additions & 6 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@
#include "tl_mlx5_mcast.h"
#include "tl_mlx5_coll.h"

#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm) \
(_req->to_send || _req->to_recv || _comm->pending_send || \
_comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \

ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task);

ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* _req);

#endif
Loading

0 comments on commit 569af1f

Please sign in to comment.