Skip to content

Commit

Permalink
[EXP][Command-Buffer] Add kernel command update
Browse files Browse the repository at this point in the history
This change introduces a new API that allows the
kernel commands of a command-buffer to be updated
with a new configuration. For example, modified
arguments or ND-Range.

See
[cl_khr_command_buffer_mutable_dispatch](https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Ext.html#cl_khr_command_buffer_mutable_dispatch)
as prior art. The differences between the proposed API and the above
are:

* No flag is required to be set on command-buffer creation to enable
  this functionality.
* Only the append kernel entry-point returns a command handle. I imagine
  this will be changed in future to enable other commands to do update.
* Only USM and buffer arguments can be updated, there is not equivalent
  update struct for `urKernelSetArgLocal`, `urKernelSetArgValue`, or
  `urKernelSetArgSampler`
* There is no granularity of optional support for update, an implementer
  must either implement all the ways to update a kernel configuration,
  or none of them.
  • Loading branch information
EwanC committed Nov 16, 2023
1 parent 534071e commit e2a9320
Show file tree
Hide file tree
Showing 21 changed files with 1,643 additions and 563 deletions.
93 changes: 90 additions & 3 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class ur_function_v(IntEnum):
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_EXP = 182 ## Enumerator for ::urCommandBufferUpdateKernelLaunchExp

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -240,6 +241,10 @@ class ur_structure_type_v(IntEnum):
KERNEL_ARG_VALUE_PROPERTIES = 32 ## ::ur_kernel_arg_value_properties_t
KERNEL_ARG_LOCAL_PROPERTIES = 33 ## ::ur_kernel_arg_local_properties_t
EXP_COMMAND_BUFFER_DESC = 0x1000 ## ::ur_exp_command_buffer_desc_t
EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC = 0x1001 ## ::ur_exp_command_buffer_update_kernel_launch_desc_t
EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC = 0x1002 ## ::ur_exp_command_buffer_update_memobj_arg_desc_t
EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC = 0x1003 ## ::ur_exp_command_buffer_update_pointer_arg_desc_t
EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC = 0x1004 ## ::ur_exp_command_buffer_update_exec_info_desc_t
EXP_SAMPLER_MIP_PROPERTIES = 0x2000 ## ::ur_exp_sampler_mip_properties_t
EXP_INTEROP_MEM_DESC = 0x2001 ## ::ur_exp_interop_mem_desc_t
EXP_INTEROP_SEMAPHORE_DESC = 0x2002 ## ::ur_exp_interop_semaphore_desc_t
Expand Down Expand Up @@ -443,6 +448,7 @@ class ur_result_v(IntEnum):
ERROR_INVALID_COMMAND_BUFFER_EXP = 0x1000 ## Invalid Command-Buffer
ERROR_INVALID_COMMAND_BUFFER_SYNC_POINT_EXP = 0x1001## Sync point is not valid for the command-buffer
ERROR_INVALID_COMMAND_BUFFER_SYNC_POINT_WAIT_LIST_EXP = 0x1002 ## Sync point wait list is invalid
ERROR_INVALID_COMMAND_BUFFER_COMMAND_HANDLE_EXP = 0x1003## Handle to command-buffer command is invalid
ERROR_UNKNOWN = 0x7ffffffe ## Unknown or internal error

class ur_result_t(c_int):
Expand Down Expand Up @@ -833,6 +839,10 @@ class ur_device_info_v(IntEnum):
## version than older devices.
VIRTUAL_MEMORY_SUPPORT = 114 ## [::ur_bool_t] return true if the device supports virtual memory.
ESIMD_SUPPORT = 115 ## [::ur_bool_t] return true if the device supports ESIMD.
COMMAND_BUFFER_SUPPORT_EXP = 0x1000 ## [::ur_bool_t] returns true if the device supports the use of
## command-buffers.
COMMAND_BUFFER_UPDATE_SUPPORT_EXP = 0x1001 ## [::ur_bool_t] returns true if the device supports updating the
## commands in a command-buffer.
BINDLESS_IMAGES_SUPPORT_EXP = 0x2000 ## [::ur_bool_t] returns true if the device supports the creation of
## bindless images
BINDLESS_IMAGES_SHARED_USM_SUPPORT_EXP = 0x2001 ## [::ur_bool_t] returns true if the device supports the creation of
Expand Down Expand Up @@ -2242,6 +2252,69 @@ class ur_exp_command_buffer_desc_t(Structure):
("pNext", c_void_p) ## [in][optional] pointer to extension-specific structure
]

###############################################################################
## @brief Descriptor type for updating a kernel command memobj argument.
class ur_exp_command_buffer_update_memobj_arg_desc_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
("argIndex", c_ulong), ## [in] Argument index.
("pProperties", *), ## [in][optinal] Pointer to memory object properties.
("hArgValue", ur_mem_handle_t) ## [in][optional] Handle of memory object.
]

###############################################################################
## @brief Descriptor type for updating a kernel command pointer argument.
class ur_exp_command_buffer_update_pointer_arg_desc_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
("argIndex", c_ulong), ## [in] Argument index.
("pProperties", *), ## [in][optinal] Pointer to USM pointer properties.
("pArgValue", *) ## [in][optional] USM pointer to memory location holding the argument
## value.
]

###############################################################################
## @brief Descriptor type for updating kernel command execution info.
class ur_exp_command_buffer_update_exec_info_desc_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
("propName", ur_kernel_exec_info_t), ## [in] Name of execution attribute.
("propSize", c_size_t), ## [in] Size of execution attribute.
("pProperties", *), ## [in][optional] Pointer to execution info properties.
("pPropValue", *) ## [in] Pointer to memory location holding the property value.
]

