Skip to content

Commit

Permalink
Add some parameterization to MemBufferFill tests.
Browse files Browse the repository at this point in the history
This effectively ports some cases from the old PI unit tests.
  • Loading branch information
aarongreig committed Aug 8, 2023
1 parent 5513ee2 commit 4a02d7b
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 70 deletions.
162 changes: 122 additions & 40 deletions test/conformance/enqueue/urEnqueueMemBufferFill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,172 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include <uur/fixtures.h>

using urEnqueueMemBufferFillTest = uur::urMemBufferQueueTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urEnqueueMemBufferFillTest);
struct testParametersFill {
size_t size;
size_t pattern_size;
};

template <typename T>
inline std::string
printFillTestString(const testing::TestParamInfo<typename T::ParamType> &info) {
const auto device_handle = std::get<0>(info.param);
const auto platform_device_name =
uur::GetPlatformAndDeviceName(device_handle);
std::stringstream test_name;
test_name << platform_device_name << "__size__"
<< std::get<1>(info.param).size << "__patternSize__"
<< std::get<1>(info.param).pattern_size;
return test_name.str();
}

struct urEnqueueMemBufferFillTest
: uur::urQueueTestWithParam<testParametersFill> {
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(
urQueueTestWithParam<testParametersFill>::SetUp());
size = std::get<1>(GetParam()).size;
pattern_size = std::get<1>(GetParam()).pattern_size;
pattern = std::vector<uint8_t>(pattern_size);
uur::generateMemFillPattern(pattern);
ASSERT_SUCCESS(urMemBufferCreate(this->context, UR_MEM_FLAG_READ_WRITE,
size, nullptr, &buffer));
}

void TearDown() override {
if (buffer) {
EXPECT_SUCCESS(urMemRelease(buffer));
}
UUR_RETURN_ON_FATAL_FAILURE(
urQueueTestWithParam<testParametersFill>::TearDown());
}

void verifyData(std::vector<uint8_t> &output, size_t verify_size) {
size_t pattern_index = 0;
for (size_t i = 0; i < verify_size; ++i) {
ASSERT_EQ(output[i], pattern[pattern_index])
<< "Result mismatch at index: " << i;

++pattern_index;
if (pattern_index % pattern_size == 0) {
pattern_index = 0;
}
}
}

ur_mem_handle_t buffer = nullptr;
std::vector<uint8_t> pattern;
size_t size;
size_t pattern_size;
};

static std::vector<testParametersFill> test_cases{
/* Everything set to 1 */
{1, 1},
/* pattern_size == size */
{256, 256},
/* pattern_size < size */
{1024, 256},
/* pattern sizes corresponding to some common scalar and vector types */
{256, 4},
{256, 8},
{256, 16},
{256, 32}};

UUR_TEST_SUITE_P(urEnqueueMemBufferFillTest, testing::ValuesIn(test_cases),
printFillTestString<urEnqueueMemBufferFillTest>);

TEST_P(urEnqueueMemBufferFillTest, Success) {
const uint32_t pattern = 0xdeadbeef;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
sizeof(pattern), 0, size, 0, nullptr,
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
pattern_size, 0, size, 0, nullptr,
nullptr));
std::vector<uint32_t> output(count, 1);
std::vector<uint8_t> output(size, 1);
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
output.data(), 0, nullptr, nullptr));
for (unsigned i = 0; i < count; ++i) {
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
}
verifyData(output, size);
}

TEST_P(urEnqueueMemBufferFillTest, SuccessPartialFill) {
const std::vector<uint32_t> input(count, 42);
if (size == 1) {
// Can't partially fill one byte
GTEST_SKIP();
}
const std::vector<uint8_t> input(size, 0);
ASSERT_SUCCESS(urEnqueueMemBufferWrite(queue, buffer, true, 0, size,
input.data(), 0, nullptr, nullptr));
const uint32_t pattern = 0xdeadbeef;
const size_t partial_fill_size = size / 2;
const size_t fill_count = count / 2;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
sizeof(pattern), 0, partial_fill_size,
0, nullptr, nullptr));
std::vector<uint32_t> output(count, 1);
// Make sure we don't end up with pattern_size > size
pattern_size = pattern_size / 2;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
pattern_size, 0, partial_fill_size, 0,
nullptr, nullptr));
std::vector<uint8_t> output(size, 1);
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
output.data(), 0, nullptr, nullptr));
for (size_t i = 0; i < count - fill_count; ++i) {
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
}
// Check the first half matches the pattern and the second half remains untouched.
verifyData(output, partial_fill_size);

