From 341114d00a0fe79ae137a16ad3a76f2cb2685a7a Mon Sep 17 00:00:00 2001 From: Alexander Batashev Date: Sat, 11 Nov 2023 08:29:18 +0000 Subject: [PATCH] [UR][Loader] Fix handling of native handles Native handles are created by adapters and thus are inheritently backend-specific. Loader can not assume anything about these handles, as even nullptr may be a valid value for such a handle. This patch changes two things about native handles: 1) Native handles are no longer wrapped in UR objects 2) Dispatch table is extracted from any other argument of the API function The above is true for all interop APIs except for urPlatformCreateWithNativeHandle, which needs a spec change. --- scripts/templates/ldrddi.cpp.mako | 9 ++- source/loader/ur_ldrddi.cpp | 129 +++--------------------------- 2 files changed, 16 insertions(+), 122 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 0c9a3ed8b0..4cd50e36ac 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -127,14 +127,17 @@ namespace ur_loader %else: <%param_replacements={}%> %for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)): - %if 0 == i: + %if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': // extract platform's function pointer table auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable; auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}; if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) return ${X}_RESULT_ERROR_UNINITIALIZED; + <%break%> %endif + %endfor + %for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)): %if 'range' in item: <% add_local = True @@ -143,6 +146,7 @@ namespace ur_loader for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i ) ${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle; %else: + %if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': // convert loader handle to platform handle %if item['optional']: ${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr; @@ -150,6 +154,7 @@ namespace ur_loader ${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle; %endif %endif + %endif %endfor // forward to device-platform @@ -170,7 +175,7 @@ namespace ur_loader %if item['release']: // release loader handle ${item['factory']}.release( ${item['name']} ); - %else: + %elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { %if 'range' in item: diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 9327f349c5..c780d51335 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -349,14 +349,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativePlatform = reinterpret_cast( - ur_native_factory.getInstance(*phNativePlatform, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -670,14 +662,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeDevice = reinterpret_cast( - ur_native_factory.getInstance(*phNativeDevice, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -696,17 +680,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( // extract platform's function pointer table auto dditable = - reinterpret_cast(hNativeDevice)->dditable; + reinterpret_cast(hPlatform)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Device.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeDevice = - reinterpret_cast(hNativeDevice)->handle; - // convert loader handle to platform handle hPlatform = reinterpret_cast(hPlatform)->handle; @@ -913,14 +893,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeContext = reinterpret_cast( - ur_native_factory.getInstance(*phNativeContext, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -941,17 +913,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle( // extract platform's function pointer table auto dditable = - reinterpret_cast(hNativeContext)->dditable; + reinterpret_cast(*phDevices)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Context.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeContext = - reinterpret_cast(hNativeContext)->handle; - // convert loader handles to platform handles auto phDevicesLocal = std::vector(numDevices); for (size_t i = 0; i < numDevices; ++i) { @@ -1204,14 +1172,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeMem = reinterpret_cast( - ur_native_factory.getInstance(*phNativeMem, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -1229,17 +1189,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeMem)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnBufferCreateWithNativeHandle = dditable->ur.Mem.pfnBufferCreateWithNativeHandle; if (nullptr == pfnBufferCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeMem = reinterpret_cast(hNativeMem)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -1279,17 +1235,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeMem)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnImageCreateWithNativeHandle = dditable->ur.Mem.pfnImageCreateWithNativeHandle; if (nullptr == pfnImageCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeMem = reinterpret_cast(hNativeMem)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -1525,14 +1477,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeSampler = reinterpret_cast( - ur_native_factory.getInstance(*phNativeSampler, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -1550,18 +1494,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeSampler)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Sampler.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeSampler = - reinterpret_cast(hNativeSampler)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -2601,14 +2540,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeProgram = reinterpret_cast( - ur_native_factory.getInstance(*phNativeProgram, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -2626,18 +2557,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeProgram)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Program.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeProgram = - reinterpret_cast(hNativeProgram)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -3085,14 +3011,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeKernel = reinterpret_cast( - ur_native_factory.getInstance(*phNativeKernel, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -3112,18 +3030,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeKernel)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Kernel.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeKernel = - reinterpret_cast(hNativeKernel)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -3297,14 +3210,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeQueue = reinterpret_cast( - ur_native_factory.getInstance(*phNativeQueue, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -3323,17 +3228,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeQueue)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Queue.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeQueue = reinterpret_cast(hNativeQueue)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; @@ -3570,14 +3471,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle( return result; } - try { - // convert platform handle to loader handle - *phNativeEvent = reinterpret_cast( - ur_native_factory.getInstance(*phNativeEvent, dditable)); - } catch (std::bad_alloc &) { - result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; - } - return result; } @@ -3595,17 +3488,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( ur_result_t result = UR_RESULT_SUCCESS; // extract platform's function pointer table - auto dditable = - reinterpret_cast(hNativeEvent)->dditable; + auto dditable = reinterpret_cast(hContext)->dditable; auto pfnCreateWithNativeHandle = dditable->ur.Event.pfnCreateWithNativeHandle; if (nullptr == pfnCreateWithNativeHandle) { return UR_RESULT_ERROR_UNINITIALIZED; } - // convert loader handle to platform handle - hNativeEvent = reinterpret_cast(hNativeEvent)->handle; - // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle;