Skip to content

Commit

Permalink
Add a basic pool manager for memory pools
Browse files Browse the repository at this point in the history
  • Loading branch information
kswiecicki committed Jul 3, 2023
1 parent 25abfe6 commit d66270b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 21 deletions.
5 changes: 1 addition & 4 deletions source/adapters/null/ur_null.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ class __urdlllocal context_t {
context_t();
~context_t() = default;

void *get() {
static uint64_t count = 0x80800000;
return reinterpret_cast<void *>(++count);
}
void *get() { return reinterpret_cast<void *>(this + 0x80800000); };
};

extern context_t d_context;
Expand Down
105 changes: 88 additions & 17 deletions source/common/ur_pool_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@
#ifndef USM_POOL_MANAGER_HPP
#define USM_POOL_MANAGER_HPP 1

#include "logger/ur_logger.hpp"
#include "umf_helpers.hpp"
#include "umf_pools/disjoint_pool.hpp"
#include "ur_api.h"
#include "ur_pool_manager.hpp"
#include "ur_util.hpp"

#include <umf/memory_pool.h>
#include <umf/memory_provider.h>

#include <functional>
#include <vector>

Expand All @@ -29,8 +34,9 @@ struct pool_descriptor {
ur_usm_type_t type;
bool deviceReadOnly;

static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
static std::size_t hash(const pool_descriptor &desc);
bool operator==(const pool_descriptor &other) const;
friend std::ostream &operator<<(std::ostream &os,
const pool_descriptor &desc);
static std::pair<ur_result_t, std::vector<pool_descriptor>>
create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
};
Expand Down Expand Up @@ -75,7 +81,7 @@ urGetSubDevices(ur_device_handle_t hDevice) {

inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
size_t deviceCount;
size_t deviceCount = 0;
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
sizeof(deviceCount), &deviceCount, nullptr);
if (ret != UR_RESULT_SUCCESS) {
Expand Down Expand Up @@ -122,9 +128,10 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
}

inline bool pool_descriptor::equal(const pool_descriptor &lhs,
const pool_descriptor &rhs) {
ur_native_handle_t lhsNative, rhsNative;
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
const pool_descriptor &lhs = *this;
const pool_descriptor &rhs = other;
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;

// We want to share a memory pool for sub-devices and sub-sub devices.
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
Expand All @@ -146,16 +153,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
lhs.poolHandle == rhs.poolHandle;
}

inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
ur_native_handle_t native;
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
os << "pool handle: " << desc.poolHandle
<< " context handle: " << desc.hContext
<< " device handle: " << desc.hDevice << " memory type: " << desc.type
<< " is read only: " << desc.deviceReadOnly;
return os;
}

inline std::pair<ur_result_t, std::vector<pool_descriptor>>
Expand Down Expand Up @@ -200,6 +203,74 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
return {ret, descriptors};
}

template <typename D> struct pool_manager {
private:
std::unordered_map<D, umf::pool_unique_handle_t> descToPoolMap;
pool_manager() : descToPoolMap(){};

public:
static std::pair<umf_result_t, pool_manager> create(
std::unordered_map<D, umf::pool_unique_handle_t> descToHandleMap = {}) {
auto manager = pool_manager();

for (auto &[desc, hPool] : descToHandleMap) {
auto ret = manager.addPool(desc, hPool);
if (ret != UMF_RESULT_SUCCESS) {
return {ret, pool_manager()};
}
}

return {UMF_RESULT_SUCCESS, std::move(manager)};
}

umf_result_t addPool(const D &desc,
umf::pool_unique_handle_t &hPool) noexcept {
if (descToPoolMap.count(desc) != 0) {
logger::error("Pool for pool descriptor: {}, already exists", desc);
return UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

try {
descToPoolMap.emplace(desc, std::move(hPool));
} catch (std::bad_alloc &) {
return UMF_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
return UMF_RESULT_ERROR_UNKNOWN;
}

return UMF_RESULT_SUCCESS;
}

std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
auto it = descToPoolMap.find(desc);
if (it == descToPoolMap.end()) {
logger::error("Pool descriptor doesn't match any existing pool: {}",
desc);
return std::nullopt;
}

return it->second.get();
}
};

} // namespace usm

namespace std {
/// @brief hash specialization for usm::pool_descriptor
template <> struct hash<usm::pool_descriptor> {
inline size_t operator()(const usm::pool_descriptor &desc) const {
ur_native_handle_t native = nullptr;
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
}
};

} // namespace std

#endif /* USM_POOL_MANAGER_HPP */

0 comments on commit d66270b

Please sign in to comment.