Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXP][CMDBUF] Add adapters code for Prefetch and Advise commands #937

Merged
merged 16 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,67 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
return Result;
}

// Appends a USM prefetch command to a CUDA command-buffer.
//
// Prefetch cmd is not supported by Cuda Graph, so the memory pointer, size
// and migration flags are ignored. The command is recorded as an empty graph
// node whose only purpose is to carry the requested dependencies.
//
// hCommandBuffer          - command-buffer whose CudaGraph receives the node.
// numSyncPointsInWaitList - number of entries in pSyncPointWaitList.
// pSyncPointWaitList      - sync points the new node has to depend on.
// pSyncPoint              - receives the sync point of the new node; the
//                           write-back is skipped when this is null.
//
// Returns UR_RESULT_SUCCESS, the error produced while resolving the wait
// list, or any ur_result_t thrown inside the try block.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
    ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
    size_t /*Size*/, ur_usm_migration_flags_t /*Flags*/,
    uint32_t numSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
  ur_result_t Result = UR_RESULT_SUCCESS;

  std::vector<CUgraphNode> DepsList;
  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
                                 pSyncPointWaitList, DepsList),
          Result);

  // Never build a node on top of a wait list that failed to resolve.
  // (Harmless if UR_CALL already returns on failure.)
  if (Result != UR_RESULT_SUCCESS) {
    return Result;
  }

  try {
    // Add an empty node to preserve dependencies.
    CUgraphNode GraphNode = nullptr;
    UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
                                       DepsList.data(), DepsList.size()));

    // Register the cuNode with a new sync point unconditionally so the
    // command-buffer bookkeeping is the same whether or not the caller asked
    // for the sync point back.
    const auto NewSyncPoint =
        hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
    if (pSyncPoint != nullptr) {
      *pSyncPoint = NewSyncPoint;
    }
  } catch (ur_result_t Err) {
    Result = Err;
  }
  return Result;
}

// Appends a USM memory-advise command to a CUDA command-buffer.
//
// Mem-Advise cmd is not supported by Cuda Graph, so the memory pointer, size
// and advice flags are ignored. The command is recorded as an empty graph
// node whose only purpose is to carry the requested dependencies.
//
// hCommandBuffer          - command-buffer whose CudaGraph receives the node.
// numSyncPointsInWaitList - number of entries in pSyncPointWaitList.
// pSyncPointWaitList      - sync points the new node has to depend on.
// pSyncPoint              - receives the sync point of the new node; the
//                           write-back is skipped when this is null.
//
// Returns UR_RESULT_SUCCESS, the error produced while resolving the wait
// list, or any ur_result_t thrown inside the try block.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
    ur_exp_command_buffer_handle_t hCommandBuffer, const void * /* Mem */,
    size_t /*Size*/, ur_usm_advice_flags_t /*Advice*/,
    uint32_t numSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
  ur_result_t Result = UR_RESULT_SUCCESS;

  std::vector<CUgraphNode> DepsList;
  UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
                                 pSyncPointWaitList, DepsList),
          Result);

  // Never build a node on top of a wait list that failed to resolve.
  // (Harmless if UR_CALL already returns on failure.)
  if (Result != UR_RESULT_SUCCESS) {
    return Result;
  }

  try {
    // Add an empty node to preserve dependencies.
    CUgraphNode GraphNode = nullptr;
    UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
                                       DepsList.data(), DepsList.size()));

    // Register the cuNode with a new sync point unconditionally so the
    // command-buffer bookkeeping is the same whether or not the caller asked
    // for the sync point back.
    const auto NewSyncPoint =
        hCommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
    if (pSyncPoint != nullptr) {
      *pSyncPoint = NewSyncPoint;
    }
  } catch (ur_result_t Err) {
    Result = Err;
  }

  return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
