Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ch4/am: Caching buffer attribute in request and use typerep fast path for H2H #7082

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/include/mpir_typerep.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ int MPIR_Typerep_iov_len(MPI_Aint count, MPI_Datatype type, MPI_Aint max_iov_byt

#define MPIR_TYPEREP_FLAG_NONE 0x0UL
#define MPIR_TYPEREP_FLAG_STREAM 0x1UL
#define MPIR_TYPEREP_FLAG_H2H 0x2UL

int MPIR_Typerep_copy(void *outbuf, const void *inbuf, MPI_Aint num_bytes, uint32_t flags);
int MPIR_Typerep_pack(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype,
Expand Down
254 changes: 224 additions & 30 deletions src/mpi/datatype/typerep/src/typerep_yaksa_pack.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@
*/

#define IS_HOST(attr) \
((attr).type == MPL_GPU_POINTER_UNREGISTERED_HOST || \
(attr).type == MPL_GPU_POINTER_REGISTERED_HOST)
((attr).type & (MPL_GPU_POINTER_UNREGISTERED_HOST | MPL_GPU_POINTER_REGISTERED_HOST))

/* When a returned typerep_req is expected, using the nonblocking yaksa routine and
* return the request; otherwise use the blocking yaksa routine. */
Expand All @@ -87,31 +86,35 @@ static int typerep_do_copy(void *outbuf, const void *inbuf, MPI_Aint num_bytes,
goto fn_exit;
}

if (flags & MPIR_TYPEREP_FLAG_H2H) {
if (flags & MPIR_TYPEREP_FLAG_STREAM) {
MPIR_Memcpy(outbuf, inbuf, num_bytes);
} else {
MPIR_Memcpy_stream(outbuf, inbuf, num_bytes);
}
}

MPL_pointer_attr_t inattr, outattr;
MPIR_GPU_query_pointer_attr(inbuf, &inattr);
MPIR_GPU_query_pointer_attr(outbuf, &outattr);

if (IS_HOST(inattr) && IS_HOST(outattr)) {
MPIR_Memcpy(outbuf, inbuf, num_bytes);
uintptr_t actual_pack_bytes;

yaksa_info_t info = MPII_yaksa_get_info(&inattr, &outattr);
if (typerep_req) {
typerep_req->info = info;
rc = yaksa_ipack(inbuf, num_bytes, YAKSA_TYPE__BYTE, 0, outbuf, num_bytes,
&actual_pack_bytes, info, YAKSA_OP__REPLACE,
(yaksa_request_t *) & typerep_req->req);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
MPIR_Assert(actual_pack_bytes == num_bytes);
} else {
uintptr_t actual_pack_bytes;

yaksa_info_t info = MPII_yaksa_get_info(&inattr, &outattr);
if (typerep_req) {
typerep_req->info = info;
rc = yaksa_ipack(inbuf, num_bytes, YAKSA_TYPE__BYTE, 0, outbuf, num_bytes,
&actual_pack_bytes, info, YAKSA_OP__REPLACE,
(yaksa_request_t *) & typerep_req->req);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
MPIR_Assert(actual_pack_bytes == num_bytes);
} else {
rc = yaksa_pack(inbuf, num_bytes, YAKSA_TYPE__BYTE, 0, outbuf, num_bytes,
&actual_pack_bytes, info, YAKSA_OP__REPLACE);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
MPIR_Assert(actual_pack_bytes == num_bytes);
rc = MPII_yaksa_free_info(info);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
}
rc = yaksa_pack(inbuf, num_bytes, YAKSA_TYPE__BYTE, 0, outbuf, num_bytes,
&actual_pack_bytes, info, YAKSA_OP__REPLACE);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
MPIR_Assert(actual_pack_bytes == num_bytes);
rc = MPII_yaksa_free_info(info);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
}

fn_exit:
Expand Down Expand Up @@ -213,6 +216,94 @@ static int typerep_do_pack(const void *inbuf, MPI_Aint incount, MPI_Datatype dat
goto fn_exit;
}

static int typerep_do_pack_h2h(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype,
MPI_Aint inoffset, void *outbuf, MPI_Aint max_pack_bytes,
MPI_Aint * actual_pack_bytes, MPIR_Typerep_req * typerep_req,
uint32_t flags)
{
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
int rc;

if (typerep_req) {
typerep_req->req = MPIR_TYPEREP_REQ_NULL;
}

if (incount == 0) {
*actual_pack_bytes = 0;
goto fn_exit;
}

MPIR_Assert(datatype != MPI_DATATYPE_NULL);

int is_contig = 0;
int element_size = -1;
const void *inbuf_ptr; /* adjusted by true_lb */
MPI_Aint total_size = 0;
if (HANDLE_IS_BUILTIN(datatype)) {
is_contig = 1;
element_size = MPIR_Datatype_get_basic_size(datatype);
inbuf_ptr = inbuf;
total_size = incount * element_size;
} else {
MPIR_Datatype *dtp;
MPIR_Datatype_get_ptr(datatype, dtp);
is_contig = dtp->is_contig;
element_size = dtp->builtin_element_size;
inbuf_ptr = MPIR_get_contig_ptr(inbuf, dtp->true_lb);
total_size = incount * dtp->size;
}

/* only query the pointer attributes in case of relative addressing */
// bool rel_addressing = (inbuf != MPI_BOTTOM);
// if (rel_addressing) {
// MPIR_GPU_query_pointer_attr(inbuf_ptr, &inattr);
// MPIR_GPU_query_pointer_attr(outbuf, &outattr);
// }

if (is_contig) {
MPI_Aint real_bytes = MPL_MIN(total_size - inoffset, max_pack_bytes);
/* Make sure we never pack partial element */
real_bytes -= real_bytes % element_size;
if (flags & MPIR_TYPEREP_FLAG_STREAM) {
MPIR_Memcpy_stream(outbuf, MPIR_get_contig_ptr(inbuf_ptr, inoffset), real_bytes);
} else {
MPIR_Memcpy(outbuf, MPIR_get_contig_ptr(inbuf_ptr, inoffset), real_bytes);
}
*actual_pack_bytes = real_bytes;
goto fn_exit;
}

yaksa_type_t type = MPII_Typerep_get_yaksa_type(datatype);
yaksa_info_t info = MPII_yaksa_info_nogpu;

uintptr_t real_pack_bytes;
if (typerep_req) {
typerep_req->info = info;
rc = yaksa_ipack(inbuf, incount, type, inoffset, outbuf, max_pack_bytes,
&real_pack_bytes, info, YAKSA_OP__REPLACE,
(yaksa_request_t *) & typerep_req->req);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
} else {
rc = yaksa_pack(inbuf, incount, type, inoffset, outbuf, max_pack_bytes,
&real_pack_bytes, info, YAKSA_OP__REPLACE);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
if (info) {
rc = MPII_yaksa_free_info(info);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
}
}

*actual_pack_bytes = (MPI_Aint) real_pack_bytes;

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

/* This function checks whether the operation is supported in yaksa for the
* provided datatype */
int MPIR_Typerep_reduce_is_supported(MPI_Op op, MPI_Aint count, MPI_Datatype datatype)
Expand Down Expand Up @@ -372,6 +463,89 @@ static int typerep_do_unpack(const void *inbuf, MPI_Aint insize, void *outbuf, M
goto fn_exit;
}

static int typerep_do_unpack_h2h(const void *inbuf, MPI_Aint insize, void *outbuf,
MPI_Aint outcount, MPI_Datatype datatype, MPI_Aint outoffset,
MPI_Aint * actual_unpack_bytes, MPIR_Typerep_req * typerep_req,
uint32_t flags)
{
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
int rc;

if (typerep_req) {
typerep_req->req = MPIR_TYPEREP_REQ_NULL;
}

if (insize == 0) {
*actual_unpack_bytes = 0;
goto fn_exit;
}

MPIR_Assert(datatype != MPI_DATATYPE_NULL);

int is_contig = 0;
int element_size = -1;
const void *outbuf_ptr; /* adjusted by true_lb */
MPI_Aint total_size = 0;
if (HANDLE_IS_BUILTIN(datatype)) {
is_contig = 1;
element_size = MPIR_Datatype_get_basic_size(datatype);
outbuf_ptr = outbuf;
total_size = outcount * element_size;
} else {
MPIR_Datatype *dtp;
MPIR_Datatype_get_ptr(datatype, dtp);
is_contig = dtp->is_contig;
element_size = dtp->builtin_element_size;
outbuf_ptr = MPIR_get_contig_ptr(outbuf, dtp->true_lb);
total_size = outcount * dtp->size;
}

if (is_contig) {
*actual_unpack_bytes = MPL_MIN(total_size - outoffset, insize);
/* We assume the amount we unpack is multiple of element_size */
MPIR_Assert(element_size < 0 || *actual_unpack_bytes % element_size == 0);
if (flags & MPIR_TYPEREP_FLAG_STREAM) {
MPIR_Memcpy_stream(MPIR_get_contig_ptr(outbuf_ptr, outoffset), inbuf,
*actual_unpack_bytes);
} else {
MPIR_Memcpy(MPIR_get_contig_ptr(outbuf_ptr, outoffset), inbuf, *actual_unpack_bytes);
}
goto fn_exit;
}

yaksa_type_t type = MPII_Typerep_get_yaksa_type(datatype);
yaksa_info_t info = (outbuf != MPI_BOTTOM) ? MPII_yaksa_info_nogpu : NULL;

uintptr_t real_insize = MPL_MIN(total_size - outoffset, insize);

uintptr_t real_unpack_bytes;
if (typerep_req) {
typerep_req->info = info;
rc = yaksa_iunpack(inbuf, real_insize, outbuf, outcount, type, outoffset,
&real_unpack_bytes, info, YAKSA_OP__REPLACE,
(yaksa_request_t *) & typerep_req->req);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
} else {
rc = yaksa_unpack(inbuf, real_insize, outbuf, outcount, type, outoffset, &real_unpack_bytes,
info, YAKSA_OP__REPLACE);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
if (info) {
rc = MPII_yaksa_free_info(info);
MPIR_ERR_CHKANDJUMP(rc, mpi_errno, MPI_ERR_INTERN, "**yaksa");
}
}

*actual_unpack_bytes = (MPI_Aint) real_unpack_bytes;

fn_exit:
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
goto fn_exit;
}

int MPIR_Typerep_icopy(void *outbuf, const void *inbuf, MPI_Aint num_bytes,
MPIR_Typerep_req * typerep_req, uint32_t flags)
{
Expand Down Expand Up @@ -402,8 +576,13 @@ int MPIR_Typerep_ipack(const void *inbuf, MPI_Aint incount, MPI_Datatype datatyp
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
mpi_errno = typerep_do_pack(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, typerep_req, flags);
if (flags & MPIR_TYPEREP_FLAG_H2H) {
mpi_errno = typerep_do_pack_h2h(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, typerep_req, flags);
} else {
mpi_errno = typerep_do_pack(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, typerep_req, flags);
}

MPIR_FUNC_EXIT;
return mpi_errno;
Expand All @@ -416,8 +595,13 @@ int MPIR_Typerep_pack(const void *inbuf, MPI_Aint incount, MPI_Datatype datatype
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
mpi_errno = typerep_do_pack(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, NULL, flags);
if (flags & MPIR_TYPEREP_FLAG_H2H) {
mpi_errno = typerep_do_pack_h2h(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, NULL, flags);
} else {
mpi_errno = typerep_do_pack(inbuf, incount, datatype, inoffset, outbuf, max_pack_bytes,
actual_pack_bytes, NULL, flags);
}

MPIR_FUNC_EXIT;
return mpi_errno;
Expand All @@ -430,8 +614,13 @@ int MPIR_Typerep_iunpack(const void *inbuf, MPI_Aint insize, void *outbuf, MPI_A
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
mpi_errno = typerep_do_unpack(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, typerep_req, flags);
if (flags & MPIR_TYPEREP_FLAG_H2H) {
mpi_errno = typerep_do_unpack_h2h(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, typerep_req, flags);
} else {
mpi_errno = typerep_do_unpack(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, typerep_req, flags);
}

MPIR_FUNC_EXIT;
return mpi_errno;
Expand All @@ -444,8 +633,13 @@ int MPIR_Typerep_unpack(const void *inbuf, MPI_Aint insize, void *outbuf, MPI_Ai
MPIR_FUNC_ENTER;

int mpi_errno = MPI_SUCCESS;
mpi_errno = typerep_do_unpack(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, NULL, flags);
if (flags & MPIR_TYPEREP_FLAG_H2H) {
mpi_errno = typerep_do_unpack_h2h(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, NULL, flags);
} else {
mpi_errno = typerep_do_unpack(inbuf, insize, outbuf, outcount, datatype, outoffset,
actual_unpack_bytes, NULL, flags);
}

MPIR_FUNC_EXIT;
return mpi_errno;
Expand Down
2 changes: 2 additions & 0 deletions src/mpid/ch4/include/mpidpre.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ typedef struct MPIDIG_req_async {
struct iovec iov_one; /* used with MPIDIG_RECV_CONTIG */
MPIDIG_recv_data_copy_cb data_copy_cb; /* called in recv_init/recv_type_init for async
* data copying */
int typerep_flags;
} MPIDIG_rreq_async_t;

typedef struct MPIDIG_sreq_async {
Expand Down Expand Up @@ -209,6 +210,7 @@ typedef struct MPIDIG_req_t {
void *buffer;
MPI_Aint count;
MPI_Datatype datatype;
MPL_pointer_attr_t buf_attr;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes the compile-time assertion about the size of the extended request object to fail in Jenkins tests. Do we need the whole attribute struct or can we get by with just the type?

union {
struct {
int dest;
Expand Down
3 changes: 3 additions & 0 deletions src/mpid/ch4/src/mpidig_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_handle_unexpected(void *buf, MPI_Aint count,
MPIDIG_REQUEST(rreq, datatype) = datatype;
MPIDIG_REQUEST(rreq, buffer) = buf;
MPIDIG_REQUEST(rreq, count) = count;
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
MPIDIG_recv_type_init(unexp_data_sz, rreq);
}
}
Expand Down Expand Up @@ -256,6 +257,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data
MPIDIG_REQUEST(unexp_req, buffer) = buf;
MPIDIG_REQUEST(unexp_req, count) = count;
MPIDIG_REQUEST(unexp_req, req->status) &= ~MPIDIG_REQ_UNEXPECTED;
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
/* MPIDIG_recv_type_init will call the callback to finish the rndv protocol */
mpi_errno = MPIDIG_recv_type_init(data_sz, unexp_req);
} else {
Expand Down Expand Up @@ -284,6 +286,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDIG_do_irecv(void *buf, MPI_Aint count, MPI_Data

MPIR_Datatype_add_ref_if_not_builtin(datatype);
MPIDIG_prepare_recv_req(rank, tag, context_id, buf, count, datatype, rreq);
MPIR_GPU_query_pointer_attr(buf, &MPIDIG_REQUEST(rreq, buf_attr));
MPIDIG_enqueue_request(rreq, &MPIDI_global.per_vci[vci].posted_list, MPIDIG_PT2PT_POSTED);
}

Expand Down
Loading