diff --git a/source/adapters/opencl/adapter.cpp b/source/adapters/opencl/adapter.cpp
index 65c5676bf9..637c5e5a7e 100644
--- a/source/adapters/opencl/adapter.cpp
+++ b/source/adapters/opencl/adapter.cpp
@@ -15,7 +15,20 @@ struct ur_adapter_handle_t_ {
   std::mutex Mutex;
 };
 
-ur_adapter_handle_t_ adapter{};
+static ur_adapter_handle_t_ *adapter = nullptr;
+
+// atexit() handler: frees the extension function pointer cache and the adapter
+// handle, leaving both pointers null so later calls can detect the teardown.
+static void globalAdapterShutdown() {
+  if (cl_ext::ExtFuncPtrCache) {
+    delete cl_ext::ExtFuncPtrCache;
+    cl_ext::ExtFuncPtrCache = nullptr;
+  }
+  if (adapter) {
+    delete adapter;
+    adapter = nullptr;
+  }
+}
 
 UR_APIEXPORT ur_result_t UR_APICALL urInit(ur_device_init_flags_t,
                                            ur_loader_config_handle_t) {
@@ -30,12 +43,19 @@ UR_APIEXPORT ur_result_t UR_APICALL
 urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
              uint32_t *pNumAdapters) {
   if (NumEntries > 0 && phAdapters) {
-    std::lock_guard Lock{adapter.Mutex};
-    if (adapter.RefCount++ == 0) {
-      cl_ext::ExtFuncPtrCache = std::make_unique<cl_ext::ExtFuncPtrCacheT>();
+    // urAdapterGet may be called after the library has already been torn down,
+    // so a temporary handle must be created for that case as well.
+    if (!adapter) {
+      adapter = new ur_adapter_handle_t_();
+      atexit(globalAdapterShutdown);
     }
 
-    *phAdapters = &adapter;
+    std::lock_guard Lock{adapter->Mutex};
+    if (adapter->RefCount++ == 0) {
+      cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT();
+    }
+
+    *phAdapters = adapter;
   }
 
   if (pNumAdapters) {
@@ -46,14 +66,23 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
-  ++adapter.RefCount;
+  // Mirror urAdapterRelease: the adapter may already have been destroyed.
+  if (adapter) {
+    ++adapter->RefCount;
+  }
   return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
-  std::lock_guard Lock{adapter.Mutex};
-  if (--adapter.RefCount == 0) {
-    cl_ext::ExtFuncPtrCache.reset();
+  // Check first if the adapter is a valid pointer
+  if (adapter) {
+    std::lock_guard Lock{adapter->Mutex};
+    if (--adapter->RefCount == 0) {
+      if (cl_ext::ExtFuncPtrCache) {
+        delete cl_ext::ExtFuncPtrCache;
+        cl_ext::ExtFuncPtrCache = nullptr;
+      }
+    }
   }
   return UR_RESULT_SUCCESS;
 }
@@ -77,7 +106,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
   case UR_ADAPTER_INFO_BACKEND:
-    return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
+    return ReturnValue(UR_ADAPTER_BACKEND_OPENCL);
   case UR_ADAPTER_INFO_REFERENCE_COUNT:
-    return ReturnValue(adapter.RefCount.load());
+    return ReturnValue(adapter->RefCount.load());
   default:
     return UR_RESULT_ERROR_INVALID_ENUMERATION;
   }
diff --git a/source/adapters/opencl/common.hpp b/source/adapters/opencl/common.hpp
index 0cb19694a6..068bd29124 100644
--- a/source/adapters/opencl/common.hpp
+++ b/source/adapters/opencl/common.hpp
@@ -343,7 +343,7 @@ struct ExtFuncPtrCacheT {
 // piTeardown to avoid issues with static destruction order (a user application
 // might have static objects that indirectly access this cache in their
 // destructor).
-inline std::unique_ptr<ExtFuncPtrCacheT> ExtFuncPtrCache;
+inline ExtFuncPtrCacheT *ExtFuncPtrCache = nullptr;
 
 // USM helper function to get an extension function pointer
 template