diff --git a/include/ur_api.h b/include/ur_api.h index a0a5dd30a6..5fb5dede1f 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -264,7 +264,8 @@ typedef enum ur_structure_type_t { UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC = 0x1001, ///< ::ur_exp_command_buffer_update_kernel_launch_desc_t UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC = 0x1002, ///< ::ur_exp_command_buffer_update_memobj_arg_desc_t UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC = 0x1003, ///< ::ur_exp_command_buffer_update_pointer_arg_desc_t - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC = 0x1004, ///< ::ur_exp_command_buffer_update_exec_info_desc_t + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC = 0x1004, ///< ::ur_exp_command_buffer_update_value_arg_desc_t + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC = 0x1005, ///< ::ur_exp_command_buffer_update_exec_info_desc_t UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES = 0x2000, ///< ::ur_exp_sampler_mip_properties_t UR_STRUCTURE_TYPE_EXP_INTEROP_MEM_DESC = 0x2001, ///< ::ur_exp_interop_mem_desc_t UR_STRUCTURE_TYPE_EXP_INTEROP_SEMAPHORE_DESC = 0x2002, ///< ::ur_exp_interop_semaphore_desc_t @@ -7798,6 +7799,19 @@ typedef struct ur_exp_command_buffer_update_pointer_arg_desc_t { } ur_exp_command_buffer_update_pointer_arg_desc_t; +/////////////////////////////////////////////////////////////////////////////// +/// @brief Descriptor type for updating a kernel command value argument. +typedef struct ur_exp_command_buffer_update_value_arg_desc_t { + ur_structure_type_t stype; ///< [in] type of this structure, must be + ///< ::UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC + const void *pNext; ///< [in][optional] pointer to extension-specific structure + uint32_t argIndex; ///< [in] Argument index. + uint32_t argSize; ///< [in] Argument size. + const ur_kernel_arg_value_properties_t *pProperties; ///< [in][optional] Pointer to value argument properties.
+ const void *pArgValue; ///< [in][optional] Argument value representing kernel arg type. + +} ur_exp_command_buffer_update_value_arg_desc_t; + /////////////////////////////////////////////////////////////////////////////// /// @brief Descriptor type for updating kernel command execution info. typedef struct ur_exp_command_buffer_update_exec_info_desc_t { @@ -7819,20 +7833,25 @@ typedef struct ur_exp_command_buffer_update_kernel_launch_desc_t { const void *pNext; ///< [in][optional] pointer to extension-specific structure uint32_t numMemobjArgs; ///< [in] Length of pArgMemobjList. uint32_t numPointerArgs; ///< [in] Length of pArgPointerList. + uint32_t numValueArgs; ///< [in] Length of pArgValueList. uint32_t numExecInfos; ///< [in] Length of pExecInfoList. - uint32_t workDim; ///< [in] Number of work dimensions in the kernel ND-range, from 1-3. - const ur_exp_command_buffer_update_memobj_arg_desc_t *pArgMemobjList; ///< [in] An array describing the new kernel mem obj arguments for the - ///< command. - const ur_exp_command_buffer_update_pointer_arg_desc_t *pArgPointerList; ///< [in] An array describing the new kernel pointer arguments for the + uint32_t workDim; ///< [in][optional] Number of work dimensions in the kernel ND-range, from + ///< 1-3. + const ur_exp_command_buffer_update_memobj_arg_desc_t *pArgMemobjList; ///< [in][optional] An array describing the new kernel mem obj arguments + ///< for the command. + const ur_exp_command_buffer_update_pointer_arg_desc_t *pArgPointerList; ///< [in][optional] An array describing the new kernel pointer arguments + ///< for the command. + const ur_exp_command_buffer_update_value_arg_desc_t *pArgValueList; ///< [in][optional] An array describing the new kernel value arguments for + ///< the command. + const ur_exp_command_buffer_update_exec_info_desc_t *pArgExecInfoList; ///< [in][optional] An array describing the execution info objects for the ///< command. 
- const ur_exp_command_buffer_update_exec_info_desc_t *pArgExecInfoList; ///< [in] An array describing the execution info objects for the command. - size_t *pGlobalWorkOffset; ///< [in] Array of workDim unsigned values that describe the offset used to - ///< calculate the global ID. - size_t *pGlobalWorkSize; ///< [in] Array of workDim unsigned values that describe the number of - ///< global work-items. - size_t *pLocalWorkSize; ///< [in] Array of workDim unsigned values that describe the number of - ///< work-items that make up a work-group. If nullptr, the runtime - ///< implementation will choose the work-group size. + size_t *pGlobalWorkOffset; ///< [in][optional] Array of workDim unsigned values that describe the + ///< offset used to calculate the global ID. + size_t *pGlobalWorkSize; ///< [in][optional] Array of workDim unsigned values that describe the + ///< number of global work-items. + size_t *pLocalWorkSize; ///< [in][optional] Array of workDim unsigned values that describe the + ///< number of work-items that make up a work-group. If nullptr, the + ///< runtime implementation will choose the work-group size. } ur_exp_command_buffer_update_kernel_launch_desc_t; @@ -8426,12 +8445,6 @@ urCommandBufferEnqueueExp( /// + `NULL == hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` -/// + `NULL == pUpdateKernelLaunch->pArgMemobjList` -/// + `NULL == pUpdateKernelLaunch->pArgPointerList` -/// + `NULL == pUpdateKernelLaunch->pArgExecInfoList` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkOffset` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkSize` -/// + `NULL == pUpdateKernelLaunch->pLocalWorkSize` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If update functionality is not supported by the device. 
/// - ::UR_RESULT_ERROR_INVALID_OPERATION diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 03a2e4f115..e068533fcd 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -323,6 +323,7 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_desc_t params); inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_memobj_arg_desc_t params); inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_pointer_arg_desc_t params); +inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_value_arg_desc_t params); inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_exec_info_desc_t params); inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur_exp_command_buffer_update_kernel_launch_desc_t params); inline std::ostream &operator<<(std::ostream &os, ur_exp_peer_info_t value); @@ -1017,6 +1018,9 @@ inline std::ostream &operator<<(std::ostream &os, ur_structure_type_t value) { case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC: os << "UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC"; break; + case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC: + os << "UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC"; + break; case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC: os << "UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC"; break; @@ -1250,6 +1254,11 @@ inline ur_result_t printStruct(std::ostream &os, const void *ptr) { printPtr(os, pstruct); } break; + case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC: { + const ur_exp_command_buffer_update_value_arg_desc_t *pstruct = (const 
ur_exp_command_buffer_update_value_arg_desc_t *)ptr; + printPtr(os, pstruct); + } break; + case UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC: { const ur_exp_command_buffer_update_exec_info_desc_t *pstruct = (const ur_exp_command_buffer_update_exec_info_desc_t *)ptr; printPtr(os, pstruct); @@ -9291,6 +9300,46 @@ inline std::ostream &operator<<(std::ostream &os, const struct ur_exp_command_bu return os; } /////////////////////////////////////////////////////////////////////////////// +/// @brief Print operator for the ur_exp_command_buffer_update_value_arg_desc_t type +/// @returns +/// std::ostream & +inline std::ostream &operator<<(std::ostream &os, const struct ur_exp_command_buffer_update_value_arg_desc_t params) { + os << "(struct ur_exp_command_buffer_update_value_arg_desc_t){"; + + os << ".stype = "; + + os << (params.stype); + + os << ", "; + os << ".pNext = "; + + ur::details::printStruct(os, + (params.pNext)); + + os << ", "; + os << ".argIndex = "; + + os << (params.argIndex); + + os << ", "; + os << ".argSize = "; + + os << (params.argSize); + + os << ", "; + os << ".pProperties = "; + + os << (params.pProperties); + + os << ", "; + os << ".pArgValue = "; + + os << (params.pArgValue); + + os << "}"; + return os; +} +/////////////////////////////////////////////////////////////////////////////// /// @brief Print operator for the ur_exp_command_buffer_update_exec_info_desc_t type /// @returns /// std::ostream & @@ -9357,6 +9406,11 @@ inline std::ostream &operator<<(std::ostream &os, const struct ur_exp_command_bu os << (params.numPointerArgs); + os << ", "; + os << ".numValueArgs = "; + + os << (params.numValueArgs); + os << ", "; os << ".numExecInfos = "; @@ -9379,6 +9433,12 @@ inline std::ostream &operator<<(std::ostream &os, const struct ur_exp_command_bu ur::details::printPtr(os, (params.pArgPointerList)); + os << ", "; + os << ".pArgValueList = "; + + ur::details::printPtr(os, + (params.pArgValueList)); + os << ", "; os << 
".pArgExecInfoList = "; diff --git a/scripts/core/EXP-COMMAND-BUFFER.rst b/scripts/core/EXP-COMMAND-BUFFER.rst index 823d186ec2..0e4c20af1f 100644 --- a/scripts/core/EXP-COMMAND-BUFFER.rst +++ b/scripts/core/EXP-COMMAND-BUFFER.rst @@ -165,7 +165,7 @@ support updating the configuration of kernel commands recorded to a command-buffer. Support of this is reported by returning true in the ${X}_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP query. -Updating a kernel commands is done by passing the new kernel configuration +Updating kernel commands is done by passing the new kernel configuration to ${x}CommandBufferUpdateKernelLaunchExp along with the command handle of the kernel command to update. Configurations that can be changed are the kernels ND-Range and parameters. @@ -178,11 +178,12 @@ kernels ND-Range and parameters. nullptr, true }; - ${x}_exp_command_buffer_command_handle_t handle; + ${x}_exp_command_buffer_handle_t hCommandBuffer; ${x}CommandBufferCreateExp(hContext, hDevice, &desc, &handle); // Append a kernel command which has two buffer parameters, an input // and an output. + ${x}_exp_command_buffer_command_handle_t handle; ${x}CommandBufferAppendKernelLaunchExp(hCommandBuffer, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize, 0, nullptr, @@ -193,7 +194,7 @@ kernels ND-Range and parameters. // Define kernel argument at index 0 to be a new input buffer object ${x}_exp_command_buffer_update_memobj_arg_desc_t newInputArg { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG, // stype + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, // stype nullptr, // pNext, 0, // argIndex, nullptr, // pProperties @@ -202,7 +203,7 @@ kernels ND-Range and parameters. 
// Define kernel argument at index 1 to be a new output buffer object ${x}_exp_command_buffer_update_memobj_arg_desc_t newOutputArg { - UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG, // stype + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, // stype nullptr, // pNext, 1, // argIndex, nullptr, // pProperties @@ -216,10 +217,12 @@ kernels ND-Range and parameters. nullptr, // pNext 2, // numMemobjArgs 0, // numPointerArgs + 0, // numValueArgs 0, // numExecInfos 0, // workDim; new_args, // pArgMemobjList nullptr, // pArgPointerList + nullptr, // pArgValueList nullptr, // pArgExecInfoList nullptr, // pGlobalWorkOffset nullptr, // pGlobalWorkSize diff --git a/scripts/core/exp-command-buffer.yml b/scripts/core/exp-command-buffer.yml index 4257713c69..f1a8c037de 100644 --- a/scripts/core/exp-command-buffer.yml +++ b/scripts/core/exp-command-buffer.yml @@ -61,9 +61,12 @@ etors: - name: EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC desc: $x_exp_command_buffer_update_pointer_arg_desc_t value: "0x1003" + - name: EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC + desc: $x_exp_command_buffer_update_value_arg_desc_t + value: "0x1004" - name: EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC desc: $x_exp_command_buffer_update_exec_info_desc_t - value: "0x1004" + value: "0x1005" --- #-------------------------------------------------------------------------- type: enum @@ -120,6 +123,25 @@ members: desc: "[in][optional] USM pointer to memory location holding the argument value." --- #-------------------------------------------------------------------------- type: struct +desc: "Descriptor type for updating a kernel command value argument." +base: $x_base_desc_t +name: $x_exp_command_buffer_update_value_arg_desc_t +members: + - type: uint32_t + name: argIndex + desc: "[in] Argument index." + - type: uint32_t + name: argSize + desc: "[in] Argument size." 
+ - type: "const ur_kernel_arg_value_properties_t *" + name: pProperties + desc: "[in][optional] Pointer to value argument properties." + - type: "const void *" + name: pArgValue + desc: "[in][optional] Argument value representing kernel arg type." + +--- #-------------------------------------------------------------------------- +type: struct desc: "Descriptor type for updating kernel command execution info." base: $x_base_desc_t name: $x_exp_command_buffer_update_exec_info_desc_t @@ -149,30 +171,36 @@ members: - type: uint32_t name: numPointerArgs desc: "[in] Length of pArgPointerList." + - type: uint32_t + name: numValueArgs + desc: "[in] Length of pArgValueList." - type: uint32_t name: numExecInfos desc: "[in] Length of pExecInfoList." - type: uint32_t name: workDim - desc: "[in] Number of work dimensions in the kernel ND-range, from 1-3." + desc: "[in][optional] Number of work dimensions in the kernel ND-range, from 1-3." - type: "const $x_exp_command_buffer_update_memobj_arg_desc_t*" name: pArgMemobjList - desc: "[in] An array describing the new kernel mem obj arguments for the command." + desc: "[in][optional] An array describing the new kernel mem obj arguments for the command." - type: "const $x_exp_command_buffer_update_pointer_arg_desc_t*" name: pArgPointerList - desc: "[in] An array describing the new kernel pointer arguments for the command." + desc: "[in][optional] An array describing the new kernel pointer arguments for the command." + - type: "const $x_exp_command_buffer_update_value_arg_desc_t*" + name: pArgValueList + desc: "[in][optional] An array describing the new kernel value arguments for the command." - type: "const $x_exp_command_buffer_update_exec_info_desc_t*" name: pArgExecInfoList - desc: "[in] An array describing the execution info objects for the command." + desc: "[in][optional] An array describing the execution info objects for the command."
- type: "size_t*" name: pGlobalWorkOffset - desc: "[in] Array of workDim unsigned values that describe the offset used to calculate the global ID." + desc: "[in][optional] Array of workDim unsigned values that describe the offset used to calculate the global ID." - type: "size_t*" name: pGlobalWorkSize - desc: "[in] Array of workDim unsigned values that describe the number of global work-items." + desc: "[in][optional] Array of workDim unsigned values that describe the number of global work-items." - type: "size_t*" name: pLocalWorkSize - desc: "[in] Array of workDim unsigned values that describe the number of work-items that make up a work-group. If nullptr, the runtime implementation will choose the work-group size." + desc: "[in][optional] Array of workDim unsigned values that describe the number of work-items that make up a work-group. If nullptr, the runtime implementation will choose the work-group size." --- #-------------------------------------------------------------------------- type: typedef desc: "A value that identifies a command inside of a command-buffer, used for defining dependencies between commands in the same command-buffer." diff --git a/source/adapters/cuda/command_buffer.cpp b/source/adapters/cuda/command_buffer.cpp index 55ace08aa4..e37d13f6dd 100644 --- a/source/adapters/cuda/command_buffer.cpp +++ b/source/adapters/cuda/command_buffer.cpp @@ -241,10 +241,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp( // Get sync point and register the cuNode with it. 
auto NodeSP = std::make_shared(GraphNode); - *pSyncPoint = hCommandBuffer->AddSyncPoint(NodeSP); + if (pSyncPoint) { + *pSyncPoint = hCommandBuffer->AddSyncPoint(NodeSP); + } - *phCommand = - hCommandBuffer->AddCommandHandle(hKernel, NodeSP, NodeParams).get(); + *phCommand = hCommandBuffer + ->AddCommandHandle(hKernel, NodeSP, NodeParams, workDim, + pGlobalWorkOffset, pGlobalWorkSize, + pLocalWorkSize) + .get(); } catch (ur_result_t Err) { Result = Err; } @@ -580,7 +585,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( const ur_exp_command_buffer_update_pointer_arg_desc_t *ArgPointerList = pKernelLaunch->pArgPointerList; for (uint32_t i = 0; i < NumPointerArgs; i++) { - auto PointerArgDesc = ArgPointerList[i]; + const auto &PointerArgDesc = ArgPointerList[i]; uint32_t ArgIndex = PointerArgDesc.argIndex; const void *ArgValue = PointerArgDesc.pArgValue; @@ -598,7 +603,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( const ur_exp_command_buffer_update_memobj_arg_desc_t *ArgMemobjList = pKernelLaunch->pArgMemobjList; for (uint32_t i = 0; i < NumMemobjArgs; i++) { - auto MemobjArgDesc = ArgMemobjList[i]; + const auto &MemobjArgDesc = ArgMemobjList[i]; uint32_t ArgIndex = MemobjArgDesc.argIndex; ur_mem_handle_t ArgValue = MemobjArgDesc.hArgValue; @@ -616,20 +621,58 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( } } + // Update value arguments to the kernel + uint32_t NumValueArgs = pKernelLaunch->numValueArgs; + const ur_exp_command_buffer_update_value_arg_desc_t *ArgValueList = + pKernelLaunch->pArgValueList; + for (uint32_t i = 0; i < NumValueArgs; i++) { + const auto &ValueArgDesc = ArgValueList[i]; + uint32_t ArgIndex = ValueArgDesc.argIndex; + size_t ArgSize = ValueArgDesc.argSize; + const void *ArgValue = ValueArgDesc.pArgValue; + + ur_result_t Result = UR_RESULT_SUCCESS; + + try { + Kernel->setKernelArg(ArgIndex, ArgSize, ArgValue); + } catch (ur_result_t Err) { + Result = 
Err; + return Result; + } + } + // Set the updated ND range - ur_context_handle_t Context = hCommand->CommandBuffer->Context; - ur_device_handle_t Device = hCommand->CommandBuffer->Device; - size_t *GlobalWorkOffset = pKernelLaunch->pGlobalWorkOffset; - size_t *GlobalWorkSize = pKernelLaunch->pGlobalWorkSize; - size_t *LocalWorkSize = pKernelLaunch->pLocalWorkSize; - uint32_t WorkDim = pKernelLaunch->workDim; + const uint32_t NewWorkDim = pKernelLaunch->workDim; + if (NewWorkDim != 0) { + UR_ASSERT(NewWorkDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION); + UR_ASSERT(NewWorkDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION); + hCommand->WorkDim = NewWorkDim; + } + + if (pKernelLaunch->pGlobalWorkOffset) { + hCommand->SetGlobalOffset(pKernelLaunch->pGlobalWorkOffset); + } + + if (pKernelLaunch->pGlobalWorkSize) { + hCommand->SetGlobalSize(pKernelLaunch->pGlobalWorkSize); + } + + if (pKernelLaunch->pLocalWorkSize) { + hCommand->SetLocalSize(pKernelLaunch->pLocalWorkSize); + } + + size_t *GlobalWorkOffset = hCommand->GlobalWorkOffset; + size_t *GlobalWorkSize = hCommand->GlobalWorkSize; + size_t *LocalWorkSize = hCommand->LocalWorkSize; + uint32_t WorkDim = hCommand->WorkDim; // Set the number of threads per block to the number of threads per warp // by default unless user has provided a better number size_t ThreadsPerBlock[3] = {32u, 1u, 1u}; size_t BlocksPerGrid[3] = {1u, 1u, 1u}; CUfunction CuFunc = Kernel->get(); - uint32_t LocalSize = Kernel->getLocalSize(); + ur_context_handle_t Context = hCommand->CommandBuffer->Context; + ur_device_handle_t Device = hCommand->CommandBuffer->Device; auto Result = setKernelParams(Context, Device, WorkDim, GlobalWorkOffset, GlobalWorkSize, LocalWorkSize, Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid); @@ -638,6 +681,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( } CUDA_KERNEL_NODE_PARAMS &Params = hCommand->Params; + Params.func = CuFunc; Params.gridDimX = BlocksPerGrid[0]; Params.gridDimY = 
BlocksPerGrid[1]; @@ -645,10 +689,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( Params.blockDimX = ThreadsPerBlock[0]; Params.blockDimY = ThreadsPerBlock[1]; Params.blockDimZ = ThreadsPerBlock[2]; - Params.sharedMemBytes = LocalSize; + Params.sharedMemBytes = Kernel->getLocalSize(); Params.kernelParams = const_cast(Kernel->getArgIndices().data()); CUgraphNode Node = *(hCommand->Node); - UR_CHECK_ERROR(cuGraphKernelNodeSetParams(Node, &Params)); + CUgraphExec CudaGraphExec = hCommand->CommandBuffer->CudaGraphExec; + UR_CHECK_ERROR(cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params)); return UR_RESULT_SUCCESS; } diff --git a/source/adapters/cuda/command_buffer.hpp b/source/adapters/cuda/command_buffer.hpp index a0d45f08d8..c31c053f2a 100644 --- a/source/adapters/cuda/command_buffer.hpp +++ b/source/adapters/cuda/command_buffer.hpp @@ -181,14 +181,60 @@ static inline const char *getUrResultString(ur_result_t Result) { struct ur_exp_command_buffer_command_handle_t_ { ur_exp_command_buffer_command_handle_t_( ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel, - std::shared_ptr Node, CUDA_KERNEL_NODE_PARAMS Params) + std::shared_ptr Node, CUDA_KERNEL_NODE_PARAMS Params, + uint32_t WorkDim, const size_t *pGlobalWorkOffset, + const size_t *pGlobalWorkSize, const size_t *pLocalWorkSize) : CommandBuffer(CommandBuffer), Kernel(Kernel), Node(Node), - Params(Params) {} + Params(Params), WorkDim(WorkDim) { + const size_t CopySize = sizeof(size_t) * WorkDim; + std::memcpy(GlobalWorkOffset, pGlobalWorkOffset, CopySize); + std::memcpy(GlobalWorkSize, pGlobalWorkSize, CopySize); + std::memcpy(LocalWorkSize, pLocalWorkSize, CopySize); + + if (WorkDim < 3) { + const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim); + std::memset(GlobalWorkOffset + WorkDim, 0, ZeroSize); + std::memset(GlobalWorkSize + WorkDim, 0, ZeroSize); + std::memset(LocalWorkSize + WorkDim, 0, ZeroSize); + } + } + + void SetGlobalOffset(const size_t 
*pGlobalWorkOffset) { + const size_t CopySize = sizeof(size_t) * WorkDim; + std::memcpy(GlobalWorkOffset, pGlobalWorkOffset, CopySize); + if (WorkDim < 3) { + const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim); + std::memset(GlobalWorkOffset + WorkDim, 0, ZeroSize); + } + } + + void SetGlobalSize(const size_t *pGlobalWorkSize) { + const size_t CopySize = sizeof(size_t) * WorkDim; + std::memcpy(GlobalWorkSize, pGlobalWorkSize, CopySize); + if (WorkDim < 3) { + const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim); + std::memset(GlobalWorkSize + WorkDim, 0, ZeroSize); + } + } + + void SetLocalSize(const size_t *pLocalWorkSize) { + const size_t copy_size = sizeof(size_t) * WorkDim; + std::memcpy(LocalWorkSize, pLocalWorkSize, copy_size); + if (WorkDim < 3) { + const size_t zero_size = sizeof(size_t) * (3 - WorkDim); + std::memset(LocalWorkSize + WorkDim, 0, zero_size); + } + } ur_exp_command_buffer_handle_t CommandBuffer; ur_kernel_handle_t Kernel; std::shared_ptr Node; CUDA_KERNEL_NODE_PARAMS Params; + + uint32_t WorkDim; + size_t GlobalWorkOffset[3]; + size_t GlobalWorkSize[3]; + size_t LocalWorkSize[3]; }; struct ur_exp_command_buffer_handle_t_ { @@ -225,10 +271,13 @@ struct ur_exp_command_buffer_handle_t_ { // @return Shared pointer to the created handle. 
std::shared_ptr AddCommandHandle(ur_kernel_handle_t Kernel, std::shared_ptr Node, - const CUDA_KERNEL_NODE_PARAMS &Params) { + const CUDA_KERNEL_NODE_PARAMS &Params, uint32_t WorkDim, + const size_t *GlobalWorkOffset, const size_t *GlobalWorkSize, + const size_t *LocalWorkSize) { Handles.push_back(std::make_shared( - this, Kernel, Node, Params)); + this, Kernel, Node, Params, WorkDim, GlobalWorkOffset, GlobalWorkSize, + LocalWorkSize)); return Handles.back(); } diff --git a/source/adapters/cuda/device.cpp b/source/adapters/cuda/device.cpp index 45297dea56..800e68ff32 100644 --- a/source/adapters/cuda/device.cpp +++ b/source/adapters/cuda/device.cpp @@ -1030,6 +1030,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_GPU_HW_THREADS_PER_EU: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: + case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: + return ReturnValue(true); + default: break; } diff --git a/source/adapters/cuda/ur_interface_loader.cpp b/source/adapters/cuda/ur_interface_loader.cpp index 4bf7367017..35e7a84b56 100644 --- a/source/adapters/cuda/ur_interface_loader.cpp +++ b/source/adapters/cuda/ur_interface_loader.cpp @@ -290,6 +290,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( pDdiTable->pfnAppendMemBufferWriteRectExp = urCommandBufferAppendMemBufferWriteRectExp; pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp; + pDdiTable->pfnUpdateKernelLaunchExp = urCommandBufferUpdateKernelLaunchExp; return retVal; } diff --git a/source/adapters/hip/device.cpp b/source/adapters/hip/device.cpp index e40470f9aa..21e88c4d6a 100644 --- a/source/adapters/hip/device.cpp +++ b/source/adapters/hip/device.cpp @@ -837,6 +837,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_ASYNC_BARRIER: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; + case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: 
+ case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: + return ReturnValue(false); + default: break; } diff --git a/source/adapters/hip/ur_interface_loader.cpp b/source/adapters/hip/ur_interface_loader.cpp index afefdc1a0a..55bcff2619 100644 --- a/source/adapters/hip/ur_interface_loader.cpp +++ b/source/adapters/hip/ur_interface_loader.cpp @@ -287,6 +287,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( pDdiTable->pfnAppendMemBufferWriteRectExp = urCommandBufferAppendMemBufferWriteRectExp; pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp; + pDdiTable->pfnUpdateKernelLaunchExp = urCommandBufferUpdateKernelLaunchExp; return retVal; } diff --git a/source/adapters/level_zero/device.cpp b/source/adapters/level_zero/device.cpp index abdfd2e541..6c4bdb1045 100644 --- a/source/adapters/level_zero/device.cpp +++ b/source/adapters/level_zero/device.cpp @@ -823,6 +823,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo( return ReturnValue(result); } + case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: + return ReturnValue(false); + default: urPrint("Unsupported ParamName in urGetDeviceInfo\n"); urPrint("ParamName=%d(0x%x)\n", ParamName, ParamName); diff --git a/source/adapters/level_zero/ur_interface_loader.cpp b/source/adapters/level_zero/ur_interface_loader.cpp index c281341bff..9da7a4d5c0 100644 --- a/source/adapters/level_zero/ur_interface_loader.cpp +++ b/source/adapters/level_zero/ur_interface_loader.cpp @@ -337,6 +337,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( pDdiTable->pfnAppendMemBufferWriteRectExp = urCommandBufferAppendMemBufferWriteRectExp; pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp; + pDdiTable->pfnUpdateKernelLaunchExp = urCommandBufferUpdateKernelLaunchExp; return retVal; } diff --git a/source/adapters/native_cpu/device.cpp b/source/adapters/native_cpu/device.cpp index 3432ce780e..3a67b30c08 
100644 --- a/source/adapters/native_cpu/device.cpp +++ b/source/adapters/native_cpu/device.cpp @@ -304,6 +304,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, CASE_UR_UNSUPPORTED(UR_DEVICE_INFO_MAX_MEMORY_BANDWIDTH); case UR_DEVICE_INFO_VIRTUAL_MEMORY_SUPPORT: return ReturnValue(false); + + case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: + case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: + return ReturnValue(false); + default: DIE_NO_IMPLEMENTATION; } diff --git a/source/adapters/native_cpu/ur_interface_loader.cpp b/source/adapters/native_cpu/ur_interface_loader.cpp index 9408101927..d884539d3d 100644 --- a/source/adapters/native_cpu/ur_interface_loader.cpp +++ b/source/adapters/native_cpu/ur_interface_loader.cpp @@ -283,6 +283,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( pDdiTable->pfnAppendMemBufferWriteRectExp = urCommandBufferAppendMemBufferWriteRectExp; pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp; + pDdiTable->pfnUpdateKernelLaunchExp = urCommandBufferUpdateKernelLaunchExp; return retVal; } diff --git a/source/adapters/opencl/device.cpp b/source/adapters/opencl/device.cpp index 27577eab39..c437dbc5c8 100644 --- a/source/adapters/opencl/device.cpp +++ b/source/adapters/opencl/device.cpp @@ -947,6 +947,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, case UR_DEVICE_INFO_ASYNC_BARRIER: { return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; } + + case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP: + return ReturnValue(true); + case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP: + return ReturnValue(false); + default: { return UR_RESULT_ERROR_INVALID_ENUMERATION; } diff --git a/source/adapters/opencl/ur_interface_loader.cpp b/source/adapters/opencl/ur_interface_loader.cpp index d09b64c6b0..b66212000e 100644 --- a/source/adapters/opencl/ur_interface_loader.cpp +++ b/source/adapters/opencl/ur_interface_loader.cpp @@ -297,6 +297,7 @@ UR_DLLEXPORT 
ur_result_t UR_APICALL urGetCommandBufferExpProcAddrTable( pDdiTable->pfnAppendMemBufferWriteRectExp = urCommandBufferAppendMemBufferWriteRectExp; pDdiTable->pfnEnqueueExp = urCommandBufferEnqueueExp; + pDdiTable->pfnUpdateKernelLaunchExp = urCommandBufferUpdateKernelLaunchExp; return retVal; } diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 81eab43d49..7a2355167f 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -7640,30 +7640,6 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( if (NULL == pUpdateKernelLaunch) { return UR_RESULT_ERROR_INVALID_NULL_POINTER; } - - if (NULL == pUpdateKernelLaunch->pArgMemobjList) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (NULL == pUpdateKernelLaunch->pArgPointerList) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (NULL == pUpdateKernelLaunch->pArgExecInfoList) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (NULL == pUpdateKernelLaunch->pGlobalWorkOffset) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (NULL == pUpdateKernelLaunch->pGlobalWorkSize) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } - - if (NULL == pUpdateKernelLaunch->pLocalWorkSize) { - return UR_RESULT_ERROR_INVALID_NULL_POINTER; - } } ur_result_t result = diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 4671709ab1..cb10d094b1 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -6885,6 +6885,19 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp( reinterpret_cast(hCommand) ->handle; + uint32_t NumMemobjArgs = pUpdateKernelLaunch->numMemobjArgs; + if (NumMemobjArgs) { + for (uint32_t i = 0; i < NumMemobjArgs; i++) { + auto &MemobjArgDesc = pUpdateKernelLaunch->pArgMemobjList[i]; + ur_mem_handle_t &ArgValue = + const_cast(MemobjArgDesc.hArgValue); + ArgValue = + (ArgValue) + ? 
reinterpret_cast(ArgValue)->handle + : nullptr; + } + } + // forward to device-platform result = pfnUpdateKernelLaunchExp(hCommand, pUpdateKernelLaunch); diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 1d3fd8b0d6..32ae64b35a 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -7846,12 +7846,6 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// + `NULL == hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` -/// + `NULL == pUpdateKernelLaunch->pArgMemobjList` -/// + `NULL == pUpdateKernelLaunch->pArgPointerList` -/// + `NULL == pUpdateKernelLaunch->pArgExecInfoList` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkOffset` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkSize` -/// + `NULL == pUpdateKernelLaunch->pLocalWorkSize` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If update functionality is not supported by the device. /// - ::UR_RESULT_ERROR_INVALID_OPERATION diff --git a/source/ur_api.cpp b/source/ur_api.cpp index dd58845649..01bc8bc9aa 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -6631,12 +6631,6 @@ ur_result_t UR_APICALL urCommandBufferEnqueueExp( /// + `NULL == hCommand` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == pUpdateKernelLaunch` -/// + `NULL == pUpdateKernelLaunch->pArgMemobjList` -/// + `NULL == pUpdateKernelLaunch->pArgPointerList` -/// + `NULL == pUpdateKernelLaunch->pArgExecInfoList` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkOffset` -/// + `NULL == pUpdateKernelLaunch->pGlobalWorkSize` -/// + `NULL == pUpdateKernelLaunch->pLocalWorkSize` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If update functionality is not supported by the device. 
/// - ::UR_RESULT_ERROR_INVALID_OPERATION diff --git a/test/conformance/CMakeLists.txt b/test/conformance/CMakeLists.txt index df80c02681..05e522c38b 100644 --- a/test/conformance/CMakeLists.txt +++ b/test/conformance/CMakeLists.txt @@ -115,6 +115,7 @@ if(UR_DPCXX) add_subdirectory(kernel) add_subdirectory(program) add_subdirectory(enqueue) + add_subdirectory(exp_command_buffer) else() message(WARNING "UR_DPCXX is not defined, the following conformance test executables \ diff --git a/test/conformance/exp_command_buffer/CMakeLists.txt b/test/conformance/exp_command_buffer/CMakeLists.txt new file mode 100644 index 0000000000..a2f70facae --- /dev/null +++ b/test/conformance/exp_command_buffer/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright (C) 2023 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_conformance_test_with_kernels_environment(exp_command_buffer + urCommandBufferAppendKernelLaunchExp.cpp +) diff --git a/test/conformance/exp_command_buffer/urCommandBufferAppendKernelLaunchExp.cpp b/test/conformance/exp_command_buffer/urCommandBufferAppendKernelLaunchExp.cpp new file mode 100644 index 0000000000..525805e4c7 --- /dev/null +++ b/test/conformance/exp_command_buffer/urCommandBufferAppendKernelLaunchExp.cpp @@ -0,0 +1,384 @@ +// Copyright (C) 2023 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +struct SingleKernelCommandTest : uur::urExpUpdatableCommandBufferTests { + void SetUp() override { + program_name = "fill"; + UUR_RETURN_ON_FATAL_FAILURE(urExpUpdatableCommandBufferTests::SetUp()); + + AddBuffer1DArg(sizeof(val) * global_size, &buffer); + AddPodArg(val); + + ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp( + updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset, + &global_size, &local_size, 0, nullptr, nullptr, &command_handle)); + ASSERT_NE(command_handle, nullptr); + + ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle)); + } + + void TearDown() override { + if (new_buffer) { + ASSERT_SUCCESS(urMemRelease(new_buffer)); + } + } + + uint32_t val = 42; + size_t local_size = 4; + size_t global_size = 32; + size_t global_offset = 0; + size_t n_dimensions = 1; + ur_mem_handle_t buffer = nullptr; + ur_mem_handle_t new_buffer = nullptr; + ur_exp_command_buffer_command_handle_t command_handle = nullptr; +}; + +UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(SingleKernelCommandTest); + +TEST_P(SingleKernelCommandTest, UpdateParmeters) { + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + + ASSERT_SUCCESS(urQueueFinish(queue)); + + const size_t buffer_size = sizeof(val) * global_size; + ValidateBuffer(buffer, buffer_size, val); + + ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, + buffer_size, nullptr, &new_buffer)); + char zero = 0; + ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, new_buffer, &zero, + sizeof(zero), 0, buffer_size, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ur_exp_command_buffer_update_memobj_arg_desc_t new_output_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, // stype + nullptr, // pNext, + 0, // argIndex, + nullptr, // pProperties + new_buffer, // hArgValue + }; + + int new_val = 33; + 
ur_exp_command_buffer_update_value_arg_desc_t new_input_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &new_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 1, // numMemobjArgs + 0, // numPointerArgs + 1, // numValueArgs + 0, // numExecInfos + 0, // workDim; + &new_output_desc, // pArgMemobjList + nullptr, // pArgPointerList + &new_input_desc, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + nullptr, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(new_buffer, buffer_size, new_val); +} + +TEST_P(SingleKernelCommandTest, UpdateGlobalSize) { + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, val); + + size_t new_global_size = 64; + const size_t buffer_size = sizeof(val) * new_global_size; + ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, + buffer_size, nullptr, &new_buffer)); + char zero = 0; + ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, new_buffer, &zero, + sizeof(zero), 0, buffer_size, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ur_exp_command_buffer_update_memobj_arg_desc_t new_output_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, // stype + nullptr, // pNext, + 0, // argIndex, + nullptr, // pProperties + new_buffer, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = { + 
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 1, // numMemobjArgs + 0, // numPointerArgs + 0, // numValueArgs + 0, // numExecInfos + 0, // workDim; + &new_output_desc, // pArgMemobjList + nullptr, // pArgPointerList + nullptr, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + &new_global_size, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + + ASSERT_SUCCESS( + urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc)); + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(new_buffer, buffer_size, val); +} + +TEST_P(SingleKernelCommandTest, SeparateUpdateCalls) { + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, val); + + size_t new_global_size = 64; + const size_t buffer_size = sizeof(val) * new_global_size; + ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, + buffer_size, nullptr, &new_buffer)); + char zero = 0; + ASSERT_SUCCESS(urEnqueueMemBufferFill(queue, new_buffer, &zero, + sizeof(zero), 0, buffer_size, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ur_exp_command_buffer_update_memobj_arg_desc_t new_output_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC, // stype + nullptr, // pNext, + 0, // argIndex, + nullptr, // pProperties + new_buffer, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t output_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 1, // numMemobjArgs + 0, // numPointerArgs + 0, // numValueArgs + 0, // numExecInfos + 0, // workDim; + &new_output_desc, // pArgMemobjList + nullptr, // pArgPointerList + nullptr, // pArgValueList + nullptr, // pArgExecInfoList + 
nullptr, // pGlobalWorkOffset + nullptr, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &output_update_desc)); + + int new_val = 33; + ur_exp_command_buffer_update_value_arg_desc_t new_input_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &new_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t input_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numMemobjArgs + 0, // numPointerArgs + 1, // numValueArgs + 0, // numExecInfos + 0, // workDim; + nullptr, // pArgMemobjList + nullptr, // pArgPointerList + &new_input_desc, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + &new_global_size, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &input_update_desc)); + + ur_exp_command_buffer_update_kernel_launch_desc_t global_size_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numMemobjArgs + 0, // numPointerArgs + 0, // numValueArgs + 0, // numExecInfos + 0, // workDim; + nullptr, // pArgMemobjList + nullptr, // pArgPointerList + nullptr, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + &new_global_size, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp( + command_handle, &global_size_update_desc)); + + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(new_buffer, buffer_size, new_val); +} + +TEST_P(SingleKernelCommandTest, OverrideUpdate) { + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, 
queue, 0, + nullptr, nullptr)); + + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, val); + + int first_val = 33; + ur_exp_command_buffer_update_value_arg_desc_t first_input_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &first_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t first_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numMemobjArgs + 0, // numPointerArgs + 1, // numValueArgs + 0, // numExecInfos + 0, // workDim; + nullptr, // pArgMemobjList + nullptr, // pArgPointerList + &first_input_desc, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + nullptr, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &first_update_desc)); + + int second_val = -99; + ur_exp_command_buffer_update_value_arg_desc_t second_input_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &second_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numMemobjArgs + 0, // numPointerArgs + 1, // numValueArgs + 0, // numExecInfos + 0, // workDim; + nullptr, // pArgMemobjList + nullptr, // pArgPointerList + &second_input_desc, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + nullptr, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &second_update_desc)); + + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, 
nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, second_val); +} + +TEST_P(SingleKernelCommandTest, OverrideArgList) { + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, val); + + ur_exp_command_buffer_update_value_arg_desc_t input_descs[2]; + int first_val = 33; + input_descs[0] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &first_val, // hArgValue + }; + + int second_val = -99; + input_descs[1] = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype + nullptr, // pNext, + 2, // argIndex, + sizeof(int), // argSize, + nullptr, // pProperties + &second_val, // hArgValue + }; + + ur_exp_command_buffer_update_kernel_launch_desc_t second_update_desc = { + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype + nullptr, // pNext + 0, // numMemobjArgs + 0, // numPointerArgs + 2, // numValueArgs + 0, // numExecInfos + 0, // workDim; + nullptr, // pArgMemobjList + nullptr, // pArgPointerList + input_descs, // pArgValueList + nullptr, // pArgExecInfoList + nullptr, // pGlobalWorkOffset + nullptr, // pGlobalWorkSize + nullptr, // pLocalWorkSize + }; + + ASSERT_SUCCESS(urCommandBufferUpdateKernelLaunchExp(command_handle, + &second_update_desc)); + + ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0, + nullptr, nullptr)); + ASSERT_SUCCESS(urQueueFinish(queue)); + + ValidateBuffer(buffer, sizeof(val) * global_size, second_val); +} diff --git a/test/conformance/testing/include/uur/fixtures.h b/test/conformance/testing/include/uur/fixtures.h index 2ede84d135..0da9bdc22f 100644 --- a/test/conformance/testing/include/uur/fixtures.h +++ b/test/conformance/testing/include/uur/fixtures.h @@ -1255,7 
+1255,7 @@ struct urBaseKernelExecutionTest : urBaseKernelTest { }; struct urKernelExecutionTest : urBaseKernelExecutionTest { - void SetUp() { + void SetUp() override { UUR_RETURN_ON_FATAL_FAILURE(urBaseKernelExecutionTest::SetUp()); Build(); } @@ -1276,6 +1276,77 @@ struct urGlobalVariableTest : uur::urKernelExecutionTest { GlobalVar global_var; }; +struct urExpCommandBufferTest : urKernelExecutionTest { + void SetUp() override { + UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::SetUp()); + + size_t returned_size; + ASSERT_SUCCESS(urDeviceGetInfo(device, UR_DEVICE_INFO_EXTENSIONS, 0, + nullptr, &returned_size)); + + std::unique_ptr returned_extensions(new char[returned_size]); + + ASSERT_SUCCESS(urDeviceGetInfo(device, UR_DEVICE_INFO_EXTENSIONS, + returned_size, returned_extensions.get(), + nullptr)); + + std::string_view extensions_string(returned_extensions.get()); + bool command_buffer_support = + extensions_string.find(UR_COMMAND_BUFFER_EXTENSION_STRING_EXP) != + std::string::npos; + + if (!command_buffer_support) { + GTEST_SKIP() << "EXP command-buffer feature is not supported."; + } + + ASSERT_SUCCESS(urDeviceGetInfo( + device, UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP, + sizeof(ur_bool_t), &updatable_command_buffer_support, nullptr)); + + // Create a command-buffer + ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, nullptr, + &cmd_buf_handle)); + ASSERT_NE(cmd_buf_handle, nullptr); + } + + void TearDown() override { + if (cmd_buf_handle) { + EXPECT_SUCCESS(urCommandBufferReleaseExp(cmd_buf_handle)); + } + UUR_RETURN_ON_FATAL_FAILURE(urKernelExecutionTest::TearDown()); + } + + ur_exp_command_buffer_handle_t cmd_buf_handle = nullptr; + ur_bool_t updatable_command_buffer_support = false; +}; + +struct urExpUpdatableCommandBufferTests : urExpCommandBufferTest { + void SetUp() override { + UUR_RETURN_ON_FATAL_FAILURE(urExpCommandBufferTest ::SetUp()); + + if (!updatable_command_buffer_support) { + GTEST_SKIP() << "Updating EXP command-buffers 
is not supported."; + } + + // Create a command-buffer with update enabled. + ur_exp_command_buffer_desc_t desc{ + UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, true}; + + ASSERT_SUCCESS(urCommandBufferCreateExp(context, device, &desc, + &updatable_cmd_buf_handle)); + ASSERT_NE(updatable_cmd_buf_handle, nullptr); + } + + void TearDown() override { + if (updatable_cmd_buf_handle) { + EXPECT_SUCCESS(urCommandBufferReleaseExp(updatable_cmd_buf_handle)); + } + UUR_RETURN_ON_FATAL_FAILURE(urExpCommandBufferTest::TearDown()); + } + + ur_exp_command_buffer_handle_t updatable_cmd_buf_handle = nullptr; +}; + } // namespace uur #endif // UR_CONFORMANCE_INCLUDE_FIXTURES_H_INCLUDED