TL/MLX5: adding mcast allgather staging based algo
MamziB committed Jun 27, 2024
1 parent 36389f4 commit 044e785
Showing 14 changed files with 1,167 additions and 49 deletions.
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/Makefile.am
@@ -29,6 +29,8 @@ mcast = \
mcast/tl_mlx5_mcast_service_coll.c \
mcast/tl_mlx5_mcast_one_sided_reliability.h \
mcast/tl_mlx5_mcast_one_sided_reliability.c \
mcast/tl_mlx5_mcast_allgather.h \
mcast/tl_mlx5_mcast_allgather.c \
mcast/tl_mlx5_mcast_team.c

sources = \
181 changes: 155 additions & 26 deletions src/components/tl/mlx5/mcast/tl_mlx5_mcast.h
@@ -32,6 +32,7 @@
#define GRH_LENGTH 40
#define DROP_THRESHOLD 1000
#define MAX_COMM_POW2 32
#define MAX_GROUP_COUNT 64

/* Allgather RDMA-based reliability designs */
#define ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE 1024
@@ -40,6 +41,24 @@
#define ONE_SIDED_ASYNCHRONOUS_PROTO 2
#define ONE_SIDED_SLOTS_COUNT 2 /* number of memory slots during async design */
#define ONE_SIDED_SLOTS_INFO_SIZE sizeof(uint32_t) /* size of the metadata prepended to each slot, in bytes */
#define ONE_SIDED_INVALID -1
#define ONE_SIDED_VALID -2
#define ONE_SIDED_PENDING_INFO -3
#define ONE_SIDED_PENDING_DATA -4
#define ONE_SIDED_MAX_ALLGATHER_COUNTER 32
#define ONE_SIDED_MAX_CONCURRENT_LEVEL 64

/* 32 here is the bit width of the IB send-immediate field */
#define ONE_SIDED_MAX_PACKET_COUNT(_max_count) \
do { \
int pow2; \
int tmp; \
pow2 = log(ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE) / log(2); \
tmp = 32 - pow2; \
pow2 = log(ONE_SIDED_MAX_ALLGATHER_COUNTER) / log(2); \
tmp = tmp - pow2; \
_max_count = pow(2, tmp); \
} while(0);
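
For concreteness, a small worked example of the bit budget this macro encodes (a sketch using the constants defined above, not code from this commit):

/* The 32-bit IB send immediate is shared between the sender rank, the
 * allgather counter, and the per-rank packet counter:
 *   ONE_SIDED_RELIABILITY_MAX_TEAM_SIZE = 1024 -> 10 bits for the rank
 *   ONE_SIDED_MAX_ALLGATHER_COUNTER     = 32   ->  5 bits for the counter
 *   32 - 10 - 5 = 17 bits left for the packet counter                    */
int max_count;
ONE_SIDED_MAX_PACKET_COUNT(max_count); /* max_count == 2^17 == 131072 */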

enum {
MCAST_PROTO_EAGER, /* Internal staging buffers */
@@ -87,7 +106,7 @@ typedef struct ucc_tl_mlx5_mcast_p2p_interface {
ucc_tl_mlx5_mcast_p2p_recv_nb_fn_t recv_nb;
} ucc_tl_mlx5_mcast_p2p_interface_t;

typedef struct mcast_coll_comm_init_spec {
typedef struct ucc_tl_mlx5_mcast_coll_comm_init_spec {
ucc_tl_mlx5_mcast_p2p_interface_t p2p_iface;
int sx_depth;
int rx_depth;
@@ -97,7 +116,9 @@ typedef struct mcast_coll_comm_init_spec {
int post_recv_thresh;
int scq_moderation;
int wsize;
int max_push_send;
int max_eager;
int one_sided_reliability_enable;
void *oob;
} ucc_tl_mlx5_mcast_coll_comm_init_spec_t;

@@ -175,18 +196,27 @@ struct pp_packet {
uint32_t psn;
int length;
uintptr_t context;
uintptr_t buf;
uint32_t packet_counter;
int qp_id;
uintptr_t buf; // buffer address, initialized once
};

struct mcast_ctx {
struct ibv_qp *qp;
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;
struct ibv_qp *qp;
struct ibv_ah *ah;
struct ibv_send_wr swr;
struct ibv_sge ssg;

// RC connection info supporting the one-sided reliability protocol
struct ibv_qp **rc_qp;
uint16_t *rc_lid;
union ibv_gid *rc_gid;

// multiple mcast groups
struct ibv_qp **qp_list;
struct ibv_ah **ah_list;
struct ibv_send_wr *swr_list;
struct ibv_sge *ssg_list;
};
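
The new *_list members hold one QP/AH/WR/SGE set per multicast group (bounded by MAX_GROUP_COUNT). A rough sketch of how these arrays might be sized, assuming UCC's ucc_calloc helper; the function name is illustrative and not part of this commit:

/* Sketch: allocate the per-group arrays once the group count is known. */
static ucc_status_t mcast_ctx_alloc_group_lists(struct mcast_ctx *ctx, int group_count)
{
    ucc_assert(group_count <= MAX_GROUP_COUNT);
    ctx->qp_list  = ucc_calloc(group_count, sizeof(*ctx->qp_list),  "qp_list");
    ctx->ah_list  = ucc_calloc(group_count, sizeof(*ctx->ah_list),  "ah_list");
    ctx->swr_list = ucc_calloc(group_count, sizeof(*ctx->swr_list), "swr_list");
    ctx->ssg_list = ucc_calloc(group_count, sizeof(*ctx->ssg_list), "ssg_list");
    if (!ctx->qp_list || !ctx->ah_list || !ctx->swr_list || !ctx->ssg_list) {
        return UCC_ERR_NO_MEMORY;
    }
    return UCC_OK;
}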

struct packet {
@@ -230,6 +260,10 @@ typedef struct ucc_tl_mlx5_mcast_one_sided_reliability_comm {
int slot_size;
/* coll req that is used during the oob service calls */
ucc_service_coll_req_t *reliability_req;
int reliability_enabled;
int reliability_ready;
int rdma_read_in_progress;
int slots_state;
} ucc_tl_mlx5_mcast_one_sided_reliability_comm_t;

typedef struct ucc_tl_mlx5_mcast_service_coll {
@@ -275,6 +309,7 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
struct packet p2p_spkt[MAX_COMM_POW2];
ucc_list_link_t bpool;
ucc_list_link_t pending_q;
ucc_list_link_t posted_q;
struct mcast_ctx mcast;
int reliable_in_progress;
int recv_drop_packet_in_progress;
@@ -298,6 +333,12 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm {
ucc_tl_mlx5_mcast_service_coll_t service_coll;
struct rdma_cm_event *event;
ucc_tl_mlx5_mcast_one_sided_reliability_comm_t one_sided;
int mcast_group_count;
int ag_under_progress_counter;
int pending_recv_per_qp[MAX_GROUP_COUNT];
int ag_counter;
int ag_max_num_packets;
int max_push_send;
struct pp_packet *r_window[1]; // note: do not add any new variable after here
} ucc_tl_mlx5_mcast_coll_comm_t;

@@ -341,27 +382,60 @@ typedef struct ucc_tl_mlx5_mcast_nack_req {
_comm->wsize]->psn == (_psn)) \
)

typedef struct ucc_tl_mlx5_mcast_tensor {
int group_id;
size_t offset;
size_t offset_left;
int root;
int count;
int to_recv;
int to_send_left;
} ucc_tl_mlx5_mcast_tensor_t;

typedef struct ucc_tl_mlx5_mcast_pipelined_ag_schedule {
ucc_tl_mlx5_mcast_tensor_t multicast_op[ONE_SIDED_MAX_CONCURRENT_LEVEL];
ucc_tl_mlx5_mcast_tensor_t prepost_buf_op[ONE_SIDED_MAX_CONCURRENT_LEVEL];
int prepost_buf_op_done;
int multicast_op_done;
int total_steps;
int num_recvd;
int to_recv;
int to_send;
} ucc_tl_mlx5_mcast_pipelined_ag_schedule_t;
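
Each schedule entry describes, for one pipeline step, which roots' data to multicast and which user buffers to prepost. A hedged sketch of how such a schedule could be walked; the helper below is illustrative only and is not this commit's progress logic:

static void ag_schedule_walk_sketch(ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *sched,
                                    int total_steps, int concurrency)
{
    int step, i;
    for (step = 0; step < total_steps; step++) {
        ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *s = &sched[step];
        for (i = 0; i < concurrency; i++) {
            /* prepost_buf_op[i]: group_id, root, count and offset of the
             * user receive buffers to post before this step's traffic */
        }
        for (i = 0; i < concurrency; i++) {
            /* multicast_op[i]: how much of this rank's data is still left
             * to push on group_id during this step (to_send_left/offset_left) */
        }
        /* the step is complete once s->to_send packets have been multicast
         * and s->num_recvd reaches s->to_recv */
    }
}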

typedef struct ucc_tl_mlx5_mcast_coll_req {
ucc_tl_mlx5_mcast_coll_comm_t *comm;
size_t length; /* bcast buffer size */
int proto;
struct ibv_mr *mr;
struct ibv_recv_wr *rwr;
struct ibv_sge *rsgs;
void *rreg;
char *ptr;
int am_root;
ucc_rank_t root;
void **rbufs;
int first_send_psn;
int to_send;
int to_recv;
ucc_rank_t parent;
uint32_t start_psn;
int num_packets;
int last_pkt_len;
int offset;
ucc_memory_type_t buf_mem_type;
ucc_tl_mlx5_mcast_coll_comm_t *comm;
size_t length;
int proto;
struct ibv_mr *mr;
struct ibv_mr *recv_mr;
struct ibv_recv_wr *rwr;
struct ibv_sge *rsgs;
void *rreg;
char *ptr;
char *rptr;
int am_root;
ucc_rank_t root;
void **rbufs;
int first_send_psn;
int to_send;
int to_recv;
ucc_rank_t parent;
uint32_t start_psn;
int num_packets;
int last_pkt_len;
int offset;
ucc_memory_type_t buf_mem_type;
int one_sided_reliability_scheme;
int concurreny_level;
int ag_counter;
int state;
ucc_tl_mlx5_mcast_pipelined_ag_schedule_t *ag_schedule;
int total_steps;
int step;
ucc_service_coll_req_t *allgather_rkeys_req;
ucc_service_coll_req_t *barrier_req;
void *recv_rreg;
} ucc_tl_mlx5_mcast_coll_req_t;

typedef struct ucc_tl_mlx5_mcast_oob_p2p_context {
Expand All @@ -384,6 +458,55 @@ static inline struct pp_packet* ucc_tl_mlx5_mcast_buf_get_free(ucc_tl_mlx5_mcast
return pp;
}

static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t *comm,
ucc_tl_mlx5_mcast_coll_req_t *req,
int group_id, ucc_rank_t root,
int coll_type,
int count,
size_t offset)
{
struct ibv_recv_wr *bad_wr = NULL;
struct ibv_recv_wr *rwr = comm->call_rwr;
struct ibv_sge *sge = comm->call_rsgs;
struct pp_packet *pp = NULL;
uint32_t i;

for (i = 0; i < count; i++) {
if (NULL == (pp = ucc_tl_mlx5_mcast_buf_get_free(comm))) {
tl_error(comm->lib, "not enough free pp packets to cover the entire message");
return UCC_ERR_NO_RESOURCE;
}

assert(offset % comm->max_per_packet == 0);
pp->packet_counter = offset / comm->max_per_packet;
pp->qp_id = group_id;
rwr[i].wr_id = ((uint64_t) pp);
rwr[i].next = &rwr[i+1];
sge[2*i + 1].addr = (uint64_t)req->rptr + root * req->length + offset;
sge[2*i + 1].lkey = req->recv_mr->lkey;
offset += comm->max_per_packet;

if (i == count - 1) {
sge[2*i + 1].length = req->last_pkt_len;
} else {
sge[2*i + 1].length = comm->max_per_packet;
}
}

if (i > 0) {
rwr[i-1].next = NULL;
if (ibv_post_recv(comm->mcast.qp_list[group_id], &rwr[0], &bad_wr)) {
tl_error(comm->lib, "Failed to prepost recvs: errno %d buffer count %d",
errno, i);
return UCC_ERR_NO_RESOURCE;
}
comm->pending_recv += i;
comm->pending_recv_per_qp[group_id] += i;
}

return UCC_OK;
}
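
A hypothetical call site (illustrative; req, the schedule step index i, and status are assumed locals): preposting the receive buffers described by one prepost entry of the pipelined schedule.

ucc_tl_mlx5_mcast_tensor_t *op = &req->ag_schedule[req->step].prepost_buf_op[i];

status = ucc_tl_mlx5_mcast_post_user_recv_buffers(comm, req, op->group_id, op->root,
                                                  UCC_COLL_TYPE_ALLGATHER, op->count,
                                                  op->offset);
if (UCC_OK != status) {
    return status;
}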

static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast_coll_comm_t* comm)
{
struct ibv_recv_wr *bad_wr = NULL;
@@ -421,6 +544,12 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast
return UCC_OK;
}

static inline uint64_t ucc_tl_mlx5_mcast_get_timer(void)
{
double t_second = ucc_get_time();
return (uint64_t) (t_second * 1000000);
}
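
The timer returns wall-clock time in microseconds (ucc_get_time() reports seconds). A hypothetical use is bounding a drain or progress loop, for example:

uint64_t start = ucc_tl_mlx5_mcast_get_timer();
while (comm->pending_recv > 0) {
    /* ... poll the recv CQ and progress reliability here ... */
    if (ucc_tl_mlx5_mcast_get_timer() - start > 1000000) { /* 1 second, illustrative */
        break;
    }
}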

ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context,
ucc_tl_mlx5_mcast_team_t **mcast_team,
ucc_tl_mlx5_mcast_context_t *ctx,