Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXP][CMDBUF] Improve CUDA Fill op implementation #1319

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(

try {
const size_t N = Size / PatternSize;
auto Value = *static_cast<const uint32_t *>(Pattern);
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
? *static_cast<CUdeviceptr *>(DstDevice)
: (CUdeviceptr)DstDevice;
Expand All @@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
NodeParams.elementSize = PatternSize;
NodeParams.height = N;
NodeParams.pitch = PatternSize;
NodeParams.value = Value;
NodeParams.width = 1;

// pattern size in bytes
switch (PatternSize) {
case 1: {
auto Value = *static_cast<const uint8_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 2: {
auto Value = *static_cast<const uint16_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 4: {
auto Value = *static_cast<const uint32_t *>(Pattern);
NodeParams.value = Value;
break;
}
}

UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
Expand All @@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
// CUDA has no memset functions that allow setting values more than 4
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
// fill, which can be more than 4 bytes. We must break up the pattern
// into 4 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
// in the pattern.
// into 1 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 1
// bytes in the pattern.

size_t NumberOfSteps = PatternSize / sizeof(uint8_t);

size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
// Shared pointer that will point to the last node created
std::shared_ptr<CUgraphNode> GraphNodePtr;
// Create a new node
CUgraphNode GraphNodeFirst;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
NodeParamsStepFirst.dst = DstPtr;
NodeParamsStepFirst.elementSize = sizeof(uint32_t);
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
NodeParamsStepFirst.pitch = sizeof(uint32_t);
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
NodeParamsStepFirst.width = 1;

// we walk up the pattern in 4-byte steps, and call cuMemset for each
// 4-byte chunk of the pattern.
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStepFirst,
CommandBuffer->Device->getContext()));

// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<CUgraphNode>(GraphNodeFirst));

DepsList.clear();
DepsList.push_back(GraphNodeFirst);

// we walk up the pattern in 1-byte steps, and call cuMemset for each
// 1-byte chunk of the pattern.
for (auto Step = 4u; Step < NumberOfSteps; ++Step) {
// take 4 bytes of the pattern
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
auto Value = *(static_cast<const uint8_t *>(Pattern) + Step);

// offset the pointer to the part of the buffer we want to write to
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t));

// Create a new node
CUgraphNode GraphNode;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
NodeParamsStep.elementSize = 4;
NodeParamsStep.height = N;
NodeParamsStep.pitch = PatternSize;
NodeParamsStep.elementSize = sizeof(uint8_t);
NodeParamsStep.height = Size / NumberOfSteps;
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
NodeParamsStep.value = Value;
NodeParamsStep.width = 1;

Expand All @@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
DepsList.size(), &NodeParamsStep,
CommandBuffer->Device->getContext()));

GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<CUgraphNode>(GraphNode));
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);

DepsList.clear();
DepsList.push_back(*GraphNodePtr.get());
}
}
} catch (ur_result_t Err) {
Expand Down