Skip to content

Commit

Permalink
[L0] move platform cache into the adapter structure
Browse files Browse the repository at this point in the history
The platform cache was globally available, protected by
a loosely-associated spin lock. However, its destruction was
happening during adapter teardown (i.e., when the adapter refcount
reached 0). This was causing issues whenever the adapter was
initialized and destroyed multiple time inside of a single process,
which, for example, happens during tests.

This patch fixes the above problem by moving the platform cache
from the global state into the adapter structure. This allowed
for a simpler implementation that no longer requires an explicit
lock and instead uses lazy loading (std::call_once).

With this patch, all platform tests are now passing for L0.
Closes #824
  • Loading branch information
pbalcer committed Jan 15, 2024
1 parent 79c28d0 commit f465e56
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 85 deletions.
8 changes: 0 additions & 8 deletions .github/workflows/cmake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,7 @@ jobs:
working-directory: ${{github.workspace}}/build
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "adapter-specific" --timeout 180

# Temporarily disabling platform test for L0, because of hang
# See issue: #824
- name: Test L0 adapter
if: matrix.adapter.name == 'L0'
working-directory: ${{github.workspace}}/build
run: ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "platform-adapter_level_zero" --timeout 180

- name: Test adapters
if: matrix.adapter.name != 'L0'
working-directory: ${{github.workspace}}/build
run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" --timeout 180

Expand Down
49 changes: 39 additions & 10 deletions source/adapters/level_zero/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,44 @@

ur_adapter_handle_t_ Adapter{};

ur_result_t adapterStateTeardown() {
// reclaim ur_platform_handle_t objects here since we don't have
// urPlatformRelease.
for (ur_platform_handle_t Platform : *URPlatformsCache) {
delete Platform;
ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
uint32_t ZeDriverCount = 0;
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
if (ZeDriverCount == 0) {
return UR_RESULT_SUCCESS;
}
delete URPlatformsCache;
delete URPlatformsCacheMutex;

std::vector<ze_driver_handle_t> ZeDrivers;
ZeDrivers.resize(ZeDriverCount);

ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
for (uint32_t I = 0; I < ZeDriverCount; ++I) {
auto platform = std::make_unique<ur_platform_handle_t_>(ZeDrivers[I]);
UR_CALL(platform->initialize());

// Save a copy in the cache for future uses.
platforms.push_back(std::move(platform));
}
return UR_RESULT_SUCCESS;
} catch (...) {
return exceptionToResult(std::current_exception());
}

ur_result_t adapterStateInit() {
/* TODO: move L0 initialization from urPlatformGet */
Adapter.PlatformCache.Compute = [](Result<PlatformVec> &result) {
PlatformVec platforms;
ur_result_t err = initPlatforms(platforms);
if (err == UR_RESULT_SUCCESS) {
result = Result<PlatformVec>(std::move(platforms));
} else {
result = Result<PlatformVec>(err);
}
};
return UR_RESULT_SUCCESS;
}

ur_result_t adapterStateTeardown() {
bool LeakFound = false;

// Print the balance of various create/destroy native calls.
Expand Down Expand Up @@ -126,9 +155,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
) {
if (NumEntries > 0 && Adapters) {
std::lock_guard<std::mutex> Lock{Adapter.Mutex};
// TODO: Some initialization that happens in urPlatformsGet could be moved
// here for when RefCount reaches 1
Adapter.RefCount++;
if (Adapter.RefCount++ == 0) {
adapterStateInit();
}
*Adapters = &Adapter;
}

Expand Down
5 changes: 5 additions & 0 deletions source/adapters/level_zero/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@

#include <atomic>
#include <mutex>
#include <ur/ur.hpp>

using PlatformVec = std::vector<std::unique_ptr<ur_platform_handle_t_>>;

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;

ZeCache<Result<PlatformVec>> PlatformCache;
};

extern ur_adapter_handle_t_ Adapter;
14 changes: 7 additions & 7 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "device.hpp"
#include "adapter.hpp"
#include "ur_level_zero.hpp"
#include <algorithm>
#include <climits>
Expand Down Expand Up @@ -1325,18 +1326,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
// Level Zero devices when we initialized the platforms/devices cache, so the
// "NativeHandle" must already be in the cache. If it is not, this must not be
// a valid Level Zero device.
//
// TODO: maybe we should populate cache of platforms if it wasn't already.
// For now assert that is was populated.
UR_ASSERT(URPlatformCachePopulated, UR_RESULT_ERROR_INVALID_VALUE);
const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};

