From 8c1c628da3d462e3d43f18d76360bcec4a52fda3 Mon Sep 17 00:00:00 2001 From: Krzysztof Swiecicki Date: Wed, 6 Sep 2023 14:42:37 +0200 Subject: [PATCH] Add a basic pool manager for memory pools --- source/common/ur_pool_manager.hpp | 121 ++++++++++++++++++++++++------ 1 file changed, 97 insertions(+), 24 deletions(-) diff --git a/source/common/ur_pool_manager.hpp b/source/common/ur_pool_manager.hpp index c4da5d149f..6e70c29e21 100644 --- a/source/common/ur_pool_manager.hpp +++ b/source/common/ur_pool_manager.hpp @@ -11,11 +11,17 @@ #ifndef USM_POOL_MANAGER_HPP #define USM_POOL_MANAGER_HPP 1 +#include "logger/ur_logger.hpp" +#include "umf_helpers.hpp" +#include "umf_pools/disjoint_pool.hpp" #include "ur_api.h" -#include "ur_pool_manager.hpp" #include "ur_util.hpp" +#include +#include + #include +#include #include namespace usm { @@ -29,8 +35,9 @@ struct pool_descriptor { ur_usm_type_t type; bool deviceReadOnly; - static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs); - static std::size_t hash(const pool_descriptor &desc); + bool operator==(const pool_descriptor &other) const; + friend std::ostream &operator<<(std::ostream &os, + const pool_descriptor &desc); static std::pair> create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext); }; @@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) { inline std::pair> urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) { - size_t deviceCount; + size_t deviceCount = 0; auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof(deviceCount), &deviceCount, nullptr); - if (ret != UR_RESULT_SUCCESS) { + if (ret != UR_RESULT_SUCCESS || deviceCount == 0) { return {ret, {}}; } @@ -122,22 +129,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) { return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly; } -inline bool pool_descriptor::equal(const pool_descriptor &lhs, - const pool_descriptor &rhs) { - ur_native_handle_t lhsNative, rhsNative; +inline bool pool_descriptor::operator==(const pool_descriptor &other) const { + const pool_descriptor &lhs = *this; + const pool_descriptor &rhs = other; + ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr; // We want to share a memory pool for sub-devices and sub-sub devices. // Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but // they share the same native_handle_t (which is used by UMF provider). // Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1 // TODO: is this L0 specific? - auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative); - if (ret != UR_RESULT_SUCCESS) { - throw ret; + if (lhs.hDevice) { + auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative); + if (ret != UR_RESULT_SUCCESS) { + throw ret; + } } - ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative); - if (ret != UR_RESULT_SUCCESS) { - throw ret; + + if (rhs.hDevice) { + auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative); + if (ret != UR_RESULT_SUCCESS) { + throw ret; + } } return lhsNative == rhsNative && lhs.type == rhs.type && @@ -146,16 +159,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs, lhs.poolHandle == rhs.poolHandle; } -inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) { - ur_native_handle_t native; - auto ret = urDeviceGetNativeHandle(desc.hDevice, &native); - if (ret != UR_RESULT_SUCCESS) { - throw ret; - } - - return combine_hashes(0, desc.type, native, - isSharedAllocationReadOnlyOnDevice(desc), - desc.poolHandle); +inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) { + os << "pool handle: " << desc.poolHandle + << " context handle: " << desc.hContext + << " device handle: " << desc.hDevice << " memory type: " << desc.type + << " is read only: " << desc.deviceReadOnly; + return os; } inline std::pair> @@ -177,6 +186,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle, pool_descriptor &desc = descriptors.emplace_back(); desc.poolHandle = poolHandle; desc.hContext = hContext; + desc.hDevice = device; desc.type = UR_USM_TYPE_DEVICE; } { @@ -200,6 +210,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle, return {ret, descriptors}; } +template struct pool_manager { + private: + using desc_to_pool_map_t = std::unordered_map; + + desc_to_pool_map_t descToPoolMap; + + public: + static std::pair + create(desc_to_pool_map_t descToHandleMap = {}) { + auto manager = pool_manager(); + + for (auto &[desc, hPool] : descToHandleMap) { + auto ret = manager.addPool(desc, hPool); + if (ret != UR_RESULT_SUCCESS) { + return {ret, pool_manager()}; + } + } + + return {UR_RESULT_SUCCESS, std::move(manager)}; + } + + ur_result_t addPool(const D &desc, + umf::pool_unique_handle_t &hPool) noexcept { + if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) { + logger::error("Pool for pool descriptor: {}, already exists", desc); + return UR_RESULT_ERROR_INVALID_ARGUMENT; + } + + return UR_RESULT_SUCCESS; + } + + std::optional getPool(const D &desc) noexcept { + auto it = descToPoolMap.find(desc); + if (it == descToPoolMap.end()) { + logger::error("Pool descriptor doesn't match any existing pool: {}", + desc); + return std::nullopt; + } + + return it->second.get(); + } +}; + } // namespace usm +namespace std { +/// @brief hash specialization for usm::pool_descriptor +template <> struct hash { + inline size_t operator()(const usm::pool_descriptor &desc) const { + ur_native_handle_t native = nullptr; + if (desc.hDevice) { + auto ret = urDeviceGetNativeHandle(desc.hDevice, &native); + if (ret != UR_RESULT_SUCCESS) { + throw ret; + } + } + + return combine_hashes(0, desc.type, native, + isSharedAllocationReadOnlyOnDevice(desc), + desc.poolHandle); + } +}; + +} // namespace std + #endif /* USM_POOL_MANAGER_HPP */