Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeviceSanitizer] Support detecting memory leaks of USM #1808

Merged
merged 18 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions source/loader/layers/sanitizer/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,34 @@ SanitizerInterceptor::findAllocInfoByAddress(uptr Address) {
return It;
}

/// Collect iterators to every entry of m_AllocationMap whose allocation
/// belongs to the given context.
///
/// The map is walked while holding a shared lock on m_AllocationMapMutex; the
/// lock is dropped when the function returns.
/// NOTE(review): the returned iterators outlive the lock — confirm callers
/// tolerate (or exclude) concurrent modification of the map.
///
/// @param Context context handle compared against each AllocInfo::Context.
/// @return iterators into m_AllocationMap for the matching entries, in map
///         iteration order; empty when nothing matches.
std::vector<AllocationIterator>
SanitizerInterceptor::findAllocInfoByContext(ur_context_handle_t Context) {
    std::shared_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);

    std::vector<AllocationIterator> Matches;
    auto Cur = m_AllocationMap.begin();
    while (Cur != m_AllocationMap.end()) {
        // Only the mapped AllocInfo is needed; the key is irrelevant here.
        if (Cur->second->Context == Context) {
            Matches.emplace_back(Cur);
        }
        ++Cur;
    }
    return Matches;
}

/// Releases Handle through the DDI table's Context.pfnRelease, then reports
/// every allocation recorded for this context that is not marked released.
ContextInfo::~ContextInfo() {
    // The release result is only inspected by the assert below, hence
    // [[maybe_unused]] for builds where assert compiles to nothing.
    [[maybe_unused]] auto ReleaseStatus =
        getContext()->urDdiTable.Context.pfnRelease(Handle);
    assert(ReleaseStatus == UR_RESULT_SUCCESS);

    // Leak check: anything still tracked for this context and not flagged
    // IsReleased is reported as a leak.
    std::vector<AllocationIterator> Tracked =
        getContext()->interceptor->findAllocInfoByContext(Handle);
    for (const auto &Entry : Tracked) {
        const auto &Alloc = Entry->second;
        if (Alloc->IsReleased) {
            continue;
        }
        ReportMemoryLeak(Alloc);
    }
}

