Skip to content

Commit

Permalink
Merge pull request #1067 from alexbatashev/fix_native_handles
Browse files Browse the repository at this point in the history
[UR][Loader] Fix handling of native handles
  • Loading branch information
kbenzie authored Dec 11, 2023
2 parents 0e281bc + 8b1bfc9 commit b25bb64
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 122 deletions.
9 changes: 7 additions & 2 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,17 @@ namespace ur_loader
%else:
<%param_replacements={}%>
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
%if 0 == i:
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
// extract platform's function pointer table
auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable;
auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
return ${X}_RESULT_ERROR_UNINITIALIZED;
<%break%>
%endif
%endfor
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
%if 'range' in item:
<%
add_local = True
Expand All @@ -146,13 +149,15 @@ namespace ur_loader
for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i )
${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle;
%else:
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
// convert loader handle to platform handle
%if item['optional']:
${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr;
%else:
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
%endif
%endif
%endif
%endfor
// forward to device-platform
Expand All @@ -173,7 +178,7 @@ namespace ur_loader
%if item['release']:
// release loader handle
${item['factory']}.release( ${item['name']} );
%else:
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
%if 'range' in item:
Expand Down
129 changes: 9 additions & 120 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,14 +352,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativePlatform = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativePlatform, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand Down Expand Up @@ -673,14 +665,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeDevice = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeDevice, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -699,17 +683,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->dditable;
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Device.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeDevice =
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->handle;

// convert loader handle to platform handle
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;

Expand Down Expand Up @@ -916,14 +896,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeContext, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -944,17 +916,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeContext)->dditable;
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Context.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeContext =
reinterpret_cast<ur_native_object_t *>(hNativeContext)->handle;

// convert loader handles to platform handles
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
for (size_t i = 0; i < numDevices; ++i) {
Expand Down Expand Up @@ -1207,14 +1175,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeMem, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -1232,17 +1192,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnBufferCreateWithNativeHandle =
dditable->ur.Mem.pfnBufferCreateWithNativeHandle;
if (nullptr == pfnBufferCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -1282,17 +1238,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnImageCreateWithNativeHandle =
dditable->ur.Mem.pfnImageCreateWithNativeHandle;
if (nullptr == pfnImageCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -1528,14 +1480,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeSampler = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeSampler, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -1553,18 +1497,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Sampler.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeSampler =
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -2604,14 +2543,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeProgram, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -2629,18 +2560,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Program.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeProgram =
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3088,14 +3014,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeKernel, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3115,18 +3033,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Kernel.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeKernel =
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3300,14 +3213,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeQueue, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3326,17 +3231,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeQueue)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Queue.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeQueue = reinterpret_cast<ur_native_object_t *>(hNativeQueue)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down Expand Up @@ -3573,14 +3474,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle(
return result;
}

try {
// convert platform handle to loader handle
*phNativeEvent = reinterpret_cast<ur_native_handle_t>(
ur_native_factory.getInstance(*phNativeEvent, dditable));
} catch (std::bad_alloc &) {
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
}

return result;
}

Expand All @@ -3598,17 +3491,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_native_object_t *>(hNativeEvent)->dditable;
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Event.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hNativeEvent = reinterpret_cast<ur_native_object_t *>(hNativeEvent)->handle;

// convert loader handle to platform handle
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

Expand Down

0 comments on commit b25bb64

Please sign in to comment.