Skip to content

Commit

Permalink
Merge pull request #938 from Bensuo/cmdbuf-fill-memset-l0
Browse files Browse the repository at this point in the history
[EXP][CMDBUF] Implement Fill commands for L0 adapter
  • Loading branch information
kbenzie committed Jan 4, 2024
2 parents cf87428 + 3ee71a7 commit 1d78636
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 1 deletion.
127 changes: 127 additions & 0 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,91 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
Params.Depth = 1;
}

// Helper function for appending memory-fill nodes to a command-buffer graph.
//
// CommandBuffer            Command-buffer whose CUDA graph receives the nodes.
// DstDevice                When DstType is CU_MEMORYTYPE_DEVICE this points to
//                          a CUdeviceptr holding the destination address;
//                          otherwise the pointer value itself is the address.
// Pattern / PatternSize    Fill pattern and its size in bytes. Callers in this
//                          file validate that PatternSize is a positive power
//                          of two before calling.
// Size                     Number of bytes to fill; Size / PatternSize copies
//                          of the pattern are written.
// NumSyncPointsInWaitList,
// SyncPointWaitList        Sync points the fill must wait on.
// SyncPoint                Receives the sync point representing the fill.
//
// Returns UR_RESULT_SUCCESS, the error produced while resolving the wait
// list, or the ur_result_t thrown by a failing CUDA call.
static ur_result_t enqueueCommandBufferFillHelper(
    ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
    const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
    size_t Size, uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {
  ur_result_t Result = UR_RESULT_SUCCESS;
  std::vector<CUgraphNode> DepsList;
  UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
                                 SyncPointWaitList, DepsList),
          Result);
  // Do not build graph nodes on top of a wait list that failed to resolve.
  if (Result != UR_RESULT_SUCCESS) {
    return Result;
  }

  try {
    const size_t N = Size / PatternSize;
    auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
                      ? *static_cast<CUdeviceptr *>(DstDevice)
                      : (CUdeviceptr)DstDevice;

    if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
      // Read exactly PatternSize bytes: dereferencing the pattern as a
      // uint32_t would read past the end of 1- and 2-byte patterns.
      uint32_t Value = 0;
      switch (PatternSize) {
      case 1:
        Value = *static_cast<const uint8_t *>(Pattern);
        break;
      case 2:
        Value = *static_cast<const uint16_t *>(Pattern);
        break;
      default:
        Value = *static_cast<const uint32_t *>(Pattern);
        break;
      }

      // Create a new node
      CUgraphNode GraphNode;
      CUDA_MEMSET_NODE_PARAMS NodeParams = {};
      NodeParams.dst = DstPtr;
      NodeParams.elementSize = PatternSize;
      NodeParams.height = N;
      NodeParams.pitch = PatternSize;
      NodeParams.value = Value;
      NodeParams.width = 1;

      UR_CHECK_ERROR(cuGraphAddMemsetNode(
          &GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
          DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));

      // Get sync point and register the cuNode with it.
      *SyncPoint =
          CommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));

    } else {
      // CUDA has no memset functions that allow setting values more than 4
      // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
      // fill, which can be more than 4 bytes. We must break up the pattern
      // into 4 byte values, and set the buffer using multiple strided nodes.
      // This means that one cuGraphAddMemsetNode call is made for every 4
      // bytes in the pattern.
      const size_t NumberOfSteps = PatternSize / sizeof(uint32_t);

      CUgraphNode GraphNode;
      // Walk up the pattern in 4-byte steps, adding one memset node per
      // 4-byte chunk of the pattern.
      for (size_t Step = 0; Step < NumberOfSteps; ++Step) {
        // take 4 bytes of the pattern
        const uint32_t Value =
            *(static_cast<const uint32_t *>(Pattern) + Step);

        // offset the pointer to the part of the buffer we want to write to
        const CUdeviceptr OffsetPtr = DstPtr + (Step * sizeof(uint32_t));

        CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
        NodeParamsStep.dst = OffsetPtr;
        NodeParamsStep.elementSize = 4;
        NodeParamsStep.height = N;
        NodeParamsStep.pitch = PatternSize;
        NodeParamsStep.value = Value;
        NodeParamsStep.width = 1;

        UR_CHECK_ERROR(cuGraphAddMemsetNode(
            &GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
            DepsList.size(), &NodeParamsStep,
            CommandBuffer->Device->getContext()));

        // Chain the nodes: the next chunk depends on this one, so the last
        // node transitively depends on every chunk and on the original wait
        // list. Without this, the returned sync point would only cover the
        // final chunk.
        DepsList.assign(1, GraphNode);
      }

      // Register a single sync point for the whole fill, backed by the last
      // node of the chain. Skipped if no node was created.
      if (NumberOfSteps > 0) {
        *SyncPoint = CommandBuffer->AddSyncPoint(
            std::make_shared<CUgraphNode>(GraphNode));
      }
    }
  } catch (ur_result_t Err) {
    Result = Err;
  }
  return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
Expand Down Expand Up @@ -596,6 +681,48 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
return Result;
}

