diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp
index dd97f48d6a..a65530a1f1 100644
--- a/source/adapters/cuda/command_buffer.cpp
+++ b/source/adapters/cuda/command_buffer.cpp
@@ -99,6 +99,108 @@ static void setCopyParams(const void *SrcPtr, const CUmemorytype_enum SrcType,
   Params.Depth = 1;
 }
 
+// Helper function for appending a memory fill to a command-buffer.
+//
+// The fill is recorded as memset node(s) in the command-buffer's CUDA graph:
+// patterns of 1, 2 or 4 bytes use a single node, larger patterns use one
+// strided node per 4-byte chunk of the pattern.
+//
+// DstDevice is the address of a CUdeviceptr when DstType is
+// CU_MEMORYTYPE_DEVICE, and the destination pointer itself otherwise.
+// Pattern points to PatternSize bytes; callers validate that PatternSize is a
+// positive power of two. Size / PatternSize repetitions are written.
+// On success *SyncPoint identifies the node that completes the fill.
+//
+// Returns UR_RESULT_SUCCESS, the failure reported while resolving the wait
+// list, or the ur_result_t thrown while adding a node.
+static ur_result_t enqueueCommandBufferFillHelper(
+    ur_exp_command_buffer_handle_t CommandBuffer, void *DstDevice,
+    const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
+    size_t Size, uint32_t NumSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *SyncPoint) {
+  ur_result_t Result = UR_RESULT_SUCCESS;
+  std::vector<CUgraphNode> DepsList;
+  UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
+                                 SyncPointWaitList, DepsList),
+          Result);
+  // Never record nodes against an incompletely resolved wait list.
+  if (Result != UR_RESULT_SUCCESS) {
+    return Result;
+  }
+
+  try {
+    const size_t N = Size / PatternSize;
+    const CUdeviceptr DstPtr = DstType == CU_MEMORYTYPE_DEVICE
+                                   ? *static_cast<CUdeviceptr *>(DstDevice)
+                                   : reinterpret_cast<CUdeviceptr>(DstDevice);
+    // Last memset node added; it becomes the command's sync point.
+    CUgraphNode GraphNode = nullptr;
+
+    if ((PatternSize == 1) || (PatternSize == 2) || (PatternSize == 4)) {
+      CUDA_MEMSET_NODE_PARAMS NodeParams = {};
+      NodeParams.dst = DstPtr;
+      NodeParams.elementSize = PatternSize;
+      NodeParams.height = N;
+      NodeParams.pitch = PatternSize;
+      NodeParams.width = 1;
+      // Load exactly PatternSize bytes: a fixed 4-byte load would read past
+      // the end of 1- and 2-byte patterns.
+      switch (PatternSize) {
+      case 1:
+        NodeParams.value = *static_cast<const uint8_t *>(Pattern);
+        break;
+      case 2:
+        NodeParams.value = *static_cast<const uint16_t *>(Pattern);
+        break;
+      default:
+        NodeParams.value = *static_cast<const uint32_t *>(Pattern);
+        break;
+      }
+
+      UR_CHECK_ERROR(cuGraphAddMemsetNode(
+          &GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
+          DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
+    } else {
+      // CUDA has no memset functions that allow setting values more than 4
+      // bytes. UR API lets you pass an arbitrary "pattern" to the buffer
+      // fill, which can be more than 4 bytes. We must break up the pattern
+      // into 4 byte values, and set the buffer using multiple strided calls.
+      // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
+      // in the pattern.
+      const size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
+      const auto *Chunks = static_cast<const uint32_t *>(Pattern);
+
+      for (size_t Step = 0; Step < NumberOfSteps; ++Step) {
+        CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
+        // Offset the destination to this chunk's position in the pattern.
+        NodeParamsStep.dst = DstPtr + (Step * sizeof(uint32_t));
+        NodeParamsStep.elementSize = sizeof(uint32_t);
+        NodeParamsStep.height = N;
+        NodeParamsStep.pitch = PatternSize;
+        NodeParamsStep.value = Chunks[Step];
+        NodeParamsStep.width = 1;
+
+        UR_CHECK_ERROR(cuGraphAddMemsetNode(
+            &GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
+            DepsList.size(), &NodeParamsStep,
+            CommandBuffer->Device->getContext()));
+
+        // Chain the chunks: each node depends on the previous one, so the
+        // final node (the sync point) completes only after every chunk.
+        DepsList = {GraphNode};
+      }
+    }
+
+    // Get sync point and register the cuNode with it.
+    *SyncPoint =
+        CommandBuffer->AddSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
+  } catch (ur_result_t Err) {
+    Result = Err;
+  }
+  return Result;
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
     ur_context_handle_t hContext, ur_device_handle_t hDevice,
     const ur_exp_command_buffer_desc_t *pCommandBufferDesc,
@@ -596,6 +698,67 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
   return Result;
 }
 
+// Appends a command that fills size bytes of hBuffer, starting at byte
+// offset, with repetitions of the patternSize-byte pattern at pPattern.
+//
+// UR_ASSERT reports UR_RESULT_ERROR_INVALID_SIZE when pPattern is null,
+// patternSize is not a positive power of two, or offset or size is not a
+// multiple of patternSize; otherwise the helper's result is returned.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
+    ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
+    const void *pPattern, size_t patternSize, size_t offset, size_t size,
+    uint32_t numSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
+  const bool PatternIsValid = (pPattern != nullptr);
+
+  // is a positive power of two
+  const bool PatternSizeIsValid =
+      (patternSize > 0) && ((patternSize & (patternSize - 1)) == 0);
+
+  // Both offset and size must be whole multiples of the pattern size. The
+  // remainders are only evaluated once patternSize is known to be non-zero,
+  // so a zero patternSize is rejected instead of dividing by zero.
+  const bool ArgsAreMultiplesOfPatternSize =
+      PatternSizeIsValid && (offset % patternSize == 0) &&
+      (size % patternSize == 0);
+
+  UR_ASSERT(ArgsAreMultiplesOfPatternSize && PatternIsValid,
+            UR_RESULT_ERROR_INVALID_SIZE);
+
+  // The helper reads the destination through a CUdeviceptr * for
+  // CU_MEMORYTYPE_DEVICE, hence the address-of below.
+  auto DstDevice = std::get<BufferMem>(hBuffer->Mem).get() + offset;
+
+  return enqueueCommandBufferFillHelper(
+      hCommandBuffer, &DstDevice, CU_MEMORYTYPE_DEVICE, pPattern, patternSize,
+      size, numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
+}
+
+// Appends a command that fills size bytes of the USM allocation at pPtr with
+// repetitions of the patternSize-byte pattern at pPattern.
+//
+// UR_ASSERT reports UR_RESULT_ERROR_INVALID_SIZE when pPattern is null or
+// patternSize is not a positive power of two; otherwise the helper's result
+// is returned.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
+    ur_exp_command_buffer_handle_t hCommandBuffer, void *pPtr,
+    const void *pPattern, size_t patternSize, size_t size,
+    uint32_t numSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *pSyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *pSyncPoint) {
+  const bool PatternIsValid = (pPattern != nullptr);
+
+  // is a positive power of two
+  const bool PatternSizeIsValid =
+      (patternSize > 0) && ((patternSize & (patternSize - 1)) == 0);
+
+  UR_ASSERT(PatternIsValid && PatternSizeIsValid, UR_RESULT_ERROR_INVALID_SIZE);
+  return enqueueCommandBufferFillHelper(
+      hCommandBuffer, pPtr, CU_MEMORYTYPE_UNIFIED, pPattern, patternSize, size,
+      numSyncPointsInWaitList, pSyncPointWaitList, pSyncPoint);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     ur_exp_command_buffer_handle_t hCommandBuffer, ur_queue_handle_t hQueue,
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
diff --git a/source/adapters/cuda/ur_interface_loader.cpp b/source/adapters/cuda/ur_interface_loader.cpp
index af18d96017..f31ffe6d87 100644
--- a/source/adapters/cuda/ur_interface_loader.cpp
+++ b/source/adapters/cuda/ur_interface_loader.cpp
@@ -279,6 +279,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
   pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
   pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
   pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
+  pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
   pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
   pDdiTable->pfnAppendMemBufferCopyRectExp =
       urCommandBufferAppendMemBufferCopyRectExp;
@@ -291,6 +292,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
       urCommandBufferAppendMemBufferWriteRectExp;
   pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
   pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
+  pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
   pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
 
   return retVal;
diff --git a/source/adapters/hip/command_buffer.cpp b/source/adapters/hip/command_buffer.cpp
index c7609b6110..54a6fa2f4e 100644
--- a/source/adapters/hip/command_buffer.cpp
+++ b/source/adapters/hip/command_buffer.cpp
@@ -137,6 +137,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
   return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
 }
 