ur_device_handle_t Dev = nullptr;
for (ur_platform_handle_t ThePlatform : *URPlatformsCache) {
auto &platforms = Adapter.PlatformCache;
if (platforms->is_err()) {
return platforms->get_error();
}
for (const auto &ThePlatform : platforms->get_value()) {
Dev = ThePlatform->getDeviceFromNativeHandle(ZeDevice);
if (Dev) {
// Check that the input Platform, if was given, matches the found one.
UR_ASSERT(!Platform || Platform == ThePlatform,
UR_ASSERT(!Platform || Platform == ThePlatform.get(),
UR_RESULT_ERROR_INVALID_PLATFORM);
break;
}
Expand Down
57 changes: 13 additions & 44 deletions source/adapters/level_zero/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,55 +73,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGet(
return ze2urResult(ZeResult);
}

// Cache ur_platform_handle_t for reuse in the future
// It solves two problems;
// 1. sycl::platform equality issue; we always return the same
// ur_platform_handle_t
// 2. performance; we can save time by immediately return from cache.
//

const std::lock_guard<SpinLock> Lock{*URPlatformsCacheMutex};
if (!URPlatformCachePopulated) {
try {
// Level Zero does not have concept of Platforms, but Level Zero driver is
// the closest match.
uint32_t ZeDriverCount = 0;
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
if (ZeDriverCount == 0) {
URPlatformCachePopulated = true;
} else {
std::vector<ze_driver_handle_t> ZeDrivers;
ZeDrivers.resize(ZeDriverCount);

ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, ZeDrivers.data()));
for (uint32_t I = 0; I < ZeDriverCount; ++I) {
auto Platform = new ur_platform_handle_t_(ZeDrivers[I]);
// Save a copy in the cache for future uses.
URPlatformsCache->push_back(Platform);

UR_CALL(Platform->initialize());
}
URPlatformCachePopulated = true;
}
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
return UR_RESULT_ERROR_UNKNOWN;
}
// Platform handles are cached for reuse. This is to ensure consistent
// handle pointers across invocations and to improve retrieval performance.
auto &cache = Adapter.PlatformCache;
if (cache->is_err()) {
return cache->get_error();
}

// Populate returned platforms from the cache.
if (Platforms) {
UR_ASSERT(NumEntries <= URPlatformsCache->size(),
UR_RESULT_ERROR_INVALID_PLATFORM);
std::copy_n(URPlatformsCache->begin(), NumEntries, Platforms);
auto &cached_platforms = cache->get_value();
if (NumEntries > cached_platforms.size()) {
return UR_RESULT_ERROR_INVALID_PLATFORM;
}

if (NumPlatforms) {
if (*NumPlatforms == 0)
*NumPlatforms = URPlatformsCache->size();
else
*NumPlatforms = (std::min)(URPlatformsCache->size(), (size_t)NumEntries);
*NumPlatforms = cached_platforms.size();
}

for (uint32_t i = 0; i < NumEntries; ++i) {
Platforms[i] = cached_platforms[i].get();
}

return UR_RESULT_SUCCESS;
Expand Down
6 changes: 0 additions & 6 deletions source/ur/ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,3 @@ bool PrintTrace = [] {
}
return false;
}();

// Apparatus for maintaining immutable cache of platforms.
std::vector<ur_platform_handle_t> *URPlatformsCache =
new std::vector<ur_platform_handle_t>;
SpinLock *URPlatformsCacheMutex = new SpinLock;
bool URPlatformCachePopulated = false;
37 changes: 27 additions & 10 deletions source/ur/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <shared_mutex>
#include <string>
#include <thread>
#include <variant>
#include <vector>

#include <ur_api.h>
Expand Down Expand Up @@ -191,16 +192,6 @@ struct _ur_platform {};
// Controls tracing UR calls from within the UR itself.
extern bool PrintTrace;

// Apparatus for maintaining immutable cache of platforms.
//
// Note we only create a simple pointer variables such that C++ RT won't
// deallocate them automatically at the end of the main program.
// The heap memory allocated for these global variables reclaimed only at
// explicit tear-down.
extern std::vector<ur_platform_handle_t> *URPlatformsCache;
extern SpinLock *URPlatformsCacheMutex;
extern bool URPlatformCachePopulated;

// The getInfo*/ReturnHelper facilities provide shortcut way of
// writing return bytes for the various getInfo APIs.
namespace ur {
Expand Down Expand Up @@ -310,3 +301,29 @@ class UrReturnHelper {
void *param_value;
size_t *param_value_size_ret;
};

template <typename T> class Result {
public:
Result(ur_result_t err) : value_or_err(err) {}
Result(T value) : value_or_err(std::move(value)) {}
Result() : value_or_err(UR_RESULT_ERROR_UNINITIALIZED) {}

bool is_err() { return std::holds_alternative<ur_result_t>(value_or_err); }
explicit operator bool() const { return !is_err(); }
const T &get_value() noexcept {
try {
return std::get<T>(value_or_err);
} catch (...) {
/* unreachable */
assert(0 &&
"get_value() called on empty Result without checking for error.");
}
}
const ur_result_t get_error() {
auto *err = std::get_if<ur_result_t>(&value_or_err);
return err ? *err : UR_RESULT_SUCCESS;
}

private:
std::variant<ur_result_t, T> value_or_err;
};

0 comments on commit f465e56

Please sign in to comment.