// Appends a buffer fill command to a CUDA command-buffer.
//
// Fills `size` bytes of hBuffer, starting `offset` bytes in, with repeated
// copies of the `patternSize`-byte pattern at pPattern.
//
// Returns UR_RESULT_ERROR_INVALID_SIZE when pPattern is null, when patternSize
// is not a positive power of two, or when offset or size is not a multiple of
// patternSize. Otherwise returns the result of the fill helper.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
    ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
    const void *pPattern, size_t patternSize, size_t offset, size_t size,
    uint32_t numSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
  const bool PatternIsValid = (pPattern != nullptr);

  const bool PatternSizeIsValid =
      (patternSize > 0) &&
      ((patternSize & (patternSize - 1)) == 0); // is a positive power of two

  // Validate the pattern size before using it as a divisor below; a zero
  // patternSize would otherwise be a division by zero.
  UR_ASSERT(PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);

  // Both the offset and the size must be multiples of the pattern size.
  const bool ArgsAreMultiplesOfPatternSize =
      (offset % patternSize == 0) && (size % patternSize == 0);
  UR_ASSERT(ArgsAreMultiplesOfPatternSize, UR_RESULT_ERROR_INVALID_SIZE);

  auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;

  // The helper reads the destination address through this pointer when the
  // memory type is CU_MEMORYTYPE_DEVICE.
  return enqueueCommandBufferFillHelper(
      hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
      size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
}

// Appends a USM fill command to a CUDA command-buffer.
//
// Fills `size` bytes at pPtr with repeated copies of the `patternSize`-byte
// pattern at pPattern. Returns UR_RESULT_ERROR_INVALID_SIZE when pPattern is
// null or patternSize is not a positive power of two; otherwise returns the
// result of the fill helper.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
    ur_exp_command_buffer_handle_t hCommandBuffer, void *pPtr,
    const void *pPattern, size_t patternSize, size_t size,
    uint32_t numSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
  // A power of two has exactly one bit set, so clearing its lowest set bit
  // leaves zero. Zero itself is excluded explicitly.
  const bool HasSingleBit = (patternSize & (patternSize - 1)) == 0;
  const bool SizeIsPositivePowerOfTwo = HasSingleBit && (patternSize > 0);
  const bool HavePattern = (pPattern != nullptr);

  UR_ASSERT(HavePattern && SizeIsPositivePowerOfTwo,
            UR_RESULT_ERROR_INVALID_SIZE);

  // For unified memory the pointer value itself is the destination address.
  return enqueueCommandBufferFillHelper(
      hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
      numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
pDdiTable->pfnAppendMemBufferCopyRectExp =
urCommandBufferAppendMemBufferCopyRectExp;
Expand All @@ -291,6 +292,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
18 changes: 18 additions & 0 deletions source/adapters/hip/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// HIP placeholder for the command-buffer buffer-fill entry point. All
// arguments are ignored: the function reports that the feature is not
// implemented via detail::ur::die and has UNSUPPORTED_FEATURE as its result.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
    ur_exp_command_buffer_handle_t /*hCommandBuffer*/,
    ur_mem_handle_t /*hBuffer*/, const void * /*pPattern*/,
    size_t /*patternSize*/, size_t /*offset*/, size_t /*size*/,
    uint32_t /*numSyncPointsInWaitList*/,
    const ur_exp_command_buffer_sync_point_t * /*pSyncPointWaitList*/,
    ur_exp_command_buffer_sync_point_t * /*pSyncPoint*/) {
  const char *const NotImplementedMessage =
      "Experimental Command-buffer feature is not "
      "implemented for HIP adapter.";
  detail::ur::die(NotImplementedMessage);
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

// HIP placeholder for the command-buffer USM-fill entry point. All arguments
// are ignored: the function reports that the feature is not implemented via
// detail::ur::die and has UNSUPPORTED_FEATURE as its result.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
    ur_exp_command_buffer_handle_t /*hCommandBuffer*/, void * /*pPtr*/,
    const void * /*pPattern*/, size_t /*patternSize*/, size_t /*size*/,
    uint32_t /*numSyncPointsInWaitList*/,
    const ur_exp_command_buffer_sync_point_t * /*pSyncPointWaitList*/,
    ur_exp_command_buffer_sync_point_t * /*pSyncPoint*/) {
  const char *const NotImplementedMessage =
      "Experimental Command-buffer feature is not "
      "implemented for HIP adapter.";
  detail::ur::die(NotImplementedMessage);
  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
const ur_event_handle_t *, ur_event_handle_t *) {
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
pDdiTable->pfnAppendMemBufferCopyRectExp =
urCommandBufferAppendMemBufferCopyRectExp;
Expand All @@ -289,6 +290,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;

return retVal;
}
Expand Down
77 changes: 77 additions & 0 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,48 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
return UR_RESULT_SUCCESS;
}

