diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index d3f270c701..d9d980073a 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper( try { const size_t N = Size / PatternSize; - auto Value = *static_cast(Pattern); auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE ? *static_cast(DstDevice) : (CUdeviceptr)DstDevice; @@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper( NodeParams.elementSize = PatternSize; NodeParams.height = N; NodeParams.pitch = PatternSize; - NodeParams.value = Value; NodeParams.width = 1; + // pattern size in bytes + switch (PatternSize) { + case 1: { + auto Value = *static_cast(Pattern); + NodeParams.value = Value; + break; + } + case 2: { + auto Value = *static_cast(Pattern); + NodeParams.value = Value; + break; + } + case 4: { + auto Value = *static_cast(Pattern); + NodeParams.value = Value; + break; + } + } + UR_CHECK_ERROR(cuGraphAddMemsetNode( &GraphNode, CommandBuffer->CudaGraph, DepsList.data(), DepsList.size(), &NodeParams, CommandBuffer->Device->getContext())); @@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper( // CUDA has no memset functions that allow setting values more than 4 // bytes. UR API lets you pass an arbitrary "pattern" to the buffer // fill, which can be more than 4 bytes. We must break up the pattern - // into 4 byte values, and set the buffer using multiple strided calls. - // This means that one cuGraphAddMemsetNode call is made for every 4 bytes - // in the pattern. + // into 1 byte values, and set the buffer using multiple strided calls. + // This means that one cuGraphAddMemsetNode call is made for every 1 + // bytes in the pattern. + + size_t NumberOfSteps = PatternSize / sizeof(uint8_t); - size_t NumberOfSteps = PatternSize / sizeof(uint32_t); + // Shared pointer that will point to the last node created + std::shared_ptr GraphNodePtr; + // Create a new node + CUgraphNode GraphNodeFirst; + // Update NodeParam + CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {}; + NodeParamsStepFirst.dst = DstPtr; + NodeParamsStepFirst.elementSize = sizeof(uint32_t); + NodeParamsStepFirst.height = Size / sizeof(uint32_t); + NodeParamsStepFirst.pitch = sizeof(uint32_t); + NodeParamsStepFirst.value = *static_cast(Pattern); + NodeParamsStepFirst.width = 1; - // we walk up the pattern in 4-byte steps, and call cuMemset for each - // 4-byte chunk of the pattern. - for (auto Step = 0u; Step < NumberOfSteps; ++Step) { + UR_CHECK_ERROR(cuGraphAddMemsetNode( + &GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(), + DepsList.size(), &NodeParamsStepFirst, + CommandBuffer->Device->getContext())); + + // Get sync point and register the cuNode with it. + *SyncPoint = CommandBuffer->addSyncPoint( + std::make_shared(GraphNodeFirst)); + + DepsList.clear(); + DepsList.push_back(GraphNodeFirst); + + // we walk up the pattern in 1-byte steps, and call cuMemset for each + // 1-byte chunk of the pattern. + for (auto Step = 4u; Step < NumberOfSteps; ++Step) { // take 4 bytes of the pattern - auto Value = *(static_cast(Pattern) + Step); + auto Value = *(static_cast(Pattern) + Step); // offset the pointer to the part of the buffer we want to write to - auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t)); + auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t)); // Create a new node CUgraphNode GraphNode; // Update NodeParam CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {}; NodeParamsStep.dst = (CUdeviceptr)OffsetPtr; - NodeParamsStep.elementSize = 4; - NodeParamsStep.height = N; - NodeParamsStep.pitch = PatternSize; + NodeParamsStep.elementSize = sizeof(uint8_t); + NodeParamsStep.height = Size / NumberOfSteps; + NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t); NodeParamsStep.value = Value; NodeParamsStep.width = 1; @@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper( DepsList.size(), &NodeParamsStep, CommandBuffer->Device->getContext())); + GraphNodePtr = std::make_shared(GraphNode); // Get sync point and register the cuNode with it. - *SyncPoint = CommandBuffer->addSyncPoint( - std::make_shared(GraphNode)); + *SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr); + + DepsList.clear(); + DepsList.push_back(*GraphNodePtr.get()); } } } catch (ur_result_t Err) {