Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with intercept layer and zesInit/zeInit given fallback to passthrough #183

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,21 @@ namespace loader
%else:
%if re.match(r"\w+ImageDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
if( context->image_handle_map.find(reinterpret_cast<ze_image_object_t*>(hImage)) != context->image_handle_map.end() ) {
nrspruit marked this conversation as resolved.
Show resolved Hide resolved
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
}
}
%endif
%if re.match(r"\w+SamplerDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
if( context->sampler_handle_map.find(reinterpret_cast<ze_sampler_object_t*>(hSampler)) != context->sampler_handle_map.end() ) {
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
}
}
%endif
// convert loader handle to driver handle
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
Expand All @@ -154,13 +164,27 @@ namespace loader
// check if the arg value is a translated handle
ze_image_object_t **imageHandle = static_cast<ze_image_object_t **>(internalArgValue);
ze_sampler_object_t **samplerHandle = static_cast<ze_sampler_object_t **>(internalArgValue);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
{
std::lock_guard<std::mutex> image_lock(context->image_handle_map_lock);
std::lock_guard<std::mutex> sampler_lock(context->sampler_handle_map_lock);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
}
}
}
%endif
%if re.match(r"\w+CommandListAppendMetricQueryEnd$", th.make_func_name(n, tags, obj)):
nrspruit marked this conversation as resolved.
Show resolved Hide resolved
// convert loader handles to driver handles
auto phWaitEventsLocal = new ze_event_handle_t [numWaitEvents];
for( size_t i = 0; ( nullptr != phWaitEvents ) && ( i < numWaitEvents ); ++i )
phWaitEventsLocal[ i ] = reinterpret_cast<ze_event_object_t*>( phWaitEvents[ i ] )->handle;

// forward to device-driver
result = pfnAppendMetricQueryEnd( hCommandList, hMetricQuery, hSignalEvent, numWaitEvents, phWaitEventsLocal );
delete []phWaitEventsLocal;
%else:
// forward to device-driver
%if add_local:
result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name", "local"]))} );
Expand All @@ -174,6 +198,7 @@ namespace loader
result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
%endif
%endif
%endif
<%
del arrays_to_delete
del add_local%>
Expand Down Expand Up @@ -210,11 +235,17 @@ namespace loader
context->${item['factory']}.getInstance( *${item['name']}, dditable ) );
%if re.match(r"\w+ImageCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+ImageViewCreateExp$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
%endif
%if re.match(r"\w+SamplerCreate$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
%endif
%endif
%endif
Expand Down
6 changes: 4 additions & 2 deletions source/lib/ze_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ namespace ze_lib
// End DDI Table Inits

// Check which drivers and layers can be init on this system.
if( ZE_RESULT_SUCCESS == result && !sysmanOnly)
// If the driver check has already been called by zesInit or zeinit, then this is skipped.
if( ZE_RESULT_SUCCESS == result && !driverCheckCompleted)
{
// Check which drivers support the ze_driver_flag_t specified
// No need to check if only initializing sysman
bool requireDdiReinit = false;
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &requireDdiReinit);
result = zelLoaderDriverCheck(flags, &ze_lib::context->initialzeDdiTable.Global, &ze_lib::context->initialzesDdiTable.Global, &requireDdiReinit, sysmanOnly);
// If a driver was removed from the driver list, then the ddi tables need to be reinit to allow for passthru directly to the driver.
// If ZET_ENABLE_PROGRAM_INSTRUMENTATION is enabled, then reInit is not possible due to the functions being intercepted with the previous ddi tables.
auto programInstrumentationEnabled = getenv_tobool( "ZET_ENABLE_PROGRAM_INSTRUMENTATION" );
Expand All @@ -130,6 +131,7 @@ namespace ze_lib
result = zesDdiTableInit();
}
}
driverCheckCompleted = true;
}

if( ZE_RESULT_SUCCESS == result )
Expand Down
1 change: 1 addition & 0 deletions source/lib/ze_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace ze_lib

std::once_flag initOnce;
std::once_flag initOnceSysMan;
bool driverCheckCompleted = false;

ze_result_t Init(ze_init_flags_t flags, bool sysmanOnly);

Expand Down
41 changes: 32 additions & 9 deletions source/loader/ze_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2935,7 +2935,10 @@ namespace loader
*phImage = reinterpret_cast<ze_image_handle_t>(
context->ze_image_factory.getInstance( *phImage, dditable ) );
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand All @@ -2961,7 +2964,12 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;

// remove the handle from the kernel arugment map
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
if( context->image_handle_map.find(reinterpret_cast<ze_image_object_t*>(hImage)) != context->image_handle_map.end() ) {
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
}
}
// convert loader handle to driver handle
hImage = reinterpret_cast<ze_image_object_t*>( hImage )->handle;

Expand Down Expand Up @@ -3897,10 +3905,14 @@ namespace loader
// check if the arg value is a translated handle
ze_image_object_t **imageHandle = static_cast<ze_image_object_t **>(internalArgValue);
ze_sampler_object_t **samplerHandle = static_cast<ze_sampler_object_t **>(internalArgValue);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
{
std::lock_guard<std::mutex> image_lock(context->image_handle_map_lock);
std::lock_guard<std::mutex> sampler_lock(context->sampler_handle_map_lock);
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
}
}
}
// forward to device-driver
Expand Down Expand Up @@ -4412,7 +4424,10 @@ namespace loader
*phSampler = reinterpret_cast<ze_sampler_handle_t>(
context->ze_sampler_factory.getInstance( *phSampler, dditable ) );
// convert loader handle to driver handle and store in map
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand All @@ -4438,7 +4453,12 @@ namespace loader
return ZE_RESULT_ERROR_UNINITIALIZED;

// remove the handle from the kernel arugment map
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
{
std::lock_guard<std::mutex> lock(context->sampler_handle_map_lock);
if( context->sampler_handle_map.find(reinterpret_cast<ze_sampler_object_t*>(hSampler)) != context->sampler_handle_map.end() ) {
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
}
}
// convert loader handle to driver handle
hSampler = reinterpret_cast<ze_sampler_object_t*>( hSampler )->handle;

Expand Down Expand Up @@ -4968,7 +4988,10 @@ namespace loader
*phImageView = reinterpret_cast<ze_image_handle_t>(
context->ze_image_factory.getInstance( *phImageView, dditable ) );
// convert loader handle to driver handle and store in map
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
{
std::lock_guard<std::mutex> lock(context->image_handle_map_lock);
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
}
}
catch( std::bad_alloc& )
{
Expand Down
Loading
Loading