Skip to content

Commit

Permalink
Add a basic pool manager for memory pools
Browse files Browse the repository at this point in the history
  • Loading branch information
kswiecicki committed Sep 7, 2023
1 parent a3120e7 commit 1c98f6e
Showing 1 changed file with 98 additions and 24 deletions.
122 changes: 98 additions & 24 deletions source/common/ur_pool_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@
#ifndef USM_POOL_MANAGER_HPP
#define USM_POOL_MANAGER_HPP 1

#include "logger/ur_logger.hpp"
#include "umf_helpers.hpp"
#include "umf_pools/disjoint_pool.hpp"
#include "ur_api.h"
#include "ur_pool_manager.hpp"
#include "ur_util.hpp"

#include <umf/memory_pool.h>
#include <umf/memory_provider.h>

#include <functional>
#include <unordered_map>
#include <vector>

namespace usm {
Expand All @@ -29,8 +35,9 @@ struct pool_descriptor {
ur_usm_type_t type;
bool deviceReadOnly;

static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
static std::size_t hash(const pool_descriptor &desc);
bool operator==(const pool_descriptor &other) const;
friend std::ostream &operator<<(std::ostream &os,
const pool_descriptor &desc);
static std::pair<ur_result_t, std::vector<pool_descriptor>>
create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
};
Expand Down Expand Up @@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {

inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
size_t deviceCount;
size_t deviceCount = 0;
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
sizeof(deviceCount), &deviceCount, nullptr);
if (ret != UR_RESULT_SUCCESS) {
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
return {ret, {}};
}

Expand Down Expand Up @@ -122,22 +129,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
}

inline bool pool_descriptor::equal(const pool_descriptor &lhs,
const pool_descriptor &rhs) {
ur_native_handle_t lhsNative, rhsNative;
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
const pool_descriptor &lhs = *this;
const pool_descriptor &rhs = other;
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;

// We want to share a memory pool for sub-devices and sub-sub devices.
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
// they share the same native_handle_t (which is used by UMF provider).
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
// TODO: is this L0 specific?
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
if (lhs.hDevice) {
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}
ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;

if (rhs.hDevice) {
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}

return lhsNative == rhsNative && lhs.type == rhs.type &&
Expand All @@ -146,16 +159,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
lhs.poolHandle == rhs.poolHandle;
}

inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
ur_native_handle_t native;
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
os << "pool handle: " << desc.poolHandle
<< " context handle: " << desc.hContext
<< " device handle: " << desc.hDevice << " memory type: " << desc.type
<< " is read only: " << desc.deviceReadOnly;
return os;
}

inline std::pair<ur_result_t, std::vector<pool_descriptor>>
Expand All @@ -177,6 +186,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
pool_descriptor &desc = descriptors.emplace_back();
desc.poolHandle = poolHandle;
desc.hContext = hContext;
desc.hDevice = device;
desc.type = UR_USM_TYPE_DEVICE;
}
{
Expand All @@ -200,6 +210,70 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
return {ret, descriptors};
}

template <typename D> struct pool_manager {
private:
using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;

desc_to_pool_map_t descToPoolMap;
pool_manager() : descToPoolMap(){};

public:
static std::pair<umf_result_t, pool_manager>
create(desc_to_pool_map_t descToHandleMap = {}) {
auto manager = pool_manager();

for (auto &[desc, hPool] : descToHandleMap) {
auto ret = manager.addPool(desc, hPool);
if (ret != UMF_RESULT_SUCCESS) {
return {ret, pool_manager()};
}
}

return {UMF_RESULT_SUCCESS, std::move(manager)};
}

umf_result_t addPool(const D &desc,
umf::pool_unique_handle_t &hPool) noexcept {
if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
logger::error("Pool for pool descriptor: {}, already exists", desc);
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

return UMF_RESULT_SUCCESS;
}

std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
auto it = descToPoolMap.find(desc);
if (it == descToPoolMap.end()) {
logger::error("Pool descriptor doesn't match any existing pool: {}",
desc);
return std::nullopt;
}

return it->second.get();
}
};

} // namespace usm

namespace std {
/// @brief hash specialization for usm::pool_descriptor
template <> struct hash<usm::pool_descriptor> {
inline size_t operator()(const usm::pool_descriptor &desc) const {
ur_native_handle_t native = nullptr;
if (desc.hDevice) {
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
}
};

} // namespace std

#endif /* USM_POOL_MANAGER_HPP */

0 comments on commit 1c98f6e

Please sign in to comment.