diff --git a/test/usm/CMakeLists.txt b/test/usm/CMakeLists.txt index b673b6d1b9..fa5454d4db 100644 --- a/test/usm/CMakeLists.txt +++ b/test/usm/CMakeLists.txt @@ -10,6 +10,8 @@ function(add_usm_test name) add_ur_executable(${TEST_TARGET_NAME} ${UR_USM_TEST_DIR}/../conformance/source/environment.cpp ${UR_USM_TEST_DIR}/../conformance/source/main.cpp + ${UR_USM_TEST_DIR}/../unified_malloc_framework/common/provider.c + ${UR_USM_TEST_DIR}/../unified_malloc_framework/common/pool.c ${ARGN}) target_link_libraries(${TEST_TARGET_NAME} PRIVATE @@ -17,10 +19,12 @@ function(add_usm_test name) ${PROJECT_NAME}::loader ur_testing GTest::gtest_main) - add_test(NAME usm-${name} + add_test(NAME usm-${name} COMMAND ${TEST_TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) - set_tests_properties(usm-${name} PROPERTIES LABELS "usm") + set_tests_properties(usm-${name} PROPERTIES + LABELS "usm" + ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$\"") target_compile_definitions("usm_test-${name}" PRIVATE DEVICES_ENVIRONMENT) endfunction() diff --git a/test/usm/usmPoolManager.cpp b/test/usm/usmPoolManager.cpp index eaf44e119d..4365a5ce7b 100644 --- a/test/usm/usmPoolManager.cpp +++ b/test/usm/usmPoolManager.cpp @@ -3,19 +3,18 @@ // See LICENSE.TXT // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "../unified_malloc_framework/common/pool.hpp" -#include "../unified_malloc_framework/common/provider.hpp" #include "ur_pool_manager.hpp" -#include +#include "../unified_malloc_framework/common/pool.h" +#include "../unified_malloc_framework/common/provider.h" -#include +#include -struct urUsmPoolManagerTest +struct urUsmPoolDescriptorTest : public uur::urMultiDeviceContextTest, ::testing::WithParamInterface {}; -TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) { +TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) { auto &devices = uur::DevicesEnvironment::instance->devices; auto poolHandle = this->GetParam(); @@ -49,7 +48,71 @@ TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) { ASSERT_EQ(sharedPools, devices.size() * 2); } -INSTANTIATE_TEST_SUITE_P(urUsmPoolManagerTest, urUsmPoolManagerTest, +INSTANTIATE_TEST_SUITE_P(urUsmPoolDescriptorTest, urUsmPoolDescriptorTest, ::testing::Values(nullptr)); // TODO: add test with sub-devices + +struct urUsmPoolManagerTest : public uur::urContextTest { + void SetUp() override { + UUR_RETURN_ON_FATAL_FAILURE(urContextTest::SetUp()); + auto [ret, descs] = usm::pool_descriptor::create(nullptr, context); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + poolDescriptors = descs; + } + + std::vector poolDescriptors; +}; + +TEST_P(urUsmPoolManagerTest, poolManagerPopulate) { + auto [ret, manager] = usm::pool_manager::create(); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + + for (auto &desc : poolDescriptors) { + // Populate the pool manager + auto pool = nullPoolCreate(); + ASSERT_NE(pool, nullptr); + auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy); + ASSERT_NE(poolUnique, nullptr); + ret = manager.addPool(desc, poolUnique); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + } + + for (auto &desc : poolDescriptors) { + // Confirm that there is a pool for each descriptor + auto hPoolOpt = manager.getPool(desc); + ASSERT_TRUE(hPoolOpt.has_value()); + ASSERT_NE(hPoolOpt.value(), nullptr); + } +} + +TEST_P(urUsmPoolManagerTest, poolManagerInsertExisting) { + auto [ret, manager] = usm::pool_manager::create(); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + + auto desc = poolDescriptors[0]; + + auto pool = nullPoolCreate(); + ASSERT_NE(pool, nullptr); + auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy); + ASSERT_NE(poolUnique, nullptr); + + ret = manager.addPool(desc, poolUnique); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + + // Inserting an existing key should return an error + ret = manager.addPool(desc, poolUnique); + ASSERT_EQ(ret, UR_RESULT_ERROR_INVALID_ARGUMENT); +} + +TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) { + auto [ret, manager] = usm::pool_manager::create(); + ASSERT_EQ(ret, UR_RESULT_SUCCESS); + + for (auto &desc : poolDescriptors) { + auto hPool = manager.getPool(desc); + ASSERT_FALSE(hPool.has_value()); + } +} + +UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);