Skip to content

Commit

Permalink
Fix first zeinit to allow for layer checks (#177)
Browse files Browse the repository at this point in the history
- Use ddi table init of zeinit for the first call to zeInit
  to enable intercept in layers.

Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit authored Aug 7, 2024
1 parent 3047d0f commit 3d1e4a7
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ namespace ze_lib
{
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
result = zelLoaderDriverCheck(flags);
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global);
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
11 changes: 6 additions & 5 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace loader
}
}

ze_result_t context_t::check_drivers(ze_init_flags_t flags) {
ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {
if (debugTraceEnabled) {
std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")";
debug_trace_message(message, "");
Expand All @@ -137,7 +137,7 @@ namespace loader
for(auto it = drivers.begin(); it != drivers.end(); )
{
std::string freeLibraryErrorValue;
ze_result_t result = init_driver(*it, flags);
ze_result_t result = init_driver(*it, flags, globalInitStored);
if(result != ZE_RESULT_SUCCESS) {
if (it->handle) {
auto free_result = FREE_DRIVER_LIBRARY(it->handle);
Expand Down Expand Up @@ -170,7 +170,7 @@ namespace loader
return ZE_RESULT_SUCCESS;
}

ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags) {
ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {

auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
Expand Down Expand Up @@ -201,15 +201,16 @@ namespace loader
}

auto pfnInit = global.pfnInit;
if(nullptr == pfnInit) {
if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

ze_result_t res = pfnInit(flags);
// Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls
ze_result_t res = globalInitStored->pfnInit(flags);
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ze_loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ zeLoaderInit()
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags)
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored)
{
return loader::context->check_drivers(flags);
return loader::context->check_drivers(flags, globalInitStored);
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion source/loader/ze_loader_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ zeLoaderInit();
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags);
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);


///////////////////////////////////////////////////////////////////////////////
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ze_loader_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ namespace loader
std::vector<zel_component_version_t> compVersions;
const char *LOADER_COMP_NAME = "loader";

ze_result_t check_drivers(ze_init_flags_t flags);
ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
void debug_trace_message(std::string errorMessage, std::string errorValue);
ze_result_t init();
ze_result_t init_driver(driver_t driver, ze_init_flags_t flags);
ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
void add_loader_version();
~context_t();
bool intercept_enabled = false;
Expand Down

0 comments on commit 3d1e4a7

Please sign in to comment.