for (size_t i = fill_count; i < count; ++i) {
ASSERT_EQ(output[i], 42) << "Result mismatch at index: " << i;
for (size_t i = partial_fill_size; i < size; ++i) {
ASSERT_EQ(output[i], input[i]) << "Result mismatch at index: " << i;
}
}

TEST_P(urEnqueueMemBufferFillTest, SuccessOffset) {
const std::vector<uint32_t> input(count, 42);
if (size == 1) {
// No room for an offset
GTEST_SKIP();
}
const std::vector<uint8_t> input(size, 0);
ASSERT_SUCCESS(urEnqueueMemBufferWrite(queue, buffer, true, 0, size,
input.data(), 0, nullptr, nullptr));
const uint32_t pattern = 0xdeadbeef;

const size_t offset_size = size / 2;
const size_t offset_count = count / 2;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, &pattern,
sizeof(pattern), offset_size,
// Make sure we don't end up with pattern_size > size
pattern_size = pattern_size / 2;
ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, buffer, pattern.data(),
pattern_size, offset_size,
offset_size, 0, nullptr, nullptr));
std::vector<uint32_t> output(count, 1);
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, size,
output.data(), 0, nullptr, nullptr));
for (size_t i = 0; i < offset_count; ++i) {
ASSERT_EQ(output[i], 42) << "Result mismatch at index: " << i;
}

for (size_t i = offset_count; i < count; ++i) {
ASSERT_EQ(output[i], pattern) << "Result mismatch at index: " << i;
// Check the second half matches the pattern and the first half remains untouched.
std::vector<uint8_t> output(offset_size);
ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, offset_size,
offset_size, output.data(), 0,
nullptr, nullptr));
verifyData(output, offset_size);

ASSERT_SUCCESS(urEnqueueMemBufferRead(queue, buffer, true, 0, offset_size,
output.data(), 0, nullptr, nullptr));
for (size_t i = 0; i < offset_size; ++i) {
ASSERT_EQ(output[i], input[i]) << "Result mismatch at index: " << i;
}
}

TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandleQueue) {
using urEnqueueMemBufferFillNegativeTest = uur::urMemBufferQueueTest;

UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urEnqueueMemBufferFillNegativeTest);

TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandleQueue) {
const uint32_t pattern = 0xdeadbeef;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urEnqueueMemBufferFill(nullptr, buffer, &pattern,
sizeof(pattern), 0, size, 0,
nullptr, nullptr));
}

TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandleBuffer) {
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandleBuffer) {
const uint32_t pattern = 0xdeadbeef;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urEnqueueMemBufferFill(queue, nullptr, &pattern,
sizeof(pattern), 0, size, 0,
nullptr, nullptr));
}

TEST_P(urEnqueueMemBufferFillTest, InvalidNullHandlePointerPattern) {
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullHandlePointerPattern) {
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
urEnqueueMemBufferFill(queue, buffer, nullptr,
sizeof(uint32_t), 0, size, 0,
nullptr, nullptr));
}

