From 9217c29d15a71a526a25154a3baf43313ed3b6c6 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Wed, 7 Aug 2024 09:36:46 -0700 Subject: [PATCH] Fix first zeinit to allow for layer checks - Use ddi table init of zeinit for the first call to zeInit to enable intercept in layers. Signed-off-by: Neil R. Spruit --- source/lib/ze_lib.cpp | 2 +- source/loader/ze_loader.cpp | 11 ++++++----- source/loader/ze_loader_api.cpp | 4 ++-- source/loader/ze_loader_api.h | 2 +- source/loader/ze_loader_internal.h | 4 ++-- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index f818045..ae402fa 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -108,7 +108,7 @@ namespace ze_lib { // Check which drivers support the ze_driver_flag_t specified // No need to check if only initializing sysman - result = zelLoaderDriverCheck(flags); + result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global); } if( ZE_RESULT_SUCCESS == result ) diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index dd4a61d..db17222 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -124,7 +124,7 @@ namespace loader } } - ze_result_t context_t::check_drivers(ze_init_flags_t flags) { + ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) { if (debugTraceEnabled) { std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")"; debug_trace_message(message, ""); @@ -137,7 +137,7 @@ namespace loader for(auto it = drivers.begin(); it != drivers.end(); ) { std::string freeLibraryErrorValue; - ze_result_t result = init_driver(*it, flags); + ze_result_t result = init_driver(*it, flags, globalInitStored); if(result != ZE_RESULT_SUCCESS) { if (it->handle) { auto free_result = FREE_DRIVER_LIBRARY(it->handle); @@ -170,7 +170,7 @@ namespace loader return ZE_RESULT_SUCCESS; } - ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags) { + ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) { auto getTable = reinterpret_cast( GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable")); @@ -201,7 +201,7 @@ namespace loader } auto pfnInit = global.pfnInit; - if(nullptr == pfnInit) { + if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) { if (debugTraceEnabled) { std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning "; debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); @@ -209,7 +209,8 @@ namespace loader return ZE_RESULT_ERROR_UNINITIALIZED; } - ze_result_t res = pfnInit(flags); + // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls + ze_result_t res = globalInitStored->pfnInit(flags); if (debugTraceEnabled) { std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning "; debug_trace_message(message, loader::to_string(res)); diff --git a/source/loader/ze_loader_api.cpp b/source/loader/ze_loader_api.cpp index aaaafe7..228f5e2 100644 --- a/source/loader/ze_loader_api.cpp +++ b/source/loader/ze_loader_api.cpp @@ -33,9 +33,9 @@ zeLoaderInit() /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags) +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) { - return loader::context->check_drivers(flags); + return loader::context->check_drivers(flags, globalInitStored); } /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_api.h b/source/loader/ze_loader_api.h index c73eb85..8d5bbbb 100644 --- a/source/loader/ze_loader_api.h +++ b/source/loader/ze_loader_api.h @@ -33,7 +33,7 @@ zeLoaderInit(); /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags); +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored); /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_internal.h b/source/loader/ze_loader_internal.h index 75f9d10..713d578 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -53,10 +53,10 @@ namespace loader std::vector compVersions; const char *LOADER_COMP_NAME = "loader"; - ze_result_t check_drivers(ze_init_flags_t flags); + ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored); void debug_trace_message(std::string errorMessage, std::string errorValue); ze_result_t init(); - ze_result_t init_driver(driver_t driver, ze_init_flags_t flags); + ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored); void add_loader_version(); ~context_t(); bool intercept_enabled = false;