Skip to content

Commit

Permalink
[SYCL][UR][CUDA] Access UMF pool handles through usm::pool_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
kswiecicki committed Dec 15, 2023
1 parent 67e4d1b commit a63a56a
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 42 deletions.
161 changes: 124 additions & 37 deletions source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,19 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
return USMHostAllocImpl(ppMem, hContext, nullptr, size, alignment);
}

auto UMFPool = hPool->HostMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
usm::pool_descriptor Desc = {hPool, hContext, nullptr, UR_USM_TYPE_HOST,
false};
auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc);
if (!hPoolInternalOpt.has_value()) {
// Internal error, every L0 context and usm pool should have Host, Device,
// Shared and SharedReadOnly UMF pools.
return UR_RESULT_ERROR_UNKNOWN;
}

auto hPoolInternal = hPoolInternalOpt.value();
*ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
auto umfErr = umfPoolGetLastAllocationError(hPoolInternal);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
Expand All @@ -60,10 +69,19 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
alignment);
}

auto UMFPool = hPool->DeviceMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
usm::pool_descriptor Desc = {hPool, hContext, hDevice, UR_USM_TYPE_DEVICE,
false};
auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc);
if (!hPoolInternalOpt.has_value()) {
// Internal error, every L0 context and usm pool should have Host, Device,
// Shared and SharedReadOnly UMF pools.
return UR_RESULT_ERROR_UNKNOWN;
}

auto hPoolInternal = hPoolInternalOpt.value();
*ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
auto umfErr = umfPoolGetLastAllocationError(hPoolInternal);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
Expand All @@ -85,10 +103,19 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
alignment);
}

auto UMFPool = hPool->SharedMemPool.get();
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
usm::pool_descriptor Desc = {hPool, hContext, hDevice, UR_USM_TYPE_SHARED,
false};
auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc);
if (!hPoolInternalOpt.has_value()) {
// Internal error, every L0 context and usm pool should have Host, Device,
// Shared and SharedReadOnly UMF pools.
return UR_RESULT_ERROR_UNKNOWN;
}

auto hPoolInternal = hPoolInternalOpt.value();
*ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment);
if (*ppMem == nullptr) {
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
auto umfErr = umfPoolGetLastAllocationError(hPoolInternal);
return umf::umf2urResult(umfErr);
}
return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -376,6 +403,48 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
return USMHostAllocImpl(ResultPtr, Context, nullptr, Size, Alignment);
}

// Template helper function for creating USM pools for given pool descriptor.
template <typename P, typename... Args>
std::pair<ur_result_t, umf::pool_unique_handle_t>
createUMFPoolForDesc(usm::pool_descriptor &Desc, Args &&...args) {
umf_result_t UmfRet = UMF_RESULT_SUCCESS;
umf::provider_unique_handle_t MemProvider = nullptr;

switch (Desc.type) {
case UR_USM_TYPE_HOST: {
std::tie(UmfRet, MemProvider) =
umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Desc.hContext,
Desc.hDevice);
break;
}
case UR_USM_TYPE_DEVICE: {
std::tie(UmfRet, MemProvider) =
umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Desc.hContext,
Desc.hDevice);
break;
}
case UR_USM_TYPE_SHARED: {
std::tie(UmfRet, MemProvider) =
umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Desc.hContext,
Desc.hDevice);
break;
}
default:
UmfRet = UMF_RESULT_ERROR_INVALID_ARGUMENT;
}

if (UmfRet)
return std::pair<ur_result_t, umf::pool_unique_handle_t>{
umf::umf2urResult(UmfRet), nullptr};

umf::pool_unique_handle_t Pool = nullptr;
std::tie(UmfRet, Pool) =
umf::poolMakeUnique<P, 1>({std::move(MemProvider)}, args...);

return std::pair<ur_result_t, umf::pool_unique_handle_t>{
umf::umf2urResult(UmfRet), std::move(Pool)};
};

ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc)
: Context(Context) {
Expand All @@ -399,40 +468,58 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
pNext = BaseDesc->pNext;
}

auto MemProvider =
umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Context, nullptr)
.second;

