From 7bcc56410443e8b8443564fcbea8c8ee774b98f7 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Thu, 22 Aug 2024 12:15:11 -0700 Subject: [PATCH 1/5] Fix zesInit And zeInit ddi table and usable driver updates - Enable zesInit to check for usable drivers thru zelLoaderDriverCheck. - If a driver check was already done, then this call is skipped to avoid breaking existing allocated handles. Signed-off-by: Neil R. Spruit --- source/lib/ze_lib.cpp | 6 +- source/lib/ze_lib.h | 1 + source/loader/ze_loader.cpp | 128 +++++++++++++++++++---------- source/loader/ze_loader_api.cpp | 4 +- source/loader/ze_loader_api.h | 2 +- source/loader/ze_loader_internal.h | 4 +- 6 files changed, 95 insertions(+), 50 deletions(-) diff --git a/source/lib/ze_lib.cpp b/source/lib/ze_lib.cpp index c79688a..997123d 100644 --- a/source/lib/ze_lib.cpp +++ b/source/lib/ze_lib.cpp @@ -104,12 +104,13 @@ namespace ze_lib // End DDI Table Inits // Check which drivers and layers can be init on this system. - if( ZE_RESULT_SUCCESS == result && !sysmanOnly) + // If the driver check has already been called by zesInit or zeInit, then this is skipped. + if( ZE_RESULT_SUCCESS == result && !driverCheckCompleted) { // Check which drivers support the ze_driver_flag_t specified // No need to check if only initializing sysman bool requireDdiReinit = false; - result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &requireDdiReinit); + result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly); // If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver. // If ZET_ENABLE_PROGRAM_INSTRUMENTATION is enabled, then reInit is not possible due to the functions being intercepted with the previous ddi tables. 
auto programInstrumentationEnabled = getenv_tobool( "ZET_ENABLE_PROGRAM_INSTRUMENTATION" ); @@ -130,6 +131,7 @@ namespace ze_lib result = zesDdiTableInit(); } } + driverCheckCompleted = true; } if( ZE_RESULT_SUCCESS == result ) diff --git a/source/lib/ze_lib.h b/source/lib/ze_lib.h index c5fd487..10c5a99 100644 --- a/source/lib/ze_lib.h +++ b/source/lib/ze_lib.h @@ -36,6 +36,7 @@ namespace ze_lib std::once_flag initOnce; std::once_flag initOnceSysMan; + bool driverCheckCompleted = false; ze_result_t Init(ze_init_flags_t flags, bool sysmanOnly); diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index ee81dfe..381e1c0 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -124,7 +124,7 @@ namespace loader } } - ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit) { + ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) { if (debugTraceEnabled) { std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")"; debug_trace_message(message, ""); @@ -137,7 +137,7 @@ namespace loader for(auto it = drivers.begin(); it != drivers.end(); ) { std::string freeLibraryErrorValue; - ze_result_t result = init_driver(*it, flags, globalInitStored); + ze_result_t result = init_driver(*it, flags, globalInitStored, sysmanGlobalInitStored, sysmanOnly); if(result != ZE_RESULT_SUCCESS) { if (it->handle) { auto free_result = FREE_DRIVER_LIBRARY(it->handle); @@ -174,56 +174,98 @@ namespace loader return ZE_RESULT_SUCCESS; } - ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) { + ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool 
sysmanOnly) { - auto getTable = reinterpret_cast( - GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable")); - if(!getTable) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + if (sysmanOnly) { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(driver.handle, "zesGetGlobalProcAddrTable")); + if(!getTable) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable function pointer null. Returning "; + debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - ze_global_dditable_t global; - auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global); - if(getTableResult != ZE_RESULT_SUCCESS) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() failed with "; - debug_trace_message(errorMessage, loader::to_string(getTableResult)); + + zes_global_dditable_t global; + auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global); + if(getTableResult != ZE_RESULT_SUCCESS) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable() failed with "; + debug_trace_message(errorMessage, loader::to_string(getTableResult)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - - if(nullptr == global.pfnInit) { - if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. 
Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + + if(nullptr == global.pfnInit) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zesInit function pointer null. Returning "; + debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - auto pfnInit = global.pfnInit; - if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) { + // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls + ze_result_t res = sysmanGlobalInitStored->pfnInit(flags); + // Verify that this driver successfully init in the call above. + if (driver.initStatus != ZE_RESULT_SUCCESS) { + res = driver.initStatus; + } if (debugTraceEnabled) { - std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning "; - debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + std::string message = "init driver " + driver.name + " zesInit(" + loader::to_string(flags) + ") returning "; + debug_trace_message(message, loader::to_string(res)); + } + return res; + } else { + auto getTable = reinterpret_cast( + GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable")); + if(!getTable) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null. Returning "; + debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; } - return ZE_RESULT_ERROR_UNINITIALIZED; - } - // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls - ze_result_t res = globalInitStored->pfnInit(flags); - // Verify that this driver successfully init in the call above. 
- if (driver.initStatus != ZE_RESULT_SUCCESS) { - res = driver.initStatus; - } - if (debugTraceEnabled) { - std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning "; - debug_trace_message(message, loader::to_string(res)); + ze_global_dditable_t global; + auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global); + if(getTableResult != ZE_RESULT_SUCCESS) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() failed with "; + debug_trace_message(errorMessage, loader::to_string(getTableResult)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + if(nullptr == global.pfnInit) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning "; + debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + auto pfnInit = global.pfnInit; + if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) { + if (debugTraceEnabled) { + std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning "; + debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED)); + } + return ZE_RESULT_ERROR_UNINITIALIZED; + } + + // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls + ze_result_t res = globalInitStored->pfnInit(flags); + // Verify that this driver successfully init in the call above. 
+ if (driver.initStatus != ZE_RESULT_SUCCESS) { + res = driver.initStatus; + } + if (debugTraceEnabled) { + std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning "; + debug_trace_message(message, loader::to_string(res)); + } + return res; } - return res; } /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_api.cpp b/source/loader/ze_loader_api.cpp index 46f3151..8897035 100644 --- a/source/loader/ze_loader_api.cpp +++ b/source/loader/ze_loader_api.cpp @@ -33,9 +33,9 @@ zeLoaderInit() /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit) +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStore, bool *requireDdiReinit, bool sysmanOnly) { - return loader::context->check_drivers(flags, globalInitStored, requireDdiReinit); + return loader::context->check_drivers(flags, globalInitStored, sysmanGlobalInitStore, requireDdiReinit, sysmanOnly); } /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_api.h b/source/loader/ze_loader_api.h index 9236556..63ad85d 100644 --- a/source/loader/ze_loader_api.h +++ b/source/loader/ze_loader_api.h @@ -33,7 +33,7 @@ zeLoaderInit(); /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit); +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStore, bool *requireDdiReinit, bool sysmanOnly); /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_internal.h 
b/source/loader/ze_loader_internal.h index d345dd9..92400ef 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -110,10 +110,10 @@ namespace loader std::vector compVersions; const char *LOADER_COMP_NAME = "loader"; - ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit); + ze_result_t check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); void debug_trace_message(std::string errorMessage, std::string errorValue); ze_result_t init(); - ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored); + ze_result_t init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly); void add_loader_version(); ~context_t(); bool intercept_enabled = false; From 3faf300802cef7de8e9dfc28d2069dde78f2c2ea Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Thu, 22 Aug 2024 12:24:40 -0700 Subject: [PATCH 2/5] Fix image/sampler handle map for thread safety - Added lock for usage of image/sampler map for thread safe operation. Signed-off-by: Neil R. 
Spruit --- scripts/templates/ldrddi.cpp.mako | 36 ++++++++++++++++++++------ source/loader/ze_ldrddi.cpp | 41 +++++++++++++++++++++++------- source/loader/ze_loader_internal.h | 2 ++ 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index d16e1f5..b91e4d0 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -135,11 +135,21 @@ namespace loader %else: %if re.match(r"\w+ImageDestroy$", th.make_func_name(n, tags, obj)): // remove the handle from the kernel arugment map - context->image_handle_map.erase(reinterpret_cast(hImage)); + { + std::lock_guard lock(context->image_handle_map_lock); + if( context->image_handle_map.find(reinterpret_cast(hImage)) != context->image_handle_map.end() ) { + context->image_handle_map.erase(reinterpret_cast(hImage)); + } + } %endif %if re.match(r"\w+SamplerDestroy$", th.make_func_name(n, tags, obj)): // remove the handle from the kernel arugment map - context->sampler_handle_map.erase(reinterpret_cast(hSampler)); + { + std::lock_guard lock(context->sampler_handle_map_lock); + if( context->sampler_handle_map.find(reinterpret_cast(hSampler)) != context->sampler_handle_map.end() ) { + context->sampler_handle_map.erase(reinterpret_cast(hSampler)); + } + } %endif // convert loader handle to driver handle ${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle; @@ -154,10 +164,14 @@ namespace loader // check if the arg value is a translated handle ze_image_object_t **imageHandle = static_cast(internalArgValue); ze_sampler_object_t **samplerHandle = static_cast(internalArgValue); - if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) { - internalArgValue = &context->image_handle_map[*imageHandle]; - } else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) { - internalArgValue = &context->sampler_handle_map[*samplerHandle]; + { + 
std::lock_guard image_lock(context->image_handle_map_lock); + std::lock_guard sampler_lock(context->sampler_handle_map_lock); + if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) { + internalArgValue = &context->image_handle_map[*imageHandle]; + } else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) { + internalArgValue = &context->sampler_handle_map[*samplerHandle]; + } } } %endif @@ -210,11 +224,17 @@ namespace loader context->${item['factory']}.getInstance( *${item['name']}, dditable ) ); %if re.match(r"\w+ImageCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+ImageViewCreateExp$", th.make_func_name(n, tags, obj)): // convert loader handle to driver handle and store in map - context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + { + std::lock_guard lock(context->image_handle_map_lock); + context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + } %endif %if re.match(r"\w+SamplerCreate$", th.make_func_name(n, tags, obj)): // convert loader handle to driver handle and store in map - context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + { + std::lock_guard lock(context->sampler_handle_map_lock); + context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + } %endif %endif %endif diff --git a/source/loader/ze_ldrddi.cpp b/source/loader/ze_ldrddi.cpp index 056ccc7..464f5af 100644 --- a/source/loader/ze_ldrddi.cpp +++ b/source/loader/ze_ldrddi.cpp @@ -2935,7 +2935,10 @@ namespace loader *phImage = reinterpret_cast( context->ze_image_factory.getInstance( *phImage, dditable ) ); // convert loader handle to driver handle and store in map - 
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + { + std::lock_guard lock(context->image_handle_map_lock); + context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + } } catch( std::bad_alloc& ) { @@ -2961,7 +2964,12 @@ namespace loader return ZE_RESULT_ERROR_UNINITIALIZED; // remove the handle from the kernel arugment map - context->image_handle_map.erase(reinterpret_cast(hImage)); + { + std::lock_guard lock(context->image_handle_map_lock); + if( context->image_handle_map.find(reinterpret_cast(hImage)) != context->image_handle_map.end() ) { + context->image_handle_map.erase(reinterpret_cast(hImage)); + } + } // convert loader handle to driver handle hImage = reinterpret_cast( hImage )->handle; @@ -3897,10 +3905,14 @@ namespace loader // check if the arg value is a translated handle ze_image_object_t **imageHandle = static_cast(internalArgValue); ze_sampler_object_t **samplerHandle = static_cast(internalArgValue); - if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) { - internalArgValue = &context->image_handle_map[*imageHandle]; - } else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) { - internalArgValue = &context->sampler_handle_map[*samplerHandle]; + { + std::lock_guard image_lock(context->image_handle_map_lock); + std::lock_guard sampler_lock(context->sampler_handle_map_lock); + if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) { + internalArgValue = &context->image_handle_map[*imageHandle]; + } else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) { + internalArgValue = &context->sampler_handle_map[*samplerHandle]; + } } } // forward to device-driver @@ -4412,7 +4424,10 @@ namespace loader *phSampler = reinterpret_cast( 
context->ze_sampler_factory.getInstance( *phSampler, dditable ) ); // convert loader handle to driver handle and store in map - context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + { + std::lock_guard lock(context->sampler_handle_map_lock); + context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + } } catch( std::bad_alloc& ) { @@ -4438,7 +4453,12 @@ namespace loader return ZE_RESULT_ERROR_UNINITIALIZED; // remove the handle from the kernel arugment map - context->sampler_handle_map.erase(reinterpret_cast(hSampler)); + { + std::lock_guard lock(context->sampler_handle_map_lock); + if( context->sampler_handle_map.find(reinterpret_cast(hSampler)) != context->sampler_handle_map.end() ) { + context->sampler_handle_map.erase(reinterpret_cast(hSampler)); + } + } // convert loader handle to driver handle hSampler = reinterpret_cast( hSampler )->handle; @@ -4968,7 +4988,10 @@ namespace loader *phImageView = reinterpret_cast( context->ze_image_factory.getInstance( *phImageView, dditable ) ); // convert loader handle to driver handle and store in map - context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + { + std::lock_guard lock(context->image_handle_map_lock); + context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr}); + } } catch( std::bad_alloc& ) { diff --git a/source/loader/ze_loader_internal.h b/source/loader/ze_loader_internal.h index 92400ef..e704d4c 100644 --- a/source/loader/ze_loader_internal.h +++ b/source/loader/ze_loader_internal.h @@ -96,6 +96,8 @@ namespace loader zet_debug_session_factory_t zet_debug_session_factory; zet_metric_programmable_exp_factory_t zet_metric_programmable_exp_factory; /////////////////////////////////////////////////////////////////////////////// + 
std::mutex image_handle_map_lock; + std::mutex sampler_handle_map_lock; std::unordered_map image_handle_map; std::unordered_map sampler_handle_map; ze_api_version_t version = ZE_API_VERSION_CURRENT; From 35bd00950e72d40f12b4002d5f4092f42b6c60c5 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Thu, 22 Aug 2024 12:26:50 -0700 Subject: [PATCH 3/5] Fix zetCommandListAppendMetricQueryEnd for phWaitEvent handle translation - Fixed zetCommandListAppendMetricQueryEnd to properly translate handles. Signed-off-by: Neil R. Spruit --- scripts/templates/ldrddi.cpp.mako | 11 +++++++++++ source/loader/zet_ldrddi.cpp | 8 +++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index b91e4d0..579317c 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -175,6 +175,16 @@ namespace loader } } %endif + %if re.match(r"\w+CommandListAppendMetricQueryEnd$", th.make_func_name(n, tags, obj)): + // convert loader handles to driver handles + auto phWaitEventsLocal = new ze_event_handle_t [numWaitEvents]; + for( size_t i = 0; ( nullptr != phWaitEvents ) && ( i < numWaitEvents ); ++i ) + phWaitEventsLocal[ i ] = reinterpret_cast( phWaitEvents[ i ] )->handle; + + // forward to device-driver + result = pfnAppendMetricQueryEnd( hCommandList, hMetricQuery, hSignalEvent, numWaitEvents, phWaitEventsLocal ); + delete []phWaitEventsLocal; + %else: // forward to device-driver %if add_local: result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name", "local"]))} ); @@ -188,6 +198,7 @@ namespace loader result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} ); %endif %endif + %endif <% del arrays_to_delete del add_local%> diff --git a/source/loader/zet_ldrddi.cpp b/source/loader/zet_ldrddi.cpp index 84ef5c8..7600426 100644 --- a/source/loader/zet_ldrddi.cpp +++ 
b/source/loader/zet_ldrddi.cpp @@ -1026,8 +1026,14 @@ namespace loader // convert loader handle to driver handle hSignalEvent = ( hSignalEvent ) ? reinterpret_cast( hSignalEvent )->handle : nullptr; + // convert loader handles to driver handles + auto phWaitEventsLocal = new ze_event_handle_t [numWaitEvents]; + for( size_t i = 0; ( nullptr != phWaitEvents ) && ( i < numWaitEvents ); ++i ) + phWaitEventsLocal[ i ] = reinterpret_cast( phWaitEvents[ i ] )->handle; + // forward to device-driver - result = pfnAppendMetricQueryEnd( hCommandList, hMetricQuery, hSignalEvent, numWaitEvents, phWaitEvents ); + result = pfnAppendMetricQueryEnd( hCommandList, hMetricQuery, hSignalEvent, numWaitEvents, phWaitEventsLocal ); + delete []phWaitEventsLocal; return result; } From e6eec45e21b266264e625eb4cba359c6fc658253 Mon Sep 17 00:00:00 2001 From: "Neil R. Spruit" Date: Thu, 22 Aug 2024 15:08:55 -0700 Subject: [PATCH 4/5] Fix map erase, typo, and variable name inconsistency Signed-off-by: Neil R. Spruit --- scripts/templates/ldrddi.cpp.mako | 8 ++------ source/loader/ze_ldrddi.cpp | 8 ++------ source/loader/ze_loader.cpp | 2 +- source/loader/ze_loader_api.cpp | 4 ++-- source/loader/ze_loader_api.h | 2 +- 5 files changed, 8 insertions(+), 16 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 579317c..ab5127e 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -137,18 +137,14 @@ namespace loader // remove the handle from the kernel arugment map { std::lock_guard lock(context->image_handle_map_lock); - if( context->image_handle_map.find(reinterpret_cast(hImage)) != context->image_handle_map.end() ) { - context->image_handle_map.erase(reinterpret_cast(hImage)); - } + context->image_handle_map.erase(reinterpret_cast(hImage)); } %endif %if re.match(r"\w+SamplerDestroy$", th.make_func_name(n, tags, obj)): // remove the handle from the kernel arugment map { std::lock_guard 
lock(context->sampler_handle_map_lock); - if( context->sampler_handle_map.find(reinterpret_cast(hSampler)) != context->sampler_handle_map.end() ) { - context->sampler_handle_map.erase(reinterpret_cast(hSampler)); - } + context->sampler_handle_map.erase(reinterpret_cast(hSampler)); } %endif // convert loader handle to driver handle diff --git a/source/loader/ze_ldrddi.cpp b/source/loader/ze_ldrddi.cpp index 464f5af..c875cf2 100644 --- a/source/loader/ze_ldrddi.cpp +++ b/source/loader/ze_ldrddi.cpp @@ -2966,9 +2966,7 @@ namespace loader // remove the handle from the kernel arugment map { std::lock_guard lock(context->image_handle_map_lock); - if( context->image_handle_map.find(reinterpret_cast(hImage)) != context->image_handle_map.end() ) { - context->image_handle_map.erase(reinterpret_cast(hImage)); - } + context->image_handle_map.erase(reinterpret_cast(hImage)); } // convert loader handle to driver handle hImage = reinterpret_cast( hImage )->handle; @@ -4455,9 +4453,7 @@ namespace loader // remove the handle from the kernel arugment map { std::lock_guard lock(context->sampler_handle_map_lock); - if( context->sampler_handle_map.find(reinterpret_cast(hSampler)) != context->sampler_handle_map.end() ) { - context->sampler_handle_map.erase(reinterpret_cast(hSampler)); - } + context->sampler_handle_map.erase(reinterpret_cast(hSampler)); } // convert loader handle to driver handle hSampler = reinterpret_cast( hSampler )->handle; diff --git a/source/loader/ze_loader.cpp b/source/loader/ze_loader.cpp index 381e1c0..0a8d2a4 100644 --- a/source/loader/ze_loader.cpp +++ b/source/loader/ze_loader.cpp @@ -205,7 +205,7 @@ namespace loader return ZE_RESULT_ERROR_UNINITIALIZED; } - // Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls + // Use the previously init ddi table pointer to zesInit to allow for intercept of the zesInit calls ze_result_t res = sysmanGlobalInitStored->pfnInit(flags); // Verify that this driver successfully init in 
the call above. if (driver.initStatus != ZE_RESULT_SUCCESS) { diff --git a/source/loader/ze_loader_api.cpp b/source/loader/ze_loader_api.cpp index 8897035..d86d92c 100644 --- a/source/loader/ze_loader_api.cpp +++ b/source/loader/ze_loader_api.cpp @@ -33,9 +33,9 @@ zeLoaderInit() /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStore, bool *requireDdiReinit, bool sysmanOnly) +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) { - return loader::context->check_drivers(flags, globalInitStored, sysmanGlobalInitStore, requireDdiReinit, sysmanOnly); + return loader::context->check_drivers(flags, globalInitStored, sysmanGlobalInitStored, requireDdiReinit, sysmanOnly); } /////////////////////////////////////////////////////////////////////////////// diff --git a/source/loader/ze_loader_api.h b/source/loader/ze_loader_api.h index 63ad85d..590f143 100644 --- a/source/loader/ze_loader_api.h +++ b/source/loader/ze_loader_api.h @@ -33,7 +33,7 @@ zeLoaderInit(); /// - ::ZE_RESULT_SUCCESS /// - ::ZE_RESULT_ERROR_UNINITIALIZED ZE_DLLEXPORT ze_result_t ZE_APICALL -zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStore, bool *requireDdiReinit, bool sysmanOnly); +zelLoaderDriverCheck(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly); /////////////////////////////////////////////////////////////////////////////// From e76e74b733b6446b0d3015034b19b40af2827ff9 Mon Sep 17 00:00:00 2001 From: "Neil R. 
Spruit" Date: Thu, 22 Aug 2024 15:12:06 -0700 Subject: [PATCH 5/5] Add comment on workaround for zetCommandListAppendMetricQueryEnd Signed-off-by: Neil R. Spruit --- scripts/templates/ldrddi.cpp.mako | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index ab5127e..89d593d 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -171,6 +171,8 @@ namespace loader } } %endif + ## Workaround due to incorrect definition of phWaitEvents in the ze headers which missed the range values. + ## To be removed once the headers have been updated in a new spec release. %if re.match(r"\w+CommandListAppendMetricQueryEnd$", th.make_func_name(n, tags, obj)): // convert loader handles to driver handles auto phWaitEventsLocal = new ze_event_handle_t [numWaitEvents];