Skip to content

Commit

Permalink
Fix global teardown of loader handles and check driver status in init…
Browse files Browse the repository at this point in the history
…_driver

- Moved all loader handle maps to within the loader context to ensure the
  maps are init and destroyed with library init/destroy vs at start of
atexit.
- Check result of individual driver for zeInit during init_driver().

Signed-off-by: Neil R. Spruit <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Aug 20, 2024
1 parent c1f6e28 commit 209af0b
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 189 deletions.
42 changes: 13 additions & 29 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,6 @@ from templates import helper as th

namespace loader
{
///////////////////////////////////////////////////////////////////////////////
%for obj in th.extract_objs(specs, r"handle"):
%if 'class' in obj:
<%
_handle_t = th.subt(n, tags, obj['name'])
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
%>${th.append_ws(_factory_t, 35)} ${_factory};
%endif
%endfor
%if re.match(r"ze_ldrddi", name):
///////////////////////////////////////////////////////////////////////////////
std::unordered_map<ze_image_object_t *, ze_image_handle_t> image_handle_map;
std::unordered_map<ze_sampler_object_t *, ze_sampler_handle_t> sampler_handle_map;
%endif

%for obj in th.extract_objs(specs, r"function"):
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
Expand Down Expand Up @@ -103,7 +87,7 @@ namespace loader
for( uint32_t i = 0; i < library_driver_handle_count; ++i ) {
uint32_t driver_index = total_driver_handle_count + i;
${obj['params'][1]['name']}[ driver_index ] = reinterpret_cast<${n}_driver_handle_t>(
${n}_driver_factory.getInstance( ${obj['params'][1]['name']}[ driver_index ], &drv.dditable ) );
context->${n}_driver_factory.getInstance( ${obj['params'][1]['name']}[ driver_index ], &drv.dditable ) );
}
}
catch( std::bad_alloc& )
Expand Down Expand Up @@ -151,11 +135,11 @@ namespace loader
%else:
%if re.match(r"\w+ImageDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
context->image_handle_map.erase(reinterpret_cast<ze_image_object_t*>(hImage));
%endif
%if re.match(r"\w+SamplerDestroy$", th.make_func_name(n, tags, obj)):
// remove the handle from the kernel arugment map
sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
context->sampler_handle_map.erase(reinterpret_cast<ze_sampler_object_t*>(hSampler));
%endif
// convert loader handle to driver handle
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
Expand All @@ -170,10 +154,10 @@ namespace loader
// check if the arg value is a translated handle
ze_image_object_t **imageHandle = static_cast<ze_image_object_t **>(internalArgValue);
ze_sampler_object_t **samplerHandle = static_cast<ze_sampler_object_t **>(internalArgValue);
if( image_handle_map.find(*imageHandle) != image_handle_map.end() ) {
internalArgValue = &image_handle_map[*imageHandle];
} else if( sampler_handle_map.find(*samplerHandle) != sampler_handle_map.end() ) {
internalArgValue = &sampler_handle_map[*samplerHandle];
if( context->image_handle_map.find(*imageHandle) != context->image_handle_map.end() ) {
internalArgValue = &context->image_handle_map[*imageHandle];
} else if( context->sampler_handle_map.find(*samplerHandle) != context->sampler_handle_map.end() ) {
internalArgValue = &context->sampler_handle_map[*samplerHandle];
}
}
%endif
Expand Down Expand Up @@ -203,34 +187,34 @@ namespace loader
%endif
%if item['release']:
// release loader handle
${item['factory']}.release( ${item['name']} );
context->${item['factory']}.release( ${item['name']} );
%else:
try
{
%if 'range' in item:
// convert driver handles to loader handles
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
${item['name']}[ i ] = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
context->${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
%else:
// convert driver handle to loader handle
%if item['optional']:
if( nullptr != ${item['name']} )
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
context->${item['factory']}.getInstance( *${item['name']}, dditable ) );
%else:
%if re.match(r"\w+ImageCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+SamplerCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+ImageViewCreateExp$", th.make_func_name(n, tags, obj)):
${item['type']} internalHandlePtr = *${item['name']};
%endif
*${item['name']} = reinterpret_cast<${item['type']}>(
${item['factory']}.getInstance( *${item['name']}, dditable ) );
context->${item['factory']}.getInstance( *${item['name']}, dditable ) );
%if re.match(r"\w+ImageCreate$", th.make_func_name(n, tags, obj)) or re.match(r"\w+ImageViewCreateExp$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
image_handle_map.insert({ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
context->image_handle_map.insert({context->ze_image_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
%endif
%if re.match(r"\w+SamplerCreate$", th.make_func_name(n, tags, obj)):
// convert loader handle to driver handle and store in map
sampler_handle_map.insert({ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
context->sampler_handle_map.insert({context->ze_sampler_factory.getInstance( internalHandlePtr, dditable ), internalHandlePtr});
%endif
%endif
%endif
Expand Down
Loading

0 comments on commit 209af0b

Please sign in to comment.