From 81c8b1b554d1b6a525def2920d6ceef386fb9346 Mon Sep 17 00:00:00 2001 From: pbalcer Date: Thu, 23 Nov 2023 13:38:39 +0100 Subject: [PATCH] [loader] perform handle conversion after info queries This fixes a bug where the loader, if interception is active, returns direct handles in GetInfo functions instead of converting them to indirect handles. Closes #1107 Reported-by: @AllanZyne --- include/ur_api.h | 2 +- include/ur_print.hpp | 6 +- scripts/core/queue.yml | 2 +- scripts/templates/helper.py | 85 ++++- scripts/templates/ldrddi.cpp.mako | 42 ++- scripts/templates/nullddi.cpp.mako | 19 +- source/adapters/null/ur_null.cpp | 88 ++--- source/adapters/null/ur_null.hpp | 3 + source/adapters/null/ur_nullddi.cpp | 213 ++++++++++++ source/loader/ur_ldrddi.cpp | 412 ++++++++++++++++++++++++ test/loader/CMakeLists.txt | 1 + test/loader/handles/CMakeLists.txt | 27 ++ test/loader/handles/fixtures.hpp | 45 +++ test/loader/handles/urLoaderHandles.cpp | 41 +++ 14 files changed, 931 insertions(+), 55 deletions(-) create mode 100644 test/loader/handles/CMakeLists.txt create mode 100644 test/loader/handles/fixtures.hpp create mode 100644 test/loader/handles/urLoaderHandles.cpp diff --git a/include/ur_api.h b/include/ur_api.h index 5c9c7af5da..b25855be01 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -5103,7 +5103,7 @@ urKernelCreateWithNativeHandle( /////////////////////////////////////////////////////////////////////////////// /// @brief Query queue info typedef enum ur_queue_info_t { - UR_QUEUE_INFO_CONTEXT = 0, ///< [::ur_queue_handle_t] context associated with this queue. + UR_QUEUE_INFO_CONTEXT = 0, ///< [::ur_context_handle_t] context associated with this queue. UR_QUEUE_INFO_DEVICE = 1, ///< [::ur_device_handle_t] device associated with this queue. UR_QUEUE_INFO_DEVICE_DEFAULT = 2, ///< [::ur_queue_handle_t] the current default queue of the underlying ///< device. diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 70e5b9886d..63cf0e3aea 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -8018,9 +8018,9 @@ inline ur_result_t printTagged(std::ostream &os, const void *ptr, ur_queue_info_ switch (value) { case UR_QUEUE_INFO_CONTEXT: { - const ur_queue_handle_t *tptr = (const ur_queue_handle_t *)ptr; - if (sizeof(ur_queue_handle_t) > size) { - os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_queue_handle_t) << ")"; + const ur_context_handle_t *tptr = (const ur_context_handle_t *)ptr; + if (sizeof(ur_context_handle_t) > size) { + os << "invalid size (is: " << size << ", expected: >=" << sizeof(ur_context_handle_t) << ")"; return UR_RESULT_ERROR_INVALID_SIZE; } os << (const void *)(tptr) << " ("; diff --git a/scripts/core/queue.yml b/scripts/core/queue.yml index 15934c0e2f..816da179ba 100644 --- a/scripts/core/queue.yml +++ b/scripts/core/queue.yml @@ -19,7 +19,7 @@ name: $x_queue_info_t typed_etors: True etors: - name: CONTEXT - desc: "[$x_queue_handle_t] context associated with this queue." + desc: "[$x_context_handle_t] context associated with this queue." - name: DEVICE desc: "[$x_device_handle_t] device associated with this queue." - name: DEVICE_DEFAULT diff --git a/scripts/templates/helper.py b/scripts/templates/helper.py index d7d29dc0a8..08ef0952f1 100644 --- a/scripts/templates/helper.py +++ b/scripts/templates/helper.py @@ -39,6 +39,13 @@ def is_handle(obj): except: return False + @staticmethod + def is_enum(obj): + try: + return True if re.match(r"enum", obj['type']) else False + except: + return False + @staticmethod def is_experimental(obj): try: @@ -449,6 +456,13 @@ def is_release(cls, item): except: return False + @classmethod + def is_typename(cls, item): + try: + return True if re.match(cls.RE_TYPENAME, item['desc']) else False + except: + return False + @classmethod def typename(cls, item): match = re.match(cls.RE_TYPENAME, item['desc']) @@ -1241,24 +1255,43 @@ def get_loader_prologue(namespace, tags, obj, meta): return prologue +""" +Public: + returns an enum object with the given name +""" +def get_enum_by_name(specs, namespace, tags, name, only_typed): + for s in specs: + for obj in s['objects']: + if obj_traits.is_enum(obj) and make_enum_name(namespace, tags, obj) == name: + typed = obj.get('typed_etors', False) is True + if only_typed: + if typed: + return obj + else: + return None + else: + return obj + return None + """ Public: returns a list of dict for converting loader output parameters """ -def get_loader_epilogue(namespace, tags, obj, meta): +def get_loader_epilogue(specs, namespace, tags, obj, meta): epilogue = [] for i, item in enumerate(obj['params']): if param_traits.is_mbz(item): continue - if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): - if type_traits.is_class_handle(item['type'], meta): - name = subt(namespace, tags, item['name']) - tname = _remove_const_ptr(subt(namespace, tags, item['type'])) - obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname) - fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname) + name = subt(namespace, tags, item['name']) + tname = _remove_const_ptr(subt(namespace, tags, item['type'])) + + obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", tname) + fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", tname) + if param_traits.is_release(item) or param_traits.is_output(item) or param_traits.is_inoutput(item): + if type_traits.is_class_handle(item['type'], meta): if param_traits.is_range(item): range_start = param_traits.range_start(item) range_end = param_traits.range_end(item) @@ -1279,6 +1312,44 @@ def get_loader_epilogue(namespace, tags, obj, meta): 'release': param_traits.is_release(item), 'optional': param_traits.is_optional(item) }) + elif param_traits.is_typename(item): + typename = param_traits.typename(item) + underlying_type = None + for inner in obj['params']: + iname = _get_param_name(namespace, tags, inner) + if iname == typename: + underlying_type = _get_type_name(namespace, tags, obj, inner) + if underlying_type is None: + continue + + prop_size = param_traits.typename_size(item) + enum = get_enum_by_name(specs, namespace, tags, underlying_type, True) + handle_etors = [] + for etor in enum['etors']: + associated_type = etor_get_associated_type(namespace, tags, etor) + if 'handle' in associated_type: + is_array = False + if value_traits.is_array(associated_type): + associated_type = value_traits.get_array_name(associated_type) + is_array = True + + etor_name = make_etor_name(namespace, tags, enum['name'], etor['name']) + obj_name = re.sub(r"(\w+)_handle_t", r"\1_object_t", associated_type) + fty_name = re.sub(r"(\w+)_handle_t", r"\1_factory", associated_type) + handle_etors.append({'name': etor_name, + 'type': associated_type, + 'obj': obj_name, + 'factory': fty_name, + 'is_array': is_array}) + + if handle_etors: + epilogue.append({ + 'name': name, + 'obj': obj_name, + 'release': False, + 'typename': typename, + 'size': prop_size, + 'etors': handle_etors}) return epilogue diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index eaca102ea9..52880247a1 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -160,6 +160,23 @@ namespace ur_loader %endif %endfor + + <% + epilogue = th.get_loader_epilogue(specs, n, tags, obj, meta) + has_typename = False + for item in epilogue: + if 'typename' in item: + has_typename = True + break + %> + + %if has_typename: + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) + pPropSizeRet = &sizeret; + %endif + // forward to device-platform %if add_local: result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name", "local"], replacements=param_replacements))} ); @@ -168,8 +185,9 @@ namespace ur_loader %endif <% del param_replacements - del add_local%> - %for i, item in enumerate(th.get_loader_epilogue(n, tags, obj, meta)): + del add_local + %> + %for i, item in enumerate(epilogue): %if 0 == i: if( ${X}_RESULT_SUCCESS != result ) return result; @@ -181,7 +199,25 @@ namespace ur_loader %elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { - %if 'range' in item: + %if 'typename' in item: + if (${item['name']} != nullptr) { + switch (${item['typename']}) { + %for etor in item['etors']: + case ${etor['name']}: { + ${etor['type']} *handles = reinterpret_cast<${etor['type']} *>(${item['name']}); + size_t nelements = *pPropSizeRet / sizeof(${etor['type']}); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast<${etor['type']}>( + ${etor['factory']}.getInstance( handles[i], dditable ) ); + } + } + } break; + %endfor + default: {} break; + } + } + %elif 'range' in item: // convert platform handles to loader handles for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i ) ${item['name']}[ i ] = reinterpret_cast<${item['type']}>( diff --git a/scripts/templates/nullddi.cpp.mako b/scripts/templates/nullddi.cpp.mako index daee79d626..f503d4073c 100644 --- a/scripts/templates/nullddi.cpp.mako +++ b/scripts/templates/nullddi.cpp.mako @@ -48,8 +48,23 @@ namespace driver else { // generic implementation - %for item in th.get_loader_epilogue(n, tags, obj, meta): - %if 'range' in item: + %for item in th.get_loader_epilogue(specs, n, tags, obj, meta): + %if 'typename' in item: + if (${item['name']} != nullptr) { + switch (${item['typename']}) { + %for etor in item['etors']: + case ${etor['name']}: { + ${etor['type']} *handles = reinterpret_cast<${etor['type']} *>(${item['name']}); + size_t nelements = ${item['size']} / sizeof(${etor['type']}); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = reinterpret_cast<${etor['type']}>( d_context.get() ); + } + } break; + %endfor + default: {} break; + } + } + %elif 'range' in item: for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i ) ${item['name']}[ i ] = reinterpret_cast<${item['type']}>( d_context.get() ); %elif not item['release']: diff --git a/source/adapters/null/ur_null.cpp b/source/adapters/null/ur_null.cpp index 094f28c8fb..84ad1ba352 100644 --- a/source/adapters/null/ur_null.cpp +++ b/source/adapters/null/ur_null.cpp @@ -17,6 +17,7 @@ context_t d_context; ////////////////////////////////////////////////////////////////////////// context_t::context_t() { + platform = get(); ////////////////////////////////////////////////////////////////////////// urDdiTable.Global.pfnAdapterGet = [](uint32_t NumAdapters, ur_adapter_handle_t *phAdapters, @@ -28,7 +29,7 @@ context_t::context_t() { *pNumAdapters = 1; } if (nullptr != phAdapters) { - *reinterpret_cast(phAdapters) = d_context.get(); + *reinterpret_cast(phAdapters) = d_context.platform; } return UR_RESULT_SUCCESS; @@ -48,7 +49,7 @@ context_t::context_t() { *pNumPlatforms = 1; } if (nullptr != phPlatforms) { - *reinterpret_cast(phPlatforms) = d_context.get(); + *reinterpret_cast(phPlatforms) = d_context.platform; } return UR_RESULT_SUCCESS; }; @@ -120,48 +121,59 @@ context_t::context_t() { }; ////////////////////////////////////////////////////////////////////////// - urDdiTable.Device.pfnGetInfo = - [](ur_device_handle_t, ur_device_info_t infoType, size_t propSize, - void *pDeviceInfo, size_t *pPropSizeRet) { - switch (infoType) { - case UR_DEVICE_INFO_TYPE: - if (pDeviceInfo && propSize != sizeof(ur_device_type_t)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } + urDdiTable.Device.pfnGetInfo = [](ur_device_handle_t, + ur_device_info_t infoType, + size_t propSize, void *pDeviceInfo, + size_t *pPropSizeRet) { + switch (infoType) { + case UR_DEVICE_INFO_TYPE: + if (pDeviceInfo && propSize != sizeof(ur_device_type_t)) { + return UR_RESULT_ERROR_INVALID_SIZE; + } - if (pDeviceInfo != nullptr) { - *reinterpret_cast(pDeviceInfo) = - UR_DEVICE_TYPE_GPU; - } - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(ur_device_type_t); - } - break; + if (pDeviceInfo != nullptr) { + *reinterpret_cast(pDeviceInfo) = + UR_DEVICE_TYPE_GPU; + } + if (pPropSizeRet != nullptr) { + *pPropSizeRet = sizeof(ur_device_type_t); + } + break; - case UR_DEVICE_INFO_NAME: { - char deviceName[] = "Null Device"; - if (pDeviceInfo && propSize < sizeof(deviceName)) { - return UR_RESULT_ERROR_INVALID_SIZE; - } - if (pDeviceInfo != nullptr) { + case UR_DEVICE_INFO_NAME: { + char deviceName[] = "Null Device"; + if (pDeviceInfo && propSize < sizeof(deviceName)) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + if (pDeviceInfo != nullptr) { #if defined(_WIN32) - strncpy_s(reinterpret_cast(pDeviceInfo), propSize, - deviceName, sizeof(deviceName)); + strncpy_s(reinterpret_cast(pDeviceInfo), propSize, + deviceName, sizeof(deviceName)); #else - strncpy(reinterpret_cast(pDeviceInfo), deviceName, - propSize); + strncpy(reinterpret_cast(pDeviceInfo), deviceName, + propSize); #endif - } - if (pPropSizeRet != nullptr) { - *pPropSizeRet = sizeof(deviceName); - } - } break; - - default: - return UR_RESULT_ERROR_INVALID_ARGUMENT; } - return UR_RESULT_SUCCESS; - }; + if (pPropSizeRet != nullptr) { + *pPropSizeRet = sizeof(deviceName); + } + } break; + case UR_DEVICE_INFO_PLATFORM: { + if (pDeviceInfo && propSize < sizeof(pDeviceInfo)) { + return UR_RESULT_ERROR_INVALID_SIZE; + } + if (pDeviceInfo != nullptr) { + *reinterpret_cast(pDeviceInfo) = d_context.platform; + } + if (pPropSizeRet != nullptr) { + *pPropSizeRet = sizeof(intptr_t); + } + } break; + default: + return UR_RESULT_ERROR_INVALID_ARGUMENT; + } + return UR_RESULT_SUCCESS; + }; ////////////////////////////////////////////////////////////////////////// urDdiTable.USM.pfnHostAlloc = [](ur_context_handle_t, const ur_usm_desc_t *, diff --git a/source/adapters/null/ur_null.hpp b/source/adapters/null/ur_null.hpp index 9029a25b97..b9b997f5bf 100644 --- a/source/adapters/null/ur_null.hpp +++ b/source/adapters/null/ur_null.hpp @@ -9,6 +9,7 @@ * @file ur_null.hpp * */ +#include "ur_api.h" #ifndef UR_ADAPTER_NULL_H #define UR_ADAPTER_NULL_H 1 @@ -27,6 +28,8 @@ class __urdlllocal context_t { context_t(); ~context_t() = default; + void *platform; + void *get() { static uint64_t count = 0x80800000; return reinterpret_cast(++count); diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/null/ur_nullddi.cpp index f016830d11..d6887ee12f 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/null/ur_nullddi.cpp @@ -375,6 +375,30 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( pfnGetInfo(hDevice, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_DEVICE_INFO_PLATFORM: { + ur_platform_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_platform_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_DEVICE_INFO_PARENT_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -662,6 +686,21 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( pfnGetInfo(hContext, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_CONTEXT_INFO_DEVICES: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -983,6 +1022,21 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( pfnGetInfo(hMemory, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_MEM_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -1110,6 +1164,21 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( pfnGetInfo(hSampler, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_SAMPLER_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -1298,6 +1367,30 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_USM_ALLOC_INFO_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_USM_ALLOC_INFO_POOL: { + ur_usm_pool_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_usm_pool_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -1391,6 +1484,21 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( pfnPoolGetInfo(hPool, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_USM_POOL_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -1897,6 +2005,30 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( pfnGetInfo(hProgram, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_PROGRAM_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_PROGRAM_INFO_DEVICES: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -2118,6 +2250,30 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( pfnGetInfo(hKernel, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_KERNEL_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_KERNEL_INFO_PROGRAM: { + ur_program_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_program_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -2429,6 +2585,39 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( pfnGetInfo(hQueue, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_QUEUE_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_QUEUE_INFO_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_QUEUE_INFO_DEVICE_DEFAULT: { + ur_queue_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_queue_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; @@ -2617,6 +2806,30 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( pfnGetInfo(hEvent, propName, propSize, pPropValue, pPropSizeRet); } else { // generic implementation + if (pPropValue != nullptr) { + switch (propName) { + case UR_EVENT_INFO_COMMAND_QUEUE: { + ur_queue_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_queue_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + case UR_EVENT_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = propSize / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + handles[i] = + reinterpret_cast(d_context.get()); + } + } break; + default: { + } break; + } + } } return result; diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 6d3dda30f0..201315272f 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -506,9 +506,54 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetInfo( // convert loader handle to platform handle hDevice = reinterpret_cast(hDevice)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hDevice, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_DEVICE_INFO_PLATFORM: { + ur_platform_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_platform_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_platform_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_DEVICE_INFO_PARENT_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_device_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -864,9 +909,42 @@ __urdlllocal ur_result_t UR_APICALL urContextGetInfo( // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hContext, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_CONTEXT_INFO_DEVICES: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_device_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1296,9 +1374,42 @@ __urdlllocal ur_result_t UR_APICALL urMemGetInfo( // convert loader handle to platform handle hMemory = reinterpret_cast(hMemory)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hMemory, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_MEM_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1448,9 +1559,42 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetInfo( // convert loader handle to platform handle hSampler = reinterpret_cast(hSampler)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hSampler, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_SAMPLER_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1687,10 +1831,55 @@ __urdlllocal ur_result_t UR_APICALL urUSMGetMemAllocInfo( // convert loader handle to platform handle hContext = reinterpret_cast(hContext)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetMemAllocInfo(hContext, pMem, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_USM_ALLOC_INFO_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_device_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_USM_ALLOC_INFO_POOL: { + ur_usm_pool_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_usm_pool_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_usm_pool_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -1803,10 +1992,43 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolGetInfo( // convert loader handle to platform handle hPool = reinterpret_cast(hPool)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnPoolGetInfo(hPool, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_USM_POOL_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -2443,9 +2665,54 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetInfo( // convert loader handle to platform handle hProgram = reinterpret_cast(hProgram)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hProgram, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_PROGRAM_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_PROGRAM_INFO_DEVICES: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_device_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -2713,9 +2980,54 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetInfo( // convert loader handle to platform handle hKernel = reinterpret_cast(hKernel)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hKernel, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_KERNEL_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_KERNEL_INFO_PROGRAM: { + ur_program_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_program_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_program_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -3090,9 +3402,65 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetInfo( // convert loader handle to platform handle hQueue = reinterpret_cast(hQueue)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hQueue, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_QUEUE_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_QUEUE_INFO_DEVICE: { + ur_device_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_device_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_device_factory.getInstance(handles[i], + dditable)); + } + } + } break; + case UR_QUEUE_INFO_DEVICE_DEFAULT: { + ur_queue_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_queue_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_queue_factory.getInstance(handles[i], dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } @@ -3332,9 +3700,53 @@ __urdlllocal ur_result_t UR_APICALL urEventGetInfo( // convert loader handle to platform handle hEvent = reinterpret_cast(hEvent)->handle; + // this value is needed for converting adapter handles to loader handles + size_t sizeret = 0; + if (pPropSizeRet == NULL) { + pPropSizeRet = &sizeret; + } + // forward to device-platform result = pfnGetInfo(hEvent, propName, propSize, pPropValue, pPropSizeRet); + if (UR_RESULT_SUCCESS != result) { + return result; + } + + try { + if (pPropValue != nullptr) { + switch (propName) { + case UR_EVENT_INFO_COMMAND_QUEUE: { + ur_queue_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_queue_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_queue_factory.getInstance(handles[i], dditable)); + } + } + } break; + case UR_EVENT_INFO_CONTEXT: { + ur_context_handle_t *handles = + reinterpret_cast(pPropValue); + size_t nelements = *pPropSizeRet / sizeof(ur_context_handle_t); + for (size_t i = 0; i < nelements; ++i) { + if (handles[i] != nullptr) { + handles[i] = reinterpret_cast( + ur_context_factory.getInstance(handles[i], + dditable)); + } + } + } break; + default: { + } break; + } + } + } catch (std::bad_alloc &) { + result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return result; } diff --git a/test/loader/CMakeLists.txt b/test/loader/CMakeLists.txt index d36f922098..5472da74bc 100644 --- a/test/loader/CMakeLists.txt +++ b/test/loader/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(adapter_registry) add_subdirectory(loader_config) add_subdirectory(loader_lifetime) add_subdirectory(platforms) +add_subdirectory(handles) diff --git a/test/loader/handles/CMakeLists.txt b/test/loader/handles/CMakeLists.txt new file mode 100644 index 0000000000..737216fc23 --- /dev/null +++ b/test/loader/handles/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (C) 2023 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_executable(test-loader-handles + urLoaderHandles.cpp +) + +target_link_libraries(test-loader-handles + PRIVATE + ${PROJECT_NAME}::common + ${PROJECT_NAME}::headers + ${PROJECT_NAME}::loader + gmock + GTest::gtest_main +) + +add_test(NAME loader-handles + COMMAND test-loader-handles + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + +set_tests_properties(loader-handles PROPERTIES + LABELS "loader" + ENVIRONMENT "UR_ENABLE_LOADER_INTERCEPT=1;UR_ADAPTERS_FORCE_LOAD=\"$\"" +) diff --git a/test/loader/handles/fixtures.hpp b/test/loader/handles/fixtures.hpp new file mode 100644 index 0000000000..c903de11ce --- /dev/null +++ b/test/loader/handles/fixtures.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2023 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef UR_LOADER_CONFIG_TEST_FIXTURES_H +#define UR_LOADER_CONFIG_TEST_FIXTURES_H + +#include "ur_api.h" +#include + +#ifndef ASSERT_SUCCESS +#define ASSERT_SUCCESS(ACTUAL) ASSERT_EQ(UR_RESULT_SUCCESS, ACTUAL) +#endif + +struct LoaderHandleTest : ::testing::Test { + void SetUp() override { + urLoaderInit(0, nullptr); + uint32_t nadapters = 0; + adapter = nullptr; + ASSERT_SUCCESS(urAdapterGet(1, &adapter, &nadapters)); + ASSERT_NE(adapter, nullptr); + uint32_t nplatforms = 0; + platform = nullptr; + ASSERT_SUCCESS(urPlatformGet(&adapter, 1, 1, &platform, &nplatforms)); + ASSERT_NE(platform, nullptr); + uint32_t ndevices; + device = nullptr; + ASSERT_SUCCESS( + urDeviceGet(platform, UR_DEVICE_TYPE_ALL, 1, &device, &ndevices)); + ASSERT_NE(device, nullptr); + } + + void TearDown() override { + urDeviceRelease(device); + urAdapterRelease(adapter); + urLoaderTearDown(); + } + + ur_adapter_handle_t adapter; + ur_platform_handle_t platform; + ur_device_handle_t device; +}; + +#endif diff --git a/test/loader/handles/urLoaderHandles.cpp b/test/loader/handles/urLoaderHandles.cpp new file mode 100644 index 0000000000..6bd17d982a --- /dev/null +++ b/test/loader/handles/urLoaderHandles.cpp @@ -0,0 +1,41 @@ +// Copyright (C) 2023 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "fixtures.hpp" +#include "ur_api.h" +#include +#include +#include + +TEST_F(LoaderHandleTest, Success) { + ur_platform_handle_t query_platform; + size_t retsize; + ASSERT_SUCCESS(urDeviceGetInfo(device, UR_DEVICE_INFO_PLATFORM, + sizeof(intptr_t), &query_platform, + &retsize)); + ASSERT_EQ(query_platform, platform); +} + +TEST_F(LoaderHandleTest, SuccessArray) { + ur_platform_handle_t query_platform[2] = {(ur_platform_handle_t)0xCAFE, + (ur_platform_handle_t)0xBEEF}; + ASSERT_SUCCESS(urDeviceGetInfo(device, UR_DEVICE_INFO_PLATFORM, + sizeof(query_platform), &query_platform, + NULL)); + ASSERT_EQ(query_platform[0], platform); + ASSERT_EQ(query_platform[1], (ur_platform_handle_t)0xBEEF); +} + +TEST_F(LoaderHandleTest, SuccessArraySizeRet) { + ur_platform_handle_t query_platform[2] = {(ur_platform_handle_t)0xCAFE, + (ur_platform_handle_t)0xBEEF}; + size_t sizeret; + ASSERT_SUCCESS(urDeviceGetInfo(device, UR_DEVICE_INFO_PLATFORM, + sizeof(query_platform), &query_platform, + &sizeret)); + ASSERT_EQ(sizeret, sizeof(intptr_t)); + ASSERT_EQ(query_platform[0], platform); + ASSERT_EQ(query_platform[1], (ur_platform_handle_t)0xBEEF); +}