Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix first zeinit to allow for layer checks #177

Merged
merged 1 commit into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ namespace ze_lib
{
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
result = zelLoaderDriverCheck(flags);
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global);
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
11 changes: 6 additions & 5 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace loader
}
}

ze_result_t context_t::check_drivers(ze_init_flags_t flags) {
ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {
if (debugTraceEnabled) {
std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")";
debug_trace_message(message, "");
Expand All @@ -137,7 +137,7 @@ namespace loader
for(auto it = drivers.begin(); it != drivers.end(); )
{
std::string freeLibraryErrorValue;
ze_result_t result = init_driver(*it, flags);
ze_result_t result = init_driver(*it, flags, globalInitStored);
if(result != ZE_RESULT_SUCCESS) {
if (it->handle) {
auto free_result = FREE_DRIVER_LIBRARY(it->handle);
Expand Down Expand Up @@ -170,7 +170,7 @@ namespace loader
return ZE_RESULT_SUCCESS;
}

ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags) {
ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {

auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
Expand Down Expand Up @@ -201,15 +201,16 @@ namespace loader
}

auto pfnInit = global.pfnInit;
if(nullptr == pfnInit) {
if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

ze_result_t res = pfnInit(flags);
// Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls
ze_result_t res = globalInitStored->pfnInit(flags);
nrspruit marked this conversation as resolved.
Show resolved Hide resolved
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ze_loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ zeLoaderInit()
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags)
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored)
{
return loader::context->check_drivers(flags);
return loader::context->check_drivers(flags, globalInitStored);
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion source/loader/ze_loader_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ zeLoaderInit();
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags);
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);


///////////////////////////////////////////////////////////////////////////////
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ze_loader_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ namespace loader
std::vector<zel_component_version_t> compVersions;
const char *LOADER_COMP_NAME = "loader";

ze_result_t check_drivers(ze_init_flags_t flags);
ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
void debug_trace_message(std::string errorMessage, std::string errorValue);
ze_result_t init();
ze_result_t init_driver(driver_t driver, ze_init_flags_t flags);
ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
void add_loader_version();
~context_t();
bool intercept_enabled = false;
Expand Down
Loading