Skip to content

Commit

Permalink
[UR] Expand lifetime validation with
Browse files Browse the repository at this point in the history
... handle type checks.
  • Loading branch information
kswiecicki committed Dec 8, 2023
1 parent 80a298e commit 510bd27
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 23 deletions.
64 changes: 45 additions & 19 deletions source/loader/layers/validation/ur_leak_check.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ur_validation_layer.hpp"

#include <mutex>
#include <typeindex>
#include <unordered_map>
#include <utility>

Expand All @@ -20,7 +21,12 @@ struct RefCountContext {
private:
struct RefRuntimeInfo {
int64_t refCount;
std::type_index type;
std::vector<BacktraceLine> backtrace;

RefRuntimeInfo(int64_t refCount, std::type_index type,
std::vector<BacktraceLine> backtrace)
: refCount(refCount), type(type), backtrace(backtrace) {}
};

enum RefCountUpdateType {
Expand All @@ -34,26 +40,32 @@ struct RefCountContext {
std::unordered_map<void *, struct RefRuntimeInfo> counts;
int64_t adapterCount = 0;

void updateRefCount(void *ptr, enum RefCountUpdateType type,
template <typename T>
void updateRefCount(T handle, enum RefCountUpdateType type,
bool isAdapterHandle = false) {
std::unique_lock<std::mutex> ulock(mutex);

void *ptr = static_cast<void *>(handle);
auto it = counts.find(ptr);

switch (type) {
case REFCOUNT_CREATE_OR_INCREASE:
if (it == counts.end()) {
counts[ptr] = {1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
if (isAdapterHandle) {
adapterCount++;
}
} else {
counts[ptr].refCount++;
it->second.refCount++;
}
break;
case REFCOUNT_CREATE:
if (it == counts.end()) {
counts[ptr] = {1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
} else {
context.logger.error("Handle {} already exists", ptr);
return;
Expand All @@ -65,29 +77,31 @@ struct RefCountContext {
"Attempting to retain nonexistent handle {}", ptr);
return;
} else {
counts[ptr].refCount++;
it->second.refCount++;
}
break;
case REFCOUNT_DECREASE:
if (it == counts.end()) {
counts[ptr] = {-1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{-1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
} else {
counts[ptr].refCount--;
it->second.refCount--;
}

if (counts[ptr].refCount < 0) {
if (it->second.refCount < 0) {
context.logger.error(
"Attempting to release nonexistent handle {}", ptr);
} else if (counts[ptr].refCount == 0 && isAdapterHandle) {
} else if (it->second.refCount == 0 && isAdapterHandle) {
adapterCount--;
}
break;
}

context.logger.debug("Reference count for handle {} changed to {}", ptr,
counts[ptr].refCount);
it->second.refCount);

if (counts[ptr].refCount == 0) {
if (it->second.refCount == 0) {
counts.erase(ptr);
}

Expand All @@ -99,23 +113,35 @@ struct RefCountContext {
}

public:
void createRefCount(void *ptr) { updateRefCount(ptr, REFCOUNT_CREATE); }
template <typename T> void createRefCount(T handle) {
updateRefCount<T>(handle, REFCOUNT_CREATE);
}

void incrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_INCREASE, isAdapterHandle);
template <typename T>
void incrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_INCREASE, isAdapterHandle);
}

void decrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_DECREASE, isAdapterHandle);
template <typename T>
void decrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_DECREASE, isAdapterHandle);
}

void createOrIncrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
template <typename T>
void createOrIncrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
}

void clear() { counts.clear(); }

bool isReferenceValid(void *ptr) { return counts.count(ptr) > 0; }
template <typename T> bool isReferenceValid(T handle) {
auto it = counts.find(static_cast<void *>(handle));
if (it == counts.end() || it->second.refCount < 1) {
return false;
}

return (it->second.type == std::type_index(typeid(handle)));
}

void logInvalidReferences() {
for (auto &[ptr, refRuntimeInfo] : counts) {
Expand Down
14 changes: 10 additions & 4 deletions test/layers/validation/lifetime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ TEST_F(urTest, testUrAdapterHandleLifetimeExpectFail) {
size_t size = 0;
ur_adapter_handle_t adapter = (ur_adapter_handle_t)0xC0FFEE;
ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND;
ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size),
UR_RESULT_ERROR_INVALID_ARGUMENT);
urAdapterGetInfo(adapter, info_type, 0, nullptr, &size);
}

TEST_F(valAdapterTest, testUrAdapterHandleLifetimeExpectSuccess) {
size_t size = 0;
ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND;
ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size),
UR_RESULT_SUCCESS);
urAdapterGetInfo(adapter, info_type, 0, nullptr, &size);
}

TEST_F(valAdapterTest, testUrAdapterHandleTypeMismatchExpectFail) {
size_t size = 0;
// Use valid adapter handle with incorrect cast.
ur_device_handle_t device = (ur_device_handle_t)adapter;
ur_device_info_t info_type = UR_DEVICE_INFO_BACKEND_RUNTIME_VERSION;
urDeviceGetInfo(device, info_type, 0, nullptr, &size);
}
5 changes: 5 additions & 0 deletions test/layers/validation/lifetime.out.match
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
<VALIDATION>[ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}}
{{IGNORE}}
[ RUN ] valAdapterTest.testUrAdapterHandleLifetimeExpectSuccess
<VALIDATION>[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1
{{^(?!.*There are no valid references to handle).*$}}
{{IGNORE}}
[ RUN ] valAdapterTest.testUrAdapterHandleTypeMismatchExpectFail
<VALIDATION>[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1
<VALIDATION>[ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}}
{{IGNORE}}

0 comments on commit 510bd27

Please sign in to comment.