Skip to content

Commit

Permalink
Merge pull request #1808 from zhaomaosu/detect-memory-leak
Browse files Browse the repository at this point in the history
[DeviceSanitizer] Support detecting memory leaks of USM
  • Loading branch information
pbalcer authored Sep 19, 2024
2 parents e8182b8 + 5653b30 commit 2af159d
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 17 deletions.
28 changes: 28 additions & 0 deletions source/loader/layers/sanitizer/asan_interceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,34 @@ SanitizerInterceptor::findAllocInfoByAddress(uptr Address) {
return It;
}

// Collects an iterator to every entry of m_AllocationMap whose allocation
// record has its Context field equal to the given context handle.
//
// The map is walked while holding a shared lock on m_AllocationMapMutex; the
// lock is dropped when this function returns.
//
// NOTE(review): the returned iterators are handed to the caller after the
// lock has been released — confirm that callers cannot race with code that
// erases entries from the map.
std::vector<AllocationIterator>
SanitizerInterceptor::findAllocInfoByContext(ur_context_handle_t Context) {
    std::shared_lock<ur_shared_mutex> ReadLock(m_AllocationMapMutex);

    std::vector<AllocationIterator> Matches;
    auto Cursor = m_AllocationMap.begin();
    const auto Last = m_AllocationMap.end();
    while (Cursor != Last) {
        // The mapped value is the allocation record; keep the iterator only
        // when that record was created for the requested context.
        if (Cursor->second->Context == Context) {
            Matches.emplace_back(Cursor);
        }
        ++Cursor;
    }
    return Matches;
}

// Destructor. Calls the context release entry point of the DDI table on
// Handle (the result is only checked by assert), then looks up every
// allocation recorded for this context and reports each one that is not
// flagged as released as a memory leak.
ContextInfo::~ContextInfo() {
    [[maybe_unused]] auto ReleaseStatus =
        getContext()->urDdiTable.Context.pfnRelease(Handle);
    assert(ReleaseStatus == UR_RESULT_SUCCESS);

    // Leak check: Handle is only used as a lookup key here.
    for (const auto &Entry :
         getContext()->interceptor->findAllocInfoByContext(Handle)) {
        const auto &Alloc = Entry->second;
        if (Alloc->IsReleased) {
            continue;
        }
        ReportMemoryLeak(Alloc);
    }
}

