diff --git a/src/mpid/ch4/netmod/ofi/ofi_recv.h b/src/mpid/ch4/netmod/ofi/ofi_recv.h index fd00fc42fbb..ffd66c98c35 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_recv.h +++ b/src/mpid/ch4/netmod/ofi/ofi_recv.h @@ -182,7 +182,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf, } } - if (data_sz && attr.type == MPL_GPU_POINTER_DEV) { + if (data_sz && MPL_gpu_query_pointer_is_dev(recv_buf, &attr)) { MPIDI_OFI_register_am_bufs(); if (!MPIDI_OFI_ENABLE_HMEM || !dt_contig || (MPIDI_OFI_ENABLE_MR_HMEM && !register_mem)) { /* FIXME: at this point, GPU data takes host-buffer staging diff --git a/src/mpid/ch4/netmod/ofi/ofi_send.h b/src/mpid/ch4/netmod/ofi/ofi_send.h index afcaee77cf3..423f44dd017 100644 --- a/src/mpid/ch4/netmod/ofi/ofi_send.h +++ b/src/mpid/ch4/netmod/ofi/ofi_send.h @@ -253,7 +253,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_normal(const void *buf, MPI_Aint cou register_mem = true; } - if (data_sz && attr.type == MPL_GPU_POINTER_DEV) { + if (data_sz && MPL_gpu_query_pointer_is_dev(send_buf, &attr)) { MPIDI_OFI_register_am_bufs(); if (!MPIDI_OFI_ENABLE_HMEM || !dt_contig || (MPIDI_OFI_ENABLE_MR_HMEM && !register_mem)) { /* Force packing of GPU buffer in host memory */ @@ -600,14 +600,13 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send(const void *buf, MPI_Aint count, MPI if (MPIR_CVAR_CH4_OFI_ENABLE_INJECT && !syncflag && dt_contig && (data_sz <= MPIDI_OFI_global.max_buffered_send)) { MPI_Aint actual_pack_bytes = 0; - if (attr.type == MPL_GPU_POINTER_DEV && data_sz) { + if (data_sz && MPL_gpu_query_pointer_is_dev(send_buf, &attr)) { MPIDI_OFI_register_am_bufs(); if (!MPIDI_OFI_ENABLE_HMEM) { /* Force pack for GPU buffer. */ void *host_buf = MPL_malloc(data_sz, MPL_MEM_OTHER); int fast_copy = 0; - if (attr.type == MPL_GPU_POINTER_DEV && dt_contig && - data_sz <= MPIR_CVAR_CH4_IPC_GPU_FAST_COPY_MAX_SIZE) { + if (data_sz <= MPIR_CVAR_CH4_IPC_GPU_FAST_COPY_MAX_SIZE) { int mpl_err; mpl_err = MPL_gpu_fast_memcpy(send_buf, &attr, host_buf, NULL, data_sz); if (mpl_err == MPL_SUCCESS) { @@ -618,8 +617,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send(const void *buf, MPI_Aint count, MPI if (!fast_copy) { MPL_gpu_engine_type_t engine = MPIDI_OFI_gpu_get_send_engine_type(MPIR_CVAR_CH4_OFI_GPU_SEND_ENGINE_TYPE); - if (dt_contig && engine != MPL_GPU_ENGINE_TYPE_LAST && - MPL_gpu_query_pointer_is_dev(send_buf, &attr)) { + if (engine != MPL_GPU_ENGINE_TYPE_LAST) { mpi_errno = MPIR_Localcopy_gpu(send_buf, data_sz, MPI_BYTE, 0, &attr, host_buf, data_sz, MPI_BYTE, 0, NULL,