Skip to content

Commit

Permalink
Merge pull request #937 from Bensuo/support-prefetch-advise-cmd-buffers
Browse files Browse the repository at this point in the history
[EXP][CMDBUF] Add adapters code for Prefetch and Advise commands
  • Loading branch information
kbenzie authored Jan 3, 2024
2 parents 749d8e5 + 01cd56d commit 5d8173a
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 0 deletions.
71 changes: 71 additions & 0 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,77 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
size_t /*Size*/, ur_usm_migration_flags_t /*Flags*/,
uint32_t numSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
// Prefetch cmd is not supported by Cuda Graph.
// We implement it as an empty node to enforce dependencies.
ur_result_t Result = UR_RESULT_SUCCESS;
CUgraphNode GraphNode;

std::vector<CUgraphNode> DepsList;
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
pSyncPointWaitList, DepsList),
Result);

try {
// Add an empty node to preserve dependencies.
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
DepsList.data(), DepsList.size()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));

setErrorMessage("Prefetch hint ignored and replaced with empty node as "
"prefetch is not supported by CUDA Graph backend",
UR_RESULT_SUCCESS);
Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
} catch (ur_result_t Err) {
Result = Err;
}
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
size_t /*Size*/, ur_usm_advice_flags_t /*Advice*/,
uint32_t numSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
// Mem-Advise cmd is not supported by Cuda Graph.
// We implement it as an empty node to enforce dependencies.
ur_result_t Result = UR_RESULT_SUCCESS;
CUgraphNode GraphNode;

std::vector<CUgraphNode> DepsList;
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
pSyncPointWaitList, DepsList),
Result);

try {
// Add an empty node to preserve dependencies.
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
DepsList.data(), DepsList.size()));

// Get sync point and register the cuNode with it.
*pSyncPoint =
hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));

setErrorMessage("Memory advice ignored and replaced with empty node as "
"memory advice is not supported by CUDA Graph backend",
UR_RESULT_SUCCESS);
Result = UR_RESULT_ERROR_ADAPTER_SPECIFIC;
} catch (ur_result_t Err) {
Result = Err;
}

return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
15 changes: 15 additions & 0 deletions source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
ur_exp_command_buffer_handle_t, const void *, size_t,
ur_usm_migration_flags_t, uint32_t,
const ur_exp_command_buffer_sync_point_t *,
ur_exp_command_buffer_sync_point_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
ur_exp_command_buffer_handle_t, const void *, size_t, ur_usm_advice_flags_t,
uint32_t, const ur_exp_command_buffer_sync_point_t *,
ur_exp_command_buffer_sync_point_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
const ur_event_handle_t *, ur_event_handle_t *) {
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
100 changes: 100 additions & 0 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,106 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
SyncPointWaitList, SyncPoint);
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
ur_exp_command_buffer_sync_point_t *SyncPoint) {
std::ignore = Flags;

std::vector<ze_event_handle_t> ZeEventList;
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, ZeEventList));

if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_PREFETCH;

// Get sync point and register the event with it.
*SyncPoint = CommandBuffer->GetNextSyncPoint();
CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);

// Add the prefetch command to the command buffer.
// Note that L0 does not handle migration flags.
ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
(CommandBuffer->ZeCommandList, Mem, Size));

// Level Zero does not have a completion "event" with the prefetch API,
// so manually add command to signal our event.
ZE2UR_CALL(zeCommandListAppendSignalEvent,
(CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
ur_usm_advice_flags_t Advice, uint32_t NumSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
ur_exp_command_buffer_sync_point_t *SyncPoint) {
// A memory chunk can be advised with muliple memory advices
// We therefore prefer if statements to switch cases to combine all potential
// flags
uint32_t Value = 0;
if (Advice & UR_USM_ADVICE_FLAG_SET_READ_MOSTLY)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_READ_MOSTLY);
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_READ_MOSTLY);
if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);
if (Advice & UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_NON_ATOMIC_MOSTLY);
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_NON_ATOMIC_MOSTLY);
if (Advice & UR_USM_ADVICE_FLAG_BIAS_CACHED)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_BIAS_CACHED);
if (Advice & UR_USM_ADVICE_FLAG_BIAS_UNCACHED)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_BIAS_UNCACHED);
if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST)
Value |= static_cast<int>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);

ze_memory_advice_t ZeAdvice = static_cast<ze_memory_advice_t>(Value);

std::vector<ze_event_handle_t> ZeEventList;
UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
SyncPointWaitList, ZeEventList));

if (NumSyncPointsInWaitList) {
ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
(CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
ZeEventList.data()));
}

ur_event_handle_t LaunchEvent;
UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
LaunchEvent->CommandType = UR_COMMAND_USM_ADVISE;

// Get sync point and register the event with it.
*SyncPoint = CommandBuffer->GetNextSyncPoint();
CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);

ZE2UR_CALL(zeCommandListAppendMemAdvise,
(CommandBuffer->ZeCommandList, CommandBuffer->Device->ZeDevice,
Mem, Size, ZeAdvice));

// Level Zero does not have a completion "event" with the advise API,
// so manually add command to signal our event.
ZE2UR_CALL(zeCommandListAppendSignalEvent,
(CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
34 changes: 34 additions & 0 deletions source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
ur_exp_command_buffer_handle_t hCommandBuffer, const void *mem, size_t size,
ur_usm_migration_flags_t flags, uint32_t numSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
(void)hCommandBuffer;
(void)mem;
(void)size;
(void)flags;
(void)numSyncPointsInWaitList;
(void)pSyncPointWaitList;
(void)pSyncPoint;

// Not implemented
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
ur_exp_command_buffer_handle_t hCommandBuffer, const void *mem, size_t size,
ur_usm_advice_flags_t advice, uint32_t numSyncPointsInWaitList,
const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
ur_exp_command_buffer_sync_point_t *pSyncPoint) {
(void)hCommandBuffer;
(void)mem;
(void)size;
(void)advice;
(void)numSyncPointsInWaitList;
(void)pSyncPointWaitList;
(void)pSyncPoint;

// Not implemented
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/opencl/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down

0 comments on commit 5d8173a

Please sign in to comment.