+// Buffer fill is not implemented for the HIP adapter: the missing feature is
+// reported through detail::ur::die. The trailing return only matters if die
+// returns -- NOTE(review): die is presumably fatal; confirm.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
+    ur_exp_command_buffer_handle_t, ur_mem_handle_t, const void *, size_t,
+    size_t, size_t, uint32_t, const ur_exp_command_buffer_sync_point_t *,
+    ur_exp_command_buffer_sync_point_t *) {
+  detail::ur::die("Experimental Command-buffer feature is not "
+                  "implemented for HIP adapter.");
+  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+}
+
+// USM fill is not implemented for the HIP adapter: the missing feature is
+// reported through detail::ur::die. The trailing return only matters if die
+// returns -- NOTE(review): die is presumably fatal; confirm.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
+    ur_exp_command_buffer_handle_t, void *, const void *, size_t, size_t,
+    uint32_t, const ur_exp_command_buffer_sync_point_t *,
+    ur_exp_command_buffer_sync_point_t *) {
+  detail::ur::die("Experimental Command-buffer feature is not "
+                  "implemented for HIP adapter.");
+  return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     ur_exp_command_buffer_handle_t, ur_queue_handle_t, uint32_t,
     const ur_event_handle_t *, ur_event_handle_t *) {
diff --git a/source/adapters/hip/ur_interface_loader.cpp b/source/adapters/hip/ur_interface_loader.cpp
index f23d395d1a..7707e78425 100644
--- a/source/adapters/hip/ur_interface_loader.cpp
+++ b/source/adapters/hip/ur_interface_loader.cpp
@@ -276,6 +276,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
   pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
   pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
   pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
+  pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
   pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
   pDdiTable->pfnAppendMemBufferCopyRectExp =
       urCommandBufferAppendMemBufferCopyRectExp;
@@ -289,6 +290,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
   pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
   pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
   pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
+  pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
 
   return retVal;
 }
