diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml
index 9cd67ad748..aa108d74b1 100644
--- a/.github/workflows/cmake.yml
+++ b/.github/workflows/cmake.yml
@@ -213,15 +213,7 @@ jobs:
       working-directory: ${{github.workspace}}/build
       run: ctest -C ${{matrix.build_type}} --output-on-failure -L "adapter-specific" --timeout 180
 
-    # Temporarily disabling platform test for L0, because of hang
-    # See issue: #824
-    - name: Test L0 adapter
-      if: matrix.adapter.name == 'L0'
-      working-directory: ${{github.workspace}}/build
-      run: ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "platform-adapter_level_zero" --timeout 180
-
     - name: Test adapters
-      if: matrix.adapter.name != 'L0'
       working-directory: ${{github.workspace}}/build
       run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" --timeout 180
 
diff --git a/source/adapters/level_zero/adapter.cpp b/source/adapters/level_zero/adapter.cpp
index d43ae07cdb..1cf3871108 100644
--- a/source/adapters/level_zero/adapter.cpp
+++ b/source/adapters/level_zero/adapter.cpp
@@ -13,15 +13,97 @@
 
 ur_adapter_handle_t_ Adapter{};
 
-ur_result_t adapterStateTeardown() {
-  // reclaim ur_platform_handle_t objects here since we don't have
-  // urPlatformRelease.
-  for (ur_platform_handle_t Platform : *URPlatformsCache) {
-    delete Platform;
+// Enumerates the Level Zero drivers with zeDriverGet and appends one
+// initialized platform per driver to `platforms`. Exceptions are converted to
+// a ur_result_t by exceptionToResult.
+ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
+  uint32_t ZeDriverCount = 0;
+  ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
+  if (ZeDriverCount == 0) {
+    return UR_RESULT_SUCCESS;
+  }
+
+  std::vector<ze_driver_handle_t> ZeDrivers;
+  ZeDrivers.resize(ZeDriverCount);
+
+  ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
+  for (uint32_t I = 0; I < ZeDriverCount; ++I) {
+    auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
+    UR_CALL(platform->initialize());
+
+    // Save a copy in the cache for future uses.
+    platforms.push_back(std::move(platform));
+  }
+  return UR_RESULT_SUCCESS;
+} catch (...) {
+  return exceptionToResult(std::current_exception());
+}
+
+// Adapter-wide setup, called by urAdapterGet for the first reference: runs
+// zeInit once and installs the callback that builds the platform list.
+ur_result_t adapterStateInit() {
+  static std::once_flag ZeCallCountInitialized;
+  try {
+    std::call_once(ZeCallCountInitialized, []() {
+      if (UrL0LeaksDebug) {
+        ZeCallCount = new std::map;
+      }
+    });
+  } catch (const std::bad_alloc &) {
+    return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
+  } catch (...) {
+    return UR_RESULT_ERROR_UNKNOWN;
   }
-  delete URPlatformsCache;
-  delete URPlatformsCacheMutex;
 
+  // initialize level zero only once.
+  if (Adapter.ZeResult == std::nullopt) {
+    // Setting these environment variables before running zeInit will enable the
+    // validation layer in the Level Zero loader.
+    if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
+      setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
+      setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
+    }
+
+    if (getenv("SYCL_ENABLE_PCI") != nullptr) {
+      urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
+    }
+
+    // TODO: We can still safely recover if something goes wrong during the
+    // init. Implement handling segfault using sigaction.
+
+    // zeInit must run only once, even if adapterStateInit() runs again. The
+    // std::nullopt check above guarantees that, since Adapter.ZeResult keeps
+    // the first zeInit result.
+    Adapter.ZeResult = ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
+  }
+
+  Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
+    assert(Adapter.ZeResult !=
+           std::nullopt); // verify that level-zero is initialized
+    PlatformVec platforms;
+
+    // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
+    if (*Adapter.ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
+      result = Result<PlatformVec>(std::move(platforms));
+      return;
+    }
+    if (*Adapter.ZeResult != ZE_RESULT_SUCCESS) {
+      urPrint("zeInit: Level Zero initialization failure\n");
+      result = Result<PlatformVec>(ze2urResult(*Adapter.ZeResult));
+      return;
+    }
+
+    ur_result_t err = initPlatforms(platforms);
+    if (err == UR_RESULT_SUCCESS) {
+      result = Result<PlatformVec>(std::move(platforms));
+    } else {
+      result = Result<PlatformVec>(err);
+    }
+  };
+  return UR_RESULT_SUCCESS;
+}
+
+ur_result_t adapterStateTeardown() {
   bool LeakFound = false;
 
   // Print the balance of various create/destroy native calls.
@@ -126,9 +208,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
 ) {
   if (NumEntries > 0 && Adapters) {
     std::lock_guard Lock{Adapter.Mutex};
-    // TODO: Some initialization that happens in urPlatformsGet could be moved
-    // here for when RefCount reaches 1
-    Adapter.RefCount++;
+    if (Adapter.RefCount++ == 0) {
+      // First reference: set up adapter-wide state. On failure undo the
+      // increment and report the error instead of silently ignoring it.
+      ur_result_t InitResult = adapterStateInit();
+      if (InitResult != UR_RESULT_SUCCESS) {
+        Adapter.RefCount--;
+        return InitResult;
+      }
+    }
     *Adapters = &Adapter;
   }
 
diff --git a/source/adapters/level_zero/adapter.hpp b/source/adapters/level_zero/adapter.hpp
index 22bb032d75..86dbde162d 100644
--- a/source/adapters/level_zero/adapter.hpp
+++ b/source/adapters/level_zero/adapter.hpp
@@ -10,10 +10,18 @@
 
 #include
 #include
+#include
+#include
+#include
+
+using PlatformVec = std::vector<std::unique_ptr<ur_platform_handle_t_>>;
 
 struct ur_adapter_handle_t_ {
   std::atomic RefCount = 0;
   std::mutex Mutex;
+
+  std::optional<ze_result_t> ZeResult;
+  ZeCache<Result<PlatformVec>> PlatformCache;
 };
 
 extern ur_adapter_handle_t_ Adapter;
diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp
index abdfd2e541..30935bfaa0 100644
--- a/source/adapters/level_zero/device.cpp
+++ b/source/adapters/level_zero/device.cpp
@@ -9,6 +9,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "device.hpp"
+#include "adapter.hpp"
 #include "ur_level_zero.hpp"
 #include
 #include
@@ -1325,18 +1326,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
   // Level Zero devices when we initialized the platforms/devices cache, so the
   // "NativeHandle" must already be in the cache. If it is not, this must not be
   // a valid Level Zero device.
-  //
-  // TODO: maybe we should populate cache of platforms if it wasn't already.
-  // For now assert that is was populated.
-  UR_ASSERT(URPlatformCachePopulated, UR_RESULT_ERROR_INVALID_VALUE);
-  const std::lock_guard Lock{*URPlatformsCacheMutex};
 
   ur_device_handle_t Dev = nullptr;
-  for (ur_platform_handle_t ThePlatform : *URPlatformsCache) {
+  auto &platforms = Adapter.PlatformCache;
+  if (platforms->is_err()) {
+    return platforms->get_error();
+  }
+  for (const auto &ThePlatform : platforms->get_value()) {
     Dev = ThePlatform->getDeviceFromNativeHandle(ZeDevice);
     if (Dev) {
       // Check that the input Platform, if was given, matches the found one.
-      UR_ASSERT(!Platform || Platform == ThePlatform,
+      UR_ASSERT(!Platform || Platform == ThePlatform.get(),
                 UR_RESULT_ERROR_INVALID_PLATFORM);
       break;
     }
diff --git a/source/adapters/level_zero/platform.cpp b/source/adapters/level_zero/platform.cpp
index 335a920294..0a4ca66f56 100644
--- a/source/adapters/level_zero/platform.cpp
+++ b/source/adapters/level_zero/platform.cpp
@@ -27,101 +27,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
     uint32_t *NumPlatforms ///< [out][optional] returns the total number of
                            ///< platforms available.
 ) {
-  static std::once_flag ZeCallCountInitialized;
-  try {
-    std::call_once(ZeCallCountInitialized, []() {
-      if (UrL0LeaksDebug) {
-        ZeCallCount = new std::map;
-      }
-    });
-  } catch (const std::bad_alloc &) {
-    return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
-  } catch (...) {
-    return UR_RESULT_ERROR_UNKNOWN;
-  }
-
-  // Setting these environment variables before running zeInit will enable the
-  // validation layer in the Level Zero loader.
-  if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
-    setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
-    setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
-  }
-
-  if (getenv("SYCL_ENABLE_PCI") != nullptr) {
-    urPrint("WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
-  }
-
-  // TODO: We can still safely recover if something goes wrong during the init.
-  // Implement handling segfault using sigaction.
-
-  // We must only initialize the driver once, even if urPlatformGet() is called
-  // multiple times. Declaring the return value as "static" ensures it's only
-  // called once.
-  static ze_result_t ZeResult =
-      ZE_CALL_NOCHECK(zeInit, (ZE_INIT_FLAG_GPU_ONLY));
-
-  // Absorb the ZE_RESULT_ERROR_UNINITIALIZED and just return 0 Platforms.
-  if (ZeResult == ZE_RESULT_ERROR_UNINITIALIZED) {
-    UR_ASSERT(NumEntries == 0, UR_RESULT_ERROR_INVALID_VALUE);
-    if (NumPlatforms)
-      *NumPlatforms = 0;
-    return UR_RESULT_SUCCESS;
+  // Platform handles are cached for reuse. This is to ensure consistent
+  // handle pointers across invocations and to improve retrieval performance.
+  auto &cache = Adapter.PlatformCache;
+  if (cache->is_err()) {
+    return cache->get_error();
   }
 
-  if (ZeResult != ZE_RESULT_SUCCESS) {
-    urPrint("zeInit: Level Zero initialization failure\n");
-    return ze2urResult(ZeResult);
+  auto &cached_platforms = cache->get_value();
+  if (NumEntries > cached_platforms.size()) {
+    return UR_RESULT_ERROR_INVALID_PLATFORM;
   }
 
-  // Cache ur_platform_handle_t for reuse in the future
-  // It solves two problems;
-  // 1. sycl::platform equality issue; we always return the same
-  // ur_platform_handle_t
-  // 2. performance; we can save time by immediately return from cache.
-  //
-
-  const std::lock_guard Lock{*URPlatformsCacheMutex};
-  if (!URPlatformCachePopulated) {
-    try {
-      // Level Zero does not have concept of Platforms, but Level Zero driver is
-      // the closest match.
-      uint32_t ZeDriverCount = 0;
-      ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
-      if (ZeDriverCount == 0) {
-        URPlatformCachePopulated = true;
-      } else {
-        std::vector ZeDrivers;
-        ZeDrivers.resize(ZeDriverCount);
-
-        ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
-        for (uint32_t I = 0; I < ZeDriverCount; ++I) {
-          auto Platform = new ur_platform_handle_t_(ZeDrivers[I]);
-          // Save a copy in the cache for future uses.
-          URPlatformsCache->push_back(Platform);
-
-          UR_CALL(Platform->initialize());
-        }
-        URPlatformCachePopulated = true;
-      }
-    } catch (const std::bad_alloc &) {
-      return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
-    } catch (...) {
-      return UR_RESULT_ERROR_UNKNOWN;
-    }
-  }
-
-  // Populate returned platforms from the cache.
-  if (Platforms) {
-    UR_ASSERT(NumEntries <= URPlatformsCache->size(),
-              UR_RESULT_ERROR_INVALID_PLATFORM);
-    std::copy_n(URPlatformsCache->begin(), NumEntries, Platforms);
+  if (NumPlatforms) {
+    *NumPlatforms = cached_platforms.size();
   }
 
-  if (NumPlatforms) {
-    if (*NumPlatforms == 0)
-      *NumPlatforms = URPlatformsCache->size();
-    else
-      *NumPlatforms = (std::min)(URPlatformsCache->size(), (size_t)NumEntries);
+  // Only write handles when the caller supplied an output array.
+  if (Platforms) {
+    for (uint32_t i = 0; i < NumEntries; ++i) {
+      Platforms[i] = cached_platforms[i].get();
+    }
   }
 
   return UR_RESULT_SUCCESS;
diff --git a/source/ur/ur.cpp b/source/ur/ur.cpp
index 4de87d53c2..dad6312d57 100644
--- a/source/ur/ur.cpp
+++ b/source/ur/ur.cpp
@@ -22,9 +22,3 @@ bool PrintTrace = [] {
   }
   return false;
 }();
-
-// Apparatus for maintaining immutable cache of platforms.
-std::vector *URPlatformsCache =
-    new std::vector;
-SpinLock *URPlatformsCacheMutex = new SpinLock;
-bool URPlatformCachePopulated = false;
diff --git a/source/ur/ur.hpp b/source/ur/ur.hpp
index 11d619ea04..8ff0fada95 100644
--- a/source/ur/ur.hpp
+++ b/source/ur/ur.hpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <variant>
 #include
 
 #include
@@ -191,16 +192,6 @@ struct _ur_platform {};
 // Controls tracing UR calls from within the UR itself.
 extern bool PrintTrace;
 
-// Apparatus for maintaining immutable cache of platforms.
-//
-// Note we only create a simple pointer variables such that C++ RT won't
-// deallocate them automatically at the end of the main program.
-// The heap memory allocated for these global variables reclaimed only at
-// explicit tear-down.
-extern std::vector *URPlatformsCache;
-extern SpinLock *URPlatformsCacheMutex;
-extern bool URPlatformCachePopulated;
-
 // The getInfo*/ReturnHelper facilities provide shortcut way of
 // writing return bytes for the various getInfo APIs.
 namespace ur {
@@ -310,3 +301,38 @@ class UrReturnHelper {
   void *param_value;
   size_t *param_value_size_ret;
 };
+
+// Holds either a value of type T or the ur_result_t explaining why no value
+// is available. A default-constructed Result holds
+// UR_RESULT_ERROR_UNINITIALIZED.
+template <class T> class Result {
+public:
+  Result(ur_result_t err) : value_or_err(err) {}
+  Result(T value) : value_or_err(std::move(value)) {}
+  Result() : value_or_err(UR_RESULT_ERROR_UNINITIALIZED) {}
+
+  // True when an error code is stored instead of a value.
+  bool is_err() const noexcept {
+    return std::holds_alternative<ur_result_t>(value_or_err);
+  }
+  explicit operator bool() const noexcept { return !is_err(); }
+
+  // Returns the stored value. Callers must check is_err() first.
+  const T &get_value() const noexcept {
+    assert(!is_err() &&
+           "get_value() called on empty Result without checking for error.");
+    // If asserts are compiled out, std::get throws std::bad_variant_access
+    // here; leaving a noexcept function that way calls std::terminate rather
+    // than flowing off the end without returning a reference.
+    return std::get<T>(value_or_err);
+  }
+
+  // Returns the stored error, or UR_RESULT_SUCCESS when a value is held.
+  ur_result_t get_error() const noexcept {
+    const auto *err = std::get_if<ur_result_t>(&value_or_err);
+    return err ? *err : UR_RESULT_SUCCESS;
+  }
+
+private:
+  std::variant<ur_result_t, T> value_or_err;
+};