ur_result_t USMLaunchInfo::initialize() {
UR_CALL(getContext()->urDdiTable.Context.pfnRetain(Context));
UR_CALL(getContext()->urDdiTable.Device.pfnRetain(Device));
Expand Down
16 changes: 10 additions & 6 deletions source/loader/layers/sanitizer/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct DeviceInfo {
// Device features
bool IsSupportSharedSystemUSM = false;

// Lock this mutex when accessing the following fields
ur_mutex Mutex;
std::queue<std::shared_ptr<AllocInfo>> Quarantine;
size_t QuarantineSize = 0;
Expand All @@ -59,6 +60,7 @@ struct DeviceInfo {
struct QueueInfo {
ur_queue_handle_t Handle;

// Lock this mutex when accessing the following fields
ur_shared_mutex Mutex;
ur_event_handle_t LastEvent;

Expand All @@ -78,8 +80,10 @@ struct QueueInfo {

struct KernelInfo {
ur_kernel_handle_t Handle;
ur_shared_mutex Mutex;
std::atomic<int32_t> RefCount = 1;

// Lock this mutex when accessing the following fields
ur_shared_mutex Mutex;
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
std::unordered_map<uint32_t, std::pair<const void *, StackTrace>>
PointerArgs;
Expand All @@ -102,6 +106,7 @@ struct KernelInfo {

struct ContextInfo {
ur_context_handle_t Handle;
std::atomic<int32_t> RefCount = 1;

std::vector<ur_device_handle_t> DeviceList;
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;
Expand All @@ -112,11 +117,7 @@ struct ContextInfo {
assert(Result == UR_RESULT_SUCCESS);
}

~ContextInfo() {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRelease(Handle);
assert(Result == UR_RESULT_SUCCESS);
}
~ContextInfo();

void insertAllocInfo(const std::vector<ur_device_handle_t> &Devices,
std::shared_ptr<AllocInfo> &AI) {
Expand Down Expand Up @@ -211,6 +212,9 @@ class SanitizerInterceptor {

std::optional<AllocationIterator> findAllocInfoByAddress(uptr Address);

std::vector<AllocationIterator>
findAllocInfoByContext(ur_context_handle_t Context);

std::shared_ptr<ContextInfo> getContextInfo(ur_context_handle_t Context) {
std::shared_lock<ur_shared_mutex> Guard(m_ContextMapMutex);
assert(m_ContextMap.find(Context) != m_ContextMap.end());
Expand Down
10 changes: 10 additions & 0 deletions source/loader/layers/sanitizer/asan_report.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
AI->AllocStack.print();
}

/// Log a memory-leak report for a single allocation.
///
/// Emits two "always" log lines — the allocation type, then the leaked size
/// (UserEnd - UserBegin) and the user start address — followed by the stack
/// trace recorded in AI->AllocStack.
///
/// @param AI allocation record to report; it is dereferenced unconditionally.
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI) {
    auto &Logger = getContext()->logger;

    Logger.always("\n====ERROR: DeviceSanitizer: detected memory leaks of {}",
                  ToString(AI->Type));
    // UserBegin is an integer address; cast it so it is printed as a pointer.
    Logger.always("Direct leak of {} byte(s) at {} allocated from:",
                  AI->UserEnd - AI->UserBegin,
                  reinterpret_cast<void *>(AI->UserBegin));

    AI->AllocStack.print();
}

void ReportFatalError(const DeviceSanitizerReport &Report) {
getContext()->logger.always("\n====ERROR: DeviceSanitizer: {}",
ToString(Report.ErrorType));
Expand Down
5 changes: 4 additions & 1 deletion source/loader/layers/sanitizer/asan_report.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ void ReportBadContext(uptr Addr, const StackTrace &stack,
void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
const std::shared_ptr<AllocInfo> &AllocInfo);

// This type of error is usually unexpected mistake and doesn't have enough debug information
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI);

// This type of error is usually an unexpected mistake and doesn't have
// enough debug information
void ReportFatalError(const DeviceSanitizerReport &Report);

void ReportGenericError(const DeviceSanitizerReport &Report,
Expand Down
48 changes: 38 additions & 10 deletions source/loader/layers/sanitizer/ur_sanddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,29 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
return result;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urContextRetain
///
/// Forwards the retain to the next layer via the DDI table and, once that
/// call has gone through UR_CALL, bumps the RefCount of the ContextInfo the
/// interceptor holds for hContext.
///
/// @return UR_RESULT_ERROR_UNSUPPORTED_FEATURE when the DDI table has no
///         pfnRetain entry; UR_RESULT_SUCCESS after the counter is bumped.
///         NOTE(review): UR_CALL / UR_ASSERT presumably return early with an
///         error code on failure — confirm against the macro definitions.
__urdlllocal ur_result_t UR_APICALL urContextRetain(
    ur_context_handle_t
        hContext ///< [in] handle of the context to get a reference of.
) {
    auto NextRetain = getContext()->urDdiTable.Context.pfnRetain;
    if (NextRetain == nullptr) {
        return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
    }

    getContext()->logger.debug("==== urContextRetain");

    // Retain the underlying context first, then update the layer's own count.
    UR_CALL(NextRetain(hContext));

    auto Info = getContext()->interceptor->getContextInfo(hContext);
    UR_ASSERT(Info != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
    ++Info->RefCount;

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urContextRelease
__urdlllocal ur_result_t UR_APICALL urContextRelease(
Expand All @@ -424,10 +447,15 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(

getContext()->logger.debug("==== urContextRelease");

UR_CALL(getContext()->interceptor->eraseContext(hContext));
ur_result_t result = pfnRelease(hContext);
UR_CALL(pfnRelease(hContext));

return result;
auto ContextInfo = getContext()->interceptor->getContextInfo(hContext);
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
if (--ContextInfo->RefCount == 0) {
UR_CALL(getContext()->interceptor->eraseContext(hContext));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1207,9 +1235,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

UR_CALL(pfnRetain(hKernel));

if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
KernelInfo->RefCount++;
}
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
KernelInfo->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1228,10 +1256,9 @@ __urdlllocal ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
if (--KernelInfo->RefCount != 0) {
return UR_RESULT_SUCCESS;
}
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
if (--KernelInfo->RefCount == 0) {
UR_CALL(getContext()->interceptor->eraseKernel(hKernel));
}

Expand Down Expand Up @@ -1426,6 +1453,7 @@ __urdlllocal ur_result_t UR_APICALL urGetContextProcAddrTable(
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::urContextCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::urContextRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::urContextRelease;

pDdiTable->pfnCreateWithNativeHandle =
Expand Down
Loading