Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiomestre committed Nov 9, 2023
1 parent 74f27e2 commit e1902fc
Showing 1 changed file with 36 additions and 26 deletions.
62 changes: 36 additions & 26 deletions source/adapters/cuda/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void getUSMHostOrDevicePtr(PtrT USMPtr, CUmemorytype *OutMemType,
ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,
ur_usm_advice_flags_t URAdviceFlags,
CUdevice Device) {

std::unordered_map<ur_usm_advice_flags_t, CUmem_advise>
URToCUMemAdviseDeviceFlagsMap = {
{UR_USM_ADVICE_FLAG_SET_READ_MOSTLY, CU_MEM_ADVISE_SET_READ_MOSTLY},
Expand Down Expand Up @@ -121,7 +122,10 @@ ur_result_t setCuMemAdvise(CUdeviceptr DevPtr, size_t Size,

for (auto &UnmappedFlag : UnmappedMemAdviceFlags) {
if (URAdviceFlags & UnmappedFlag) {
throw UR_RESULT_ERROR_INVALID_ENUMERATION;
setErrorMessage("Memory advice ignored because the CUDA backend does not "
"support some of the specified flags",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}
}

Expand Down Expand Up @@ -1363,20 +1367,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
ur_device_handle_t Device = hQueue->getContext()->getDevice();

bool Supported = true;
// Certain CUDA devices and Windows do not support some Unified Memory
// features. cuMemPrefetchAsync requires concurrent managed access for
// managed memory. Therefore, the prefetch hint is ignored if concurrent
// managed access is not available.
if (!getAttribute(Device, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)) {
Supported = false;
setErrorMessage("Prefetch hint ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

unsigned int IsManaged;
UR_CHECK_ERROR(cuPointerGetAttribute(
&IsManaged, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)pMem));
if (!IsManaged) {
Supported = false;
setErrorMessage("Prefetch hint ignored as prefetch only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

ur_result_t Result = UR_RESULT_SUCCESS;
Expand All @@ -1393,10 +1401,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
UR_COMMAND_MEM_BUFFER_COPY, hQueue, CuStream));
UR_CHECK_ERROR(EventPtr->start());
}
if (Supported) {
UR_CHECK_ERROR(
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream));
}
UR_CHECK_ERROR(
cuMemPrefetchAsync((CUdeviceptr)pMem, size, Device->get(), CuStream));

if (phEvent) {
UR_CHECK_ERROR(EventPtr->record());
*phEvent = EventPtr.release();
Expand All @@ -1416,7 +1423,6 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
&PointerRangeSize, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)pMem));
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);

bool Supported = true;
// Certain CUDA devices and Windows do not have support for some Unified
// Memory features. Passing CU_MEM_ADVISE_SET/CLEAR_PREFERRED_LOCATION
// to cuMemAdvise on a GPU device requires the GPU device to report a non-zero
Expand All @@ -1429,7 +1435,10 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
(advice & UR_USM_ADVICE_FLAG_DEFAULT)) {
ur_device_handle_t Device = hQueue->getContext()->getDevice();
if (!getAttribute(Device, CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS)) {
Supported = false;
setErrorMessage("Mem advise ignored as device does not support "
"concurrent managed access",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

// TODO: If ptr points to valid system-allocated pageable memory we should
Expand All @@ -1441,7 +1450,10 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
UR_CHECK_ERROR(cuPointerGetAttribute(
&IsManaged, CU_POINTER_ATTRIBUTE_IS_MANAGED, (CUdeviceptr)pMem));
if (!IsManaged) {
Supported = false;
setErrorMessage(
"Memory advice ignored as memory advices only works with USM",
UR_RESULT_SUCCESS);
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
}

ur_result_t Result = UR_RESULT_SUCCESS;
Expand All @@ -1457,21 +1469,19 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
UR_CHECK_ERROR(EventPtr->start());
}

if (Supported) {
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_READ_MOSTLY,
hQueue->getContext()->getDevice()->get()));
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
hQueue->getContext()->getDevice()->get()));
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_ACCESSED_BY,
hQueue->getContext()->getDevice()->get()));
} else {
Result = setCuMemAdvise((CUdeviceptr)pMem, size, advice,
hQueue->getContext()->getDevice()->get());
}
if (advice & UR_USM_ADVICE_FLAG_DEFAULT) {
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_READ_MOSTLY,
hQueue->getContext()->getDevice()->get()));
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION,
hQueue->getContext()->getDevice()->get()));
UR_CHECK_ERROR(cuMemAdvise((CUdeviceptr)pMem, size,
CU_MEM_ADVISE_UNSET_ACCESSED_BY,
hQueue->getContext()->getDevice()->get()));
} else {
Result = setCuMemAdvise((CUdeviceptr)pMem, size, advice,
hQueue->getContext()->getDevice()->get());
}

if (phEvent) {
Expand Down

0 comments on commit e1902fc

Please sign in to comment.