19 changes: 19 additions & 0 deletions source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// USM prefetch append entry point for the HIP adapter.
//
// Every argument is ignored: the function reports through detail::ur::die()
// that the experimental command-buffer feature is not implemented for HIP.
// The trailing return yields UR_RESULT_ERROR_UNSUPPORTED_FEATURE should
// die() ever hand control back.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
    ur_exp_command_buffer_handle_t, const void *, size_t,
    ur_usm_migration_flags_t, uint32_t,
    const ur_exp_command_buffer_sync_point_t *,
    ur_exp_command_buffer_sync_point_t *) {
  detail::ur::die("Experimental Command-buffer feature is not "
                  "implemented for HIP adapter.");
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// USM memory-advise append entry point for the HIP adapter.
//
// Every argument is ignored: the function reports through detail::ur::die()
// that the experimental command-buffer feature is not implemented for HIP.
// The trailing return yields UR_RESULT_ERROR_UNSUPPORTED_FEATURE should
// die() ever hand control back.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
    ur_exp_command_buffer_handle_t, const void *, size_t, ur_usm_advice_flags_t,
    uint32_t, const ur_exp_command_buffer_sync_point_t *,
    ur_exp_command_buffer_sync_point_t *) {
  detail::ur::die("Experimental Command-buffer feature is not "
                  "implemented for HIP adapter.");
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
const ur_event_handle_t *, ur_event_handle_t *) {
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
100 changes: 100 additions & 0 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,106 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
SyncPointWaitList, SyncPoint);
}

// Appends a USM prefetch command to a Level Zero command-buffer.
//
// The recorded sequence on CommandBuffer->ZeCommandList is:
//   1. wait on the events behind SyncPointWaitList (only if the list is
//      non-empty),
//   2. zeCommandListAppendMemoryPrefetch for [Mem, Mem + Size),
//   3. a signal of a freshly created event, which is what the new sync point
//      refers to.
//
// CommandBuffer           - command-buffer being recorded into.
// Mem, Size               - memory range handed to the prefetch call.
// Flags                   - ignored; L0 does not handle migration flags.
// NumSyncPointsInWaitList - number of entries in SyncPointWaitList.
// SyncPointWaitList       - sync points this command has to wait for.
// SyncPoint               - receives the sync point of this command; the
//                           write-back is skipped when this is null.
//
// Returns UR_RESULT_SUCCESS when every step was recorded; failures are
// propagated by the UR_CALL / ZE2UR_CALL macros.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
    ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
    ur_usm_migration_flags_t Flags, uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {
  std::ignore = Flags;

  std::vector<ze_event_handle_t> ZeEventList;
  UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
                                  SyncPointWaitList, ZeEventList));

  if (NumSyncPointsInWaitList) {
    ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
               (CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
                ZeEventList.data()));
  }

  ur_event_handle_t LaunchEvent;
  UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
  LaunchEvent->CommandType = UR_COMMAND_USM_PREFETCH;

  // Get sync point and register the event with it. Registration happens
  // regardless of whether the caller wants the sync point handed back.
  const auto NewSyncPoint = CommandBuffer->GetNextSyncPoint();
  CommandBuffer->RegisterSyncPoint(NewSyncPoint, LaunchEvent);
  if (SyncPoint != nullptr) {
    *SyncPoint = NewSyncPoint;
  }

  // Add the prefetch command to the command buffer.
  // Note that L0 does not handle migration flags.
  ZE2UR_CALL(zeCommandListAppendMemoryPrefetch,
             (CommandBuffer->ZeCommandList, Mem, Size));

  // Level Zero does not have a completion "event" with the prefetch API,
  // so manually add command to signal our event.
  ZE2UR_CALL(zeCommandListAppendSignalEvent,
             (CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));

  return UR_RESULT_SUCCESS;
}

