Add a basic pool manager for memory pools #630

Merged 3 commits on Oct 4, 2023
Changes from all commits
130 changes: 104 additions & 26 deletions source/common/ur_pool_manager.hpp
@@ -11,11 +11,17 @@
#ifndef USM_POOL_MANAGER_HPP
#define USM_POOL_MANAGER_HPP 1

#include "logger/ur_logger.hpp"
#include "umf_helpers.hpp"
#include "umf_pools/disjoint_pool.hpp"
#include "ur_api.h"
#include "ur_pool_manager.hpp"
#include "ur_util.hpp"

#include <umf/memory_pool.h>
#include <umf/memory_provider.h>

#include <functional>
#include <unordered_map>
#include <vector>

namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
ur_usm_type_t type;
bool deviceReadOnly;

static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
static std::size_t hash(const pool_descriptor &desc);
bool operator==(const pool_descriptor &other) const;
friend std::ostream &operator<<(std::ostream &os,
const pool_descriptor &desc);
static std::pair<ur_result_t, std::vector<pool_descriptor>>
create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
};
@@ -45,8 +52,8 @@ urGetSubDevices(ur_device_handle_t hDevice) {
}

ur_device_partition_property_t prop;
prop.type = UR_DEVICE_PARTITION_EQUALLY;
prop.value.equally = nComputeUnits;
prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
prop.value.affinity_domain = 0;

ur_device_partition_properties_t properties{
UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {

inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
size_t deviceCount;
size_t deviceCount = 0;
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
sizeof(deviceCount), &deviceCount, nullptr);
if (ret != UR_RESULT_SUCCESS) {
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
return {ret, {}};
}

@@ -110,6 +117,11 @@ urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
for (size_t i = 0; i < deviceCount; i++) {
ret = addPoolsForDevicesRec(devices[i]);
if (ret != UR_RESULT_SUCCESS) {
if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
// Return main devices when sub-devices are unsupported.
return {ret, std::move(devices)};
}

return {ret, {}};
}
}
@@ -122,22 +134,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
}

inline bool pool_descriptor::equal(const pool_descriptor &lhs,
const pool_descriptor &rhs) {
ur_native_handle_t lhsNative, rhsNative;
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
const pool_descriptor &lhs = *this;
const pool_descriptor &rhs = other;
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;

// We want to share a memory pool for sub-devices and sub-sub devices.
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
// they share the same native_handle_t (which is used by UMF provider).
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
// TODO: is this L0 specific?
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
if (lhs.hDevice) {
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}
ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;

if (rhs.hDevice) {
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}

return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +164,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
lhs.poolHandle == rhs.poolHandle;
}

inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
ur_native_handle_t native;
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
os << "pool handle: " << desc.poolHandle
<< " context handle: " << desc.hContext
<< " device handle: " << desc.hDevice << " memory type: " << desc.type
<< " is read only: " << desc.deviceReadOnly;
return os;
}

inline std::pair<ur_result_t, std::vector<pool_descriptor>>
@@ -177,6 +191,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
pool_descriptor &desc = descriptors.emplace_back();
desc.poolHandle = poolHandle;
desc.hContext = hContext;
desc.hDevice = device;
desc.type = UR_USM_TYPE_DEVICE;
}
{
@@ -200,6 +215,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
return {ret, descriptors};
}

template <typename D> struct pool_manager {
private:
using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;

desc_to_pool_map_t descToPoolMap;

public:
static std::pair<ur_result_t, pool_manager>
create(desc_to_pool_map_t descToHandleMap = {}) {
auto manager = pool_manager();

for (auto &[desc, hPool] : descToHandleMap) {
auto ret = manager.addPool(desc, hPool);
if (ret != UR_RESULT_SUCCESS) {
return {ret, pool_manager()};
}
}

return {UR_RESULT_SUCCESS, std::move(manager)};
}

ur_result_t addPool(const D &desc,
umf::pool_unique_handle_t &hPool) noexcept {
if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
logger::error("Pool for pool descriptor: {}, already exists", desc);
return UR_RESULT_ERROR_INVALID_ARGUMENT;
}

return UR_RESULT_SUCCESS;
}

std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
auto it = descToPoolMap.find(desc);
if (it == descToPoolMap.end()) {
logger::error("Pool descriptor doesn't match any existing pool: {}",
desc);
return std::nullopt;
}

return it->second.get();
}
};

} // namespace usm

namespace std {
/// @brief hash specialization for usm::pool_descriptor
template <> struct hash<usm::pool_descriptor> {
inline size_t operator()(const usm::pool_descriptor &desc) const {
ur_native_handle_t native = nullptr;
if (desc.hDevice) {
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
throw ret;
}
}

return combine_hashes(0, desc.type, native,
isSharedAllocationReadOnlyOnDevice(desc),
desc.poolHandle);
}
};

} // namespace std

#endif /* USM_POOL_MANAGER_HPP */
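
For reference, a minimal usage sketch of the pool_manager API introduced above, following the signatures in this header. Here poolHandle and hContext are assumed to come from the caller, and makePoolForDescriptor is a hypothetical helper that builds a umf::pool_unique_handle_t for a descriptor:

    // Enumerate pool descriptors for a (pool handle, context) pair.
    auto [ret, descriptors] = usm::pool_descriptor::create(poolHandle, hContext);
    if (ret != UR_RESULT_SUCCESS) { /* handle error */ }

    // Create an empty manager keyed by pool_descriptor.
    auto [mret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
    if (mret != UR_RESULT_SUCCESS) { /* handle error */ }

    for (auto &desc : descriptors) {
        // makePoolForDescriptor is a hypothetical factory returning a
        // umf::pool_unique_handle_t; addPool takes ownership of it.
        auto hPool = makePoolForDescriptor(desc);
        if (manager.addPool(desc, hPool) != UR_RESULT_SUCCESS) { /* handle error */ }
    }

    // Later: look up the UMF pool backing a descriptor (assuming the vector
    // is non-empty) and allocate from it via UMF.
    if (auto hPoolOpt = manager.getPool(descriptors[0])) {
        umf_memory_pool_handle_t pool = *hPoolOpt;
    }

The tests in test/usm/usmPoolManager.cpp below exercise the same flow with a null pool implementation.
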
8 changes: 6 additions & 2 deletions test/usm/CMakeLists.txt
@@ -10,17 +10,21 @@ function(add_usm_test name)
add_ur_executable(${TEST_TARGET_NAME}
${UR_USM_TEST_DIR}/../conformance/source/environment.cpp
${UR_USM_TEST_DIR}/../conformance/source/main.cpp
${UR_USM_TEST_DIR}/../unified_malloc_framework/common/provider.c
${UR_USM_TEST_DIR}/../unified_malloc_framework/common/pool.c
${ARGN})
target_link_libraries(${TEST_TARGET_NAME}
PRIVATE
${PROJECT_NAME}::common
${PROJECT_NAME}::loader
ur_testing
GTest::gtest_main)
add_test(NAME usm-${name}
add_test(NAME usm-${name}
COMMAND ${TEST_TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_tests_properties(usm-${name} PROPERTIES LABELS "usm")
set_tests_properties(usm-${name} PROPERTIES
LABELS "usm"
ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_null>\"")
target_compile_definitions("usm_test-${name}" PRIVATE DEVICES_ENVIRONMENT)
endfunction()

77 changes: 70 additions & 7 deletions test/usm/usmPoolManager.cpp
@@ -3,19 +3,18 @@
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "../unified_malloc_framework/common/pool.hpp"
#include "../unified_malloc_framework/common/provider.hpp"
#include "ur_pool_manager.hpp"

#include <uur/fixtures.h>
#include "../unified_malloc_framework/common/pool.h"
#include "../unified_malloc_framework/common/provider.h"

#include <unordered_set>
#include <uur/fixtures.h>

struct urUsmPoolManagerTest
struct urUsmPoolDescriptorTest
: public uur::urMultiDeviceContextTest,
::testing::WithParamInterface<ur_usm_pool_handle_t> {};

TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
auto &devices = uur::DevicesEnvironment::instance->devices;
auto poolHandle = this->GetParam();

@@ -49,7 +48,71 @@ TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
ASSERT_EQ(sharedPools, devices.size() * 2);
}

INSTANTIATE_TEST_SUITE_P(urUsmPoolManagerTest, urUsmPoolManagerTest,
INSTANTIATE_TEST_SUITE_P(urUsmPoolDescriptorTest, urUsmPoolDescriptorTest,
::testing::Values(nullptr));

// TODO: add test with sub-devices

struct urUsmPoolManagerTest : public uur::urContextTest {
void SetUp() override {
UUR_RETURN_ON_FATAL_FAILURE(urContextTest::SetUp());
auto [ret, descs] = usm::pool_descriptor::create(nullptr, context);
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
poolDescriptors = descs;
}

std::vector<usm::pool_descriptor> poolDescriptors;
};

TEST_P(urUsmPoolManagerTest, poolManagerPopulate) {
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
ASSERT_EQ(ret, UR_RESULT_SUCCESS);

for (auto &desc : poolDescriptors) {
// Populate the pool manager
auto pool = nullPoolCreate();
ASSERT_NE(pool, nullptr);
auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
ASSERT_NE(poolUnique, nullptr);
ret = manager.addPool(desc, poolUnique);
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
}

for (auto &desc : poolDescriptors) {
// Confirm that there is a pool for each descriptor
auto hPoolOpt = manager.getPool(desc);
ASSERT_TRUE(hPoolOpt.has_value());
ASSERT_NE(hPoolOpt.value(), nullptr);
}
}

TEST_P(urUsmPoolManagerTest, poolManagerInsertExisting) {
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
ASSERT_EQ(ret, UR_RESULT_SUCCESS);

auto desc = poolDescriptors[0];

auto pool = nullPoolCreate();
ASSERT_NE(pool, nullptr);
auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
ASSERT_NE(poolUnique, nullptr);

ret = manager.addPool(desc, poolUnique);
ASSERT_EQ(ret, UR_RESULT_SUCCESS);

// Inserting an existing key should return an error
ret = manager.addPool(desc, poolUnique);
ASSERT_EQ(ret, UR_RESULT_ERROR_INVALID_ARGUMENT);
}

TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
ASSERT_EQ(ret, UR_RESULT_SUCCESS);

for (auto &desc : poolDescriptors) {
auto hPool = manager.getPool(desc);
ASSERT_FALSE(hPool.has_value());
}
}

UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);
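
Side note on the design: pool_descriptor now provides operator== and a std::hash specialization, both keyed on the device's native handle rather than the ur_device_handle_t, which is what lets it serve directly as the key of the manager's std::unordered_map and collapses descriptors for sub-devices that share a native handle. A small sketch, reusing the poolDescriptors fixture member from the tests above:

    // Descriptors with the same native device, memory type, pool handle and
    // read-only flag hash and compare equal, so they collapse into a single
    // entry here (and into a single pool inside pool_manager).
    std::unordered_set<usm::pool_descriptor> uniqueDescriptors;
    for (const auto &desc : poolDescriptors) {
        uniqueDescriptors.insert(desc);
    }
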