Skip to content

Commit

Permalink
[SYCL][Graph] Improve CUDA Fill op implementation.
Browse files Browse the repository at this point in the history
Adjustment of value pointer size according to pattern size.
Large patterns are now broken into 1-byte chunks, as in the regular implementation.
  • Loading branch information
mfrancepillois authored and EwanC committed Mar 12, 2024
1 parent 93e8469 commit b34d7a7
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(

try {
const size_t N = Size / PatternSize;
auto Value = *static_cast<const uint32_t *>(Pattern);
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
? *static_cast<CUdeviceptr *>(DstDevice)
: (CUdeviceptr)DstDevice;
Expand All @@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
NodeParams.elementSize = PatternSize;
NodeParams.height = N;
NodeParams.pitch = PatternSize;
NodeParams.value = Value;
NodeParams.width = 1;

// pattern size in bytes
switch (PatternSize) {
case 1: {
auto Value = *static_cast<const uint8_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 2: {
auto Value = *static_cast<const uint16_t *>(Pattern);
NodeParams.value = Value;
break;
}
case 4: {
auto Value = *static_cast<const uint32_t *>(Pattern);
NodeParams.value = Value;
break;
}
}

UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
Expand All @@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
// CUDA has no memset functions that allow setting values more than 4
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
// fill, which can be more than 4 bytes. We must break up the pattern
// into 4 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
// in the pattern.
// into 1 byte values, and set the buffer using multiple strided calls.
// This means that one cuGraphAddMemsetNode call is made for every 1
// bytes in the pattern.

size_t NumberOfSteps = PatternSize / sizeof(uint8_t);

size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
// Shared pointer that will point to the last node created
std::shared_ptr<CUgraphNode> GraphNodePtr;
// Create a new node
CUgraphNode GraphNodeFirst;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
NodeParamsStepFirst.dst = DstPtr;
NodeParamsStepFirst.elementSize = sizeof(uint32_t);
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
NodeParamsStepFirst.pitch = sizeof(uint32_t);
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
NodeParamsStepFirst.width = 1;

// we walk up the pattern in 4-byte steps, and call cuMemset for each
// 4-byte chunk of the pattern.
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
UR_CHECK_ERROR(cuGraphAddMemsetNode(
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
DepsList.size(), &NodeParamsStepFirst,
CommandBuffer->Device->getContext()));

// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->AddSyncPoint(
std::make_shared<CUgraphNode>(GraphNodeFirst));

DepsList.clear();
DepsList.push_back(GraphNodeFirst);

// we walk up the pattern in 1-byte steps, and call cuMemset for each
// 1-byte chunk of the pattern.
for (auto Step = 4u; Step < NumberOfSteps; ++Step) {
// take 4 bytes of the pattern
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
auto Value = *(static_cast<const uint8_t *>(Pattern) + Step);

// offset the pointer to the part of the buffer we want to write to
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t));

// Create a new node
CUgraphNode GraphNode;
// Update NodeParam
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
NodeParamsStep.elementSize = 4;
NodeParamsStep.height = N;
NodeParamsStep.pitch = PatternSize;
NodeParamsStep.elementSize = sizeof(uint8_t);
NodeParamsStep.height = Size / NumberOfSteps;
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
NodeParamsStep.value = Value;
NodeParamsStep.width = 1;

Expand All @@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
DepsList.size(), &NodeParamsStep,
CommandBuffer->Device->getContext()));

GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
// Get sync point and register the cuNode with it.
*SyncPoint = CommandBuffer->addSyncPoint(
std::make_shared<CUgraphNode>(GraphNode));
*SyncPoint = CommandBuffer->AddSyncPoint(GraphNodePtr);

DepsList.clear();
DepsList.push_back(*GraphNodePtr.get());
}
}
} catch (ur_result_t Err) {
Expand Down

0 comments on commit b34d7a7

Please sign in to comment.