Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with intercept layer and zesInit/zeInit given fallback to passthrough #183

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,17 @@ namespace loader
%else:
%if re.match(r"\w+ImageDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
}
%endif
%if re.match(r"\w+SamplerDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
}
%endif
// convert loader handle to driver handle
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
Expand All @@ -154,13 +160,29 @@ namespace loader
// check if the arg value is a translated handle
ze_image_object_t **imageHandle = static_cast<ze_image_object_t **>(internalArgValue);
ze_sampler_object_t **samplerHandle = static_cast<ze_sampler_object_t **>(internalArgValue);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
{
std::lock_guard<std::mutex> image_lock(context->image_handle_map_lock);
std::lock_guard<std::mutex> sampler_lock(context->sampler_handle_map_lock);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
}
}
}
%endif
## Workaround due to incorrect defintion of phWaitEvents in the ze headers which missed the range values.
## To be removed once the headers have been updated in a new spec release.
%if re.match(r"\w+CommandListAppendMetricQueryEnd$", th.make_func_name(n, tags, obj)):
nrspruit marked this conversation as resolved.
Show resolved Hide resolved
// convert loader handles to driver handles
auto phWaitEventsLocal = new ze_event_handle_t [numWaitEvents];
for( size_t i = 0; ( nullptr != phWaitEvents ) && ( i < numWaitEvents ); ++i )
phWaitEventsLocal[ i ] = reinterpret_cast<ze_event_object_t*>( phWaitEvents[ i ] )->handle;

// forward to device-driver
result = pfnAppendMetricQueryEnd( hCommandList, hMetricQuery, hSignalEvent, numWaitEvents, phWaitEventsLocal );
delete []phWaitEventsLocal;
%else:
// forward to device-driver
%if add_local:
result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name", "local"]))} );
Expand All @@ -174,6 +196,7 @@ namespace loader
result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%endif
%endif
%endif
<%
del arrays_to_delete
del add_local%>
Expand Down Expand Up @@ -210,11 +233,17 @@ namespace loader
context->${item['factory']}.getInstance( *${item['name']}, dditable ) );
%if re.match(r"\w+ImageCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+ImageViewCreateExp$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
%endif
%if re.match(r"\w+SamplerCreate$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
%endif
%endif
%endif
Expand Down
6 changes: 4 additions & 2 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ namespace ze_lib
// End DDI Table Inits

// Check which drivers and layers can be init on this system.
if( ZE_RESULT_SUCCESS == result && !sysmanOnly)
// If the driver check has already been called by zesInit or zeinit, then this is skipped.
if( ZE_RESULT_SUCCESS == result && !driverCheckCompleted)
{
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
bool requireDdiReinit = false;
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &requireDdiReinit);
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
// If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver.
// If ZET_ENABLE_PROGRAM_INSTRUMENTATION is enabled, then reInit is not possible due to the functions being intercepted with the previous ddi tables.
auto programInstrumentationEnabled = getenv_tobool( "ZET_ENABLE_PROGRAM_INSTRUMENTATION" );
Expand All @@ -130,6 +131,7 @@ namespace ze_lib
result = zesDdiTableInit();
}
}
driverCheckCompleted = true;
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
1 change: 1 addition & 0 deletions source/lib/ze_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace ze_lib

std::once_flag initOnce;
std::once_flag initOnceSysMan;
bool driverCheckCompleted = false;

ze_result_t Init(ze_init_flags_t flags, bool sysmanOnly);

Expand Down
37 changes: 28 additions & 9 deletions source/loader/ze_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2935,7 +2935,10 @@ namespace loader
*phImage = reinterpret_cast<ze_image_handle_t>(
context->ze_image_factory.getInstance( *phImage, dditable ) );
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand All @@ -2961,7 +2964,10 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;

// remove the handle from the kernel arugment map
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
}
// convert loader handle to driver handle
hImage = reinterpret_cast<ze_image_object_t*>( hImage )->handle;

Expand Down Expand Up @@ -3897,10 +3903,14 @@ namespace loader
// check if the arg value is a translated handle
ze_image_object_t **imageHandle = static_cast<ze_image_object_t **>(internalArgValue);
ze_sampler_object_t **samplerHandle = static_cast<ze_sampler_object_t **>(internalArgValue);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
{
std::lock_guard<std::mutex> image_lock(context->image_handle_map_lock);
std::lock_guard<std::mutex> sampler_lock(context->sampler_handle_map_lock);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
}
}
}
// forward to device-driver
Expand Down Expand Up @@ -4412,7 +4422,10 @@ namespace loader
*phSampler = reinterpret_cast<ze_sampler_handle_t>(
context->ze_sampler_factory.getInstance( *phSampler, dditable ) );
// convert loader handle to driver handle and store in map
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand All @@ -4438,7 +4451,10 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;

