Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ch4/am: Caching buffer attribute in request and use typerep fast path for H2H #7082

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ typedef struct MPIDIG_req_async {
struct iovec iov_one; /* used with MPIDIG_RECV_CONTIG */
MPIDIG_recv_data_copy_cb data_copy_cb; /* called in recv_init/recv_type_init for async
* data copying */
int typerep_flags;
} MPIDIG_rreq_async_t;

typedef struct MPIDIG_sreq_async {
Expand Down Expand Up @@ -209,6 +210,7 @@ typedef struct MPIDIG_req_t {
void *buffer;
MPI_Aint count;
MPI_Datatype datatype;
MPL_pointer_attr_t buf_attr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes the compile-time assertion about the size of the extended request object to fail in Jenkins tests. Do we need the whole attribute struct or can we get by with just the type?

union {
struct {
int dest;
Expand Down
3 changes: 3 additions & 0 deletions src/mpid/ch4/src/mpidig_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_handle_unexpected(void *buf, MPI_Aint count,
MPIDIG_REQUEST(rreq, datatype) = datatype;
MPIDIG_REQUEST(rreq, buffer) = buf;
MPIDIG_REQUEST(rreq, count) = count;
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
MPIDIG_recv_type_init(unexp_data_sz, rreq);
}
}
Expand Down Expand Up @@ -256,6 +257,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data
MPIDIG_REQUEST(unexp_req, buffer) = buf;
MPIDIG_REQUEST(unexp_req, count) = count;
MPIDIG_REQUEST(unexp_req, req->status) &= ~MPIDIG_REQ_UNEXPECTED;
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
/* MPIDIG_recv_type_init will call the callback to finish the rndv protocol */
mpi_errno = MPIDIG_recv_type_init(data_sz, unexp_req);
} else {
Expand Down Expand Up @@ -284,6 +286,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data

MPIR_Datatype_add_ref_if_not_builtin(datatype);
MPIDIG_prepare_recv_req(rank, tag, context_id, buf, count, datatype, rreq);
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
MPIDIG_enqueue_request(rreq, &MPIDI_global.per_vci[vci].posted_list, MPIDIG_PT2PT_POSTED);
}

Expand Down
20 changes: 14 additions & 6 deletions src/mpid/ch4/src/mpidig_recv_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_copy(void *in_data, MPIR_Request * rre
{
MPIDIG_rreq_async_t *p = &(MPIDIG_REQUEST(rreq, req->recv_async));
MPI_Aint in_data_sz = p->in_data_sz;
int flags = (MPIDIG_REQUEST(rreq, buf_attr).type &
(MPL_GPU_POINTER_UNREGISTERED_HOST | MPL_GPU_POINTER_REGISTERED_HOST))
? MPIR_TYPEREP_FLAG_H2H : MPIR_TYPEREP_FLAG_NONE;
if (in_data_sz == 0) {
/* otherwise if recv size = 0, it is at least a truncation error */
MPIR_STATUS_SET_COUNT(rreq->status, 0);
Expand All @@ -228,7 +231,7 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_copy(void *in_data, MPIR_Request * rre
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype),
0, &actual_unpack_bytes, MPIR_TYPEREP_FLAG_NONE);
0, &actual_unpack_bytes, flags & MPIR_TYPEREP_FLAG_NONE);
if (!rreq->status.MPI_ERROR && in_data_sz > actual_unpack_bytes) {
/* Truncation error has been checked at MPIDIG_recv_type_init.
* If the receive buffer had enough space, but we still
Expand All @@ -251,7 +254,7 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_copy(void *in_data, MPIR_Request * rre
}

data_sz = MPL_MIN(data_sz, in_data_sz);
MPIR_Typerep_copy(data, in_data, data_sz, MPIR_TYPEREP_FLAG_NONE);
MPIR_Typerep_copy(data, in_data, data_sz, flags & MPIR_TYPEREP_FLAG_NONE);
MPIR_STATUS_SET_COUNT(rreq->status, data_sz);
} else {
/* noncontig case */
Expand All @@ -263,7 +266,7 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_copy(void *in_data, MPIR_Request * rre
for (int i = 0; i < iov_len && rem > 0; i++) {
int curr_len = MPL_MIN(rem, iov[i].iov_len);
MPIR_Typerep_copy(iov[i].iov_base, (char *) in_data + done, curr_len,
MPIR_TYPEREP_FLAG_NONE);
flags & MPIR_TYPEREP_FLAG_NONE);
rem -= curr_len;
done += curr_len;
}
Expand All @@ -282,6 +285,10 @@ MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_copy(void *in_data, MPIR_Request * rre
MPL_STATIC_INLINE_PREFIX void MPIDIG_recv_setup(MPIR_Request * rreq)
{
MPIDIG_rreq_async_t *p = &(MPIDIG_REQUEST(rreq, req->recv_async));
MPIDIG_REQUEST(rreq, req->recv_async).typerep_flags =
(MPIDIG_REQUEST(rreq, buf_attr).type &
(MPL_GPU_POINTER_UNREGISTERED_HOST | MPL_GPU_POINTER_REGISTERED_HOST)) ?
MPIR_TYPEREP_FLAG_H2H : MPIR_TYPEREP_FLAG_NONE;
p->offset = 0;
if (p->recv_type == MPIDIG_RECV_DATATYPE) {
/* it's ready, rreq status to be set */
Expand Down Expand Up @@ -330,7 +337,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_recv_copy_seg(void *payload, MPI_Aint payloa
MPIDIG_REQUEST(rreq, buffer),
MPIDIG_REQUEST(rreq, count),
MPIDIG_REQUEST(rreq, datatype),
p->offset, &actual_unpack_bytes, MPIR_TYPEREP_FLAG_NONE);
p->offset, &actual_unpack_bytes,
p->typerep_flags & MPIR_TYPEREP_FLAG_NONE);
p->offset += payload_sz;
if (payload_sz > actual_unpack_bytes) {
/* basic element size mismatch */
Expand All @@ -353,15 +361,15 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_recv_copy_seg(void *payload, MPI_Aint payloa
for (int i = 0; i < p->iov_num; i++) {
if (payload_sz < p->iov_ptr[i].iov_len) {
MPIR_Typerep_copy(p->iov_ptr[i].iov_base, payload, payload_sz,
MPIR_TYPEREP_FLAG_NONE);
p->typerep_flags & MPIR_TYPEREP_FLAG_NONE);
p->iov_ptr[i].iov_base = (char *) p->iov_ptr[i].iov_base + payload_sz;
p->iov_ptr[i].iov_len -= payload_sz;
/* not done */
break;
} else {
/* fill one iov */
MPIR_Typerep_copy(p->iov_ptr[i].iov_base, payload, p->iov_ptr[i].iov_len,
MPIR_TYPEREP_FLAG_NONE);
p->typerep_flags & MPIR_TYPEREP_FLAG_NONE);
payload = (char *) payload + p->iov_ptr[i].iov_len;
payload_sz -= p->iov_ptr[i].iov_len;
iov_done++;
Expand Down