diff --git a/source/common/unified_malloc_framework/CMakeLists.txt b/source/common/unified_malloc_framework/CMakeLists.txt index 30991629cd..86d19f3099 100644 --- a/source/common/unified_malloc_framework/CMakeLists.txt +++ b/source/common/unified_malloc_framework/CMakeLists.txt @@ -16,6 +16,7 @@ if(UMF_BUILD_SHARED_LIBRARY) "Do not use the shared library in production software.") add_library(unified_malloc_framework SHARED ${UMF_SOURCES}) + target_compile_definitions(unified_malloc_framework PUBLIC UMF_SHARED_LIBRARY) else() add_library(unified_malloc_framework STATIC ${UMF_SOURCES}) diff --git a/source/common/unified_malloc_framework/src/memory_tracker.cpp b/source/common/unified_malloc_framework/src/memory_tracker.cpp index 74638579cc..adbe2aa5e9 100644 --- a/source/common/unified_malloc_framework/src/memory_tracker.cpp +++ b/source/common/unified_malloc_framework/src/memory_tracker.cpp @@ -18,6 +18,10 @@ #include #include +#ifdef _WIN32 +#include +#endif + // TODO: reimplement in C and optimize... struct umf_memory_tracker_t { enum umf_result_t add(void *pool, const void *ptr, size_t size) { @@ -84,11 +88,30 @@ umfMemoryTrackerRemove(umf_memory_tracker_handle_t hTracker, const void *ptr, extern "C" { -umf_memory_tracker_handle_t umfMemoryTrackerGet(void) { - static umf_memory_tracker_t tracker; - return &tracker; +#if defined(_WIN32) && defined(UMF_SHARED_LIBRARY) +umf_memory_tracker_t *tracker = nullptr; +BOOL APIENTRY DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) { + if (fdwReason == DLL_PROCESS_DETACH) { + delete tracker; + } else if (fdwReason == DLL_PROCESS_ATTACH) { + tracker = new umf_memory_tracker_t; + } + return TRUE; +} +#elif defined(_WIN32) +umf_memory_tracker_t trackerInstance; +umf_memory_tracker_t *tracker = &trackerInstance; +#else +umf_memory_tracker_t *tracker = nullptr; +void __attribute__((constructor)) createLibTracker() { + tracker = new umf_memory_tracker_t; } +void __attribute__((destructor)) deleteLibTracker() { delete tracker; } +#endif + +umf_memory_tracker_handle_t umfMemoryTrackerGet(void) { return tracker; } + void *umfMemoryTrackerGetPool(umf_memory_tracker_handle_t hTracker, const void *ptr) { return hTracker->find(ptr);