ur_result_t USMLaunchInfo::initialize() {
UR_CALL(getContext()->urDdiTable.Context.pfnRetain(Context));
UR_CALL(getContext()->urDdiTable.Device.pfnRetain(Device));
Expand Down
16 changes: 10 additions & 6 deletions source/loader/layers/sanitizer/asan_interceptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct DeviceInfo {
// Device features
bool IsSupportSharedSystemUSM = false;

// Lock this mutex when accessing the following fields
ur_mutex Mutex;
std::queue<std::shared_ptr<AllocInfo>> Quarantine;
size_t QuarantineSize = 0;
Expand All @@ -59,6 +60,7 @@ struct DeviceInfo {
struct QueueInfo {
ur_queue_handle_t Handle;

// Lock this mutex when accessing the following fields
ur_shared_mutex Mutex;
ur_event_handle_t LastEvent;

Expand All @@ -78,8 +80,10 @@ struct QueueInfo {

struct KernelInfo {
ur_kernel_handle_t Handle;
ur_shared_mutex Mutex;
std::atomic<int32_t> RefCount = 1;

// Lock this mutex when accessing the following fields
ur_shared_mutex Mutex;
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
std::unordered_map<uint32_t, std::pair<const void *, StackTrace>>
PointerArgs;
Expand All @@ -102,6 +106,7 @@ struct KernelInfo {

struct ContextInfo {
ur_context_handle_t Handle;
std::atomic<int32_t> RefCount = 1;

std::vector<ur_device_handle_t> DeviceList;
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;
Expand All @@ -112,11 +117,7 @@ struct ContextInfo {
assert(Result == UR_RESULT_SUCCESS);
}

~ContextInfo() {
[[maybe_unused]] auto Result =
getContext()->urDdiTable.Context.pfnRelease(Handle);
assert(Result == UR_RESULT_SUCCESS);
}
~ContextInfo();

void insertAllocInfo(const std::vector<ur_device_handle_t> &Devices,
std::shared_ptr<AllocInfo> &AI) {
Expand Down Expand Up @@ -211,6 +212,9 @@ class SanitizerInterceptor {

std::optional<AllocationIterator> findAllocInfoByAddress(uptr Address);

std::vector<AllocationIterator>
findAllocInfoByContext(ur_context_handle_t Context);

std::shared_ptr<ContextInfo> getContextInfo(ur_context_handle_t Context) {
std::shared_lock<ur_shared_mutex> Guard(m_ContextMapMutex);
assert(m_ContextMap.find(Context) != m_ContextMap.end());
Expand Down
10 changes: 10 additions & 0 deletions source/loader/layers/sanitizer/asan_report.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
AI->AllocStack.print();
}

// Emits a leak report for a single allocation through the always-on logger:
// a header line naming the allocation type, a line with the size in bytes
// (UserEnd - UserBegin) and the start address, followed by the stack trace
// stored in AI->AllocStack.
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI) {
    auto &Log = getContext()->logger;

    Log.always("\n====ERROR: DeviceSanitizer: detected memory leaks of {}",
               ToString(AI->Type));

    const auto LeakedBytes = AI->UserEnd - AI->UserBegin;
    Log.always("Direct leak of {} byte(s) at {} allocated from:", LeakedBytes,
               reinterpret_cast<void *>(AI->UserBegin));

    AI->AllocStack.print();
}

void ReportFatalError(const DeviceSanitizerReport &Report) {
getContext()->logger.always("\n====ERROR: DeviceSanitizer: {}",
ToString(Report.ErrorType));
Expand Down
5 changes: 4 additions & 1 deletion source/loader/layers/sanitizer/asan_report.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ void ReportBadContext(uptr Addr, const StackTrace &stack,
void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
const std::shared_ptr<AllocInfo> &AllocInfo);

// This type of error is usually unexpected mistake and doesn't have enough debug information
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI);

// This type of error is usually an unexpected mistake and doesn't have
// enough debug information
void ReportFatalError(const DeviceSanitizerReport &Report);

void ReportGenericError(const DeviceSanitizerReport &Report,
Expand Down
48 changes: 38 additions & 10 deletions source/loader/layers/sanitizer/ur_sanddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,29 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
return result;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urContextRetain
///
/// Forwards the retain to the next layer's pfnRetain and, once that call has
/// gone through UR_CALL, increments the RefCount of the ContextInfo that the
/// interceptor holds for this context.
///
/// @return UR_RESULT_ERROR_UNSUPPORTED_FEATURE when no pfnRetain is
///         registered, UR_RESULT_ERROR_INVALID_VALUE (via UR_ASSERT) when no
///         ContextInfo is found, otherwise UR_RESULT_SUCCESS.
///         NOTE(review): UR_CALL presumably returns early with the
///         underlying error on failure — confirm against the macro.
__urdlllocal ur_result_t UR_APICALL urContextRetain(
    ur_context_handle_t
        hContext ///< [in] handle of the context to get a reference of.
) {
    const auto pfnRetain = getContext()->urDdiTable.Context.pfnRetain;
    if (pfnRetain == nullptr) {
        return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
    }

    getContext()->logger.debug("==== urContextRetain");

    // Retain the underlying context first; the layer's own counter is only
    // touched after that call.
    UR_CALL(pfnRetain(hContext));

    const auto CtxInfo = getContext()->interceptor->getContextInfo(hContext);
    UR_ASSERT(CtxInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
    ++CtxInfo->RefCount;

    return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urContextRelease
__urdlllocal ur_result_t UR_APICALL urContextRelease(
Expand All @@ -424,10 +447,15 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(

getContext()->logger.debug("==== urContextRelease");

UR_CALL(getContext()->interceptor->eraseContext(hContext));
ur_result_t result = pfnRelease(hContext);
UR_CALL(pfnRelease(hContext));

return result;
auto ContextInfo = getContext()->interceptor->getContextInfo(hContext);
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
if (--ContextInfo->RefCount == 0) {
UR_CALL(getContext()->interceptor->eraseContext(hContext));
}

return UR_RESULT_SUCCESS;
}

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1207,9 +1235,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

UR_CALL(pfnRetain(hKernel));

if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
KernelInfo->RefCount++;
}
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
KernelInfo->RefCount++;

return UR_RESULT_SUCCESS;
}
Expand All @@ -1228,10 +1256,9 @@ __urdlllocal ur_result_t urKernelRelease(
getContext()->logger.debug("==== urKernelRelease");
UR_CALL(pfnRelease(hKernel));

if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
if (--KernelInfo->RefCount != 0) {
return UR_RESULT_SUCCESS;
}
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
if (--KernelInfo->RefCount == 0) {
UR_CALL(getContext()->interceptor->eraseKernel(hKernel));
}

Expand Down Expand Up @@ -1426,6 +1453,7 @@ __urdlllocal ur_result_t UR_APICALL urGetContextProcAddrTable(
ur_result_t result = UR_RESULT_SUCCESS;

pDdiTable->pfnCreate = ur_sanitizer_layer::urContextCreate;
pDdiTable->pfnRetain = ur_sanitizer_layer::urContextRetain;
pDdiTable->pfnRelease = ur_sanitizer_layer::urContextRelease;

pDdiTable->pfnCreateWithNativeHandle =
Expand Down

0 comments on commit 2af159d

Please sign in to comment.