From 91a37cfb9b45edc8b6cb13f0c12c5c93b0a1bf6a Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Thu, 7 Mar 2024 12:49:56 -0800 Subject: [PATCH] Delete Adapter in atexit after refcnt is 0 due to multi DLLMain - On Windows, SYCL and UMF both define DllMain, so a DllMain dedicated to the adapter is not possible. To fix this, the L0 adapter initializes the global adapter at variable initialization and registers an atexit teardown once its reference count reaches 0. Signed-off-by: Neil R. Spruit --- source/adapters/level_zero/CMakeLists.txt | 9 +-- source/adapters/level_zero/adapter.cpp | 56 +++++++++++++------ source/adapters/level_zero/adapter.hpp | 2 +- .../level_zero/adapter_lib_init_linux.cpp | 9 +-- .../level_zero/adapter_lib_init_windows.cpp | 27 --------- source/adapters/level_zero/device.cpp | 4 +- source/adapters/level_zero/platform.cpp | 6 +- source/adapters/level_zero/queue.cpp | 2 +- 8 files changed, 52 insertions(+), 63 deletions(-) delete mode 100644 source/adapters/level_zero/adapter_lib_init_windows.cpp diff --git a/source/adapters/level_zero/CMakeLists.txt b/source/adapters/level_zero/CMakeLists.txt index 4cebeefc4f..d26d0aeb26 100644 --- a/source/adapters/level_zero/CMakeLists.txt +++ b/source/adapters/level_zero/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2024 Intel Corporation # Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
# See LICENSE.TXT # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -122,12 +122,7 @@ add_ur_adapter(${TARGET_NAME} ${CMAKE_CURRENT_SOURCE_DIR}/../../ur/ur.cpp ) -if(WIN32) - target_sources(ur_adapter_level_zero - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_windows.cpp - ) -else() +if(NOT WIN32) target_sources(ur_adapter_level_zero PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/adapter_lib_init_linux.cpp diff --git a/source/adapters/level_zero/adapter.cpp b/source/adapters/level_zero/adapter.cpp index 504df5b20d..9d5b1038a2 100644 --- a/source/adapters/level_zero/adapter.cpp +++ b/source/adapters/level_zero/adapter.cpp @@ -11,7 +11,13 @@ #include "adapter.hpp" #include "ur_level_zero.hpp" -ur_adapter_handle_t_ *Adapter; +// Due to multiple DLLMain definitions with SYCL, Global Adapter is init at +// variable creation. +#if defined(_WIN32) +ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_(); +#else +ur_adapter_handle_t_ *GlobalAdapter; +#endif ur_result_t initPlatforms(PlatformVec &platforms) noexcept try { uint32_t ZeDriverCount = 0; @@ -53,7 +59,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() { } // initialize level zero only once. - if (Adapter->ZeResult == std::nullopt) { + if (GlobalAdapter->ZeResult == std::nullopt) { // Setting these environment variables before running zeInit will enable // the validation layer in the Level Zero loader. if (UrL0Debug & UR_L0_DEBUG_VALIDATION) { @@ -72,20 +78,21 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() { // We must only initialize the driver once, even if urPlatformGet() is // called multiple times. Declaring the return value as "static" ensures // it's only called once. 
- Adapter->ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY)); + GlobalAdapter->ZeResult = + ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY)); } - assert(Adapter->ZeResult != + assert(GlobalAdapter->ZeResult != std::nullopt); // verify that level-zero is initialized PlatformVec platforms; // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms. - if (*Adapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) { + if (*GlobalAdapter->ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) { result = std::move(platforms); return; } - if (*Adapter->ZeResult != ZE_RESULT_SUCCESS) { + if (*GlobalAdapter->ZeResult != ZE_RESULT_SUCCESS) { urPrint("zeInit: Level Zero initialization failure\n"); - result = ze2urResult(*Adapter->ZeResult); + result = ze2urResult(*GlobalAdapter->ZeResult); return; } @@ -98,6 +105,14 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() { }; } +#if defined(_WIN32) +void globalAdapterWindowsCleanup() { + if (GlobalAdapter) { + delete GlobalAdapter; + } +} +#endif + ur_result_t adapterStateTeardown() { bool LeakFound = false; @@ -183,6 +198,11 @@ ur_result_t adapterStateTeardown() { } if (LeakFound) return UR_RESULT_ERROR_INVALID_MEM_OBJECT; + // Due to multiple DLLMain definitions with SYCL, register to cleanup the + // Global Adapter after refcnt is 0 +#if defined(_WIN32) + std::atexit(globalAdapterWindowsCleanup); +#endif return UR_RESULT_SUCCESS; } @@ -202,12 +222,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( ///< adapters available. 
) { if (NumEntries > 0 && Adapters) { - if (Adapter) { - std::lock_guard Lock{Adapter->Mutex}; - if (Adapter->RefCount++ == 0) { + if (GlobalAdapter) { + std::lock_guard Lock{GlobalAdapter->Mutex}; + if (GlobalAdapter->RefCount++ == 0) { adapterStateInit(); } - *Adapters = Adapter; + *Adapters = GlobalAdapter; } else { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -222,9 +242,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet( UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { // Check first if the Adapter pointer is valid - if (Adapter) { - std::lock_guard Lock{Adapter->Mutex}; - if (--Adapter->RefCount == 0) { + if (GlobalAdapter) { + std::lock_guard Lock{GlobalAdapter->Mutex}; + if (--GlobalAdapter->RefCount == 0) { return adapterStateTeardown(); } } @@ -233,9 +253,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - if (Adapter) { - std::lock_guard Lock{Adapter->Mutex}; - Adapter->RefCount++; + if (GlobalAdapter) { + std::lock_guard Lock{GlobalAdapter->Mutex}; + GlobalAdapter->RefCount++; } else { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -267,7 +287,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_ADAPTER_BACKEND_LEVEL_ZERO); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(Adapter->RefCount.load()); + return ReturnValue(GlobalAdapter->RefCount.load()); default: return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/adapters/level_zero/adapter.hpp b/source/adapters/level_zero/adapter.hpp index e3e05b118b..1fdf3a9294 100644 --- a/source/adapters/level_zero/adapter.hpp +++ b/source/adapters/level_zero/adapter.hpp @@ -25,4 +25,4 @@ struct ur_adapter_handle_t_ { ZeCache> PlatformCache; }; -extern ur_adapter_handle_t_ *Adapter; +extern ur_adapter_handle_t_ *GlobalAdapter; diff --git 
a/source/adapters/level_zero/adapter_lib_init_linux.cpp b/source/adapters/level_zero/adapter_lib_init_linux.cpp index 13044d251f..bb6f1d4e6d 100644 --- a/source/adapters/level_zero/adapter_lib_init_linux.cpp +++ b/source/adapters/level_zero/adapter_lib_init_linux.cpp @@ -12,13 +12,14 @@ #include "ur_level_zero.hpp" void __attribute__((constructor)) createAdapterHandle() { - if (!Adapter) { - Adapter = new ur_adapter_handle_t_(); + if (!GlobalAdapter) { + GlobalAdapter = new ur_adapter_handle_t_(); } } void __attribute__((destructor)) deleteAdapterHandle() { - if (Adapter) { - delete Adapter; + if (GlobalAdapter) { + delete GlobalAdapter; + GlobalAdapter = nullptr; } } diff --git a/source/adapters/level_zero/adapter_lib_init_windows.cpp b/source/adapters/level_zero/adapter_lib_init_windows.cpp deleted file mode 100644 index 7f5146461f..0000000000 --- a/source/adapters/level_zero/adapter_lib_init_windows.cpp +++ /dev/null @@ -1,27 +0,0 @@ -//===--------- adapter_lib_init_windows.cpp - Level Zero Adapter ----------===// -// -// Copyright (C) 2023 Intel Corporation -// -// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM -// Exceptions. 
See LICENSE.TXT -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "adapter.hpp" -#include "ur_level_zero.hpp" -#include - -extern "C" BOOL APIENTRY DllMain(HINSTANCE hinstDLL, DWORD fdwReason, - LPVOID lpvReserved) { - if (fdwReason == DLL_PROCESS_DETACH) { - if (Adapter) { - delete Adapter; - } - } else if (fdwReason == DLL_PROCESS_ATTACH) { - if (!Adapter) { - Adapter = new ur_adapter_handle_t_(); - } - } - return TRUE; -} diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index 624c5a7e7f..437b4d6603 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -1442,7 +1442,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // a valid Level Zero device. ur_device_handle_t Dev = nullptr; - if (const auto *platforms = Adapter->PlatformCache->get_value()) { + if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) { for (const auto &p : *platforms) { Dev = p->getDeviceFromNativeHandle(ZeDevice); if (Dev) { @@ -1453,7 +1453,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( } } } else { - return Adapter->PlatformCache->get_error(); + return GlobalAdapter->PlatformCache->get_error(); } if (Dev == nullptr) diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp index d8e0583e73..ab577247bd 100644 --- a/source/adapters/level_zero/platform.cpp +++ b/source/adapters/level_zero/platform.cpp @@ -29,7 +29,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( ) { // Platform handles are cached for reuse. This is to ensure consistent // handle pointers across invocations and to improve retrieval performance. 
- if (const auto *cached_platforms = Adapter->PlatformCache->get_value(); + if (const auto *cached_platforms = GlobalAdapter->PlatformCache->get_value(); cached_platforms) { uint32_t nplatforms = (uint32_t)cached_platforms->size(); if (NumPlatforms) { @@ -41,7 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet( } } } else { - return Adapter->PlatformCache->get_error(); + return GlobalAdapter->PlatformCache->get_error(); } return UR_RESULT_SUCCESS; @@ -133,7 +133,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformCreateWithNativeHandle( auto ZeDriver = ur_cast(NativePlatform); uint32_t NumPlatforms = 0; - ur_adapter_handle_t AdapterHandle = Adapter; + ur_adapter_handle_t AdapterHandle = GlobalAdapter; UR_CALL(urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms)); if (NumPlatforms) { diff --git a/source/adapters/level_zero/queue.cpp b/source/adapters/level_zero/queue.cpp index a50921f207..17ead460d0 100644 --- a/source/adapters/level_zero/queue.cpp +++ b/source/adapters/level_zero/queue.cpp @@ -569,7 +569,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle( // Maybe this is not completely correct. uint32_t NumEntries = 1; ur_platform_handle_t Platform{}; - ur_adapter_handle_t AdapterHandle = Adapter; + ur_adapter_handle_t AdapterHandle = GlobalAdapter; UR_CALL(urPlatformGet(&AdapterHandle, 1, NumEntries, &Platform, nullptr)); ur_device_handle_t UrDevice = Device;