// Shared implementation of the Level Zero command-buffer fill commands.
//
// Validates the pattern size, resolves the wait list into Level Zero events,
// creates an event tagged with CommandType, registers it as the command's sync
// point and appends a memory fill of Size bytes at Ptr to the command-buffer's
// command list.
//
// Returns UR_RESULT_ERROR_INVALID_VALUE when PatternSize is not a positive
// power of two or exceeds the compute queue group's maxMemoryFillPatternSize;
// otherwise propagates any error from the called helpers, or
// UR_RESULT_SUCCESS.
static ur_result_t enqueueCommandBufferFillHelper(
    ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
    void *Ptr, const void *Pattern, size_t PatternSize, size_t Size,
    uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {
  // Pattern size must be a power of two.
  const bool IsPositivePowerOfTwo =
      (PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0);
  UR_ASSERT(IsPositivePowerOfTwo, UR_RESULT_ERROR_INVALID_VALUE);

  // Pattern size must fit the compute queue capabilities.
  const auto &ComputeGroup =
      CommandBuffer->Device
          ->QueueGroup[ur_device_handle_t_::queue_group_info_t::Compute];
  UR_ASSERT(PatternSize <= ComputeGroup.ZeProperties.maxMemoryFillPatternSize,
            UR_RESULT_ERROR_INVALID_VALUE);

  // Translate the sync points we depend on into Level Zero events.
  std::vector<ze_event_handle_t> WaitEvents;
  UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
                                  SyncPointWaitList, WaitEvents));

  // Event signalled by the fill; it backs the sync point handed to the caller.
  ur_event_handle_t FillEvent;
  UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &FillEvent));
  FillEvent->CommandType = CommandType;

  // Get sync point and register the event with it.
  *SyncPoint = CommandBuffer->GetNextSyncPoint();
  CommandBuffer->RegisterSyncPoint(*SyncPoint, FillEvent);

  ZE2UR_CALL(zeCommandListAppendMemoryFill,
             (CommandBuffer->ZeCommandList, Ptr, Pattern, PatternSize, Size,
              FillEvent->ZeEvent, WaitEvents.size(), WaitEvents.data()));

  urPrint("calling zeCommandListAppendMemoryFill() with"
          " ZeEvent %#lx\n",
          ur_cast<std::uintptr_t>(FillEvent->ZeEvent));

  return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
const ur_exp_command_buffer_desc_t *CommandBufferDesc,
Expand Down Expand Up @@ -783,6 +825,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
return UR_RESULT_SUCCESS;
}

// Appends a buffer fill command to a Level Zero command-buffer.
//
// Holds the buffer's mutex for the duration of the call, obtains the buffer's
// Level Zero handle for write-only access on the command-buffer's device and
// forwards the fill (starting Offset bytes into the buffer) to the shared fill
// helper. Pattern and PatternSize are passed through unchanged.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
    ur_exp_command_buffer_handle_t CommandBuffer, ur_mem_handle_t Buffer,
    const void *Pattern, size_t PatternSize, size_t Offset, size_t Size,
    uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {

  std::scoped_lock<ur_shared_mutex> BufferGuard(Buffer->Mutex);

  auto *BufferImpl = reinterpret_cast<_ur_buffer *>(Buffer);
  char *ZeDstBase = nullptr;
  UR_CALL(BufferImpl->getZeHandle(ZeDstBase, ur_mem_handle_t_::write_only,
                                  CommandBuffer->Device));

  // First byte to be filled.
  char *FillStart = ZeDstBase + Offset;

  return enqueueCommandBufferFillHelper(
      UR_COMMAND_MEM_BUFFER_FILL, CommandBuffer, FillStart, Pattern,
      PatternSize, Size, NumSyncPointsInWaitList, SyncPointWaitList,
      SyncPoint);
}

// Appends a USM fill command to a Level Zero command-buffer.
//
// Forwards every argument unchanged to the shared fill helper, which performs
// the pattern-size validation.
// NOTE(review): the command is tagged UR_COMMAND_MEM_BUFFER_FILL rather than a
// USM-specific command type - confirm this is intended.
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
    ur_exp_command_buffer_handle_t CommandBuffer, void *Ptr,
    const void *Pattern, size_t PatternSize, size_t Size,
    uint32_t NumSyncPointsInWaitList,
    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
    ur_exp_command_buffer_sync_point_t *SyncPoint) {

  const ur_command_t FillCommandType = UR_COMMAND_MEM_BUFFER_FILL;

  return enqueueCommandBufferFillHelper(
      FillCommandType, CommandBuffer, Ptr, Pattern, PatternSize, Size,
      NumSyncPointsInWaitList, SyncPointWaitList, SyncPoint);
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/level_zero/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
pDdiTable->pfnAppendMemBufferCopyRectExp =
urCommandBufferAppendMemBufferCopyRectExp;
Expand All @@ -338,6 +339,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
const void *pPattern, size_t patternSize, size_t offset, size_t size,
uint32_t numSyncPointsInWaitList,
Expand Down
2 changes: 2 additions & 0 deletions source/adapters/opencl/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
pDdiTable->pfnAppendMemBufferCopyRectExp =
urCommandBufferAppendMemBufferCopyRectExp;
Expand All @@ -298,6 +299,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
urCommandBufferAppendMemBufferWriteRectExp;
pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;

return retVal;
Expand Down

0 comments on commit 1d78636

Please sign in to comment.