From 34d722a99a656f50e748b8d7580cba85b454b351 Mon Sep 17 00:00:00 2001 From: Krzysztof Swiecicki Date: Thu, 12 Oct 2023 15:30:32 +0200 Subject: [PATCH] [SYCL][UR][CUDA] Access UMF pool handles through usm::pool_manager --- source/adapters/cuda/usm.cpp | 161 +++++++++++++++++++++++------- source/adapters/cuda/usm.hpp | 5 +- source/common/ur_pool_manager.hpp | 10 +- 3 files changed, 134 insertions(+), 42 deletions(-) diff --git a/source/adapters/cuda/usm.cpp b/source/adapters/cuda/usm.cpp index 8929fb7fa1..c70ad07b93 100644 --- a/source/adapters/cuda/usm.cpp +++ b/source/adapters/cuda/usm.cpp @@ -36,10 +36,19 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc, return USMHostAllocImpl(ppMem, hContext, nullptr, size, alignment); } - auto UMFPool = hPool->HostMemPool.get(); - *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + usm::pool_descriptor Desc = {hPool, hContext, nullptr, UR_USM_TYPE_HOST, + false}; + auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc); + if (!hPoolInternalOpt.has_value()) { + // Internal error, every usm pool should have Host, Device, + // Shared UMF pools. + return UR_RESULT_ERROR_UNKNOWN; + } + + auto hPoolInternal = hPoolInternalOpt.value(); + *ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment); if (*ppMem == nullptr) { - auto umfErr = umfPoolGetLastAllocationError(UMFPool); + auto umfErr = umfPoolGetLastAllocationError(hPoolInternal); return umf::umf2urResult(umfErr); } return UR_RESULT_SUCCESS; @@ -61,10 +70,19 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice, alignment); } - auto UMFPool = hPool->DeviceMemPool.get(); - *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + usm::pool_descriptor Desc = {hPool, hContext, hDevice, UR_USM_TYPE_DEVICE, + false}; + auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc); + if (!hPoolInternalOpt.has_value()) { + // Internal error, every usm pool should have Host, Device, + // Shared UMF pools. + return UR_RESULT_ERROR_UNKNOWN; + } + + auto hPoolInternal = hPoolInternalOpt.value(); + *ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment); if (*ppMem == nullptr) { - auto umfErr = umfPoolGetLastAllocationError(UMFPool); + auto umfErr = umfPoolGetLastAllocationError(hPoolInternal); return umf::umf2urResult(umfErr); } return UR_RESULT_SUCCESS; @@ -86,10 +104,19 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice, alignment); } - auto UMFPool = hPool->SharedMemPool.get(); - *ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment); + usm::pool_descriptor Desc = {hPool, hContext, hDevice, UR_USM_TYPE_SHARED, + false}; + auto hPoolInternalOpt = hPool->PoolManager.getPool(Desc); + if (!hPoolInternalOpt.has_value()) { + // Internal error, every usm pool should have Host, Device, + // Shared UMF pools. + return UR_RESULT_ERROR_UNKNOWN; + } + + auto hPoolInternal = hPoolInternalOpt.value(); + *ppMem = umfPoolAlignedMalloc(hPoolInternal, size, alignment); if (*ppMem == nullptr) { - auto umfErr = umfPoolGetLastAllocationError(UMFPool); + auto umfErr = umfPoolGetLastAllocationError(hPoolInternal); return umf::umf2urResult(umfErr); } return UR_RESULT_SUCCESS; @@ -373,6 +400,48 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, return USMHostAllocImpl(ResultPtr, Context, nullptr, Size, Alignment); } +// Template helper function for creating USM pools for given pool descriptor. +template +std::pair +createUMFPoolForDesc(usm::pool_descriptor &Desc, Args &&...args) { + umf_result_t UmfRet = UMF_RESULT_SUCCESS; + umf::provider_unique_handle_t MemProvider = nullptr; + + switch (Desc.type) { + case UR_USM_TYPE_HOST: { + std::tie(UmfRet, MemProvider) = + umf::memoryProviderMakeUnique(Desc.hContext, + Desc.hDevice); + break; + } + case UR_USM_TYPE_DEVICE: { + std::tie(UmfRet, MemProvider) = + umf::memoryProviderMakeUnique(Desc.hContext, + Desc.hDevice); + break; + } + case UR_USM_TYPE_SHARED: { + std::tie(UmfRet, MemProvider) = + umf::memoryProviderMakeUnique(Desc.hContext, + Desc.hDevice); + break; + } + default: + UmfRet = UMF_RESULT_ERROR_INVALID_ARGUMENT; + } + + if (UmfRet) + return std::pair{ + umf::umf2urResult(UmfRet), nullptr}; + + umf::pool_unique_handle_t Pool = nullptr; + std::tie(UmfRet, Pool) = + umf::poolMakeUnique({std::move(MemProvider)}, args...); + + return std::pair{ + umf::umf2urResult(UmfRet), std::move(Pool)}; +}; + ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc) : Context(Context) { @@ -396,40 +465,58 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, pNext = BaseDesc->pNext; } - auto MemProvider = - umf::memoryProviderMakeUnique(Context, nullptr) - .second; - - HostMemPool = - umf::poolMakeUnique( - {std::move(MemProvider)}, - this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host]) - .second; + ur_result_t Ret; + std::tie(Ret, PoolManager) = + usm::pool_manager::create(); + if (Ret) { + throw UsmAllocationException(Ret); + } auto Device = Context->DeviceID; - MemProvider = - umf::memoryProviderMakeUnique(Context, Device) - .second; - DeviceMemPool = - umf::poolMakeUnique( - {std::move(MemProvider)}, - this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Device]) - .second; - - MemProvider = - umf::memoryProviderMakeUnique(Context, Device) - .second; - SharedMemPool = - umf::poolMakeUnique( - {std::move(MemProvider)}, - this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Shared]) - .second; + auto UrUSMPool = reinterpret_cast(this); + + // TODO: Replace this with appropriate usm::pool_descriptor 'create' static + // function. + usm::pool_descriptor Descs[] = { + {UrUSMPool, Context, nullptr, UR_USM_TYPE_HOST, false}, + {UrUSMPool, Context, Device, UR_USM_TYPE_DEVICE, false}, + {UrUSMPool, Context, Device, UR_USM_TYPE_SHARED, false}}; + + // Helper lambda function matching USM type to DisjointPoolMemType + auto descTypeToDisjointPoolType = + [](usm::pool_descriptor &Desc) -> usm::DisjointPoolMemType { + switch (Desc.type) { + case UR_USM_TYPE_HOST: + return usm::DisjointPoolMemType::Host; + case UR_USM_TYPE_DEVICE: + return usm::DisjointPoolMemType::Device; + case UR_USM_TYPE_SHARED: + return (Desc.deviceReadOnly) ? usm::DisjointPoolMemType::SharedReadOnly + : usm::DisjointPoolMemType::Shared; + default: + // Should not be reachable. + ur::unreachable(); + } + }; + + for (auto &Desc : Descs) { + umf::pool_unique_handle_t Pool = nullptr; + auto PoolType = descTypeToDisjointPoolType(Desc); + + std::tie(Ret, Pool) = createUMFPoolForDesc( + Desc, this->DisjointPoolConfigs.Configs[PoolType]); + if (Ret) { + throw UsmAllocationException(Ret); + } + + PoolManager.addPool(Desc, Pool); + } + Context->addPool(this); } bool ur_usm_pool_handle_t_::hasUMFPool(umf_memory_pool_t *umf_pool) { - return DeviceMemPool.get() == umf_pool || SharedMemPool.get() == umf_pool || - HostMemPool.get() == umf_pool; + return PoolManager.hasPool(umf_pool); } UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( diff --git a/source/adapters/cuda/usm.hpp b/source/adapters/cuda/usm.hpp index 2ec3df150f..541e3617d3 100644 --- a/source/adapters/cuda/usm.hpp +++ b/source/adapters/cuda/usm.hpp @@ -12,6 +12,7 @@ #include #include +#include usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); @@ -23,9 +24,7 @@ struct ur_usm_pool_handle_t_ { usm::DisjointPoolAllConfigs DisjointPoolConfigs = usm::DisjointPoolAllConfigs(); - umf::pool_unique_handle_t DeviceMemPool; - umf::pool_unique_handle_t SharedMemPool; - umf::pool_unique_handle_t HostMemPool; + usm::pool_manager PoolManager; ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc); diff --git a/source/common/ur_pool_manager.hpp b/source/common/ur_pool_manager.hpp index 2215bd0575..4accd55631 100644 --- a/source/common/ur_pool_manager.hpp +++ b/source/common/ur_pool_manager.hpp @@ -249,13 +249,19 @@ template struct pool_manager { std::optional getPool(const D &desc) noexcept { auto it = descToPoolMap.find(desc); if (it == descToPoolMap.end()) { - logger::error("Pool descriptor doesn't match any existing pool: {}", - desc); + logger::error( + "Pool descriptor: {}, doesn't match any existing pool", desc); return std::nullopt; } return it->second.get(); } + + bool hasPool(umf_memory_pool_handle_t hPool) noexcept { + return std::any_of( + descToPoolMap.begin(), descToPoolMap.end(), + [&hPool](const auto &pair) { return hPool == pair.second.get(); }); + } }; } // namespace usm