Skip to content

Commit

Permalink
Fix fallback to passthrough calls to single driver given drivers removed
Browse files Browse the repository at this point in the history
- Fix the fallback case where multiple drivers are found, but only one
  inits such that the ddi tables are reinit to point back to the driver.

Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Aug 20, 2024
1 parent 05b5259 commit cfe26f0
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 6 deletions.
20 changes: 19 additions & 1 deletion source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,25 @@ namespace ze_lib
{
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global);
bool requireDdiReinit = false;
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &requireDdiReinit);
if (requireDdiReinit) {
// reInit the ZE DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zeDdiTableInit();
}
// reInit the ZET DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zetDdiTableInit();
}
// reInit the ZES DDI Tables
if( ZE_RESULT_SUCCESS == result )
{
result = zesDdiTableInit();
}
}
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
3 changes: 2 additions & 1 deletion source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace loader
}
}

ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {
ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit) {
if (debugTraceEnabled) {
std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")";
debug_trace_message(message, "");
Expand Down Expand Up @@ -156,6 +156,7 @@ namespace loader
debug_trace_message(errorMessage, loader::to_string(result));
}
it = drivers.erase(it);
*requireDdiReinit = true;
if(return_first_driver_result)
return result;
}
Expand Down
4 changes: 2 additions & 2 deletions source/loader/ze_loader_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ zeLoaderInit()
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored)
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit)
{
return loader::context->check_drivers(flags, globalInitStored);
return loader::context->check_drivers(flags, globalInitStored, requireDdiReinit);
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion source/loader/ze_loader_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ zeLoaderInit();
/// - ::ZE_RESULT_SUCCESS
/// - ::ZE_RESULT_ERROR_UNINITIALIZED
ZE_DLLEXPORT ze_result_t ZE_APICALL
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit);


///////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion source/loader/ze_loader_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ namespace loader
std::vector<zel_component_version_t> compVersions;
const char *LOADER_COMP_NAME = "loader";

ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit);
void debug_trace_message(std::string errorMessage, std::string errorValue);
ze_result_t init();
ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored);
Expand Down

0 comments on commit cfe26f0

Please sign in to comment.