diff --git a/source/adapters/level_zero/command_buffer.cpp b/source/adapters/level_zero/command_buffer.cpp
index bb081f9b2d..4b811ab033 100644
--- a/source/adapters/level_zero/command_buffer.cpp
+++ b/source/adapters/level_zero/command_buffer.cpp
@@ -379,6 +379,60 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
   return UR_RESULT_SUCCESS;
 }
 
+// Helper function for appending a memory fill to a command-buffer.
+//
+// Creates an event for the command, registers it under a new sync point and
+// appends a zeCommandListAppendMemoryFill call to the command-buffer's
+// Level Zero command list, waiting on the events of the given sync points.
+//
+// CommandType is stored on the created event, Ptr is the fill destination,
+// Pattern/PatternSize describe the pattern and Size is forwarded unchanged.
+// *SyncPoint receives the newly registered sync point.
+//
+// PatternSize is checked (UR_RESULT_ERROR_INVALID_VALUE) to be a positive
+// power of two no larger than the compute queue group's
+// maxMemoryFillPatternSize.
+static ur_result_t enqueueCommandBufferFillHelper(
+    ur_command_t CommandType, ur_exp_command_buffer_handle_t CommandBuffer,
+    void *Ptr, const void *Pattern, size_t PatternSize, size_t Size,
+    uint32_t NumSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *SyncPoint) {
+  // Pattern size must be a power of two.
+  UR_ASSERT((PatternSize > 0) && ((PatternSize & (PatternSize - 1)) == 0),
+            UR_RESULT_ERROR_INVALID_VALUE);
+
+  // Pattern size must fit the compute queue capabilities.
+  UR_ASSERT(
+      PatternSize <=
+          CommandBuffer->Device
+              ->QueueGroup[ur_device_handle_t_::queue_group_info_t::Compute]
+              .ZeProperties.maxMemoryFillPatternSize,
+      UR_RESULT_ERROR_INVALID_VALUE);
+
+  std::vector<ze_event_handle_t> ZeEventList;
+  UR_CALL(getEventsFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
+                                  SyncPointWaitList, ZeEventList));
+
+  ur_event_handle_t LaunchEvent;
+  UR_CALL(EventCreate(CommandBuffer->Context, nullptr, true, &LaunchEvent));
+  LaunchEvent->CommandType = CommandType;
+
+  // Get sync point and register the event with it.
+  *SyncPoint = CommandBuffer->GetNextSyncPoint();
+  CommandBuffer->RegisterSyncPoint(*SyncPoint, LaunchEvent);
+
+  ZE2UR_CALL(zeCommandListAppendMemoryFill,
+             (CommandBuffer->ZeCommandList, Ptr, Pattern, PatternSize, Size,
+              LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));
+
+  urPrint("calling zeCommandListAppendMemoryFill() with"
+          " ZeEvent %#lx\n",
+          ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
+
+  return UR_RESULT_SUCCESS;
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL
 urCommandBufferCreateExp(ur_context_handle_t Context, ur_device_handle_t Device,
                          const ur_exp_command_buffer_desc_t *CommandBufferDesc,
@@ -783,6 +837,43 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
   return UR_RESULT_SUCCESS;
 }
 
+// Appends a command that fills Size bytes of Buffer, starting at byte
+// Offset, with the PatternSize-byte pattern at Pattern. Buffer's mutex is
+// held while its Level Zero handle is obtained and the fill is appended.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
+    ur_exp_command_buffer_handle_t CommandBuffer, ur_mem_handle_t Buffer,
+    const void *Pattern, size_t PatternSize, size_t Offset, size_t Size,
+    uint32_t NumSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *SyncPoint) {
+
+  std::scoped_lock Lock(Buffer->Mutex);
+
+  char *ZeHandleDst = nullptr;
+  _ur_buffer *UrBuffer = reinterpret_cast<_ur_buffer *>(Buffer);
+  UR_CALL(UrBuffer->getZeHandle(ZeHandleDst, ur_mem_handle_t_::write_only,
+                                CommandBuffer->Device));
+
+  return enqueueCommandBufferFillHelper(
+      UR_COMMAND_MEM_BUFFER_FILL, CommandBuffer, ZeHandleDst + Offset, Pattern,
+      PatternSize, Size, NumSyncPointsInWaitList, SyncPointWaitList, SyncPoint);
+}
+
+// Appends a command that fills Size bytes of the USM allocation at Ptr with
+// the PatternSize-byte pattern at Pattern. The command's event is tagged
+// UR_COMMAND_USM_FILL.
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMFillExp(
+    ur_exp_command_buffer_handle_t CommandBuffer, void *Ptr,
+    const void *Pattern, size_t PatternSize, size_t Size,
+    uint32_t NumSyncPointsInWaitList,
+    const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
+    ur_exp_command_buffer_sync_point_t *SyncPoint) {
+
+  return enqueueCommandBufferFillHelper(
+      UR_COMMAND_USM_FILL, CommandBuffer, Ptr, Pattern, PatternSize, Size,
+      NumSyncPointsInWaitList, SyncPointWaitList, SyncPoint);
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
     ur_exp_command_buffer_handle_t CommandBuffer, ur_queue_handle_t Queue,
     uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
diff --git a/source/adapters/level_zero/ur_interface_loader.cpp b/source/adapters/level_zero/ur_interface_loader.cpp
index 5371fac082..74d0706b31 100644
--- a/source/adapters/level_zero/ur_interface_loader.cpp
+++ b/source/adapters/level_zero/ur_interface_loader.cpp
@@ -326,6 +326,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
   pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
   pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
   pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
+  pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
   pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
   pDdiTable->pfnAppendMemBufferCopyRectExp =
       urCommandBufferAppendMemBufferCopyRectExp;
@@ -338,6 +339,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
       urCommandBufferAppendMemBufferWriteRectExp;
   pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
   pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
+  pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
   pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
 
   return retVal;
diff --git a/source/adapters/opencl/command_buffer.cpp b/source/adapters/opencl/command_buffer.cpp
index 25d3311b79..74cdd8a03d 100644
--- a/source/adapters/opencl/command_buffer.cpp
+++ b/source/adapters/opencl/command_buffer.cpp
@@ -273,7 +273,7 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
   return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
 }
 
-UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMembufferFillExp(
+UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
     ur_exp_command_buffer_handle_t hCommandBuffer, ur_mem_handle_t hBuffer,
     const void *pPattern, size_t patternSize, size_t offset, size_t size,
     uint32_t numSyncPointsInWaitList,
diff --git a/source/adapters/opencl/ur_interface_loader.cpp b/source/adapters/opencl/ur_interface_loader.cpp
index b9887b1b1a..ac2c33475b 100644
--- a/source/adapters/opencl/ur_interface_loader.cpp
+++ b/source/adapters/opencl/ur_interface_loader.cpp
@@ -286,6 +286,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
   pDdiTable->pfnFinalizeExp = urCommandBufferFinalizeExp;
   pDdiTable->pfnAppendKernelLaunchExp = urCommandBufferAppendKernelLaunchExp;
   pDdiTable->pfnAppendUSMMemcpyExp = urCommandBufferAppendUSMMemcpyExp;
+  pDdiTable->pfnAppendUSMFillExp = urCommandBufferAppendUSMFillExp;
   pDdiTable->pfnAppendMemBufferCopyExp = urCommandBufferAppendMemBufferCopyExp;
   pDdiTable->pfnAppendMemBufferCopyRectExp =
       urCommandBufferAppendMemBufferCopyRectExp;
@@ -298,6 +299,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable(
       urCommandBufferAppendMemBufferWriteRectExp;
   pDdiTable->pfnAppendUSMPrefetchExp = urCommandBufferAppendUSMPrefetchExp;
   pDdiTable->pfnAppendUSMAdviseExp = urCommandBufferAppendUSMAdviseExp;
+  pDdiTable->pfnAppendMemBufferFillExp = urCommandBufferAppendMemBufferFillExp;
   pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp;
 
   return retVal;