Skip to content

Commit

Permalink
Changed to standalone function
Browse files Browse the repository at this point in the history
  • Loading branch information
konradkusiak97 committed Mar 4, 2024
1 parent 7a05c32 commit 19df225
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions source/adapters/hip/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,30 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
return Result;
}

static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
size_t Size, const void *pPattern,
hipDeviceptr_t Ptr) {

// Calculate the number of patterns, stride and the number of times the
// pattern needs to be applied.
auto NumberOfSteps = PatternSize / sizeof(uint8_t);
auto Pitch = NumberOfSteps * sizeof(uint8_t);
auto Height = Size / NumberOfSteps;

for (auto step = 4u; step < NumberOfSteps; ++step) {
// take 1 byte of the pattern
auto Value = *(static_cast<const uint8_t *>(pPattern) + step);

// offset the pointer to the part of the buffer we want to write to
auto OffsetPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(Ptr) +
(step * sizeof(uint8_t)));

// set all of the pattern chunks
UR_CHECK_ERROR(hipMemset2DAsync(OffsetPtr, Pitch, Value, sizeof(uint8_t),
Height, Stream));
}
}

// HIP has no memset functions that allow setting values more than 4 bytes. UR
// API lets you pass an arbitrary "pattern" to the buffer fill, which can be
// more than 4 bytes. We must break up the pattern into 1 byte values, and set
Expand All @@ -776,42 +800,14 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
auto Value = *(static_cast<const uint32_t *>(pPattern));
UR_CHECK_ERROR(hipMemsetD32Async(Ptr, Value, Count32, Stream));

auto memsetRemainPattern = [&Stream, &pPattern,
&Ptr](const auto Size, const auto PatternSize) {
// Calculate the number of patterns, stride and the number of times the
// pattern needs to be applied.
auto NumberOfSteps = PatternSize / sizeof(uint8_t);
auto Pitch = NumberOfSteps * sizeof(uint8_t);
auto Height = Size / NumberOfSteps;

for (auto step = 4u; step < NumberOfSteps; ++step) {
// take 1 byte of the pattern
auto Value = *(static_cast<const uint8_t *>(pPattern) + step);

// offset the pointer to the part of the buffer we want to write to
auto OffsetPtr = reinterpret_cast<void *>(
reinterpret_cast<uint8_t *>(Ptr) + (step * sizeof(uint8_t)));

// set all of the pattern chunks
UR_CHECK_ERROR(hipMemset2DAsync(OffsetPtr, Pitch, Value, sizeof(uint8_t),
Height, Stream));
}
};
// There is a bug in ROCm prior to 6.0.0 version which causes hipMemset2D and
// hipMemset3D to behave incorrectly when acting on host pinned memory. In
// such a case the following part of memsetting the remaining part of the
// pattern is emulated with memcpy.
#if HIP_VERSION < 60000000
// There is a bug in ROCm prior to 6.0.0 version which causes hipMemset2D
// to behave incorrectly when acting on host pinned memory.
// In such a case, the memset operation is partially emulated with memcpy.
#if HIP_VERSION_MAJOR < 6
hipPointerAttribute_t ptrAttribs{};
UR_CHECK_ERROR(hipPointerGetAttributes(&ptrAttribs, (const void *)Ptr));

const bool ptrIsHost{ptrAttribs.memoryType == hipMemoryTypeHost};

// The memoryType member of ptrAttrbis is set to hipMemoryTypeHost for both
// hipHostMalloc and (incorrectly) for hipMallocManaged. So to make sure that
// the Ptr is corresponding to host pinned memory we need to additionally use
// a boolean member of ptrAttribs: isManaged.
if (ptrIsHost && !ptrAttribs.isManaged) {
if (ptrAttribs.hostPointer) {
const auto NumOfCopySteps = Size / PatternSize;
const auto Offset = sizeof(uint32_t);
const auto LeftPatternSize = PatternSize - Offset;
Expand All @@ -827,10 +823,11 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
LeftPatternSize, hipMemcpyHostToHost,
Stream));
}
} else
memsetRemainPattern(Size, PatternSize);
} else {
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr);
}
#else
memsetRemainPattern(Size, PatternSize);
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr);
#endif
return UR_RESULT_SUCCESS;
}
Expand Down

0 comments on commit 19df225

Please sign in to comment.