// Appends a USM memory-advise command to a Level Zero command-buffer.
//
// The UR advice flags are translated to a ze_memory_advice_t value, then the
// following sequence is recorded on CommandBuffer->ZeCommandList:
//   1. wait on the events behind SyncPointWaitList (only if the list is
//      non-empty),
//   2. zeCommandListAppendMemAdvise for [Mem, Mem + Size) on
//      CommandBuffer->Device->ZeDevice,
//   3. a signal of a freshly created event, which is what the new sync point
//      refers to.
//
// CommandBuffer           - command-buffer being recorded into.
// Mem, Size               - memory range the advice applies to.
// Advice                  - bitfield of UR_USM_ADVICE_FLAG_* values.
// NumSyncPointsInWaitList - number of entries in SyncPointWaitList.
// SyncPointWaitList       - sync points this command has to wait for.
// SyncPoint               - receives the sync point of this command; the
//                           write-back is skipped when this is null.
//
// Returns UR_RESULT_SUCCESS when every step was recorded; failures are
// propagated by the UR_CALL / ZE2UR_CALL macros.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
    ur_exp_command_buffer_handle_t CommandBuffer, const void *Mem, size_t Size,
    ur_usm_advice_flags_t Advice, uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {
  // A memory chunk can be advised with multiple memory advices.
  // We therefore prefer if statements to switch cases to combine all potential
  // flags.
  // NOTE(review): the ZE_MEMORY_ADVICE_* enumerators are OR-ed together here;
  // confirm against the Level Zero headers that they are meant to be combined
  // bitwise.
  uint32_t Value = 0;
  if (Advice & UR_USM_ADVICE_FLAG_SET_READ_MOSTLY)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_SET_READ_MOSTLY);
  if (Advice & UR_USM_ADVICE_FLAG_CLEAR_READ_MOSTLY)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_CLEAR_READ_MOSTLY);
  if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
  if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);
  if (Advice & UR_USM_ADVICE_FLAG_SET_NON_ATOMIC_MOSTLY)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_SET_NON_ATOMIC_MOSTLY);
  if (Advice & UR_USM_ADVICE_FLAG_CLEAR_NON_ATOMIC_MOSTLY)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_CLEAR_NON_ATOMIC_MOSTLY);
  if (Advice & UR_USM_ADVICE_FLAG_BIAS_CACHED)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_BIAS_CACHED);
  if (Advice & UR_USM_ADVICE_FLAG_BIAS_UNCACHED)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_BIAS_UNCACHED);
  // The *_HOST variants map onto the same Level Zero advice values as their
  // non-host counterparts.
  if (Advice & UR_USM_ADVICE_FLAG_SET_PREFERRED_LOCATION_HOST)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_SET_PREFERRED_LOCATION);
  if (Advice & UR_USM_ADVICE_FLAG_CLEAR_PREFERRED_LOCATION_HOST)
    Value |= static_cast<uint32_t>(ZE_MEMORY_ADVICE_CLEAR_PREFERRED_LOCATION);

  ze_memory_advice_t ZeAdvice = static_cast<ze_memory_advice_t>(Value);

  std::vector<ze_event_handle_t> ZeEventList;
  UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
                                  SyncPointWaitList, ZeEventList));

  if (NumSyncPointsInWaitList) {
    ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
               (CommandBuffer->ZeCommandList, NumSyncPointsInWaitList,
                ZeEventList.data()));
  }

  ur_event_handle_t LaunchEvent;
  UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
  LaunchEvent->CommandType = UR_COMMAND_USM_ADVISE;

  // Get sync point and register the event with it. Registration happens
  // regardless of whether the caller wants the sync point handed back.
  const auto NewSyncPoint = CommandBuffer->GetNextSyncPoint();
  CommandBuffer->RegisterSyncPoint(NewSyncPoint, LaunchEvent);
  if (SyncPoint != nullptr) {
    *SyncPoint = NewSyncPoint;
  }

  ZE2UR_CALL(zeCommandListAppendMemAdvise,
             (CommandBuffer->ZeCommandList, CommandBuffer->Device->ZeDevice,
              Mem, Size, ZeAdvice));

  // Level Zero does not have a completion "event" with the advise API,
  // so manually add command to signal our event.
  ZE2UR_CALL(zeCommandListAppendSignalEvent,
             (CommandBuffer->ZeCommandList, LaunchEvent->ZeEvent));

  return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
34 changes: 34 additions & 0 deletions source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
return UR_RESULT_SUCCESS;
}

// USM prefetch append entry point for the OpenCL adapter.
//
// Not implemented: no argument is inspected and the call always returns
// UR_RESULT_ERROR_UNSUPPORTED_FEATURE.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
    ur_exp_command_buffer_handle_t /*hCommandBuffer*/, const void * /*mem*/,
    size_t /*size*/, ur_usm_migration_flags_t /*flags*/,
    uint32_t /*numSyncPointsInWaitList*/,
    const ur_exp_command_buffer_sync_point_t * /*pSyncPointWaitList*/,
    ur_exp_command_buffer_sync_point_t * /*pSyncPoint*/) {
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// USM memory-advise append entry point for the OpenCL adapter.
//
// Not implemented: no argument is inspected and the call always returns
// UR_RESULT_ERROR_UNSUPPORTED_FEATURE.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
    ur_exp_command_buffer_handle_t /*hCommandBuffer*/, const void * /*mem*/,
    size_t /*size*/, ur_usm_advice_flags_t /*advice*/,
    uint32_t /*numSyncPointsInWaitList*/,
    const ur_exp_command_buffer_sync_point_t * /*pSyncPointWaitList*/,
    ur_exp_command_buffer_sync_point_t * /*pSyncPoint*/) {
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/opencl/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteExp;
pDdiTable->pfnAppendMemBufferWriteRectExp =
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down