###############################################################################
## @brief Descriptor type for updating a kernel launch command.
class ur_exp_command_buffer_update_kernel_launch_desc_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC
("pNext", c_void_p), ## [in][optional] pointer to extension-specific structure
("numMemobjArgs", c_ulong), ## [in] Length of pArgMemobjList.
("numPointerArgs", c_ulong), ## [in] Length of pArgPointerList.
("numExecInfos", c_ulong), ## [in] Length of pExecInfoList.
("workDim", c_ulong), ## [in] Number of work dimensions in the kernel ND-range, from 1-3.
("pArgMemobjList", POINTER(ur_exp_command_buffer_update_memobj_arg_desc_t)),## [in] An array describing the new kernel mem obj arguments for the
## command.
("pArgPointerList", POINTER(ur_exp_command_buffer_update_pointer_arg_desc_t)), ## [in] An array describing the new kernel pointer arguments for the
## command.
("pArgExecInfoList", POINTER(ur_exp_command_buffer_update_exec_info_desc_t)), ## [in] An array describing the execution info objects for the command.
("pGlobalWorkOffset", POINTER(c_size_t)), ## [in] Array of workDim unsigned values that describe the offset used to
## calculate the global ID.
("pGlobalWorkSize", POINTER(c_size_t)), ## [in] Array of workDim unsigned values that describe the number of
## global work-items.
("pLocalWorkSize", POINTER(c_size_t)) ## [in] Array of workDim unsigned values that describe the number of
## work-items that make up a work-group. If nullptr, the runtime
## implementation will choose the work-group size.
]

###############################################################################
## @brief A value that identifies a command inside of a command-buffer, used for
## defining dependencies between commands in the same command-buffer.
Expand All @@ -2253,6 +2326,11 @@ class ur_exp_command_buffer_sync_point_t(c_ulong):
class ur_exp_command_buffer_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of a Command-Buffer command
class ur_exp_command_buffer_command_handle_t(c_void_p):
pass

###############################################################################
## @brief Supported peer info
class ur_exp_peer_info_v(IntEnum):
Expand Down Expand Up @@ -3431,9 +3509,9 @@ class ur_usm_exp_dditable_t(Structure):
###############################################################################
## @brief Function-pointer for urCommandBufferAppendKernelLaunchExp
if __use_win_types:
_urCommandBufferAppendKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t) )
_urCommandBufferAppendKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_command_handle_t) )
else:
_urCommandBufferAppendKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t) )
_urCommandBufferAppendKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_sync_point_t), POINTER(ur_exp_command_buffer_command_handle_t) )

###############################################################################
## @brief Function-pointer for urCommandBufferAppendMemcpyUSMExp
Expand Down Expand Up @@ -3491,6 +3569,13 @@ class ur_usm_exp_dditable_t(Structure):
else:
_urCommandBufferEnqueueExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_handle_t, ur_queue_handle_t, c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )

###############################################################################
## @brief Function-pointer for urCommandBufferUpdateKernelLaunchExp
if __use_win_types:
_urCommandBufferUpdateKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_exp_command_buffer_command_handle_t, POINTER(ur_exp_command_buffer_update_kernel_launch_desc_t) )
else:
_urCommandBufferUpdateKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_exp_command_buffer_command_handle_t, POINTER(ur_exp_command_buffer_update_kernel_launch_desc_t) )


###############################################################################
## @brief Table of CommandBufferExp functions pointers
Expand All @@ -3508,7 +3593,8 @@ class ur_command_buffer_exp_dditable_t(Structure):
("pfnAppendMembufferCopyRectExp", c_void_p), ## _urCommandBufferAppendMembufferCopyRectExp_t
("pfnAppendMembufferWriteRectExp", c_void_p), ## _urCommandBufferAppendMembufferWriteRectExp_t
("pfnAppendMembufferReadRectExp", c_void_p), ## _urCommandBufferAppendMembufferReadRectExp_t
("pfnEnqueueExp", c_void_p) ## _urCommandBufferEnqueueExp_t
("pfnEnqueueExp", c_void_p), ## _urCommandBufferEnqueueExp_t
("pfnUpdateKernelLaunchExp", c_void_p) ## _urCommandBufferUpdateKernelLaunchExp_t
]

###############################################################################
Expand Down Expand Up @@ -4054,6 +4140,7 @@ def __init__(self, version : ur_api_version_t):
self.urCommandBufferAppendMembufferWriteRectExp = _urCommandBufferAppendMembufferWriteRectExp_t(self.__dditable.CommandBufferExp.pfnAppendMembufferWriteRectExp)
self.urCommandBufferAppendMembufferReadRectExp = _urCommandBufferAppendMembufferReadRectExp_t(self.__dditable.CommandBufferExp.pfnAppendMembufferReadRectExp)
self.urCommandBufferEnqueueExp = _urCommandBufferEnqueueExp_t(self.__dditable.CommandBufferExp.pfnEnqueueExp)
self.urCommandBufferUpdateKernelLaunchExp = _urCommandBufferUpdateKernelLaunchExp_t(self.__dditable.CommandBufferExp.pfnUpdateKernelLaunchExp)

# call driver to get function pointers
UsmP2PExp = ur_usm_p2p_exp_dditable_t()
Expand Down
Loading

0 comments on commit e2a9320

Please sign in to comment.