diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
index 8d60bc102b..023fd7e473 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -36,29 +36,23 @@
 /* Allgather RDMA-based reliability designs */
 #define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024
-#define ONE_SIDED_NO_RELIABILITY            0
-#define ONE_SIDED_SYNCHRONOUS_PROTO         1
-#define ONE_SIDED_ASYNCHRONOUS_PROTO        2
 #define ONE_SIDED_SLOTS_COUNT               2 /* number of memory slots during async design */
 #define ONE_SIDED_SLOTS_INFO_SIZE           sizeof(uint32_t) /* size of metadata prepended to each slots in bytes */
-#define ONE_SIDED_INVALID                  -1
-#define ONE_SIDED_VALID                    -2
-#define ONE_SIDED_PENDING_INFO             -3
-#define ONE_SIDED_PENDING_DATA             -4
 #define ONE_SIDED_MAX_ALLGATHER_COUNTER     32
 #define ONE_SIDED_MAX_CONCURRENT_LEVEL      64
-/* 32 here is the bit count of ib send immediate */
-#define ONE_SIDED_MAX_PACKET_COUNT(_max_count)                        \
-    do {                                                              \
-        int pow2;                                                     \
-        int tmp;                                                      \
-        pow2 = log(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) / log(2);     \
-        tmp  = 32 - pow2;                                             \
-        pow2 = log(ONE_SIDED_MAX_ALLGATHER_COUNTER) / log(2);         \
-        tmp  = tmp - pow2;                                            \
-        _max_count = pow(2, tmp);                                     \
-    } while(0);
+enum ucc_tl_mlx5_mcast_one_sided_slot_states {
+    ONE_SIDED_INVALID = -4,
+    ONE_SIDED_VALID,
+    ONE_SIDED_PENDING_INFO,
+    ONE_SIDED_PENDING_DATA,
+};
+
+enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme {
+    ONE_SIDED_NO_RELIABILITY = 0,
+    ONE_SIDED_SYNCHRONOUS_PROTO,
+    ONE_SIDED_ASYNCHRONOUS_PROTO
+};
 
 enum {
     MCAST_PROTO_EAGER,     /* Internal staging buffers */
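The slot bookkeeping constants become two typed enums, which lets the compiler check switch exhaustiveness and keeps the scheme and state values from being mixed up. Below is a minimal editor-added sketch (not part of the patch) of the async-slot layout those states guard, assuming each slot begins with the `ONE_SIDED_SLOTS_INFO_SIZE` metadata word described above; the helper name is hypothetical:

```c
#include <stdint.h>
#include <string.h>

/* hypothetical helper: read the 4-byte metadata word that prefixes slot_id;
 * slot_size is assumed to already include the metadata prefix */
static inline uint32_t slot_read_info(const char *slots_buffer, int slot_size,
                                      int slot_id)
{
    const char *slot = slots_buffer + (size_t)slot_id * slot_size;
    uint32_t    info;

    memcpy(&info, slot, sizeof(info)); /* ONE_SIDED_SLOTS_INFO_SIZE bytes */
    return info;
}
```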
@@ -195,8 +189,8 @@ struct pp_packet {
     ucc_list_link_t super;
     uint32_t        psn;
     int             length;
+    int             packet_counter;
     uintptr_t       context;
-    uint32_t        packet_counter;
     int             qp_id;
     uintptr_t       buf; // buffer address, initialized once
 };
@@ -245,25 +239,25 @@ typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm {
     /* holds all the remote-addr/rkey of sendbuf from processes in the team
      * used in sync design. it needs to be set during each mcast-allgather call
      * after sendbuf registration */
-    ucc_tl_mlx5_mcast_slot_mem_info_t *sendbuf_memkey_list;
+    ucc_tl_mlx5_mcast_slot_mem_info_t            *sendbuf_memkey_list;
     /* counter for each target recv packet */
-    uint32_t *recvd_pkts_tracker;
+    uint32_t                                     *recvd_pkts_tracker;
     /* holds the remote targets' collective call counter. it is used to check
      * if remote temp slot is ready for RDMA READ in async design */
-    uint32_t *remote_slot_info;
-    struct ibv_mr *remote_slot_info_mr;
-    int reliability_scheme_msg_threshold;
+    uint32_t                                     *remote_slot_info;
+    struct ibv_mr                                *remote_slot_info_mr;
+    int                                           reliability_scheme_msg_threshold;
     /* mem address and mem keys of the temp slots in async design */
-    char *slots_buffer;
-    struct ibv_mr *slots_mr;
+    char                                         *slots_buffer;
+    struct ibv_mr                                *slots_mr;
     /* size of a temp slot in async design */
-    int slot_size;
+    int                                           slot_size;
     /* coll req that is used during the oob service calls */
-    ucc_service_coll_req_t *reliability_req;
-    int reliability_enabled;
-    int reliability_ready;
-    int rdma_read_in_progress;
-    int slots_state;
+    ucc_service_coll_req_t                       *reliability_req;
+    int                                           reliability_enabled;
+    int                                           reliability_ready;
+    int                                           rdma_read_in_progress;
+    enum ucc_tl_mlx5_mcast_one_sided_slot_states  slots_state;
 } ucc_tl_mlx5_mcast_one_sided_reliability_comm_t;
 
 typedef struct ucc_tl_mlx5_mcast_service_coll {
@@ -273,6 +267,32 @@ typedef struct ucc_tl_mlx5_mcast_service_coll {
     ucc_status_t (*coll_test) (ucc_service_coll_req_t*);
 } ucc_tl_mlx5_mcast_service_coll_t;
 
+typedef struct ucc_tl_mlx5_mcast_allgather_comm {
+    int under_progress_counter;
+    int coll_counter;
+    int max_num_packets;
+    int max_push_send;
+} ucc_tl_mlx5_mcast_allgather_comm_t;
+
+typedef struct ucc_tl_mlx5_mcast_bcast_comm {
+    uint32_t      last_psn;
+    uint32_t      racks_n;
+    uint32_t      sacks_n;
+    uint32_t      last_acked;
+    uint32_t      child_n;
+    uint32_t      parent_n;
+    struct packet p2p_pkt[MAX_COMM_POW2];
+    struct packet p2p_spkt[MAX_COMM_POW2];
+    int           reliable_in_progress;
+    int           recv_drop_packet_in_progress;
+    ucc_rank_t    parents[MAX_COMM_POW2];
+    ucc_rank_t    children[MAX_COMM_POW2];
+    int           nack_requests;
+    int           nacks_counter;
+    int           n_mcast_reliable;
+    int           wsize;
+} ucc_tl_mlx5_mcast_bcast_comm_t;
+
 typedef struct ucc_tl_mlx5_mcast_coll_comm {
     struct pp_packet                                dummy_packet;
     ucc_tl_mlx5_mcast_coll_context_t               *ctx;
@@ -297,22 +317,11 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     char                                           *pp_buf;
     struct pp_packet                               *pp;
     uint32_t                                        psn;
-    uint32_t                                        last_psn;
-    uint32_t                                        racks_n;
-    uint32_t                                        sacks_n;
-    uint32_t                                        last_acked;
-    uint32_t                                        naks_n;
-    uint32_t                                        child_n;
-    uint32_t                                        parent_n;
     int                                             buf_n;
-    struct packet                                   p2p_pkt[MAX_COMM_POW2];
-    struct packet                                   p2p_spkt[MAX_COMM_POW2];
     ucc_list_link_t                                 bpool;
     ucc_list_link_t                                 pending_q;
     ucc_list_link_t                                 posted_q;
     struct mcast_ctx                                mcast;
-    int                                             reliable_in_progress;
-    int                                             recv_drop_packet_in_progress;
     struct ibv_recv_wr                             *call_rwr;
     struct ibv_sge                                 *call_rsgs;
     uint64_t                                        timer;
@@ -321,24 +330,15 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
     void                                           *p2p_ctx;
     ucc_base_lib_t                                 *lib;
     struct sockaddr_in6                             mcast_addr;
-    ucc_rank_t                                      parents[MAX_COMM_POW2];
-    ucc_rank_t                                      children[MAX_COMM_POW2];
-    int                                             nack_requests;
-    int                                             nacks_counter;
-    int                                             n_prep_reliable;
-    int                                             n_mcast_reliable;
-    int                                             wsize;
     ucc_tl_mlx5_mcast_join_info_t                  *group_setup_info;
     ucc_service_coll_req_t                         *group_setup_info_req;
     ucc_tl_mlx5_mcast_service_coll_t                service_coll;
     struct rdma_cm_event                           *event;
     ucc_tl_mlx5_mcast_one_sided_reliability_comm_t  one_sided;
     int                                             mcast_group_count;
-    int                                             ag_under_progress_counter;
     int                                             pending_recv_per_qp[MAX_GROUP_COUNT];
-    int                                             ag_counter;
-    int                                             ag_max_num_packets;
-    int                                             max_push_send;
+    ucc_tl_mlx5_mcast_allgather_comm_t              allgather_comm;
+    ucc_tl_mlx5_mcast_bcast_comm_t                  bcast_comm;
     struct pp_packet                               *r_window[1]; // note: do not add any new variable after here
 } ucc_tl_mlx5_mcast_coll_comm_t;
@@ -355,31 +355,31 @@ typedef struct ucc_tl_mlx5_mcast_nack_req {
     ucc_tl_mlx5_mcast_coll_comm_t *comm;
 } ucc_tl_mlx5_mcast_nack_req_t;
 
-#define PSN_IS_IN_RANGE(_psn, _call, _comm)                       \
-    (                                                             \
-        ((_psn >= _call->start_psn) &&                            \
-        (_psn < _call->start_psn + _call->num_packets) &&         \
-        (_psn >= _comm->last_acked) &&                            \
-        (_psn < _comm->last_acked + _comm->wsize))                \
-    )
+#define PSN_IS_IN_RANGE(_psn, _call, _comm)                                      \
+    (                                                                            \
+        ((_psn >= _call->start_psn) &&                                           \
+        (_psn < _call->start_psn + _call->num_packets) &&                        \
+        (_psn >= _comm->bcast_comm.last_acked) &&                                \
+        (_psn < _comm->bcast_comm.last_acked + _comm->bcast_comm.wsize))         \
+    )
 
-#define PSN_TO_RECV_OFFSET(_psn, _call, _comm)                    \
-    (                                                             \
-        ((ptrdiff_t)((_psn - _call->start_psn)                    \
-        * (_comm->max_per_packet)))                               \
-    )
+#define PSN_TO_RECV_OFFSET(_psn, _call, _comm)                                   \
+    (                                                                            \
+        ((ptrdiff_t)((_psn - _call->start_psn)                                   \
+        * (_comm->max_per_packet)))                                              \
+    )
 
-#define PSN_TO_RECV_LEN(_psn, _call, _comm)                       \
-    (                                                             \
-        ((_psn - _call->start_psn + 1) %                          \
-        _call->num_packets == 0 ? _call->last_pkt_len :           \
-        _comm->max_per_packet)                                    \
-    )
+#define PSN_TO_RECV_LEN(_psn, _call, _comm)                                      \
+    (                                                                            \
+        ((_psn - _call->start_psn + 1) %                                         \
+        _call->num_packets == 0 ? _call->last_pkt_len :                          \
+        _comm->max_per_packet)                                                   \
+    )
 
-#define PSN_RECEIVED(_psn, _comm)                                 \
-    (                                                             \
-        (_comm->r_window[(_psn) %                                 \
-        _comm->wsize]->psn == (_psn))                             \
-    )
+#define PSN_RECEIVED(_psn, _comm)                                                \
+    (                                                                            \
+        (_comm->r_window[(_psn) %                                                \
+        _comm->bcast_comm.wsize]->psn == (_psn))                                 \
+    )
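A packet is accepted by `PSN_IS_IN_RANGE` only if it belongs to the current request *and* sits inside the sliding reliability window. A standalone worked example with hypothetical numbers: a request spanning PSNs [96, 128) and a window [100, 164) accept exactly PSNs 100..127.

```c
#include <stdio.h>

int main(void)
{
    unsigned last_acked = 100, wsize = 64;        /* window:  [100, 164) */
    unsigned start_psn  = 96, num_packets = 32;   /* request: [96, 128)  */
    unsigned psn;

    for (psn = 90; psn < 170; psn++) {
        int in_req = psn >= start_psn && psn < start_psn + num_packets;
        int in_win = psn >= last_acked && psn < last_acked + wsize;
        if (in_req && in_win) {
            printf("psn %u accepted\n", psn);     /* prints 100..127 */
        }
    }
    return 0;
}
```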
 
 typedef struct ucc_tl_mlx5_mcast_tensor {
@@ -404,38 +404,37 @@ typedef struct ucc_tl_mlx5_mcast_pipelined_ag_schedule {
 } ucc_tl_mlx5_mcast_pipelined_ag_schedule_t;
 
 typedef struct ucc_tl_mlx5_mcast_coll_req {
-    ucc_tl_mlx5_mcast_coll_comm_t             *comm;
-    size_t                                     length;
-    int                                        proto;
-    struct ibv_mr                             *mr;
-    struct ibv_mr                             *recv_mr;
-    struct ibv_recv_wr                        *rwr;
-    struct ibv_sge                            *rsgs;
-    void                                      *rreg;
-    char                                      *ptr;
-    char                                      *rptr;
-    int                                        am_root;
-    ucc_rank_t                                 root;
-    void                                     **rbufs;
-    int                                        first_send_psn;
-    int                                        to_send;
-    int                                        to_recv;
-    ucc_rank_t                                 parent;
-    uint32_t                                   start_psn;
-    int                                        num_packets;
-    int                                        last_pkt_len;
-    int                                        offset;
-    ucc_memory_type_t                          buf_mem_type;
-    int                                        one_sided_reliability_scheme;
-    int                                        concurreny_level;
-    int                                        ag_counter;
-    int                                        state;
-    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
-    int                                        total_steps;
-    int                                        step;
-    ucc_service_coll_req_t                    *allgather_rkeys_req;
-    ucc_service_coll_req_t                    *barrier_req;
-    void                                      *recv_rreg;
+    ucc_tl_mlx5_mcast_coll_comm_t                        *comm;
+    size_t                                                length;
+    int                                                   proto;
+    struct ibv_mr                                        *mr;
+    struct ibv_mr                                        *recv_mr;
+    struct ibv_recv_wr                                   *rwr;
+    struct ibv_sge                                       *rsgs;
+    void                                                 *rreg;
+    char                                                 *ptr;
+    char                                                 *rptr;
+    int                                                   am_root;
+    ucc_rank_t                                            root;
+    void                                                **rbufs;
+    int                                                   first_send_psn;
+    int                                                   to_send;
+    int                                                   to_recv;
+    ucc_rank_t                                            parent;
+    uint32_t                                              start_psn;
+    int                                                   num_packets;
+    int                                                   last_pkt_len;
+    int                                                   offset;
+    ucc_memory_type_t                                     buf_mem_type;
+    enum ucc_tl_mlx5_mcast_one_sided_reliability_scheme   one_sided_reliability_scheme;
+    int                                                   ag_counter;
+    int                                                   state;
+    ucc_tl_mlx5_mcast_pipelined_ag_schedule_t            *ag_schedule;
+    int                                                   total_steps;
+    int                                                   step;
+    ucc_service_coll_req_t                               *allgather_rkeys_req;
+    ucc_service_coll_req_t                               *barrier_req;
+    void                                                 *recv_rreg;
 } ucc_tl_mlx5_mcast_coll_req_t;
 
 typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
@@ -458,6 +457,49 @@ static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast
     return pp;
 }
 
+static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t* comm)
+{
+    struct ibv_recv_wr *bad_wr = NULL;
+    struct ibv_recv_wr *rwr    = comm->call_rwr;
+    struct ibv_sge     *sge    = comm->call_rsgs;
+    struct pp_packet   *pp     = NULL;
+    int                 count  = comm->params.rx_depth - comm->pending_recv;
+    int                 i;
+
+    if (count <= comm->params.post_recv_thresh) {
+        return UCC_OK;
+    }
+
+    for (i = 0; i < count - 1; i++) {
+        if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) {
+            break;
+        }
+
+        rwr[i].wr_id      = ((uint64_t) pp);
+        rwr[i].next       = &rwr[i+1];
+        sge[2*i + 1].addr = pp->buf;
+
+        ucc_assert((uint64_t)comm->pp <= rwr[i].wr_id
+                   && ((uint64_t)comm->pp + comm->buf_n * sizeof(struct pp_packet)) > rwr[i].wr_id);
+    }
+    if (i != 0) {
+        rwr[i-1].next = NULL;
+        if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
+            tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
+            return UCC_ERR_NO_RESOURCE;
+        }
+        comm->pending_recv += i;
+    }
+
+    return UCC_OK;
+}
+
+static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void)
+{
+    double t_second = ucc_get_time();
+    return (uint64_t) (t_second * 1000000);
+}
+
 static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                                     ucc_tl_mlx5_mcast_coll_req_t *req,
                                                                     int group_id, ucc_rank_t root,
@@ -481,7 +523,6 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_
         pp->packet_counter = offset / comm->max_per_packet;
         pp->qp_id          = group_id;
         rwr[i].wr_id       = ((uint64_t) pp);
-        rwr[i].next        = &rwr[i+1];
         sge[2*i + 1].addr  = (uint64_t)req->rptr + root * req->length + offset;
         sge[2*i + 1].lkey  = req->recv_mr->lkey;
         offset            += comm->max_per_packet;
@@ -490,6 +531,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_
             sge[2*i + 1].length = req->last_pkt_len;
         } else {
             sge[2*i + 1].length = comm->max_per_packet;
+            rwr[i].next = &rwr[i+1];
         }
     }
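The one-line move above (chaining `rwr[i].next` only in the `else` branch) is what keeps the receive-WR list properly terminated: `ibv_post_recv()` walks the `next` pointers until it hits `NULL`, so the work request that receives the final, short packet must end the chain rather than point past the array. A self-contained sketch of that rule, with a hypothetical helper name:

```c
#include <infiniband/verbs.h>

/* post n receive work requests as one chain; the last wr must end the list */
static int post_recv_chain(struct ibv_qp *qp, struct ibv_recv_wr *wrs, int n)
{
    struct ibv_recv_wr *bad_wr = NULL;
    int                 i;

    for (i = 0; i < n; i++) {
        wrs[i].next = (i == n - 1) ? NULL : &wrs[i + 1];
    }
    return ibv_post_recv(qp, &wrs[0], &bad_wr); /* 0 on success, errno otherwise */
}
```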
@@ -507,49 +549,6 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_
     return UCC_OK;
 }
 
-static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t* comm)
-{
-    struct ibv_recv_wr *bad_wr = NULL;
-    struct ibv_recv_wr *rwr    = comm->call_rwr;
-    struct ibv_sge     *sge    = comm->call_rsgs;
-    struct pp_packet   *pp     = NULL;
-    int                 count  = comm->params.rx_depth - comm->pending_recv;
-    int                 i;
-
-    if (count <= comm->params.post_recv_thresh) {
-        return UCC_OK;
-    }
-
-    for (i = 0; i < count - 1; i++) {
-        if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) {
-            break;
-        }
-
-        rwr[i].wr_id      = ((uint64_t) pp);
-        rwr[i].next       = &rwr[i+1];
-        sge[2*i + 1].addr = pp->buf;
-
-        ucc_assert((uint64_t)comm->pp <= rwr[i].wr_id
-                   && ((uint64_t)comm->pp + comm->buf_n * sizeof(struct pp_packet)) > rwr[i].wr_id);
-    }
-    if (i != 0) {
-        rwr[i-1].next = NULL;
-        if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) {
-            tl_error(comm->lib, "failed to prepost recvs: errno %d", errno);
-            return UCC_ERR_NO_RESOURCE;
-        }
-        comm->pending_recv += i;
-    }
-
-    return UCC_OK;
-}
-
-static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void)
-{
-    double t_second = ucc_get_time();
-    return (uint64_t) (t_second * 1000000);
-}
-
 ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
                                          ucc_tl_mlx5_mcast_team_t **mcast_team,
                                          ucc_tl_mlx5_mcast_context_t *ctx,
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
index 74d7288936..edc0522402 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.c
@@ -11,6 +11,24 @@
 #include "tl_mlx5_mcast_allgather.h"
 #include <math.h>
 
+/* 32 here is the bit count of ib mcast packet's immediate data */
+#define TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT 32
+
+static inline void ucc_tl_mlx5_mcast_get_max_allgather_packet_count(int *max_count)
+{
+    int pow2;
+    int tmp;
+    pow2       = log(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) / log(2);
+    tmp        = TL_MLX5_MCAST_IB_IMMEDIATE_PACKET_BIT_COUNT - pow2;
+    pow2       = log(ONE_SIDED_MAX_ALLGATHER_COUNTER) / log(2);
+    tmp        = tmp - pow2;
+    *max_count = pow(2, tmp);
+}
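With the limits defined in tl_mlx5_mcast.h (1024 ranks, 32 outstanding allgathers) this budget works out to log2(1024) = 10 and log2(32) = 5 bits, leaving 32 - 10 - 5 = 17 bits, i.e. 131072 packets per allgather. An editor-added sketch computing the same budget with integer shifts instead of `log()`/`pow()`:

```c
#include <stdio.h>

int main(void)
{
    unsigned rank_bits    = 10; /* log2(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) */
    unsigned counter_bits = 5;  /* log2(ONE_SIDED_MAX_ALLGATHER_COUNTER) */
    unsigned max_count    = 1u << (32 - rank_bits - counter_bits);

    printf("max packets per allgather: %u\n", max_count); /* 131072 */
    return 0;
}
```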
+
+#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm)                                  \
+        (_req->to_send || _req->to_recv || _comm->pending_send ||                 \
+         _comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \
+
 static inline ucc_status_t ucc_tl_mlx5_mcast_check_collective(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                               ucc_tl_mlx5_mcast_coll_req_t *req)
 {
@@ -80,7 +98,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reset_reliablity(ucc_tl_mlx5_mcast_
     ucc_tl_mlx5_mcast_reg_t *reg = NULL;
     ucc_status_t             status;
 
-    ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
+    ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);
 
     if (comm->one_sided.reliability_enabled && !comm->one_sided.reliability_ready) {
         /* initialize the structures needed by reliablity protocol */
@@ -129,7 +147,7 @@ static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mc
     ucc_tl_mlx5_mcast_coll_comm_t *comm = req->comm;
     void                          *dest;
 
-    ucc_assert(req->ag_counter == comm->ag_under_progress_counter);
+    ucc_assert(req->ag_counter == comm->allgather_comm.under_progress_counter);
 
     if (ONE_SIDED_ASYNCHRONOUS_PROTO == req->one_sided_reliability_scheme &&
         ONE_SIDED_INVALID == comm->one_sided.slots_state) {
@@ -147,7 +165,7 @@ static inline void ucc_tl_mlx5_mcast_init_async_reliability_slots(ucc_tl_mlx5_mc
     }
 }
 
-static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll_req_t *req)
+static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_mlx5_mcast_coll_req_t *req)
 {
     ucc_status_t                   status = UCC_OK;
     ucc_tl_mlx5_mcast_coll_comm_t *comm   = req->comm;
@@ -162,12 +180,16 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_allgather(ucc_tl_mlx5_mcast_coll
     }
 
     if (req->to_send || req->to_recv) {
-        ucc_assert(comm->max_push_send >= comm->pending_send);
+        ucc_assert(comm->allgather_comm.max_push_send >= comm->pending_send);
         if (req->to_send &&
-            (comm->max_push_send - comm->pending_send) > 0) {
-            ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->max_push_send -
-                                              comm->pending_send, req->to_send),
-                                              zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX);
+            (comm->allgather_comm.max_push_send - comm->pending_send) > 0) {
+            status = ucc_tl_mlx5_mcast_send_collective(comm, req, ucc_min(comm->allgather_comm.max_push_send -
+                                                       comm->pending_send, req->to_send),
+                                                       zcopy, UCC_COLL_TYPE_ALLGATHER, -1, SIZE_MAX);
+            if (status < 0) {
+                tl_error(comm->lib, "a failure happened while sending mcast packets");
+                return status;
+            }
         }
 
         ucc_tl_mlx5_mcast_init_async_reliability_slots(req);
@@ -223,11 +245,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_staging_based_allgather(ucc_tl_m
     }
 }
 
-ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
+static inline ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
 {
     ucc_status_t status;
 
-    status = ucc_tl_mlx5_mcast_do_allgather(req);
+    status = ucc_tl_mlx5_mcast_do_staging_based_allgather(req);
     if (UCC_OK == status) {
         ucc_assert(req->comm->ctx != NULL);
         ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
@@ -248,17 +270,44 @@ ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* req)
     return status;
 }
 
-static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void *rbuf, int size,
-                                                               ucc_tl_mlx5_mcast_coll_comm_t *comm,
-                                                               ucc_tl_mlx5_mcast_coll_req_t *req)
+ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
 {
-    ucc_tl_mlx5_mcast_reg_t *reg = NULL;
-    ucc_status_t status;
+    ucc_tl_mlx5_task_t            *task      = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    ucc_tl_mlx5_team_t            *mlx5_team = TASK_TEAM(task);
+    ucc_tl_mlx5_mcast_team_t      *team      = mlx5_team->mcast;
+    ucc_coll_args_t               *args      = &TASK_ARGS(task);
+    ucc_datatype_t                 dt        = args->src.info.datatype;
+    size_t                         count     = args->src.info.count;
+    ucc_status_t                   status    = UCC_OK;
+    size_t                         data_size = ucc_dt_size(dt) * count;
+    void                          *sbuf      = args->src.info.buffer;
+    void                          *rbuf      = args->dst.info.buffer;
+    ucc_tl_mlx5_mcast_coll_comm_t *comm      = team->mcast_comm;
+    ucc_tl_mlx5_mcast_reg_t       *reg       = NULL;
+    ucc_tl_mlx5_mcast_coll_req_t  *req;
+
+    if (!data_size) {
+        coll_task->status = UCC_OK;
+        return ucc_task_complete(coll_task);
+    }
+
+    task->coll_mcast.req_handle = NULL;
+
+    tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %ld, comm %d, "
+             "comm_size %d, counter %d",
+             sbuf, rbuf, data_size, comm->comm_id, comm->commsize, comm->allgather_comm.coll_counter);
+
+    req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
+    if (!req) {
+        tl_warn(comm->lib, "malloc failed");
+        status = UCC_ERR_NO_MEMORY;
+        goto failed;
+    }
 
     req->comm   = comm;
     req->ptr    = sbuf;
     req->rptr   = rbuf;
-    req->length = size;
+    req->length = data_size;
     req->mr     = comm->pp_mr;
     req->rreg   = NULL;
     /* - zero copy protocol only provides zero copy design at sender side
@@ -267,27 +316,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
     req->proto  = (req->length < comm->max_eager) ? MCAST_PROTO_EAGER :
                   MCAST_PROTO_ZCOPY;
 
-    if (comm->commsize > ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) {
-        tl_warn(comm->lib,
-                "team size is %d but max supported team size of mcast allgather is %d",
-                comm->commsize, ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
-        return UCC_ERR_NOT_SUPPORTED;
-    }
+    assert(comm->commsize <= ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
 
     req->offset      = 0;
-    req->num_packets = (req->length + comm->max_per_packet - 1)/comm->max_per_packet;
+    req->num_packets = ucc_max(1, (req->length + comm->max_per_packet - 1)/comm->max_per_packet);
 
-    if (req->num_packets == 0) {
-        req->num_packets = 1;
-    }
+    ucc_tl_mlx5_mcast_get_max_allgather_packet_count(&comm->allgather_comm.max_num_packets);
 
-    ONE_SIDED_MAX_PACKET_COUNT(comm->ag_max_num_packets);
-
-    if (comm->ag_max_num_packets < req->num_packets) {
+    if (comm->allgather_comm.max_num_packets < req->num_packets) {
         tl_warn(comm->lib,
                 "msg size is %ld but max supported msg size of mcast allgather is %d",
-                req->length, comm->ag_max_num_packets * comm->max_per_packet);
-        return UCC_ERR_NOT_SUPPORTED;
+                req->length, comm->allgather_comm.max_num_packets * comm->max_per_packet);
+        status = UCC_ERR_NOT_SUPPORTED;
+        goto failed;
     }
 
     req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;
@@ -298,7 +339,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
         /* register the send buffer */
         status = ucc_tl_mlx5_mcast_mem_register(comm->ctx, req->ptr, req->length, &reg);
         if (UCC_OK != status) {
-            return status;
+            ucc_free(req);
+            goto failed;
         }
         req->rreg = reg;
         req->mr   = reg->mr;
@@ -312,84 +354,32 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_allgather(void* sbuf, void
         req->one_sided_reliability_scheme = ONE_SIDED_NO_RELIABILITY;
     }
 
-    req->ag_counter = comm->ag_counter;
+    req->ag_counter = comm->allgather_comm.coll_counter;
     req->to_send    = req->num_packets;
     req->to_recv    = comm->commsize * req->num_packets;
 
-    comm->ag_counter++;
-    return UCC_OK;
-}
-
-static inline ucc_status_t ucc_tl_mlx5_mcast_coll_do_allgather(void* sbuf, void *rbuf, int size,
-                                                               ucc_tl_mlx5_mcast_coll_comm_t *comm,
-                                                               ucc_tl_mlx5_mcast_coll_req_t **task_req_handle)
-{
-    ucc_tl_mlx5_mcast_coll_req_t *req;
-    ucc_status_t                  status;
+    comm->allgather_comm.coll_counter++;
 
-    tl_trace(comm->lib, "MCAST allgather start, sbuf %p, rbuf %p, size %d, comm %d, "
-             "comm_size %d, counter %d",
-             sbuf, rbuf, size, comm->comm_id, comm->commsize, comm->ag_counter);
-
-    req = ucc_calloc(1, sizeof(ucc_tl_mlx5_mcast_coll_req_t), "mcast_req");
-    if (!req) {
-        tl_error(comm->lib, "malloc failed");
-        return UCC_ERR_NO_MEMORY;
-    }
-
-    status = ucc_tl_mlx5_mcast_prepare_allgather(sbuf, rbuf, size, comm, req);
-    if (UCC_OK != status) {
-        tl_warn(comm->lib, "prepare mcast allgather failed");
-        ucc_free(req);
-        return status;
-    }
-
-    status = UCC_INPROGRESS;
-
-    *task_req_handle = req;
-
-    return status;
-}
-
-ucc_status_t ucc_tl_mlx5_mcast_allgather_start(ucc_coll_task_t *coll_task)
-{
-    ucc_tl_mlx5_task_t            *task      = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
-    ucc_tl_mlx5_team_t            *mlx5_team = TASK_TEAM(task);
-    ucc_tl_mlx5_mcast_team_t      *team      = mlx5_team->mcast;
-    ucc_coll_args_t               *args      = &TASK_ARGS(task);
-    ucc_datatype_t                 dt        = args->src.info.datatype;
-    size_t                         count     = args->src.info.count;
-    ucc_status_t                   status    = UCC_OK;
-    size_t                         data_size = ucc_dt_size(dt) * count;
-    void                          *sbuf      = args->src.info.buffer;
-    void                          *rbuf      = args->dst.info.buffer;
-    ucc_tl_mlx5_mcast_coll_comm_t *comm      = team->mcast_comm;
-
-    task->coll_mcast.req_handle = NULL;
-
-    status = ucc_tl_mlx5_mcast_coll_do_allgather(sbuf, rbuf, data_size, comm, &task->coll_mcast.req_handle);
-    if (status < 0) {
-        tl_warn(UCC_TASK_LIB(task), "do mcast allgather failed:%d", status);
-        coll_task->status = status;
-        return ucc_task_complete(coll_task);
-    }
+    task->coll_mcast.req_handle = req;
+    coll_task->status           = UCC_INPROGRESS;
+    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
 
+failed:
+    tl_warn(UCC_TASK_LIB(task), "mcast allgather start failed: %d", status);
     coll_task->status = status;
-
-    return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super);
+    return ucc_task_complete(coll_task);
 }
 
 void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
 {
-    ucc_tl_mlx5_task_t           *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
-    ucc_tl_mlx5_mcast_coll_req_t *req  = task->coll_mcast.req_handle;
+    ucc_tl_mlx5_task_t           *task = ucc_derived_of(coll_task, ucc_tl_mlx5_task_t);
+    ucc_tl_mlx5_mcast_coll_req_t *req  = task->coll_mcast.req_handle;
     ucc_status_t                  status;
 
-    if (task->coll_mcast.req_handle != NULL) {
-        req = task->coll_mcast.req_handle;
-        if (req->ag_counter != req->comm->ag_under_progress_counter) {
+    if (req != NULL) {
+        if (req->ag_counter != req->comm->allgather_comm.under_progress_counter) {
             /* it is not this task's turn for progress */
-            ucc_assert(req->comm->ag_under_progress_counter < req->ag_counter);
+            ucc_assert(req->comm->allgather_comm.under_progress_counter < req->ag_counter);
             return;
         }
 
@@ -398,14 +388,27 @@ void ucc_tl_mlx5_mcast_allgather_progress(ucc_coll_task_t *coll_task)
             return;
         } else if (UCC_OK == status) {
             coll_task->status = UCC_OK;
-            req->comm->ag_under_progress_counter++;
+            req->comm->allgather_comm.under_progress_counter++;
             ucc_free(req);
             task->coll_mcast.req_handle = NULL;
         } else {
             tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed:%d", status);
             coll_task->status = status;
+            if (req->rreg) {
+                ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->rreg);
+                req->rreg = NULL;
+            }
+            if (req->recv_rreg) {
+                ucc_tl_mlx5_mcast_mem_deregister(req->comm->ctx, req->recv_rreg);
+                req->recv_rreg = NULL;
+            }
+            ucc_free(req);
             ucc_task_complete(coll_task);
         }
+    } else {
+        tl_error(UCC_TASK_LIB(task), "progress mcast allgather failed, mcast coll not initialized");
+        coll_task->status = UCC_ERR_NO_RESOURCE;
+        ucc_task_complete(coll_task);
     }
 
     return;
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h
index fea17aa79e..a51aea451f 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_allgather.h
@@ -10,12 +10,6 @@
 #include "tl_mlx5_mcast.h"
 #include "tl_mlx5_coll.h"
 
-#define MCAST_ALLGATHER_IN_PROGRESS(_req, _comm)                                  \
-        (_req->to_send || _req->to_recv || _comm->pending_send ||                 \
-         _comm->one_sided.rdma_read_in_progress || (NULL != _req->allgather_rkeys_req)) \
-
 ucc_status_t ucc_tl_mlx5_mcast_allgather_init(ucc_tl_mlx5_task_t *task);
 
-ucc_status_t ucc_tl_mlx5_mcast_test_allgather(ucc_tl_mlx5_mcast_coll_req_t* _req);
-
 #endif
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
index 046e0e5963..00c3765be9 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c
@@ -12,12 +12,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
                                                               ucc_tl_mlx5_mcast_coll_req_t *req)
 {
     ucc_status_t      status        = UCC_OK;
-    int               wsize         = comm->wsize;
-    int               num_free_win  = wsize - (comm->psn - comm->last_acked);
+    int               wsize         = comm->bcast_comm.wsize;
+    int               num_free_win  = wsize - (comm->psn - comm->bcast_comm.last_acked);
     int               req_completed = (req->to_send == 0 && req->to_recv == 0);
     struct pp_packet *pp            = NULL;
 
-    ucc_assert(comm->recv_drop_packet_in_progress == false);
+    ucc_assert(comm->bcast_comm.recv_drop_packet_in_progress == false);
     ucc_assert(req->to_send >= 0);
 
     /* When do we need to perform reliability protocol:
@@ -33,12 +33,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_
         return status;
     }
 
-    comm->n_mcast_reliable++;
+    comm->bcast_comm.n_mcast_reliable++;
 
-    for (;comm->last_acked < comm->psn; comm->last_acked++) {
-        pp = comm->r_window[comm->last_acked & (wsize-1)];
+    for (;comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) {
+        pp = comm->r_window[comm->bcast_comm.last_acked & (wsize-1)];
         ucc_assert(pp != &comm->dummy_packet);
-        comm->r_window[comm->last_acked & (wsize-1)] = &comm->dummy_packet;
+        comm->r_window[comm->bcast_comm.last_acked & (wsize-1)] = &comm->dummy_packet;
 
         pp->context = 0;
         ucc_list_add_tail(&comm->bpool, &pp->super);
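The recycle loop above indexes the window with `last_acked & (wsize-1)`, which is equivalent to a modulo only when `wsize` is a power of two (presumably guaranteed by the configuration). A small standalone check of that identity and of the free-window arithmetic, with hypothetical numbers:

```c
#include <assert.h>
#include <stdio.h>

int main(void)
{
    unsigned wsize      = 64;  /* must be a power of two for the mask trick */
    unsigned last_acked = 100;
    unsigned psn        = 130;
    unsigned num_free   = wsize - (psn - last_acked); /* 34 slots still free */

    /* psn & (wsize - 1) == psn % wsize holds only for power-of-two wsize */
    assert((psn & (wsize - 1)) == (psn % wsize));
    printf("window slot %u, free slots %u\n", psn & (wsize - 1), num_free);
    return 0;
}
```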
@@ -60,7 +60,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
     ucc_status_t                   status = UCC_OK;
     ucc_tl_mlx5_mcast_coll_comm_t *comm   = req->comm;
     int                            zcopy  = req->proto != MCAST_PROTO_EAGER;
-    int                            wsize  = comm->wsize;
+    int                            wsize  = comm->bcast_comm.wsize;
     int                            num_free_win;
     int                            num_sent;
     int                            to_send;
@@ -74,29 +74,29 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
         return status;
     }
 
-    if (ucc_unlikely(comm->recv_drop_packet_in_progress)) {
+    if (ucc_unlikely(comm->bcast_comm.recv_drop_packet_in_progress)) {
         /* wait till parent resend the dropped packet */
         return UCC_INPROGRESS;
     }
 
     if (req->to_send || req->to_recv) {
-        num_free_win = wsize - (comm->psn - comm->last_acked);
+        num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked);
 
         /* Send data if i'm root and there is a space in the window */
         if (num_free_win && req->am_root) {
             num_sent = req->num_packets - req->to_send;
             ucc_assert(req->to_send > 0);
-            ucc_assert(req->first_send_psn + num_sent < comm->last_acked + wsize);
-            if (req->first_send_psn + num_sent < comm->last_acked + wsize &&
+            ucc_assert(req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize);
+            if (req->first_send_psn + num_sent < comm->bcast_comm.last_acked + wsize &&
                 req->to_send) {
                 /* How many to send: either all that are left (if they fit into window) or
                    up to the window limit */
                 to_send = ucc_min(req->to_send,
-                                  comm->last_acked + wsize - (req->first_send_psn + num_sent));
+                                  comm->bcast_comm.last_acked + wsize - (req->first_send_psn + num_sent));
                 ucc_tl_mlx5_mcast_send(comm, req, to_send, zcopy);
 
-                num_free_win = wsize - (comm->psn - comm->last_acked);
+                num_free_win = wsize - (comm->psn - comm->bcast_comm.last_acked);
             }
         }
 
@@ -119,8 +119,8 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
                 tl_trace(comm->lib, "Did not receive the packet with psn in"
                          " current window range, so get ready for drop"
                          " event. pending_q_size %d current comm psn %d"
-                         " last_acked psn %d stall threshold %d ",
-                         pending_q_size, comm->psn, comm->last_acked,
+                         " bcast_comm.last_acked psn %d stall threshold %d ",
+                         pending_q_size, comm->psn, comm->bcast_comm.last_acked,
                          DROP_THRESHOLD);
 
                 status = ucc_tl_mlx5_mcast_bcast_check_drop(comm, req);
@@ -144,7 +144,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_do_bcast(ucc_tl_mlx5_mcast_coll_req
         return status;
     }
 
-    if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->last_acked)) {
+    if (req->to_send || req->to_recv || (zcopy && comm->psn != comm->bcast_comm.last_acked)) {
         return UCC_INPROGRESS;
     } else {
         return status;
@@ -201,16 +201,16 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_prepare_bcast(void* buf, size_t siz
     }
 
     req->offset    = 0;
-    req->start_psn = comm->last_psn;
+    req->start_psn = comm->bcast_comm.last_psn;
     req->num_packets = ucc_max(ucc_div_round_up(req->length, comm->max_per_packet), 1);
 
     req->last_pkt_len = req->length - (req->num_packets - 1)*comm->max_per_packet;
 
     ucc_assert(req->last_pkt_len > 0 && req->last_pkt_len <= comm->max_per_packet);
 
-    comm->last_psn      += req->num_packets;
-    req->first_send_psn  = req->start_psn;
-    req->to_send         = req->am_root ? req->num_packets : 0;
-    req->to_recv         = req->am_root ? 0 : req->num_packets;
+    comm->bcast_comm.last_psn += req->num_packets;
+    req->first_send_psn        = req->start_psn;
+    req->to_send               = req->am_root ? req->num_packets : 0;
+    req->to_recv               = req->am_root ? 0 : req->num_packets;
 
     return UCC_OK;
 }
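With hedged, concrete numbers the packet accounting in `prepare_bcast` looks like this (a 4K MTU and the 40-byte InfiniBand GRH are assumptions for the example):

```c
#include <stdio.h>

int main(void)
{
    unsigned max_per_packet = 4096 - 40;  /* mtu - GRH_LENGTH = 4056 */
    unsigned length         = 10000;
    unsigned num_packets    = (length + max_per_packet - 1) / max_per_packet; /* 3 */
    unsigned last_pkt_len   = length - (num_packets - 1) * max_per_packet;    /* 1888 */
    unsigned start_psn      = 500; /* bcast_comm.last_psn before the call */

    printf("PSNs %u..%u, last packet %u bytes\n",
           start_psn, start_psn + num_packets - 1, last_pkt_len);
    return 0;
}
```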
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
index f57daeab5e..b84b7120f9 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c
@@ -657,7 +657,7 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm)
         tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, "
                  "nacks counter %d, n_mcast_rel %d",
                  comm->comm_id, comm->commsize, comm->psn, comm->rank,
-                 comm->nacks_counter, comm->n_mcast_reliable);
+                 comm->bcast_comm.nacks_counter, comm->bcast_comm.n_mcast_reliable);
     }
 
     if (comm->p2p_ctx != NULL) {
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
index 36440d5c16..1dd338f7a3 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h
@@ -86,7 +86,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
         swr[0].imm_data   = htonl(pp->psn);
         swr[0].send_flags = (length <= comm->max_inline) ? IBV_SEND_INLINE : 0;
 
-        comm->r_window[pp->psn & (comm->wsize-1)] = pp;
+        comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp;
         comm->psn++;
         req->to_send--;
         offset += length;
@@ -102,7 +102,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t
                  pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED);
 
         if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) {
-            tl_error(comm->lib, "Post send failed: ret %d, start_psn %d, to_send %d, "
+            tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
                      "to_recv %d, length %d, psn %d, inline %d",
                      rc, req->start_psn, req->to_send, req->to_recv,
                      length, pp->psn, length <= comm->max_inline);
@@ -127,7 +127,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_process_pp(ucc_tl_mlx5_mcast_coll_c
 {
     ucc_status_t status = UCC_OK;
 
-    if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->last_acked) {
+    if (PSN_RECEIVED(pp->psn, comm) || pp->psn < comm->bcast_comm.last_acked) {
         /* This psn was already received */
         ucc_assert(pp->context == 0);
         if (in_pending_queue) {
@@ -336,7 +336,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send_collective(ucc_tl_mlx5_mcast_c
                      mcast_group_index);
 
         if (0 != (rc = ibv_post_send(comm->mcast.qp_list[mcast_group_index], &swr[0], &bad_wr))) {
-            tl_error(comm->lib, "Post send failed: ret %d, start_psn %d, to_send %d, "
+            tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, "
                      "to_recv %d, length %d, psn %d, inline %d",
                      rc, req->start_psn, req->to_send, req->to_recv,
                      length, pp->psn, length <= comm->max_inline);
@@ -426,6 +426,7 @@ static inline int ucc_tl_mlx5_mcast_recv_collective(ucc_tl_mlx5_mcast_coll_comm_
 
             if (UCC_OK != status) {
                 tl_error(comm->lib, "process allgather packet failed, status %d",
                          status);
+                ucc_free(wc);
                 return -1;
             }
 
@@ -499,8 +500,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com
 {
     ucc_status_t status = UCC_OK;
 
-    if (comm->racks_n != comm->child_n || comm->sacks_n != comm->parent_n ||
-        comm->nack_requests) {
+    if (comm->bcast_comm.racks_n != comm->bcast_comm.child_n ||
+        comm->bcast_comm.sacks_n != comm->bcast_comm.parent_n ||
+        comm->bcast_comm.nack_requests) {
         if (comm->pending_send) {
             status = ucc_tl_mlx5_mcast_poll_send(comm);
             if (UCC_OK != status) {
@@ -508,7 +510,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com
         }
     }
 
-    if (comm->parent_n) {
+    if (comm->bcast_comm.parent_n) {
         status = ucc_tl_mlx5_mcast_poll_recv(comm);
         if (UCC_OK != status) {
             return status;
@@ -521,26 +523,27 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable(ucc_tl_mlx5_mcast_coll_com
         }
     }
 
-    if (comm->parent_n && !comm->reliable_in_progress) {
+    if (comm->bcast_comm.parent_n && !comm->bcast_comm.reliable_in_progress) {
         status = ucc_tl_mlx5_mcast_reliable_send(comm);
         if (UCC_OK != status) {
             return status;
         }
     }
 
-    if (!comm->reliable_in_progress) {
-        comm->reliable_in_progress = 1;
+    if (!comm->bcast_comm.reliable_in_progress) {
+        comm->bcast_comm.reliable_in_progress = 1;
     }
 
-    if (comm->racks_n == comm->child_n && comm->sacks_n == comm->parent_n &&
-        0 == comm->nack_requests) {
+    if (comm->bcast_comm.racks_n == comm->bcast_comm.child_n &&
+        comm->bcast_comm.sacks_n == comm->bcast_comm.parent_n && 0 ==
+        comm->bcast_comm.nack_requests) {
         // Reset for next round.
-        memset(comm->parents,  0, sizeof(comm->parents));
-        memset(comm->children, 0, sizeof(comm->children));
+        memset(comm->bcast_comm.parents,  0, sizeof(comm->bcast_comm.parents));
+        memset(comm->bcast_comm.children, 0, sizeof(comm->bcast_comm.children));
 
-        comm->racks_n = comm->child_n  = 0;
-        comm->sacks_n = comm->parent_n = 0;
-        comm->reliable_in_progress     = 0;
+        comm->bcast_comm.racks_n = comm->bcast_comm.child_n  = 0;
+        comm->bcast_comm.sacks_n = comm->bcast_comm.parent_n = 0;
+        comm->bcast_comm.reliable_in_progress = 0;
 
         return UCC_OK;
     }
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c
index a04e593713..7a89ecf592 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_progress.c
@@ -223,7 +223,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet_collective(ucc_tl_mlx5_mcast_coll_
 
     // this means that a packet which was considered dropped in previous run has not just arrived
     // need to check the allgather call counter and ignore this packet if it does not match
-    if (ag_counter == (req->ag_counter % ONE_SIDED_MAX_ALLGATHER_COUNTER)) { 
+    if (ag_counter == (req->ag_counter % ONE_SIDED_MAX_ALLGATHER_COUNTER)) {
         if (pp->length) {
             if (pp->length == comm->max_per_packet) {
                 dest = req->rptr + offset * pp->length + source_rank * req->length;
@@ -242,11 +242,11 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet_collective(ucc_tl_mlx5_mcast_coll_
             if (comm->one_sided.recvd_pkts_tracker[source_rank] > req->num_packets) {
                 tl_error(comm->lib, "reliablity failed: comm->one_sided.recvd_pkts_tracker[%d] %d"
-                         " req->num_packets %d offset %d PACKET_TO_DROP %d"
-                         " comm->ag_under_progress_counter %d req->ag_counter"
+                         " req->num_packets %d offset %d"
+                         " comm->allgather_comm.under_progress_counter %d req->ag_counter"
                          " %d \n", source_rank, comm->one_sided.recvd_pkts_tracker[source_rank],
-                         req->num_packets, offset, PACKET_TO_DROP,
-                         comm->ag_under_progress_counter, req->ag_counter);
+                         req->num_packets, offset,
+                         comm->allgather_comm.under_progress_counter, req->ag_counter);
                 return UCC_ERR_NO_MESSAGE;
             }
         }
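This is the receive side of the immediate-data budget computed in tl_mlx5_mcast_allgather.c: the source rank, allgather counter and packet offset are all unpacked from the 32-bit immediate. A hypothetical decode, assuming the rank sits in the low 10 bits, the counter in the next 5 and the offset in the remaining 17 -- the field order used by the real protocol may differ:

```c
#include <stdint.h>

static void decode_imm(uint32_t imm, uint32_t *rank, uint32_t *counter,
                       uint32_t *offset)
{
    *rank    = imm & 0x3ff;        /* 10 bits: source rank (team size <= 1024) */
    *counter = (imm >> 10) & 0x1f; /*  5 bits: allgather counter (mod 32)      */
    *offset  = imm >> 15;          /* 17 bits: packet index within the message */
}
```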
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c
index 85d63a82d0..a3d16fd6d8 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c
@@ -169,6 +169,13 @@ ucc_status_t ucc_tl_mlx5_mcast_one_sided_reliability_init(ucc_base_team_t *team)
     ucc_tl_mlx5_mcast_coll_comm_t *comm    = tl_team->mcast->mcast_comm;
     ucc_status_t                   status  = UCC_OK;
 
+    if (comm->commsize > ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) {
+        tl_warn(comm->lib,
+                "team size is %d but max supported team size of mcast one-sided reliability is %d",
+                comm->commsize, ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE);
+        return UCC_ERR_NOT_SUPPORTED;
+    }
+
     status = ucc_tl_mlx5_mcast_one_sided_setup_reliability_buffers(team);
     if (status != UCC_OK) {
         tl_error(comm->lib, "setup reliablity resources failed");
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
index f415acb2b2..2467d470cf 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c
@@ -24,12 +24,12 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
 
     if (pkt_id != UINT_MAX) {
         /* we sent the real data to our child so reduce the nack reqs */
-        ucc_assert(comm->nack_requests > 0);
-        ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING);
-        comm->p2p_pkt[pkt_id].type = MCAST_P2P_ACK;
-        comm->nack_requests--;
-        status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[pkt_id],
-                                                sizeof(struct packet), comm->p2p_pkt[pkt_id].from,
+        ucc_assert(comm->bcast_comm.nack_requests > 0);
+        ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_NACK_SEND_PENDING);
+        comm->bcast_comm.p2p_pkt[pkt_id].type = MCAST_P2P_ACK;
+        comm->bcast_comm.nack_requests--;
+        status = comm->params.p2p_iface.recv_nb(&comm->bcast_comm.p2p_pkt[pkt_id],
+                                                sizeof(struct packet), comm->bcast_comm.p2p_pkt[pkt_id].from,
                                                 comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                                 ucc_tl_mlx5_mcast_recv_completion, pkt_id, NULL));
         if (status < 0) {
@@ -45,21 +45,21 @@ static ucc_status_t ucc_tl_mlx5_mcast_reliability_send_completion(ucc_tl_mlx5_mc
 static inline ucc_status_t ucc_tl_mlx5_mcast_resend_packet_reliable(ucc_tl_mlx5_mcast_coll_comm_t *comm,
                                                                     int p2p_pkt_id)
 {
-    uint32_t          psn = comm->p2p_pkt[p2p_pkt_id].psn;
-    struct pp_packet *pp  = comm->r_window[psn % comm->wsize];
+    uint32_t          psn = comm->bcast_comm.p2p_pkt[p2p_pkt_id].psn;
+    struct pp_packet *pp  = comm->r_window[psn % comm->bcast_comm.wsize];
     ucc_status_t      status;
 
     ucc_assert(pp->psn == psn);
-    ucc_assert(comm->p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
+    ucc_assert(comm->bcast_comm.p2p_pkt[p2p_pkt_id].type == MCAST_P2P_NEED_NACK_SEND);
 
-    comm->p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING;
+    comm->bcast_comm.p2p_pkt[p2p_pkt_id].type = MCAST_P2P_NACK_SEND_PENDING;
 
     tl_trace(comm->lib, "[comm %d, rank %d] Send data NACK: to %d, psn %d, context %ld nack_requests %d \n",
              comm->comm_id, comm->rank,
-             comm->p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->nack_requests);
+             comm->bcast_comm.p2p_pkt[p2p_pkt_id].from, psn, pp->context, comm->bcast_comm.nack_requests);
 
     status = comm->params.p2p_iface.send_nb((void*) (pp->context ? pp->context : pp->buf),
-                                            pp->length, comm->p2p_pkt[p2p_pkt_id].from,
+                                            pp->length, comm->bcast_comm.p2p_pkt[p2p_pkt_id].from,
                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                             ucc_tl_mlx5_mcast_reliability_send_completion, NULL, p2p_pkt_id));
     if (status < 0) {
@@ -75,14 +75,14 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t
     int               i;
     struct pp_packet *pp;
 
-    if (!comm->nack_requests) {
+    if (!comm->bcast_comm.nack_requests) {
         return UCC_OK;
     }
 
     if (psn != UINT32_MAX) {
-        for (i=0; i<comm->child_n; i++) {
-            if (psn == comm->p2p_pkt[i].psn &&
-                comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
+        for (i=0; i<comm->bcast_comm.child_n; i++) {
+            if (psn == comm->bcast_comm.p2p_pkt[i].psn &&
+                comm->bcast_comm.p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
                 status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i);
                 if (status != UCC_OK) {
                     break;
@@ -90,10 +90,10 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t
             }
         }
     } else {
-        for (i=0; i<comm->child_n; i++){
-            if (comm->p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
-                psn = comm->p2p_pkt[i].psn;
-                pp  = comm->r_window[psn % comm->wsize];
+        for (i=0; i<comm->bcast_comm.child_n; i++){
+            if (comm->bcast_comm.p2p_pkt[i].type == MCAST_P2P_NEED_NACK_SEND) {
+                psn = comm->bcast_comm.p2p_pkt[i].psn;
+                pp  = comm->r_window[psn % comm->bcast_comm.wsize];
 
                 if (psn == pp->psn) {
                     status = ucc_tl_mlx5_mcast_resend_packet_reliable(comm, i);
                     if (status < 0) {
@@ -110,9 +110,9 @@ ucc_status_t ucc_tl_mlx5_mcast_check_nack_requests(ucc_tl_mlx5_mcast_coll_comm_t
 static inline int ucc_tl_mlx5_mcast_find_nack_psn(ucc_tl_mlx5_mcast_coll_comm_t* comm,
                                                   ucc_tl_mlx5_mcast_coll_req_t *req)
 {
-    int psn            = ucc_max(comm->last_acked, req->start_psn);
+    int psn            = ucc_max(comm->bcast_comm.last_acked, req->start_psn);
     int max_search_psn = ucc_min(req->start_psn + req->num_packets,
-                                 comm->last_acked + comm->wsize + 1);
+                                 comm->bcast_comm.last_acked + comm->bcast_comm.wsize + 1);
 
     for (; psn < max_search_psn; psn++) {
         if (!PSN_RECEIVED(psn, comm)) {
@@ -144,7 +144,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
         dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm);
         memcpy(dest, (void*) pp->buf, pp->length);
         req->to_recv--;
-        comm->r_window[pp->psn % comm->wsize] = pp;
+        comm->r_window[pp->psn % comm->bcast_comm.wsize] = pp;
 
         status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn);
         if (status < 0) {
@@ -152,7 +152,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_data_completion(ucc_tl_mlx5_mcast_p2p
     }
 
     comm->psn++;
-    comm->recv_drop_packet_in_progress = false;
+    comm->bcast_comm.recv_drop_packet_in_progress = false;
 
     return status;
 }
@@ -174,7 +174,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
 
     parent = ucc_tl_mlx5_mcast_get_nack_parent(req);
 
-    comm->nacks_counter++;
+    comm->bcast_comm.nacks_counter++;
 
     status = comm->params.p2p_iface.send_nb(p, sizeof(struct packet), parent,
                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
@@ -191,7 +191,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_reliable_send_NACK(ucc_tl_mlx5_mcas
         pp->psn    = psn;
         pp->length = PSN_TO_RECV_LEN(pp->psn, req, comm);
 
-        comm->recv_drop_packet_in_progress = true;
+        comm->bcast_comm.recv_drop_packet_in_progress = true;
 
         status = comm->params.p2p_iface.recv_nb((void*) pp->buf,
                                                 pp->length, parent,
@@ -211,20 +211,20 @@ ucc_status_t ucc_tl_mlx5_mcast_reliable_send(ucc_tl_mlx5_mcast_coll_comm_t *comm
     ucc_status_t status;
 
     tl_trace(comm->lib, "comm %p, psn %d, last_acked %d, n_parent %d",
-             comm, comm->psn, comm->last_acked, comm->parent_n);
+             comm, comm->psn, comm->bcast_comm.last_acked, comm->bcast_comm.parent_n);
 
-    ucc_assert(!comm->reliable_in_progress);
+    ucc_assert(!comm->bcast_comm.reliable_in_progress);
 
-    for (i=0; i<comm->parent_n; i++) {
-        parent                    = comm->parents[i];
-        comm->p2p_spkt[i].type    = MCAST_P2P_ACK;
-        comm->p2p_spkt[i].psn     = comm->last_acked + comm->wsize;
-        comm->p2p_spkt[i].comm_id = comm->comm_id;
+    for (i=0; i<comm->bcast_comm.parent_n; i++) {
+        parent                               = comm->bcast_comm.parents[i];
+        comm->bcast_comm.p2p_spkt[i].type    = MCAST_P2P_ACK;
+        comm->bcast_comm.p2p_spkt[i].psn     = comm->bcast_comm.last_acked + comm->bcast_comm.wsize;
+        comm->bcast_comm.p2p_spkt[i].comm_id = comm->comm_id;
 
         tl_trace(comm->lib, "rank %d, Posting SEND to parent %d, n_parent %d, psn %d",
-                 comm->rank, parent, comm->parent_n, comm->psn);
+                 comm->rank, parent, comm->bcast_comm.parent_n, comm->psn);
 
-        status = comm->params.p2p_iface.send_nb(&comm->p2p_spkt[i],
+        status = comm->params.p2p_iface.send_nb(&comm->bcast_comm.p2p_spkt[i],
                                                 sizeof(struct packet), parent,
                                                 comm->p2p_ctx, GET_COMPL_OBJ(comm,
                                                 ucc_tl_mlx5_mcast_send_completion, i, NULL));
@@ -244,19 +244,19 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp
     struct pp_packet *pp;
     ucc_status_t      status;
 
-    ucc_assert(comm->comm_id == comm->p2p_pkt[pkt_id].comm_id);
+    ucc_assert(comm->comm_id == comm->bcast_comm.p2p_pkt[pkt_id].comm_id);
 
-    if (comm->p2p_pkt[pkt_id].type != MCAST_P2P_ACK) {
-        ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_NACK);
-        psn = comm->p2p_pkt[pkt_id].psn;
-        pp  = comm->r_window[psn % comm->wsize];
+    if (comm->bcast_comm.p2p_pkt[pkt_id].type != MCAST_P2P_ACK) {
+        ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_NACK);
+        psn = comm->bcast_comm.p2p_pkt[pkt_id].psn;
+        pp  = comm->r_window[psn % comm->bcast_comm.wsize];
 
         tl_trace(comm->lib, "[comm %d, rank %d] Got NACK: from %d, psn %d, avail %d pkt_id %d",
                  comm->comm_id, comm->rank,
-                 comm->p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id);
+                 comm->bcast_comm.p2p_pkt[pkt_id].from, psn, pp->psn == psn, pkt_id);
 
-        comm->p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND;
-        comm->nack_requests++;
+        comm->bcast_comm.p2p_pkt[pkt_id].type = MCAST_P2P_NEED_NACK_SEND;
+        comm->bcast_comm.nack_requests++;
 
         if (pp->psn == psn) {
             /* parent already has this packet so it is ready to forward it to its child */
@@ -267,8 +267,8 @@ static ucc_status_t ucc_tl_mlx5_mcast_recv_completion(ucc_tl_mlx5_mcast_p2p_comp
 
     } else {
-        ucc_assert(comm->p2p_pkt[pkt_id].type == MCAST_P2P_ACK);
-        comm->racks_n++;
+        ucc_assert(comm->bcast_comm.p2p_pkt[pkt_id].type == MCAST_P2P_ACK);
+        comm->bcast_comm.racks_n++;
     }
 
     ucc_mpool_put(obj); /* return the completion object back to the mem pool compl_objects_mp */
@@ -280,7 +280,7 @@ static ucc_status_t ucc_tl_mlx5_mcast_send_completion(ucc_tl_mlx5_mcast_p2p_comp
 {
     ucc_tl_mlx5_mcast_coll_comm_t *comm = (ucc_tl_mlx5_mcast_coll_comm_t*)obj->data[0];
 
-    comm->sacks_n++;
+    comm->bcast_comm.sacks_n++;
     ucc_mpool_put(obj);
     return UCC_OK;
 }
@@ -314,20 +314,20 @@ ucc_status_t ucc_tl_mlx5_mcast_prepare_reliable(ucc_tl_mlx5_mcast_coll_comm_t *c
     while (mask < comm->commsize) {
         if (vrank & mask) {
             req->parent = TO_ORIGINAL((vrank ^ mask), comm->commsize, root);
-            add_uniq(comm->parents, &comm->parent_n, req->parent);
+            add_uniq(comm->bcast_comm.parents, &comm->bcast_comm.parent_n, req->parent);
             break;
         } else {
             child = vrank ^ mask;
             if (child < comm->commsize) {
                 child = TO_ORIGINAL(child, comm->commsize, root);
-                if (add_uniq(comm->children, &comm->child_n, child)) {
+                if (add_uniq(comm->bcast_comm.children, &comm->bcast_comm.child_n, child)) {
                     tl_trace(comm->lib, "rank %d, Posting RECV from child %d, n_child %d, psn %d",
-                             comm->rank, child, comm->child_n, comm->psn);
+                             comm->rank, child, comm->bcast_comm.child_n, comm->psn);
 
-                    status = comm->params.p2p_iface.recv_nb(&comm->p2p_pkt[comm->child_n - 1],
+                    status = comm->params.p2p_iface.recv_nb(&comm->bcast_comm.p2p_pkt[comm->bcast_comm.child_n - 1],
                                                             sizeof(struct packet), child,
                                                             comm->p2p_ctx, GET_COMPL_OBJ(comm,
-                                                            ucc_tl_mlx5_mcast_recv_completion, comm->child_n - 1, req));
+                                                            ucc_tl_mlx5_mcast_recv_completion, comm->bcast_comm.child_n - 1, req));
                     if (status < 0) {
                         return status;
                     }
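The loop above builds a binomial tree over virtual ranks (`vrank` is the rank rotated so the root becomes 0). An editor-added, self-contained sketch with a simplified stand-in for `TO_ORIGINAL` (the real macro may differ): for size 8 and root 0, rank 0 gets children 1, 2 and 4, while rank 4 gets children 5 and 6 and parent 0.

```c
#include <stdio.h>

/* simplified stand-in: rotate a virtual rank back around the root */
#define TO_ORIG(vr, size, root) (((vr) + (root)) % (size))

int main(void)
{
    int size = 8, root = 0, rank = 4;
    int vrank = (rank - root + size) % size;
    int mask;

    for (mask = 1; mask < size; mask <<= 1) {
        if (vrank & mask) {
            /* first set bit points at the parent; stop there */
            printf("parent: %d\n", TO_ORIG(vrank ^ mask, size, root));
            break;
        } else if ((vrank ^ mask) < size) {
            printf("child: %d\n", TO_ORIG(vrank ^ mask, size, root));
        }
    }
    return 0;
}
```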
@@ -376,7 +376,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
         memcpy(dest, (void*) pp->buf, pp->length);
     }
 
-    comm->r_window[pp->psn & (comm->wsize-1)] = pp;
+    comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp;
     status = ucc_tl_mlx5_mcast_check_nack_requests(comm, pp->psn);
     if (status < 0) {
         return status;
@@ -384,7 +384,7 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com
 
     req->to_recv--;
     comm->psn++;
-    ucc_assert(comm->recv_drop_packet_in_progress == false);
+    ucc_assert(comm->bcast_comm.recv_drop_packet_in_progress == false);
 
     return status;
 }
diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
index b4e42614c1..869a13ae40 100644
--- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
+++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c
@@ -88,13 +88,13 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
 
     memcpy(&comm->params, conf_params, sizeof(*conf_params));
 
-    comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
-    comm->wsize                         = conf_params->wsize;
-    comm->max_push_send                 = conf_params->max_push_send;
-    comm->max_eager                     = conf_params->max_eager;
-    comm->comm_id                       = team_params->id;
-    comm->ctx                           = mcast_context;
-    comm->mcast_group_count             = 1; /* TODO: add support for more number of mcast groups */
+    comm->one_sided.reliability_enabled = conf_params->one_sided_reliability_enable;
+    comm->bcast_comm.wsize              = conf_params->wsize;
+    comm->allgather_comm.max_push_send  = conf_params->max_push_send;
+    comm->max_eager                     = conf_params->max_eager;
+    comm->comm_id                       = team_params->id;
+    comm->ctx                           = mcast_context;
+    comm->mcast_group_count             = 1; /* TODO: add support for a larger number of mcast groups */
 
     comm->grh_buf = (char *)ucc_malloc(GRH_LENGTH * sizeof(char), "grh_buf");
     if (!comm->grh_buf) {
@@ -132,20 +132,20 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context,
         goto cleanup;
     }
 
-    comm->rank           = team_params->rank;
-    comm->commsize       = team_params->size;
-    comm->max_per_packet = mcast_context->mtu - GRH_LENGTH;
-    comm->last_acked     = comm->last_psn = 0;
-    comm->racks_n        = comm->sacks_n  = 0;
-    comm->child_n        = comm->parent_n = 0;
-    comm->p2p_ctx        = conf_params->oob;
+    comm->rank                  = team_params->rank;
+    comm->commsize              = team_params->size;
+    comm->max_per_packet        = mcast_context->mtu - GRH_LENGTH;
+    comm->bcast_comm.last_acked = comm->bcast_comm.last_psn = 0;
+    comm->bcast_comm.racks_n    = comm->bcast_comm.sacks_n  = 0;
+    comm->bcast_comm.child_n    = comm->bcast_comm.parent_n = 0;
+    comm->p2p_ctx               = conf_params->oob;
 
     memcpy(&comm->p2p, &conf_params->p2p_iface,
            sizeof(ucc_tl_mlx5_mcast_p2p_interface_t));
 
     comm->dummy_packet.psn = UINT32_MAX;
 
-    for (i=0; i< comm->wsize; i++) {
+    for (i=0; i< comm->bcast_comm.wsize; i++) {
         comm->r_window[i] = &comm->dummy_packet;
     }
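This initialization loop relies on `r_window` being declared as a one-element array at the very end of `ucc_tl_mlx5_mcast_coll_comm_t` (note the "do not add any new variable after here" comment in tl_mlx5_mcast.h) -- the classic trailing-array idiom. A sketch of that allocation pattern, under the assumption that this is how the real code sizes the comm object:

```c
#include <stdlib.h>

struct demo_comm {
    int   wsize;
    void *r_window[1]; /* grows past the end of the struct */
};

/* allocate wsize - 1 extra pointer slots so r_window[i] is valid for i < wsize */
static struct demo_comm *demo_comm_alloc(int wsize)
{
    return calloc(1, sizeof(struct demo_comm) + (wsize - 1) * sizeof(void*));
}
```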
@@ -257,15 +257,14 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_
         goto error;
     }
 
-    memset(comm->parents,  0, sizeof(comm->parents));
-    memset(comm->children, 0, sizeof(comm->children));
+    memset(comm->bcast_comm.parents,  0, sizeof(comm->bcast_comm.parents));
+    memset(comm->bcast_comm.children, 0, sizeof(comm->bcast_comm.children));
 
-    comm->nacks_counter                = 0;
-    comm->tx                           = 0;
-    comm->n_prep_reliable              = 0;
-    comm->n_mcast_reliable             = 0;
-    comm->reliable_in_progress         = 0;
-    comm->recv_drop_packet_in_progress = 0;
+    comm->bcast_comm.nacks_counter                = 0;
+    comm->bcast_comm.n_mcast_reliable             = 0;
+    comm->bcast_comm.reliable_in_progress         = 0;
+    comm->bcast_comm.recv_drop_packet_in_progress = 0;
+    comm->tx                                      = 0;
 
     return status;
 
diff --git a/src/components/tl/mlx5/tl_mlx5.h b/src/components/tl/mlx5/tl_mlx5.h
index 2054fdbc6a..1b6404e6bd 100644
--- a/src/components/tl/mlx5/tl_mlx5.h
+++ b/src/components/tl/mlx5/tl_mlx5.h
@@ -164,7 +164,6 @@ typedef struct ucc_tl_mlx5_rcache_region {
     ucc_tl_mlx5_reg_t   reg;
 } ucc_tl_mlx5_rcache_region_t;
 
-//TODO: add UCC_COLL_TYPE_ALLGATHER once mcast design is completed
 #define UCC_TL_MLX5_SUPPORTED_COLLS (UCC_COLL_TYPE_ALLTOALL | UCC_COLL_TYPE_BCAST)
 
 #define UCC_TL_MLX5_TEAM_LIB(_team) \
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c
index 80bf10d32c..d0dcc59433 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.c
+++ b/src/components/tl/mlx5/tl_mlx5_coll.c
@@ -9,15 +9,15 @@
 #include "mcast/tl_mlx5_mcast_allgather.h"
 #include "alltoall/alltoall.h"
 
-ucc_status_t ucc_tl_mlx5_allgather_mcast_init(ucc_base_coll_args_t *coll_args,
-                                              ucc_base_team_t *team,
-                                              ucc_coll_task_t **task_h)
+ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_coll_task_t **task_h)
 {
     ucc_status_t        status = UCC_OK;
     ucc_tl_mlx5_task_t *task   = NULL;
 
     if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
-        tl_trace(team->context->lib, "mcast allgather not supported for active sets");
+        tl_trace(team->context->lib, "mcast collective not supported for active sets");
         return UCC_ERR_NOT_SUPPORTED;
     }
 
@@ -29,50 +29,26 @@ ucc_status_t ucc_tl_mlx5_allgather_mcast_init(ucc_base_coll_args_t *coll_args,
 
     task->super.finalize = ucc_tl_mlx5_task_finalize;
 
-    status = ucc_tl_mlx5_mcast_allgather_init(task);
-    if (ucc_unlikely(UCC_OK != status)) {
-        goto free_task;
-    }
-
-    *task_h = &(task->super);
-
-    tl_debug(UCC_TASK_LIB(task), "initialized mcast allgather coll task %p", task);
-
-    return UCC_OK;
-
-free_task:
-    ucc_mpool_put(task);
-    return status;
-}
-
-ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
-                                          ucc_base_team_t *team,
-                                          ucc_coll_task_t **task_h)
-{
-    ucc_status_t        status = UCC_OK;
-    ucc_tl_mlx5_task_t *task   = NULL;
-
-    if (UCC_COLL_ARGS_ACTIVE_SET(&coll_args->args)) {
-        tl_trace(team->context->lib, "mcast bcast not supported for active sets");
-        return UCC_ERR_NOT_SUPPORTED;
-    }
-
-    task = ucc_tl_mlx5_get_task(coll_args, team);
-
-    if (ucc_unlikely(!task)) {
-        return UCC_ERR_NO_MEMORY;
-    }
-
-    task->super.finalize = ucc_tl_mlx5_task_finalize;
-
-    status = ucc_tl_mlx5_mcast_bcast_init(task);
-    if (ucc_unlikely(UCC_OK != status)) {
-        goto free_task;
+    switch (coll_args->args.coll_type) {
+        case UCC_COLL_TYPE_BCAST:
+            status = ucc_tl_mlx5_mcast_bcast_init(task);
+            if (ucc_unlikely(UCC_OK != status)) {
+                goto free_task;
+            }
+            break;
+        case UCC_COLL_TYPE_ALLGATHER:
+            status = ucc_tl_mlx5_mcast_allgather_init(task);
+            if (ucc_unlikely(UCC_OK != status)) {
+                goto free_task;
+            }
+            break;
+        default:
+            status = UCC_ERR_NOT_SUPPORTED;
     }
 
     *task_h = &(task->super);
 
-    tl_debug(UCC_TASK_LIB(task), "initialized mcast bcast coll task %p", task);
+    tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task);
 
     return UCC_OK;
 
@@ -120,10 +96,8 @@ ucc_status_t ucc_tl_mlx5_coll_init(ucc_base_coll_args_t *coll_args,
         status = ucc_tl_mlx5_alltoall_init(coll_args, team, task_h);
         break;
     case UCC_COLL_TYPE_BCAST:
-        status = ucc_tl_mlx5_bcast_mcast_init(coll_args, team, task_h);
-        break;
     case UCC_COLL_TYPE_ALLGATHER:
-        status = ucc_tl_mlx5_allgather_mcast_init(coll_args, team, task_h);
+        status = ucc_tl_mlx5_coll_mcast_init(coll_args, team, task_h);
         break;
     default:
         status = UCC_ERR_NOT_SUPPORTED;
diff --git a/src/components/tl/mlx5/tl_mlx5_coll.h b/src/components/tl/mlx5/tl_mlx5_coll.h
index a63c7462c9..4c0711adfc 100644
--- a/src/components/tl/mlx5/tl_mlx5_coll.h
+++ b/src/components/tl/mlx5/tl_mlx5_coll.h
@@ -113,9 +113,9 @@ static inline void ucc_tl_mlx5_put_schedule(ucc_tl_mlx5_schedule_t *schedule)
     ucc_mpool_put(schedule);
 }
 
-ucc_status_t ucc_tl_mlx5_bcast_mcast_init(ucc_base_coll_args_t *coll_args,
-                                          ucc_base_team_t *team,
-                                          ucc_coll_task_t **task_h);
+ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args,
+                                         ucc_base_team_t *team,
+                                         ucc_coll_task_t **task_h);
 
 ucc_status_t ucc_tl_mlx5_task_finalize(ucc_coll_task_t *coll_task);
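For context, a hedged end-to-end usage sketch (editor-added, not from the patch): a UCC application posting an allgather, which is the path that now reaches `ucc_tl_mlx5_coll_mcast_init()` through `ucc_tl_mlx5_coll_init()` and is then driven by `ucc_tl_mlx5_mcast_allgather_progress()` from the progress queue. Team/context setup is elided; the dst-count convention (total elements across the team) is assumed.

```c
#include <ucc/api/ucc.h>

/* nranks is the team size; rbuf must hold count * nranks elements */
ucc_status_t run_allgather(ucc_team_h team, ucc_context_h ctx, void *sbuf,
                           void *rbuf, size_t count, size_t nranks)
{
    ucc_coll_args_t args = {0};
    ucc_coll_req_h  req;
    ucc_status_t    st;

    args.coll_type         = UCC_COLL_TYPE_ALLGATHER;
    args.src.info.buffer   = sbuf;
    args.src.info.count    = count;
    args.src.info.datatype = UCC_DT_UINT8;
    args.src.info.mem_type = UCC_MEMORY_TYPE_HOST;
    args.dst.info.buffer   = rbuf;
    args.dst.info.count    = count * nranks; /* total elements, all ranks */
    args.dst.info.datatype = UCC_DT_UINT8;
    args.dst.info.mem_type = UCC_MEMORY_TYPE_HOST;

    st = ucc_collective_init(&args, &req, team);
    if (st != UCC_OK) {
        return st;
    }
    st = ucc_collective_post(req);
    if (st == UCC_OK) {
        /* each progress call eventually drives the TL task's progress fn */
        while ((st = ucc_collective_test(req)) == UCC_INPROGRESS) {
            ucc_context_progress(ctx);
        }
    }
    ucc_collective_finalize(req);
    return st;
}
```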