diff --git a/source/common/ur_pool_manager.hpp b/source/common/ur_pool_manager.hpp
index c4da5d149f..2215bd0575 100644
--- a/source/common/ur_pool_manager.hpp
+++ b/source/common/ur_pool_manager.hpp
@@ -11,11 +11,17 @@
 #ifndef USM_POOL_MANAGER_HPP
 #define USM_POOL_MANAGER_HPP 1
 
+#include "logger/ur_logger.hpp"
+#include "umf_helpers.hpp"
+#include "umf_pools/disjoint_pool.hpp"
 #include "ur_api.h"
-#include "ur_pool_manager.hpp"
 #include "ur_util.hpp"
 
+#include <umf/memory_pool.h>
+#include <umf/memory_provider.h>
+
 #include <functional>
+#include <unordered_map>
 #include <vector>
 
 namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
     ur_usm_type_t type;
     bool deviceReadOnly;
-    static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
-    static std::size_t hash(const pool_descriptor &desc);
+    bool operator==(const pool_descriptor &other) const;
+    friend std::ostream &operator<<(std::ostream &os,
+                                    const pool_descriptor &desc);
 
     static std::pair<ur_result_t, std::vector<pool_descriptor>>
     create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
 };
@@ -45,8 +52,8 @@ urGetSubDevices(ur_device_handle_t hDevice) {
     }
 
     ur_device_partition_property_t prop;
-    prop.type = UR_DEVICE_PARTITION_EQUALLY;
-    prop.value.equally = nComputeUnits;
+    prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
+    prop.value.affinity_domain = 0;
 
     ur_device_partition_properties_t properties{
         UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {
 
 inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
 urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
-    size_t deviceCount;
+    size_t deviceCount = 0;
     auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
                                 sizeof(deviceCount), &deviceCount, nullptr);
-    if (ret != UR_RESULT_SUCCESS) {
+    if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
         return {ret, {}};
     }
 
@@ -110,6 +117,11 @@ urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
     for (size_t i = 0; i < deviceCount; i++) {
         ret = addPoolsForDevicesRec(devices[i]);
         if (ret != UR_RESULT_SUCCESS) {
+            if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
+                // Return main devices when sub-devices are unsupported.
+                return {ret, std::move(devices)};
+            }
+
             return {ret, {}};
         }
     }
@@ -122,22 +134,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
     return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
 }
 
-inline bool pool_descriptor::equal(const pool_descriptor &lhs,
-                                   const pool_descriptor &rhs) {
-    ur_native_handle_t lhsNative, rhsNative;
+inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
+    const pool_descriptor &lhs = *this;
+    const pool_descriptor &rhs = other;
+    ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;
 
     // We want to share a memory pool for sub-devices and sub-sub devices.
     // Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
    // they share the same native_handle_t (which is used by UMF provider).
     // Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
     // TODO: is this L0 specific?
-    auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
-    if (ret != UR_RESULT_SUCCESS) {
-        throw ret;
+    if (lhs.hDevice) {
+        auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
+        if (ret != UR_RESULT_SUCCESS) {
+            throw ret;
+        }
     }
-    ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
-    if (ret != UR_RESULT_SUCCESS) {
-        throw ret;
+
+    if (rhs.hDevice) {
+        auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
+        if (ret != UR_RESULT_SUCCESS) {
+            throw ret;
+        }
     }
 
     return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +164,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
            lhs.poolHandle == rhs.poolHandle;
 }
 
-inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
-    ur_native_handle_t native;
-    auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
-    if (ret != UR_RESULT_SUCCESS) {
-        throw ret;
-    }
-
-    return combine_hashes(0, desc.type, native,
-                          isSharedAllocationReadOnlyOnDevice(desc),
-                          desc.poolHandle);
+inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
+    os << "pool handle: " << desc.poolHandle
+       << " context handle: " << desc.hContext
+       << " device handle: " << desc.hDevice << " memory type: " << desc.type
+       << " is read only: " << desc.deviceReadOnly;
+    return os;
 }
 
 inline std::pair<ur_result_t, std::vector<pool_descriptor>>
@@ -177,6 +191,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
         pool_descriptor &desc = descriptors.emplace_back();
         desc.poolHandle = poolHandle;
         desc.hContext = hContext;
+        desc.hDevice = device;
         desc.type = UR_USM_TYPE_DEVICE;
     }
     {
@@ -200,6 +215,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
     return {ret, descriptors};
 }
 
+template <typename D> struct pool_manager {
+  private:
+    using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;
+
+    desc_to_pool_map_t descToPoolMap;
+
+  public:
+    static std::pair<ur_result_t, pool_manager>
+    create(desc_to_pool_map_t descToHandleMap = {}) {
+        auto manager = pool_manager();
+
+        for (auto &[desc, hPool] : descToHandleMap) {
+            auto ret = manager.addPool(desc, hPool);
+            if (ret != UR_RESULT_SUCCESS) {
+                return {ret, pool_manager()};
+            }
+        }
+
+        return {UR_RESULT_SUCCESS, std::move(manager)};
+    }
+
+    ur_result_t addPool(const D &desc,
+                        umf::pool_unique_handle_t &hPool) noexcept {
+        if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
+            logger::error("Pool for pool descriptor: {}, already exists", desc);
+            return UR_RESULT_ERROR_INVALID_ARGUMENT;
+        }
+
+        return UR_RESULT_SUCCESS;
+    }
+
+    std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
+        auto it = descToPoolMap.find(desc);
+        if (it == descToPoolMap.end()) {
+            logger::error("Pool descriptor doesn't match any existing pool: {}",
+                          desc);
+            return std::nullopt;
+        }
+
+        return it->second.get();
+    }
+};
+
 } // namespace usm
 
+namespace std {
+/// @brief hash specialization for usm::pool_descriptor
+template <> struct hash<usm::pool_descriptor> {
+    inline size_t operator()(const usm::pool_descriptor &desc) const {
+        ur_native_handle_t native = nullptr;
+        if (desc.hDevice) {
+            auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
+            if (ret != UR_RESULT_SUCCESS) {
+                throw ret;
+            }
+        }
+
+        return combine_hashes(0, desc.type, native,
+                              isSharedAllocationReadOnlyOnDevice(desc),
+                              desc.poolHandle);
+    }
+};
+
+} // namespace std
+
 #endif /* USM_POOL_MANAGER_HPP */
diff --git a/test/usm/CMakeLists.txt b/test/usm/CMakeLists.txt
index b673b6d1b9..fa5454d4db 100644
--- a/test/usm/CMakeLists.txt
+++ b/test/usm/CMakeLists.txt
@@ -10,6 +10,8 @@ function(add_usm_test name)
     add_ur_executable(${TEST_TARGET_NAME}
         ${UR_USM_TEST_DIR}/../conformance/source/environment.cpp
         ${UR_USM_TEST_DIR}/../conformance/source/main.cpp
+        ${UR_USM_TEST_DIR}/../unified_malloc_framework/common/provider.c
+        ${UR_USM_TEST_DIR}/../unified_malloc_framework/common/pool.c
         ${ARGN})
     target_link_libraries(${TEST_TARGET_NAME}
         PRIVATE
@@ -17,10 +19,12 @@ function(add_usm_test name)
         ${PROJECT_NAME}::loader
         ur_testing
         GTest::gtest_main)
-    add_test(NAME usm-${name} 
+    add_test(NAME usm-${name}
         COMMAND ${TEST_TARGET_NAME}
         WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
-    set_tests_properties(usm-${name} PROPERTIES LABELS "usm")
+    set_tests_properties(usm-${name} PROPERTIES
+        LABELS "usm"
+        ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_null>\"")
 
     target_compile_definitions("usm_test-${name}" PRIVATE DEVICES_ENVIRONMENT)
 endfunction()
diff --git a/test/usm/usmPoolManager.cpp b/test/usm/usmPoolManager.cpp
index eaf44e119d..4365a5ce7b 100644
--- a/test/usm/usmPoolManager.cpp
+++ b/test/usm/usmPoolManager.cpp
@@ -3,19 +3,18 @@
 // See LICENSE.TXT
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-#include "../unified_malloc_framework/common/pool.hpp"
-#include "../unified_malloc_framework/common/provider.hpp"
 #include "ur_pool_manager.hpp"
 
-#include <uur/fixtures.h>
+#include "../unified_malloc_framework/common/pool.h"
+#include "../unified_malloc_framework/common/provider.h"
 
-#include <unordered_set>
+#include <uur/fixtures.h>
 
-struct urUsmPoolManagerTest
+struct urUsmPoolDescriptorTest
     : public uur::urMultiDeviceContextTest,
      ::testing::WithParamInterface<ur_usm_pool_handle_t> {};
 
-TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
+TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
     auto &devices = uur::DevicesEnvironment::instance->devices;
     auto poolHandle = this->GetParam();
 
@@ -49,7 +48,71 @@ TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
     ASSERT_EQ(sharedPools, devices.size() * 2);
 }
 
-INSTANTIATE_TEST_SUITE_P(urUsmPoolManagerTest, urUsmPoolManagerTest,
+INSTANTIATE_TEST_SUITE_P(urUsmPoolDescriptorTest, urUsmPoolDescriptorTest,
                          ::testing::Values(nullptr));
 
 // TODO: add test with sub-devices
+
+struct urUsmPoolManagerTest : public uur::urContextTest {
+    void SetUp() override {
+        UUR_RETURN_ON_FATAL_FAILURE(urContextTest::SetUp());
+        auto [ret, descs] = usm::pool_descriptor::create(nullptr, context);
+        ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+        poolDescriptors = descs;
+    }
+
+    std::vector<usm::pool_descriptor> poolDescriptors;
+};
+
+TEST_P(urUsmPoolManagerTest, poolManagerPopulate) {
+    auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
+    ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+
+    for (auto &desc : poolDescriptors) {
+        // Populate the pool manager
+        auto pool = nullPoolCreate();
+        ASSERT_NE(pool, nullptr);
+        auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
+        ASSERT_NE(poolUnique, nullptr);
+        ret = manager.addPool(desc, poolUnique);
+        ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+    }
+
+    for (auto &desc : poolDescriptors) {
+        // Confirm that there is a pool for each descriptor
+        auto hPoolOpt = manager.getPool(desc);
+        ASSERT_TRUE(hPoolOpt.has_value());
+        ASSERT_NE(hPoolOpt.value(), nullptr);
+    }
+}
+
+TEST_P(urUsmPoolManagerTest, poolManagerInsertExisting) {
+    auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
+    ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+
+    auto desc = poolDescriptors[0];
+
+    auto pool = nullPoolCreate();
+    ASSERT_NE(pool, nullptr);
+    auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
+    ASSERT_NE(poolUnique, nullptr);
+
+    ret = manager.addPool(desc, poolUnique);
+    ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+
+    // Inserting an existing key should return an error
+    ret = manager.addPool(desc, poolUnique);
+    ASSERT_EQ(ret, UR_RESULT_ERROR_INVALID_ARGUMENT);
+}
+
+TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
+    auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
+    ASSERT_EQ(ret, UR_RESULT_SUCCESS);
+
+    for (auto &desc : poolDescriptors) {
+        auto hPool = manager.getPool(desc);
+        ASSERT_FALSE(hPool.has_value());
+    }
+}
+
+UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);
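
Usage sketch (not part of the patch): the snippet below shows how an adapter could tie the pieces added above together, using only names introduced in this diff (usm::pool_descriptor::create, usm::pool_manager, umf::pool_unique_handle_t, umfPoolDestroy). The createPoolForDescriptor() helper is hypothetical and stands in for whatever UMF provider/pool construction an adapter performs; the tests above use nullPoolCreate() from the shared test helpers in the same role.

    // Sketch only, assuming a hypothetical createPoolForDescriptor() helper
    // that returns a umf::pool_unique_handle_t for a given descriptor.
    std::pair<ur_result_t, usm::pool_manager<usm::pool_descriptor>>
    setupPoolsForContext(ur_usm_pool_handle_t hPool, ur_context_handle_t hContext) {
        // One descriptor per (context, device, memory type, read-only) combination.
        auto [descRet, descriptors] = usm::pool_descriptor::create(hPool, hContext);
        if (descRet != UR_RESULT_SUCCESS) {
            return {descRet, {}};
        }

        auto [mgrRet, manager] = usm::pool_manager<usm::pool_descriptor>::create();
        if (mgrRet != UR_RESULT_SUCCESS) {
            return {mgrRet, {}};
        }

        for (auto &desc : descriptors) {
            // addPool() takes ownership of the UMF pool and rejects duplicate keys.
            umf::pool_unique_handle_t pool = createPoolForDescriptor(desc);
            if (auto ret = manager.addPool(desc, pool); ret != UR_RESULT_SUCCESS) {
                return {ret, {}};
            }
        }

        // Allocation paths can later resolve a pool by descriptor:
        //   std::optional<umf_memory_pool_handle_t> p = manager.getPool(desc);
        return {UR_RESULT_SUCCESS, std::move(manager)};
    }

Because pool_descriptor now has operator== and a std::hash specialization, it can key the unordered_map inside pool_manager directly; the operator<< overload is presumably what the logger::error calls in addPool()/getPool() use to print the offending descriptor.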