HostMemPool =
umf::poolMakeUnique<usm::DisjointPool, 1>(
{std::move(MemProvider)},
this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host])
.second;
ur_result_t Ret;
std::tie(Ret, PoolManager) =
usm::pool_manager<usm::pool_descriptor>::create();
if (Ret) {
throw UsmAllocationException(Ret);
}

auto Device = Context->DeviceID;
MemProvider =
umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Context, Device)
.second;
DeviceMemPool =
umf::poolMakeUnique<usm::DisjointPool, 1>(
{std::move(MemProvider)},
this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Device])
.second;

MemProvider =
umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Context, Device)
.second;
SharedMemPool =
umf::poolMakeUnique<usm::DisjointPool, 1>(
{std::move(MemProvider)},
this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Shared])
.second;
auto UrUSMPool = reinterpret_cast<ur_usm_pool_handle_t>(this);

// TODO: Replace this with appropriate usm::pool_descriptor 'create' static
// function.
usm::pool_descriptor Descs[] = {
{UrUSMPool, Context, nullptr, UR_USM_TYPE_HOST, false},
{UrUSMPool, Context, Device, UR_USM_TYPE_DEVICE, false},
{UrUSMPool, Context, Device, UR_USM_TYPE_SHARED, false}};

// Helper lambda function matching USM type to DisjointPoolMemType
auto descTypeToDisjointPoolType =
[](usm::pool_descriptor &Desc) -> usm::DisjointPoolMemType {
switch (Desc.type) {
case UR_USM_TYPE_HOST:
return usm::DisjointPoolMemType::Host;
case UR_USM_TYPE_DEVICE:
return usm::DisjointPoolMemType::Device;
case UR_USM_TYPE_SHARED:
return (Desc.deviceReadOnly) ? usm::DisjointPoolMemType::SharedReadOnly
: usm::DisjointPoolMemType::Shared;
default:
// Added to suppress 'not all control paths return a value' warning.
return usm::DisjointPoolMemType::All;
}
};

for (auto &Desc : Descs) {
umf::pool_unique_handle_t Pool = nullptr;
auto PoolType = descTypeToDisjointPoolType(Desc);

std::tie(Ret, Pool) = createUMFPoolForDesc<usm::DisjointPool>(
Desc, this->DisjointPoolConfigs.Configs[PoolType]);
if (Ret) {
throw UsmAllocationException(Ret);
}

PoolManager.addPool(Desc, Pool);
}

Context->addPool(this);
}

bool ur_usm_pool_handle_t_::hasUMFPool(umf_memory_pool_t *umf_pool) {
return DeviceMemPool.get() == umf_pool || SharedMemPool.get() == umf_pool ||
HostMemPool.get() == umf_pool;
return PoolManager.hasPool(umf_pool);
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
Expand Down
5 changes: 2 additions & 3 deletions source/adapters/cuda/usm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <umf_helpers.hpp>
#include <umf_pools/disjoint_pool_config_parser.hpp>
#include <ur_pool_manager.hpp>

usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig();

Expand All @@ -23,9 +24,7 @@ struct ur_usm_pool_handle_t_ {
usm::DisjointPoolAllConfigs DisjointPoolConfigs =
usm::DisjointPoolAllConfigs();

umf::pool_unique_handle_t DeviceMemPool;
umf::pool_unique_handle_t SharedMemPool;
umf::pool_unique_handle_t HostMemPool;
usm::pool_manager<usm::pool_descriptor> PoolManager;

ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc);
Expand Down
10 changes: 8 additions & 2 deletions source/common/ur_pool_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,19 @@ template <typename D> struct pool_manager {
std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
auto it = descToPoolMap.find(desc);
if (it == descToPoolMap.end()) {
logger::error("Pool descriptor doesn't match any existing pool: {}",
desc);
logger::error(
"Pool descriptor: {}, doesn't match any existing pool", desc);
return std::nullopt;
}

return it->second.get();
}

bool hasPool(umf_memory_pool_handle_t hPool) noexcept {
return std::any_of(
descToPoolMap.begin(), descToPoolMap.end(),
[&hPool](const auto &pair) { return hPool == pair.second.get(); });
}
};

} // namespace usm
Expand Down

0 comments on commit a63a56a

Please sign in to comment.