diff --git a/include/ur.py b/include/ur.py index d0954563fb..c078f9980f 100644 --- a/include/ur.py +++ b/include/ur.py @@ -198,8 +198,8 @@ class ur_function_v(IntEnum): COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190 ## Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191## Enumerator for ::urCommandBufferAppendMemBufferReadRectExp COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192 ## Enumerator for ::urCommandBufferAppendMemBufferFillExp - ENQUEUE_COOPERATIVE_KERNEL_LAUNCH = 193 ## Enumerator for ::urEnqueueCooperativeKernelLaunch - KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCount + ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193 ## Enumerator for ::urEnqueueCooperativeKernelLaunchExp + KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp class ur_function_t(c_int): def __str__(self): @@ -2684,13 +2684,6 @@ class ur_program_dditable_t(Structure): else: _urKernelSetSpecializationConstants_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, c_ulong, POINTER(ur_specialization_constant_info_t) ) -############################################################################### -## @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCount -if __use_win_types: - _urKernelSuggestMaxCooperativeGroupCount_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) -else: - _urKernelSuggestMaxCooperativeGroupCount_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) - ############################################################################### ## @brief Table of Kernel functions pointers @@ -2710,8 +2703,22 @@ class ur_kernel_dditable_t(Structure): ("pfnSetExecInfo", c_void_p), ## _urKernelSetExecInfo_t ("pfnSetArgSampler", c_void_p), ## _urKernelSetArgSampler_t ("pfnSetArgMemObj", c_void_p), ## _urKernelSetArgMemObj_t - ("pfnSetSpecializationConstants", c_void_p), ## _urKernelSetSpecializationConstants_t - ("pfnSuggestMaxCooperativeGroupCount", c_void_p) ## _urKernelSuggestMaxCooperativeGroupCount_t + ("pfnSetSpecializationConstants", c_void_p) ## _urKernelSetSpecializationConstants_t + ] + +############################################################################### +## @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp +if __use_win_types: + _urKernelSuggestMaxCooperativeGroupCountExp_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) +else: + _urKernelSuggestMaxCooperativeGroupCountExp_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) ) + + +############################################################################### +## @brief Table of KernelExp functions pointers +class ur_kernel_exp_dditable_t(Structure): + _fields_ = [ + ("pfnSuggestMaxCooperativeGroupCountExp", c_void_p) ## _urKernelSuggestMaxCooperativeGroupCountExp_t ] ############################################################################### @@ -3109,13 +3116,6 @@ class ur_global_dditable_t(Structure): else: _urEnqueueWriteHostPipe_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_program_handle_t, c_char_p, c_bool, c_void_p, c_size_t, c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) -############################################################################### -## @brief Function-pointer for urEnqueueCooperativeKernelLaunch -if __use_win_types: - _urEnqueueCooperativeKernelLaunch_t = WINFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) -else: - _urEnqueueCooperativeKernelLaunch_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) - ############################################################################### ## @brief Table of Enqueue functions pointers @@ -3145,8 +3145,22 @@ class ur_enqueue_dditable_t(Structure): ("pfnDeviceGlobalVariableWrite", c_void_p), ## _urEnqueueDeviceGlobalVariableWrite_t ("pfnDeviceGlobalVariableRead", c_void_p), ## _urEnqueueDeviceGlobalVariableRead_t ("pfnReadHostPipe", c_void_p), ## _urEnqueueReadHostPipe_t - ("pfnWriteHostPipe", c_void_p), ## _urEnqueueWriteHostPipe_t - ("pfnCooperativeKernelLaunch", c_void_p) ## _urEnqueueCooperativeKernelLaunch_t + ("pfnWriteHostPipe", c_void_p) ## _urEnqueueWriteHostPipe_t + ] + +############################################################################### +## @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp +if __use_win_types: + _urEnqueueCooperativeKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) +else: + _urEnqueueCooperativeKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) ) + + +############################################################################### +## @brief Table of EnqueueExp functions pointers +class ur_enqueue_exp_dditable_t(Structure): + _fields_ = [ + ("pfnCooperativeKernelLaunchExp", c_void_p) ## _urEnqueueCooperativeKernelLaunchExp_t ] ############################################################################### @@ -3781,11 +3795,13 @@ class ur_dditable_t(Structure): ("Event", ur_event_dditable_t), ("Program", ur_program_dditable_t), ("Kernel", ur_kernel_dditable_t), + ("KernelExp", ur_kernel_exp_dditable_t), ("Sampler", ur_sampler_dditable_t), ("Mem", ur_mem_dditable_t), ("PhysicalMem", ur_physical_mem_dditable_t), ("Global", ur_global_dditable_t), ("Enqueue", ur_enqueue_dditable_t), + ("EnqueueExp", ur_enqueue_exp_dditable_t), ("Queue", ur_queue_dditable_t), ("BindlessImagesExp", ur_bindless_images_exp_dditable_t), ("USM", ur_usm_dditable_t), @@ -3905,7 +3921,16 @@ def __init__(self, version : ur_api_version_t): self.urKernelSetArgSampler = _urKernelSetArgSampler_t(self.__dditable.Kernel.pfnSetArgSampler) self.urKernelSetArgMemObj = _urKernelSetArgMemObj_t(self.__dditable.Kernel.pfnSetArgMemObj) self.urKernelSetSpecializationConstants = _urKernelSetSpecializationConstants_t(self.__dditable.Kernel.pfnSetSpecializationConstants) - self.urKernelSuggestMaxCooperativeGroupCount = _urKernelSuggestMaxCooperativeGroupCount_t(self.__dditable.Kernel.pfnSuggestMaxCooperativeGroupCount) + + # call driver to get function pointers + KernelExp = ur_kernel_exp_dditable_t() + r = ur_result_v(self.__dll.urGetKernelExpProcAddrTable(version, byref(KernelExp))) + if r != ur_result_v.SUCCESS: + raise Exception(r) + self.__dditable.KernelExp = KernelExp + + # attach function interface to function address + self.urKernelSuggestMaxCooperativeGroupCountExp = _urKernelSuggestMaxCooperativeGroupCountExp_t(self.__dditable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp) # call driver to get function pointers Sampler = ur_sampler_dditable_t() @@ -4000,7 +4025,16 @@ def __init__(self, version : ur_api_version_t): self.urEnqueueDeviceGlobalVariableRead = _urEnqueueDeviceGlobalVariableRead_t(self.__dditable.Enqueue.pfnDeviceGlobalVariableRead) self.urEnqueueReadHostPipe = _urEnqueueReadHostPipe_t(self.__dditable.Enqueue.pfnReadHostPipe) self.urEnqueueWriteHostPipe = _urEnqueueWriteHostPipe_t(self.__dditable.Enqueue.pfnWriteHostPipe) - self.urEnqueueCooperativeKernelLaunch = _urEnqueueCooperativeKernelLaunch_t(self.__dditable.Enqueue.pfnCooperativeKernelLaunch) + + # call driver to get function pointers + EnqueueExp = ur_enqueue_exp_dditable_t() + r = ur_result_v(self.__dll.urGetEnqueueExpProcAddrTable(version, byref(EnqueueExp))) + if r != ur_result_v.SUCCESS: + raise Exception(r) + self.__dditable.EnqueueExp = EnqueueExp + + # attach function interface to function address + self.urEnqueueCooperativeKernelLaunchExp = _urEnqueueCooperativeKernelLaunchExp_t(self.__dditable.EnqueueExp.pfnCooperativeKernelLaunchExp) # call driver to get function pointers Queue = ur_queue_dditable_t() diff --git a/include/ur_api.h b/include/ur_api.h index 750f827052..ba971b1820 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -207,8 +207,8 @@ typedef enum ur_function_t { UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190, ///< Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191, ///< Enumerator for ::urCommandBufferAppendMemBufferReadRectExp UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192, ///< Enumerator for ::urCommandBufferAppendMemBufferFillExp - UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH = 193, ///< Enumerator for ::urEnqueueCooperativeKernelLaunch - UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCount + UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp /// @cond UR_FUNCTION_FORCE_UINT32 = 0x7fffffff /// @endcond @@ -8159,7 +8159,7 @@ urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES UR_APIEXPORT ur_result_t UR_APICALL -urEnqueueCooperativeKernelLaunch( +urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and @@ -8197,7 +8197,7 @@ urEnqueueCooperativeKernelLaunch( /// + `NULL == pGroupCountRet` /// - ::UR_RESULT_ERROR_INVALID_KERNEL UR_APIEXPORT ur_result_t UR_APICALL -urKernelSuggestMaxCooperativeGroupCount( +urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ); @@ -8971,13 +8971,13 @@ typedef struct ur_kernel_set_specialization_constants_params_t { } ur_kernel_set_specialization_constants_params_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for urKernelSuggestMaxCooperativeGroupCount +/// @brief Function parameters for urKernelSuggestMaxCooperativeGroupCountExp /// @details Each entry is a pointer to the parameter passed to the function; /// allowing the callback the ability to modify the parameter's value -typedef struct ur_kernel_suggest_max_cooperative_group_count_params_t { +typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t { ur_kernel_handle_t *phKernel; uint32_t **ppGroupCountRet; -} ur_kernel_suggest_max_cooperative_group_count_params_t; +} ur_kernel_suggest_max_cooperative_group_count_exp_params_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urSamplerCreate @@ -9627,10 +9627,10 @@ typedef struct ur_enqueue_write_host_pipe_params_t { } ur_enqueue_write_host_pipe_params_t; /////////////////////////////////////////////////////////////////////////////// -/// @brief Function parameters for urEnqueueCooperativeKernelLaunch +/// @brief Function parameters for urEnqueueCooperativeKernelLaunchExp /// @details Each entry is a pointer to the parameter passed to the function; /// allowing the callback the ability to modify the parameter's value -typedef struct ur_enqueue_cooperative_kernel_launch_params_t { +typedef struct ur_enqueue_cooperative_kernel_launch_exp_params_t { ur_queue_handle_t *phQueue; ur_kernel_handle_t *phKernel; uint32_t *pworkDim; @@ -9640,7 +9640,7 @@ typedef struct ur_enqueue_cooperative_kernel_launch_params_t { uint32_t *pnumEventsInWaitList; const ur_event_handle_t **pphEventWaitList; ur_event_handle_t **pphEvent; -} ur_enqueue_cooperative_kernel_launch_params_t; +} ur_enqueue_cooperative_kernel_launch_exp_params_t; /////////////////////////////////////////////////////////////////////////////// /// @brief Function parameters for urQueueGetInfo diff --git a/include/ur_ddi.h b/include/ur_ddi.h index 321ccb9e27..ae5edf6371 100644 --- a/include/ur_ddi.h +++ b/include/ur_ddi.h @@ -526,12 +526,6 @@ typedef ur_result_t(UR_APICALL *ur_pfnKernelSetSpecializationConstants_t)( uint32_t, const ur_specialization_constant_info_t *); -/////////////////////////////////////////////////////////////////////////////// -/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCount -typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCount_t)( - ur_kernel_handle_t, - uint32_t *); - /////////////////////////////////////////////////////////////////////////////// /// @brief Table of Kernel functions pointers typedef struct ur_kernel_dditable_t { @@ -550,7 +544,6 @@ typedef struct ur_kernel_dditable_t { ur_pfnKernelSetArgSampler_t pfnSetArgSampler; ur_pfnKernelSetArgMemObj_t pfnSetArgMemObj; ur_pfnKernelSetSpecializationConstants_t pfnSetSpecializationConstants; - ur_pfnKernelSuggestMaxCooperativeGroupCount_t pfnSuggestMaxCooperativeGroupCount; } ur_kernel_dditable_t; /////////////////////////////////////////////////////////////////////////////// @@ -574,6 +567,39 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)( ur_api_version_t, ur_kernel_dditable_t *); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp +typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)( + ur_kernel_handle_t, + uint32_t *); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Table of KernelExp functions pointers +typedef struct ur_kernel_exp_dditable_t { + ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t pfnSuggestMaxCooperativeGroupCountExp; +} ur_kernel_exp_dditable_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL +urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urGetKernelExpProcAddrTable +typedef ur_result_t(UR_APICALL *ur_pfnGetKernelExpProcAddrTable_t)( + ur_api_version_t, + ur_kernel_exp_dditable_t *); + /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urSamplerCreate typedef ur_result_t(UR_APICALL *ur_pfnSamplerCreate_t)( @@ -1202,19 +1228,6 @@ typedef ur_result_t(UR_APICALL *ur_pfnEnqueueWriteHostPipe_t)( const ur_event_handle_t *, ur_event_handle_t *); -/////////////////////////////////////////////////////////////////////////////// -/// @brief Function-pointer for urEnqueueCooperativeKernelLaunch -typedef ur_result_t(UR_APICALL *ur_pfnEnqueueCooperativeKernelLaunch_t)( - ur_queue_handle_t, - ur_kernel_handle_t, - uint32_t, - const size_t *, - const size_t *, - const size_t *, - uint32_t, - const ur_event_handle_t *, - ur_event_handle_t *); - /////////////////////////////////////////////////////////////////////////////// /// @brief Table of Enqueue functions pointers typedef struct ur_enqueue_dditable_t { @@ -1243,7 +1256,6 @@ typedef struct ur_enqueue_dditable_t { ur_pfnEnqueueDeviceGlobalVariableRead_t pfnDeviceGlobalVariableRead; ur_pfnEnqueueReadHostPipe_t pfnReadHostPipe; ur_pfnEnqueueWriteHostPipe_t pfnWriteHostPipe; - ur_pfnEnqueueCooperativeKernelLaunch_t pfnCooperativeKernelLaunch; } ur_enqueue_dditable_t; /////////////////////////////////////////////////////////////////////////////// @@ -1267,6 +1279,46 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueProcAddrTable_t)( ur_api_version_t, ur_enqueue_dditable_t *); +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp +typedef ur_result_t(UR_APICALL *ur_pfnEnqueueCooperativeKernelLaunchExp_t)( + ur_queue_handle_t, + ur_kernel_handle_t, + uint32_t, + const size_t *, + const size_t *, + const size_t *, + uint32_t, + const ur_event_handle_t *, + ur_event_handle_t *); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Table of EnqueueExp functions pointers +typedef struct ur_enqueue_exp_dditable_t { + ur_pfnEnqueueCooperativeKernelLaunchExp_t pfnCooperativeKernelLaunchExp; +} ur_enqueue_exp_dditable_t; + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL +urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers +); + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Function-pointer for urGetEnqueueExpProcAddrTable +typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueExpProcAddrTable_t)( + ur_api_version_t, + ur_enqueue_exp_dditable_t *); + /////////////////////////////////////////////////////////////////////////////// /// @brief Function-pointer for urQueueGetInfo typedef ur_result_t(UR_APICALL *ur_pfnQueueGetInfo_t)( @@ -2175,11 +2227,13 @@ typedef struct ur_dditable_t { ur_event_dditable_t Event; ur_program_dditable_t Program; ur_kernel_dditable_t Kernel; + ur_kernel_exp_dditable_t KernelExp; ur_sampler_dditable_t Sampler; ur_mem_dditable_t Mem; ur_physical_mem_dditable_t PhysicalMem; ur_global_dditable_t Global; ur_enqueue_dditable_t Enqueue; + ur_enqueue_exp_dditable_t EnqueueExp; ur_queue_dditable_t Queue; ur_bindless_images_exp_dditable_t BindlessImagesExp; ur_usm_dditable_t USM; diff --git a/scripts/core/EXP-COOPERATIVE-KERNELS.rst b/scripts/core/EXP-COOPERATIVE-KERNELS.rst index d82b55a7bc..c6b64ef669 100644 --- a/scripts/core/EXP-COOPERATIVE-KERNELS.rst +++ b/scripts/core/EXP-COOPERATIVE-KERNELS.rst @@ -42,8 +42,8 @@ Macros Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -* ${x}EnqueueCooperativeKernelLaunch -* ${x}KernelSuggestMaxCooperativeGroupCount +* ${x}EnqueueCooperativeKernelLaunchExp +* ${x}KernelSuggestMaxCooperativeGroupCountExp Changelog -------------------------------------------------------------------------------- diff --git a/scripts/core/exp-cooperative-kernels.yml b/scripts/core/exp-cooperative-kernels.yml index 529a1aec9a..fb2c6b3a4a 100644 --- a/scripts/core/exp-cooperative-kernels.yml +++ b/scripts/core/exp-cooperative-kernels.yml @@ -22,7 +22,7 @@ value: "\"$x_exp_cooperative_kernels\"" type: function desc: "Enqueue a command to execute a cooperative kernel" class: $xEnqueue -name: CooperativeKernelLaunch +name: CooperativeKernelLaunchExp params: - type: $x_queue_handle_t name: hQueue @@ -73,7 +73,7 @@ returns: type: function desc: "Query the maximum number of work groups for a cooperative kernel" class: $xKernel -name: SuggestMaxCooperativeGroupCount +name: SuggestMaxCooperativeGroupCountExp params: - type: $x_kernel_handle_t name: hKernel diff --git a/scripts/core/registry.yml b/scripts/core/registry.yml index 1a5375e0f0..4924e3e947 100644 --- a/scripts/core/registry.yml +++ b/scripts/core/registry.yml @@ -535,11 +535,11 @@ etors: - name: COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP desc: Enumerator for $xCommandBufferAppendMemBufferFillExp value: '192' -- name: ENQUEUE_COOPERATIVE_KERNEL_LAUNCH - desc: Enumerator for $xEnqueueCooperativeKernelLaunch +- name: ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP + desc: Enumerator for $xEnqueueCooperativeKernelLaunchExp value: '193' -- name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT - desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCount +- name: KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP + desc: Enumerator for $xKernelSuggestMaxCooperativeGroupCountExp value: '194' --- type: enum diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/null/ur_nullddi.cpp index b2b7b0e6b5..314cb9138e 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/null/ur_nullddi.cpp @@ -4928,8 +4928,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueCooperativeKernelLaunch -__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -4961,10 +4961,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( ur_result_t result = UR_RESULT_SUCCESS; // if the driver has created a custom function, then call it instead of using the generic path - auto pfnCooperativeKernelLaunch = - d_context.urDdiTable.Enqueue.pfnCooperativeKernelLaunch; - if (nullptr != pfnCooperativeKernelLaunch) { - result = pfnCooperativeKernelLaunch( + auto pfnCooperativeKernelLaunchExp = + d_context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr != pfnCooperativeKernelLaunchExp) { + result = pfnCooperativeKernelLaunchExp( hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); } else { @@ -4980,18 +4980,18 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCount -__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) try { ur_result_t result = UR_RESULT_SUCCESS; // if the driver has created a custom function, then call it instead of using the generic path - auto pfnSuggestMaxCooperativeGroupCount = - d_context.urDdiTable.Kernel.pfnSuggestMaxCooperativeGroupCount; - if (nullptr != pfnSuggestMaxCooperativeGroupCount) { - result = pfnSuggestMaxCooperativeGroupCount(hKernel, pGroupCountRet); + auto pfnSuggestMaxCooperativeGroupCountExp = + d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) { + result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); } else { // generic implementation } @@ -5426,8 +5426,36 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( pDdiTable->pfnWriteHostPipe = driver::urEnqueueWriteHostPipe; - pDdiTable->pfnCooperativeKernelLaunch = - driver::urEnqueueCooperativeKernelLaunch; + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) try { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (driver::d_context.version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnCooperativeKernelLaunchExp = + driver::urEnqueueCooperativeKernelLaunchExp; return result; } catch (...) { @@ -5534,8 +5562,36 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( pDdiTable->pfnSetSpecializationConstants = driver::urKernelSetSpecializationConstants; - pDdiTable->pfnSuggestMaxCooperativeGroupCount = - driver::urKernelSuggestMaxCooperativeGroupCount; + return result; +} catch (...) { + return exceptionToResult(std::current_exception()); +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers + ) try { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (driver::d_context.version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + driver::urKernelSuggestMaxCooperativeGroupCountExp; return result; } catch (...) { diff --git a/source/common/ur_params.hpp b/source/common/ur_params.hpp index 283532a518..09a9c667a4 100644 --- a/source/common/ur_params.hpp +++ b/source/common/ur_params.hpp @@ -1150,12 +1150,12 @@ inline std::ostream &operator<<(std::ostream &os, enum ur_function_t value) { os << "UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP"; break; - case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH: - os << "UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH"; + case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP: + os << "UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP"; break; - case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT: - os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT"; + case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP: + os << "UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP"; break; default: os << "unknown enumerator"; @@ -12925,9 +12925,9 @@ operator<<(std::ostream &os, return os; } -inline std::ostream & -operator<<(std::ostream &os, - const struct ur_enqueue_cooperative_kernel_launch_params_t *params) { +inline std::ostream &operator<<( + std::ostream &os, + const struct ur_enqueue_cooperative_kernel_launch_exp_params_t *params) { os << ".hQueue = "; @@ -13531,10 +13531,10 @@ inline std::ostream &operator<<( return os; } -inline std::ostream & -operator<<(std::ostream &os, - const struct ur_kernel_suggest_max_cooperative_group_count_params_t - *params) { +inline std::ostream &operator<<( + std::ostream &os, + const struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t + *params) { os << ".hKernel = "; @@ -15767,8 +15767,8 @@ inline int serializeFunctionParams(std::ostream &os, uint32_t function, case UR_FUNCTION_ENQUEUE_WRITE_HOST_PIPE: { os << (const struct ur_enqueue_write_host_pipe_params_t *)params; } break; - case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH: { - os << (const struct ur_enqueue_cooperative_kernel_launch_params_t *) + case UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP: { + os << (const struct ur_enqueue_cooperative_kernel_launch_exp_params_t *) params; } break; case UR_FUNCTION_EVENT_GET_INFO: { @@ -15843,9 +15843,10 @@ inline int serializeFunctionParams(std::ostream &os, uint32_t function, os << (const struct ur_kernel_set_specialization_constants_params_t *) params; } break; - case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT: { + case UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP: { os << (const struct - ur_kernel_suggest_max_cooperative_group_count_params_t *)params; + ur_kernel_suggest_max_cooperative_group_count_exp_params_t *) + params; } break; case UR_FUNCTION_LOADER_INIT: { os << (const struct ur_loader_init_params_t *)params; diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index 433142e99a..6e4ce75bb8 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -5696,8 +5696,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueCooperativeKernelLaunch -__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -5726,14 +5726,14 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( phEvent ///< [out][optional] return an event object that identifies this particular ///< kernel execution instance. ) { - auto pfnCooperativeKernelLaunch = - context.urDdiTable.Enqueue.pfnCooperativeKernelLaunch; + auto pfnCooperativeKernelLaunchExp = + context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; - if (nullptr == pfnCooperativeKernelLaunch) { + if (nullptr == pfnCooperativeKernelLaunchExp) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } - ur_enqueue_cooperative_kernel_launch_params_t params = { + ur_enqueue_cooperative_kernel_launch_exp_params_t params = { &hQueue, &hKernel, &workDim, @@ -5744,45 +5744,46 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( &phEventWaitList, &phEvent}; uint64_t instance = - context.notify_begin(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH, - "urEnqueueCooperativeKernelLaunch", ¶ms); + context.notify_begin(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP, + "urEnqueueCooperativeKernelLaunchExp", ¶ms); - ur_result_t result = pfnCooperativeKernelLaunch( + ur_result_t result = pfnCooperativeKernelLaunchExp( hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); - context.notify_end(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH, - "urEnqueueCooperativeKernelLaunch", ¶ms, &result, + context.notify_end(UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP, + "urEnqueueCooperativeKernelLaunchExp", ¶ms, &result, instance); return result; } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCount -__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) { - auto pfnSuggestMaxCooperativeGroupCount = - context.urDdiTable.Kernel.pfnSuggestMaxCooperativeGroupCount; + auto pfnSuggestMaxCooperativeGroupCountExp = + context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; - if (nullptr == pfnSuggestMaxCooperativeGroupCount) { + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } - ur_kernel_suggest_max_cooperative_group_count_params_t params = { + ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = { &hKernel, &pGroupCountRet}; uint64_t instance = context.notify_begin( - UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT, - "urKernelSuggestMaxCooperativeGroupCount", ¶ms); + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP, + "urKernelSuggestMaxCooperativeGroupCountExp", ¶ms); ur_result_t result = - pfnSuggestMaxCooperativeGroupCount(hKernel, pGroupCountRet); + pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); - context.notify_end(UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT, - "urKernelSuggestMaxCooperativeGroupCount", ¶ms, - &result, instance); + context.notify_end( + UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP, + "urKernelSuggestMaxCooperativeGroupCountExp", ¶ms, &result, + instance); return result; } @@ -6336,9 +6337,40 @@ __urdlllocal ur_result_t UR_APICALL urGetEnqueueProcAddrTable( dditable.pfnWriteHostPipe = pDdiTable->pfnWriteHostPipe; pDdiTable->pfnWriteHostPipe = ur_tracing_layer::urEnqueueWriteHostPipe; - dditable.pfnCooperativeKernelLaunch = pDdiTable->pfnCooperativeKernelLaunch; - pDdiTable->pfnCooperativeKernelLaunch = - ur_tracing_layer::urEnqueueCooperativeKernelLaunch; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_tracing_layer::context.urDdiTable.EnqueueExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_tracing_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCooperativeKernelLaunchExp = + pDdiTable->pfnCooperativeKernelLaunchExp; + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_tracing_layer::urEnqueueCooperativeKernelLaunchExp; return result; } @@ -6473,10 +6505,40 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable( pDdiTable->pfnSetSpecializationConstants = ur_tracing_layer::urKernelSetSpecializationConstants; - dditable.pfnSuggestMaxCooperativeGroupCount = - pDdiTable->pfnSuggestMaxCooperativeGroupCount; - pDdiTable->pfnSuggestMaxCooperativeGroupCount = - ur_tracing_layer::urKernelSuggestMaxCooperativeGroupCount; + return result; +} +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +__urdlllocal ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_tracing_layer::context.urDdiTable.KernelExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_tracing_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnSuggestMaxCooperativeGroupCountExp = + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp; + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_tracing_layer::urKernelSuggestMaxCooperativeGroupCountExp; return result; } @@ -7094,6 +7156,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = ur_tracing_layer::urGetEnqueueExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_tracing_layer::urGetEventProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Event); @@ -7104,6 +7171,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = ur_tracing_layer::urGetKernelExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_tracing_layer::urGetMemProcAddrTable(UR_API_VERSION_CURRENT, &dditable->Mem); diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 571504b1a6..ebc7dd67bf 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -7061,8 +7061,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueCooperativeKernelLaunch -__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -7091,10 +7091,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( phEvent ///< [out][optional] return an event object that identifies this particular ///< kernel execution instance. ) { - auto pfnCooperativeKernelLaunch = - context.urDdiTable.Enqueue.pfnCooperativeKernelLaunch; + auto pfnCooperativeKernelLaunchExp = + context.urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; - if (nullptr == pfnCooperativeKernelLaunch) { + if (nullptr == pfnCooperativeKernelLaunchExp) { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -7124,7 +7124,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( } } - ur_result_t result = pfnCooperativeKernelLaunch( + ur_result_t result = pfnCooperativeKernelLaunchExp( hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); @@ -7132,15 +7132,15 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCount -__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) { - auto pfnSuggestMaxCooperativeGroupCount = - context.urDdiTable.Kernel.pfnSuggestMaxCooperativeGroupCount; + auto pfnSuggestMaxCooperativeGroupCountExp = + context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; - if (nullptr == pfnSuggestMaxCooperativeGroupCount) { + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -7155,7 +7155,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( } ur_result_t result = - pfnSuggestMaxCooperativeGroupCount(hKernel, pGroupCountRet); + pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); return result; } @@ -7742,9 +7742,41 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( dditable.pfnWriteHostPipe = pDdiTable->pfnWriteHostPipe; pDdiTable->pfnWriteHostPipe = ur_validation_layer::urEnqueueWriteHostPipe; - dditable.pfnCooperativeKernelLaunch = pDdiTable->pfnCooperativeKernelLaunch; - pDdiTable->pfnCooperativeKernelLaunch = - ur_validation_layer::urEnqueueCooperativeKernelLaunch; + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_validation_layer::context.urDdiTable.EnqueueExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_validation_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnCooperativeKernelLaunchExp = + pDdiTable->pfnCooperativeKernelLaunchExp; + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_validation_layer::urEnqueueCooperativeKernelLaunchExp; return result; } @@ -7884,10 +7916,41 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( pDdiTable->pfnSetSpecializationConstants = ur_validation_layer::urKernelSetSpecializationConstants; - dditable.pfnSuggestMaxCooperativeGroupCount = - pDdiTable->pfnSuggestMaxCooperativeGroupCount; - pDdiTable->pfnSuggestMaxCooperativeGroupCount = - ur_validation_layer::urKernelSuggestMaxCooperativeGroupCount; + return result; +} + +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + auto &dditable = ur_validation_layer::context.urDdiTable.KernelExp; + + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != + UR_MAJOR_VERSION(version) || + UR_MINOR_VERSION(ur_validation_layer::context.version) > + UR_MINOR_VERSION(version)) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + dditable.pfnSuggestMaxCooperativeGroupCountExp = + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp; + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_validation_layer::urKernelSuggestMaxCooperativeGroupCountExp; return result; } @@ -8532,6 +8595,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = ur_validation_layer::urGetEnqueueExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_validation_layer::urGetEventProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Event); @@ -8542,6 +8610,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable, UR_API_VERSION_CURRENT, &dditable->Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = ur_validation_layer::urGetKernelExpProcAddrTable( + UR_API_VERSION_CURRENT, &dditable->KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = ur_validation_layer::urGetMemProcAddrTable( UR_API_VERSION_CURRENT, &dditable->Mem); diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index ade8c87230..528b2d1eba 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -6847,8 +6847,8 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferEnqueueExp( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urEnqueueCooperativeKernelLaunch -__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +/// @brief Intercept function for urEnqueueCooperativeKernelLaunchExp +__urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -6881,9 +6881,9 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( // extract platform's function pointer table auto dditable = reinterpret_cast(hQueue)->dditable; - auto pfnCooperativeKernelLaunch = - dditable->ur.Enqueue.pfnCooperativeKernelLaunch; - if (nullptr == pfnCooperativeKernelLaunch) { + auto pfnCooperativeKernelLaunchExp = + dditable->ur.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr == pfnCooperativeKernelLaunchExp) { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -6902,10 +6902,10 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( } // forward to device-platform - result = pfnCooperativeKernelLaunch(hQueue, hKernel, workDim, - pGlobalWorkOffset, pGlobalWorkSize, - pLocalWorkSize, numEventsInWaitList, - phEventWaitListLocal.data(), phEvent); + result = pfnCooperativeKernelLaunchExp( + hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize, numEventsInWaitList, phEventWaitListLocal.data(), + phEvent); if (UR_RESULT_SUCCESS != result) { return result; @@ -6925,8 +6925,8 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( } /////////////////////////////////////////////////////////////////////////////// -/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCount -__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp +__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) { @@ -6934,9 +6934,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( // extract platform's function pointer table auto dditable = reinterpret_cast(hKernel)->dditable; - auto pfnSuggestMaxCooperativeGroupCount = - dditable->ur.Kernel.pfnSuggestMaxCooperativeGroupCount; - if (nullptr == pfnSuggestMaxCooperativeGroupCount) { + auto pfnSuggestMaxCooperativeGroupCountExp = + dditable->ur.KernelExp.pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { return UR_RESULT_ERROR_UNINITIALIZED; } @@ -6944,7 +6944,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( hKernel = reinterpret_cast(hKernel)->handle; // forward to device-platform - result = pfnSuggestMaxCooperativeGroupCount(hKernel, pGroupCountRet); + result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); return result; } @@ -7469,8 +7469,6 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( ur_loader::urEnqueueDeviceGlobalVariableRead; pDdiTable->pfnReadHostPipe = ur_loader::urEnqueueReadHostPipe; pDdiTable->pfnWriteHostPipe = ur_loader::urEnqueueWriteHostPipe; - pDdiTable->pfnCooperativeKernelLaunch = - ur_loader::urEnqueueCooperativeKernelLaunch; } else { // return pointers directly to platform's DDIs *pDdiTable = @@ -7481,6 +7479,61 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's EnqueueExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_enqueue_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (ur_loader::context->version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + // Load the device-platform DDI tables + for (auto &platform : ur_loader::context->platforms) { + if (platform.initStatus != UR_RESULT_SUCCESS) { + continue; + } + auto getTable = reinterpret_cast( + ur_loader::LibLoader::getFunctionPtr( + platform.handle.get(), "urGetEnqueueExpProcAddrTable")); + if (!getTable) { + continue; + } + platform.initStatus = + getTable(version, &platform.dditable.ur.EnqueueExp); + } + + if (UR_RESULT_SUCCESS == result) { + if (ur_loader::context->platforms.size() != 1 || + ur_loader::context->forceIntercept) { + // return pointers to loader's DDIs + pDdiTable->pfnCooperativeKernelLaunchExp = + ur_loader::urEnqueueCooperativeKernelLaunchExp; + } else { + // return pointers directly to platform's DDIs + *pDdiTable = + ur_loader::context->platforms.front().dditable.ur.EnqueueExp; + } + } + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Event table /// with current process' addresses @@ -7601,8 +7654,6 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( pDdiTable->pfnSetArgMemObj = ur_loader::urKernelSetArgMemObj; pDdiTable->pfnSetSpecializationConstants = ur_loader::urKernelSetSpecializationConstants; - pDdiTable->pfnSuggestMaxCooperativeGroupCount = - ur_loader::urKernelSuggestMaxCooperativeGroupCount; } else { // return pointers directly to platform's DDIs *pDdiTable = @@ -7613,6 +7664,61 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelProcAddrTable( return result; } +/////////////////////////////////////////////////////////////////////////////// +/// @brief Exported function for filling application's KernelExp table +/// with current process' addresses +/// +/// @returns +/// - ::UR_RESULT_SUCCESS +/// - ::UR_RESULT_ERROR_UNINITIALIZED +/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER +/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION +UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable( + ur_api_version_t version, ///< [in] API version requested + ur_kernel_exp_dditable_t + *pDdiTable ///< [in,out] pointer to table of DDI function pointers +) { + if (nullptr == pDdiTable) { + return UR_RESULT_ERROR_INVALID_NULL_POINTER; + } + + if (ur_loader::context->version < version) { + return UR_RESULT_ERROR_UNSUPPORTED_VERSION; + } + + ur_result_t result = UR_RESULT_SUCCESS; + + // Load the device-platform DDI tables + for (auto &platform : ur_loader::context->platforms) { + if (platform.initStatus != UR_RESULT_SUCCESS) { + continue; + } + auto getTable = reinterpret_cast( + ur_loader::LibLoader::getFunctionPtr( + platform.handle.get(), "urGetKernelExpProcAddrTable")); + if (!getTable) { + continue; + } + platform.initStatus = + getTable(version, &platform.dditable.ur.KernelExp); + } + + if (UR_RESULT_SUCCESS == result) { + if (ur_loader::context->platforms.size() != 1 || + ur_loader::context->forceIntercept) { + // return pointers to loader's DDIs + pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = + ur_loader::urKernelSuggestMaxCooperativeGroupCountExp; + } else { + // return pointers directly to platform's DDIs + *pDdiTable = + ur_loader::context->platforms.front().dditable.ur.KernelExp; + } + } + + return result; +} + /////////////////////////////////////////////////////////////////////////////// /// @brief Exported function for filling application's Mem table /// with current process' addresses diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 138fb1e2b3..8f7e04e76c 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -7641,7 +7641,7 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_INVALID_VALUE /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES -ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -7670,13 +7670,13 @@ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( phEvent ///< [out][optional] return an event object that identifies this particular ///< kernel execution instance. ) try { - auto pfnCooperativeKernelLaunch = - ur_lib::context->urDdiTable.Enqueue.pfnCooperativeKernelLaunch; - if (nullptr == pfnCooperativeKernelLaunch) { + auto pfnCooperativeKernelLaunchExp = + ur_lib::context->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp; + if (nullptr == pfnCooperativeKernelLaunchExp) { return UR_RESULT_ERROR_UNINITIALIZED; } - return pfnCooperativeKernelLaunch( + return pfnCooperativeKernelLaunchExp( hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, numEventsInWaitList, phEventWaitList, phEvent); } catch (...) { @@ -7696,17 +7696,18 @@ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pGroupCountRet` /// - ::UR_RESULT_ERROR_INVALID_KERNEL -ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) try { - auto pfnSuggestMaxCooperativeGroupCount = - ur_lib::context->urDdiTable.Kernel.pfnSuggestMaxCooperativeGroupCount; - if (nullptr == pfnSuggestMaxCooperativeGroupCount) { + auto pfnSuggestMaxCooperativeGroupCountExp = + ur_lib::context->urDdiTable.KernelExp + .pfnSuggestMaxCooperativeGroupCountExp; + if (nullptr == pfnSuggestMaxCooperativeGroupCountExp) { return UR_RESULT_ERROR_UNINITIALIZED; } - return pfnSuggestMaxCooperativeGroupCount(hKernel, pGroupCountRet); + return pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet); } catch (...) { return exceptionToResult(std::current_exception()); } diff --git a/source/loader/ur_libddi.cpp b/source/loader/ur_libddi.cpp index 912449b54e..9d0a2566f4 100644 --- a/source/loader/ur_libddi.cpp +++ b/source/loader/ur_libddi.cpp @@ -45,6 +45,11 @@ __urdlllocal ur_result_t context_t::urLoaderInit() { &urDdiTable.Enqueue); } + if (UR_RESULT_SUCCESS == result) { + result = urGetEnqueueExpProcAddrTable(UR_API_VERSION_CURRENT, + &urDdiTable.EnqueueExp); + } + if (UR_RESULT_SUCCESS == result) { result = urGetEventProcAddrTable(UR_API_VERSION_CURRENT, &urDdiTable.Event); @@ -55,6 +60,11 @@ __urdlllocal ur_result_t context_t::urLoaderInit() { &urDdiTable.Kernel); } + if (UR_RESULT_SUCCESS == result) { + result = urGetKernelExpProcAddrTable(UR_API_VERSION_CURRENT, + &urDdiTable.KernelExp); + } + if (UR_RESULT_SUCCESS == result) { result = urGetMemProcAddrTable(UR_API_VERSION_CURRENT, &urDdiTable.Mem); } diff --git a/source/ur_api.cpp b/source/ur_api.cpp index f622aaa2e5..413120fa0a 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -6446,7 +6446,7 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// - ::UR_RESULT_ERROR_INVALID_VALUE /// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY /// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES -ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( +ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp( ur_queue_handle_t hQueue, ///< [in] handle of the queue object ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t @@ -6492,7 +6492,7 @@ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunch( /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pGroupCountRet` /// - ::UR_RESULT_ERROR_INVALID_KERNEL -ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCount( +ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp( ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups ) {