// remove the handle from the kernel arugment map
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
}
// convert loader handle to driver handle
hSampler = reinterpret_cast<ze_sampler_object_t*>( hSampler )->handle;

Expand Down Expand Up @@ -4968,7 +4984,10 @@ namespace loader
*phImageView = reinterpret_cast<ze_image_handle_t>(
context->ze_image_factory.getInstance( *phImageView, dditable ) );
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand Down
128 changes: 85 additions & 43 deletions source/loader/ze_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ namespace loader
}
}

ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, bool *requireDdiReinit) {
ze_result_t context_t::check_drivers(ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool *requireDdiReinit, bool sysmanOnly) {
if (debugTraceEnabled) {
std::string message = "check_drivers(" + std::string("flags=") + loader::to_string(flags) + ")";
debug_trace_message(message, "");
Expand All @@ -137,7 +137,7 @@ namespace loader
for(auto it = drivers.begin(); it != drivers.end(); )
{
std::string freeLibraryErrorValue;
ze_result_t result = init_driver(*it, flags, globalInitStored);
ze_result_t result = init_driver(*it, flags, globalInitStored, sysmanGlobalInitStored, sysmanOnly);
if(result != ZE_RESULT_SUCCESS) {
if (it->handle) {
auto free_result = FREE_DRIVER_LIBRARY(it->handle);
Expand Down Expand Up @@ -174,56 +174,98 @@ namespace loader
return ZE_RESULT_SUCCESS;
}

ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored) {
ze_result_t context_t::init_driver(driver_t driver, ze_init_flags_t flags, ze_global_dditable_t *globalInitStored, zes_global_dditable_t *sysmanGlobalInitStored, bool sysmanOnly) {

auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
if (sysmanOnly) {
auto getTable = reinterpret_cast<zes_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zesGetGlobalProcAddrTable"));
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

ze_global_dditable_t global;
auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() failed with ";
debug_trace_message(errorMessage, loader::to_string(getTableResult));

zes_global_dditable_t global;
auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zesGetGlobalProcAddrTable() failed with ";
debug_trace_message(errorMessage, loader::to_string(getTableResult));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

if(nullptr == global.pfnInit) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));

if(nullptr == global.pfnInit) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zesInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

auto pfnInit = global.pfnInit;
if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) {
// Use the previously init ddi table pointer to zesInit to allow for intercept of the zesInit calls
ze_result_t res = sysmanGlobalInitStored->pfnInit(flags);
// Verify that this driver successfully init in the call above.
if (driver.initStatus != ZE_RESULT_SUCCESS) {
res = driver.initStatus;
}
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
std::string message = "init driver " + driver.name + " zesInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
} else {
auto getTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
GET_FUNCTION_PTR(driver.handle, "zeGetGlobalProcAddrTable"));
if(!getTable) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

// Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls
ze_result_t res = globalInitStored->pfnInit(flags);
// Verify that this driver successfully init in the call above.
if (driver.initStatus != ZE_RESULT_SUCCESS) {
res = driver.initStatus;
}
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
ze_global_dditable_t global;
auto getTableResult = getTable(ZE_API_VERSION_CURRENT, &global);
if(getTableResult != ZE_RESULT_SUCCESS) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeGetGlobalProcAddrTable() failed with ";
debug_trace_message(errorMessage, loader::to_string(getTableResult));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

if(nullptr == global.pfnInit) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

auto pfnInit = global.pfnInit;
if(nullptr == pfnInit || globalInitStored->pfnInit == nullptr) {
if (debugTraceEnabled) {
std::string errorMessage = "init driver " + driver.name + " failed, zeInit function pointer null. Returning ";
debug_trace_message(errorMessage, loader::to_string(ZE_RESULT_ERROR_UNINITIALIZED));
}
return ZE_RESULT_ERROR_UNINITIALIZED;
}

// Use the previously init ddi table pointer to zeInit to allow for intercept of the zeInit calls
ze_result_t res = globalInitStored->pfnInit(flags);
// Verify that this driver successfully init in the call above.
if (driver.initStatus != ZE_RESULT_SUCCESS) {
res = driver.initStatus;
}
if (debugTraceEnabled) {
std::string message = "init driver " + driver.name + " zeInit(" + loader::to_string(flags) + ") returning ";
debug_trace_message(message, loader::to_string(res));
}
return res;
}
return res;
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
Loading
Loading