diff --git a/source/loader/layers/validation/ur_leak_check.hpp b/source/loader/layers/validation/ur_leak_check.hpp index b7cb487f88..8202ac447d 100644 --- a/source/loader/layers/validation/ur_leak_check.hpp +++ b/source/loader/layers/validation/ur_leak_check.hpp @@ -9,6 +9,7 @@ #include "ur_validation_layer.hpp" #include +#include #include #include @@ -20,7 +21,12 @@ struct RefCountContext { private: struct RefRuntimeInfo { int64_t refCount; + std::type_index type; std::vector backtrace; + + RefRuntimeInfo(int64_t refCount, std::type_index type, + std::vector backtrace) + : refCount(refCount), type(type), backtrace(backtrace) {} }; enum RefCountUpdateType { @@ -34,26 +40,32 @@ struct RefCountContext { std::unordered_map counts; int64_t adapterCount = 0; - void updateRefCount(void *ptr, enum RefCountUpdateType type, + template + void updateRefCount(T handle, enum RefCountUpdateType type, bool isAdapterHandle = false) { std::unique_lock ulock(mutex); + void *ptr = static_cast(handle); auto it = counts.find(ptr); switch (type) { case REFCOUNT_CREATE_OR_INCREASE: if (it == counts.end()) { - counts[ptr] = {1, getCurrentBacktrace()}; + std::tie(it, std::ignore) = counts.emplace( + ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)), + getCurrentBacktrace()}); if (isAdapterHandle) { adapterCount++; } } else { - counts[ptr].refCount++; + it->second.refCount++; } break; case REFCOUNT_CREATE: if (it == counts.end()) { - counts[ptr] = {1, getCurrentBacktrace()}; + std::tie(it, std::ignore) = counts.emplace( + ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)), + getCurrentBacktrace()}); } else { context.logger.error("Handle {} already exists", ptr); return; @@ -65,29 +77,31 @@ struct RefCountContext { "Attempting to retain nonexistent handle {}", ptr); return; } else { - counts[ptr].refCount++; + it->second.refCount++; } break; case REFCOUNT_DECREASE: if (it == counts.end()) { - counts[ptr] = {-1, getCurrentBacktrace()}; + std::tie(it, std::ignore) = counts.emplace( + ptr, RefRuntimeInfo{-1, std::type_index(typeid(handle)), + getCurrentBacktrace()}); } else { - counts[ptr].refCount--; + it->second.refCount--; } - if (counts[ptr].refCount < 0) { + if (it->second.refCount < 0) { context.logger.error( "Attempting to release nonexistent handle {}", ptr); - } else if (counts[ptr].refCount == 0 && isAdapterHandle) { + } else if (it->second.refCount == 0 && isAdapterHandle) { adapterCount--; } break; } context.logger.debug("Reference count for handle {} changed to {}", ptr, - counts[ptr].refCount); + it->second.refCount); - if (counts[ptr].refCount == 0) { + if (it->second.refCount == 0) { counts.erase(ptr); } @@ -99,23 +113,35 @@ struct RefCountContext { } public: - void createRefCount(void *ptr) { updateRefCount(ptr, REFCOUNT_CREATE); } + template void createRefCount(T handle) { + updateRefCount(handle, REFCOUNT_CREATE); + } - void incrementRefCount(void *ptr, bool isAdapterHandle = false) { - updateRefCount(ptr, REFCOUNT_INCREASE, isAdapterHandle); + template + void incrementRefCount(T handle, bool isAdapterHandle = false) { + updateRefCount(handle, REFCOUNT_INCREASE, isAdapterHandle); } - void decrementRefCount(void *ptr, bool isAdapterHandle = false) { - updateRefCount(ptr, REFCOUNT_DECREASE, isAdapterHandle); + template + void decrementRefCount(T handle, bool isAdapterHandle = false) { + updateRefCount(handle, REFCOUNT_DECREASE, isAdapterHandle); } - void createOrIncrementRefCount(void *ptr, bool isAdapterHandle = false) { - updateRefCount(ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle); + template + void createOrIncrementRefCount(T handle, bool isAdapterHandle = false) { + updateRefCount(handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle); } void clear() { counts.clear(); } - bool isReferenceValid(void *ptr) { return counts.count(ptr) > 0; } + template bool isReferenceValid(T handle) { + auto it = counts.find(static_cast(handle)); + if (it == counts.end() || it->second.refCount < 1) { + return false; + } + + return (it->second.type == std::type_index(typeid(handle))); + } void logInvalidReferences() { for (auto &[ptr, refRuntimeInfo] : counts) { diff --git a/test/layers/validation/lifetime.cpp b/test/layers/validation/lifetime.cpp index 4c642fedb0..bd9a40ebe1 100644 --- a/test/layers/validation/lifetime.cpp +++ b/test/layers/validation/lifetime.cpp @@ -9,13 +9,19 @@ TEST_F(urTest, testUrAdapterHandleLifetimeExpectFail) { size_t size = 0; ur_adapter_handle_t adapter = (ur_adapter_handle_t)0xC0FFEE; ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND; - ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size), - UR_RESULT_ERROR_INVALID_ARGUMENT); + urAdapterGetInfo(adapter, info_type, 0, nullptr, &size); } TEST_F(valAdapterTest, testUrAdapterHandleLifetimeExpectSuccess) { size_t size = 0; ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND; - ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size), - UR_RESULT_SUCCESS); + urAdapterGetInfo(adapter, info_type, 0, nullptr, &size); } + +TEST_F(valAdapterTest, testUrAdapterHandleTypeMismatchExpectFail) { + size_t size = 0; + // Use valid adapter handle with incorrect cast. + ur_device_handle_t device = (ur_device_handle_t)adapter; + ur_device_info_t info_type = UR_DEVICE_INFO_BACKEND_RUNTIME_VERSION; + urDeviceGetInfo(device, info_type, 0, nullptr, &size); +} \ No newline at end of file diff --git a/test/layers/validation/lifetime.out.match b/test/layers/validation/lifetime.out.match index c454500641..f73bd71ba4 100644 --- a/test/layers/validation/lifetime.out.match +++ b/test/layers/validation/lifetime.out.match @@ -3,5 +3,10 @@ [ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}} {{IGNORE}} [ RUN ] valAdapterTest.testUrAdapterHandleLifetimeExpectSuccess +[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1 {{^(?!.*There are no valid references to handle).*$}} {{IGNORE}} +[ RUN ] valAdapterTest.testUrAdapterHandleTypeMismatchExpectFail +[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1 +[ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}} +{{IGNORE}}