TEST_P(urEnqueueMemBufferFillTest, InvalidNullPtrEventWaitList) {
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidNullPtrEventWaitList) {
const uint32_t pattern = 0xdeadbeef;
ASSERT_EQ_RESULT(urEnqueueMemBufferFill(queue, buffer, &pattern,
sizeof(uint32_t), 0, size, 1,
Expand All @@ -103,7 +185,7 @@ TEST_P(urEnqueueMemBufferFillTest, InvalidNullPtrEventWaitList) {
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST);
}

TEST_P(urEnqueueMemBufferFillTest, InvalidSize) {
TEST_P(urEnqueueMemBufferFillNegativeTest, InvalidSize) {
const uint32_t pattern = 0xdeadbeef;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_SIZE,
urEnqueueMemBufferFill(queue, buffer, &pattern,
Expand Down
22 changes: 6 additions & 16 deletions test/conformance/enqueue/urEnqueueUSMFill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <random>
#include <uur/fixtures.h>

struct testParametersFill {
Expand Down Expand Up @@ -34,7 +33,7 @@ struct urEnqueueUSMFillTestWithParam
host_mem = std::vector<uint8_t>(size);
pattern_size = std::get<1>(GetParam()).pattern_size;
pattern = std::vector<uint8_t>(pattern_size);
generatePattern();
uur::generateMemFillPattern(pattern);

ur_device_usm_access_capability_flags_t device_usm = 0;
ASSERT_SUCCESS(uur::GetDeviceUSMDeviceSupport(device, device_usm));
Expand All @@ -54,19 +53,6 @@ struct urEnqueueUSMFillTestWithParam
UUR_RETURN_ON_FATAL_FAILURE(urQueueTestWithParam::TearDown());
}

void generatePattern() {

const size_t seed = 1;
std::mt19937 mersenne_engine{seed};
std::uniform_int_distribution<int> dist{0, 255};

auto gen = [&dist, &mersenne_engine]() {
return static_cast<uint8_t>(dist(mersenne_engine));
};

std::generate(begin(pattern), end(pattern), gen);
}

void verifyData() {
ASSERT_SUCCESS(urEnqueueUSMMemcpy(queue, true, host_mem.data(), ptr,
size, 0, nullptr, nullptr));
Expand Down Expand Up @@ -98,7 +84,11 @@ static std::vector<testParametersFill> test_cases{
{256, 256},
/* pattern_size < size */
{1024, 256},
};
/* pattern sizes corresponding to some common scalar and vector types */
{256, 4},
{256, 8},
{256, 16},
{256, 32}};

UUR_TEST_SUITE_P(urEnqueueUSMFillTestWithParam, testing::ValuesIn(test_cases),
printFillTestString<urEnqueueUSMFillTestWithParam>);
Expand Down
15 changes: 1 addition & 14 deletions test/conformance/enqueue/urEnqueueUSMFill2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct urEnqueueUSMFill2DTestWithParam
height = std::get<1>(GetParam()).height;
pattern_size = std::get<1>(GetParam()).pattern_size;
pattern = std::vector<uint8_t>(pattern_size);
generatePattern();
uur::generateMemFillPattern(pattern);
allocation_size = pitch * height;
host_mem = std::vector<uint8_t>(allocation_size);

Expand All @@ -60,19 +60,6 @@ struct urEnqueueUSMFill2DTestWithParam
UUR_RETURN_ON_FATAL_FAILURE(urQueueTestWithParam::TearDown());
}

void generatePattern() {

const size_t seed = 1;
std::mt19937 mersenne_engine{seed};
std::uniform_int_distribution<int> dist{0, 255};

auto gen = [&dist, &mersenne_engine]() {
return static_cast<uint8_t>(dist(mersenne_engine));
};

std::generate(begin(pattern), end(pattern), gen);
}

void verifyData() {
ASSERT_SUCCESS(urEnqueueUSMMemcpy2D(queue, true, host_mem.data(), pitch,
ptr, pitch, width, height, 0,
Expand Down
15 changes: 15 additions & 0 deletions test/conformance/testing/include/uur/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <uur/environment.h>
#include <uur/utils.h>

#include <random>

#define UUR_RETURN_ON_FATAL_FAILURE(...) \
__VA_ARGS__; \
if (this->HasFatalFailure() || this->IsSkipped()) { \
Expand Down Expand Up @@ -846,6 +848,19 @@ struct urUSMDeviceAllocTestWithParam : urQueueTestWithParam<T> {
void *ptr = nullptr;
};

// Generates a random byte pattern for MemFill type entry-points.
inline void generateMemFillPattern(std::vector<uint8_t> &pattern) {
const size_t seed = 1;
std::mt19937 mersenne_engine{seed};
std::uniform_int_distribution<int> dist{0, 255};

auto gen = [&dist, &mersenne_engine]() {
return static_cast<uint8_t>(dist(mersenne_engine));
};

std::generate(begin(pattern), end(pattern), gen);
}

/// @brief
/// @tparam T
/// @param info
Expand Down

0 comments on commit 4a02d7b

Please sign in to comment.