Skip to content

Commit

Permalink
Merge pull request #1319 from Bensuo/maxime/cuda-large-fill-pattern
Browse files Browse the repository at this point in the history
[EXP][CMDBUF] Improve CUDA Fill op implementation
  • Loading branch information
kbenzie authored Mar 14, 2024
2 parents ec634ff + ef72b3f commit bb589ca
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(

try {
const size_t N = Size / PatternSize;
auto Value = *static_cast<const uint32_t *>(Pattern);
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
? *static_cast<CUdeviceptr *>(DstDevice)
: (CUdeviceptr)DstDevice;
Expand All @@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
NodeParams.elementSize = PatternSize;
NodeParams.height = N;
NodeParams.pitch = PatternSize;
NodeParams.value = Value;
NodeParams.width = 1;

// pattern size in bytes
switch (PatternSize) {
case 1: {
auto Value = *static_cast<const uint8_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 2: {
auto Value = *static_cast<const uint16_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 4: {
auto Value = *static_cast<const uint32_t *>(Pattern);
NodeParams.value = Value;
break;
}
}

UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
Expand All @@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
// CUDA has no memset functions that allow setting values more than 4
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
// fill, which can be more than 4 bytes. We must break up the pattern
// into 4 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
// in the pattern.
// into 1 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 1
// bytes in the pattern.

size_t NumberOfSteps = PatternSize / sizeof(uint8_t);

size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
// Shared pointer that will point to the last node created
std::shared_ptr<CUgraphNode> GraphNodePtr;
// Create a new node
CUgraphNode GraphNodeFirst;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
NodeParamsStepFirst.dst = DstPtr;
NodeParamsStepFirst.elementSize = sizeof(uint32_t);
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
NodeParamsStepFirst.pitch = sizeof(uint32_t);
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
NodeParamsStepFirst.width = 1;

// we walk up the pattern in 4-byte steps, and call cuMemset for each
// 4-byte chunk of the pattern.
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStepFirst,
CommandBuffer->Device->getContext()));

// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<CUgraphNode>(GraphNodeFirst));

DepsList.clear();
DepsList.push_back(GraphNodeFirst);

// we walk up the pattern in 1-byte steps, and call cuMemset for each
// 1-byte chunk of the pattern.
for (auto Step = 4u; Step < NumberOfSteps; ++Step) {
// take 4 bytes of the pattern
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
auto Value = *(static_cast<const uint8_t *>(Pattern) + Step);

// offset the pointer to the part of the buffer we want to write to
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t));

// Create a new node
CUgraphNode GraphNode;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
NodeParamsStep.elementSize = 4;
NodeParamsStep.height = N;
NodeParamsStep.pitch = PatternSize;
NodeParamsStep.elementSize = sizeof(uint8_t);
NodeParamsStep.height = Size / NumberOfSteps;
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
NodeParamsStep.value = Value;
NodeParamsStep.width = 1;

Expand All @@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
DepsList.size(), &NodeParamsStep,
CommandBuffer->Device->getContext()));

GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<CUgraphNode>(GraphNode));
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);

DepsList.clear();
DepsList.push_back(*GraphNodePtr.get());
}
}
} catch (ur_result_t Err) {
Expand Down

0 comments on commit bb589ca

Please sign in to comment.