Skip to content

Commit

Permalink
Merge pull request #571 from KseniyaTikhomirova/mem_obj_proposal
Browse files Browse the repository at this point in the history
[UR] Proposal of new API for memory object properties
  • Loading branch information
kbenzie committed Jun 15, 2023
2 parents aa5d052 + bee29ae commit 7320835
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 20 deletions.
16 changes: 14 additions & 2 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ class ur_structure_type_v(IntEnum):
DEVICE_PARTITION_PROPERTIES = 26 ## ::ur_device_partition_properties_t
EXP_COMMAND_BUFFER_DESC = 27 ## ::ur_exp_command_buffer_desc_t
EXP_SAMPLER_MIP_PROPERTIES = 28 ## ::ur_exp_sampler_mip_properties_t
KERNEL_ARG_MEM_OBJ_PROPERTIES = 29 ## ::ur_kernel_arg_mem_obj_properties_t

class ur_structure_type_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -1513,6 +1514,17 @@ def __str__(self):
return str(ur_kernel_exec_info_v(self.value))


###############################################################################
## @brief Properties for for ::urKernelSetArgMemObj.
class ur_kernel_arg_mem_obj_properties_t(Structure):
_fields_ = [
("stype", ur_structure_type_t), ## [in] type of this structure, must be
## ::UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES
("pNext", c_void_p), ## [in,out][optional] pointer to extension-specific structure
("memoryAccess", ur_mem_flags_t) ## [in] Memory access flag. Allowed values are: ::UR_MEM_FLAG_READ_WRITE,
## ::UR_MEM_FLAG_WRITE_ONLY, ::UR_MEM_FLAG_READ_ONLY.
]

###############################################################################
## @brief Properties for for ::urKernelCreateWithNativeHandle.
class ur_kernel_native_properties_t(Structure):
Expand Down Expand Up @@ -2399,9 +2411,9 @@ class ur_program_dditable_t(Structure):
###############################################################################
## @brief Function-pointer for urKernelSetArgMemObj
if __use_win_types:
_urKernelSetArgMemObj_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, c_ulong, ur_mem_handle_t )
_urKernelSetArgMemObj_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, c_ulong, POINTER(ur_kernel_arg_mem_obj_properties_t), ur_mem_handle_t )
else:
_urKernelSetArgMemObj_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, c_ulong, ur_mem_handle_t )
_urKernelSetArgMemObj_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, c_ulong, POINTER(ur_kernel_arg_mem_obj_properties_t), ur_mem_handle_t )

###############################################################################
## @brief Function-pointer for urKernelSetSpecializationConstants
Expand Down
20 changes: 17 additions & 3 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ typedef enum ur_structure_type_t {
UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES = 26, ///< ::ur_device_partition_properties_t
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC = 27, ///< ::ur_exp_command_buffer_desc_t
UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES = 28, ///< ::ur_exp_sampler_mip_properties_t
UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES = 29, ///< ::ur_kernel_arg_mem_obj_properties_t
/// @cond
UR_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
/// @endcond
Expand Down Expand Up @@ -3893,6 +3894,17 @@ urKernelSetArgSampler(
ur_sampler_handle_t hArgValue ///< [in] handle of Sampler object.
);

///////////////////////////////////////////////////////////////////////////////
/// @brief Properties for for ::urKernelSetArgMemObj.
typedef struct ur_kernel_arg_mem_obj_properties_t {
ur_structure_type_t stype; ///< [in] type of this structure, must be
///< ::UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES
void *pNext; ///< [in,out][optional] pointer to extension-specific structure
ur_mem_flags_t memoryAccess; ///< [in] Memory access flag. Allowed values are: ::UR_MEM_FLAG_READ_WRITE,
///< ::UR_MEM_FLAG_WRITE_ONLY, ::UR_MEM_FLAG_READ_ONLY.

} ur_kernel_arg_mem_obj_properties_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Set a Memory object as the argument value of a Kernel.
///
Expand All @@ -3910,9 +3922,10 @@ urKernelSetArgSampler(
/// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX
UR_APIEXPORT ur_result_t UR_APICALL
urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t *pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
);

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -7432,6 +7445,7 @@ typedef struct ur_kernel_set_arg_sampler_params_t {
typedef struct ur_kernel_set_arg_mem_obj_params_t {
ur_kernel_handle_t *phKernel;
uint32_t *pargIndex;
const ur_kernel_arg_mem_obj_properties_t **ppProperties;
ur_mem_handle_t *phArgValue;
} ur_kernel_set_arg_mem_obj_params_t;

Expand Down
1 change: 1 addition & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnKernelSetArgSampler_t)(
typedef ur_result_t(UR_APICALL *ur_pfnKernelSetArgMemObj_t)(
ur_kernel_handle_t,
uint32_t,
const ur_kernel_arg_mem_obj_properties_t *,
ur_mem_handle_t);

///////////////////////////////////////////////////////////////////////////////
Expand Down
6 changes: 3 additions & 3 deletions scripts/core/PROG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ explicit and implicit kernel arguments along with data needed for launch.
// Create kernel object from program
${x}_kernel_handle_t hKernel;
${x}KernelCreate(hProgram, "addVectors", &hKernel);
${x}KernelSetArgMemObj(hKernel, 0, A);
${x}KernelSetArgMemObj(hKernel, 1, B);
${x}KernelSetArgMemObj(hKernel, 2, C);
${x}KernelSetArgMemObj(hKernel, 0, nullptr, A);
${x}KernelSetArgMemObj(hKernel, 1, nullptr, B);
${x}KernelSetArgMemObj(hKernel, 2, nullptr, C);
Queue and Enqueue
=================
Expand Down
2 changes: 2 additions & 0 deletions scripts/core/common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ etors:
desc: $x_exp_command_buffer_desc_t
- name: EXP_SAMPLER_MIP_PROPERTIES
desc: $x_exp_sampler_mip_properties_t
- name: KERNEL_ARG_MEM_OBJ_PROPERTIES
desc: $x_kernel_arg_mem_obj_properties_t
--- #--------------------------------------------------------------------------
type: struct
desc: "Base for all properties types"
Expand Down
13 changes: 13 additions & 0 deletions scripts/core/kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,16 @@ params:
returns:
- $X_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX
--- #--------------------------------------------------------------------------
type: struct
desc: "Properties for for $xKernelSetArgMemObj."
class: $xKernel
name: $x_kernel_arg_mem_obj_properties_t
base: $x_base_properties_t
members:
- type: $x_mem_flags_t
name: memoryAccess
desc: "[in] Memory access flag. Allowed values are: $X_MEM_FLAG_READ_WRITE, $X_MEM_FLAG_WRITE_ONLY, $X_MEM_FLAG_READ_ONLY."
--- #--------------------------------------------------------------------------
type: function
desc: "Set a Memory object as the argument value of a Kernel."
class: $xKernel
Expand All @@ -372,6 +382,9 @@ params:
- type: "uint32_t"
name: argIndex
desc: "[in] argument index in range [0, num args - 1]"
- type: "const $x_kernel_arg_mem_obj_properties_t*"
name: pProperties
desc: "[in][optional] pointer to Memory object properties."
- type: "$x_mem_handle_t"
name: hArgValue
desc: "[in][optional] handle of Memory object."
Expand Down
4 changes: 3 additions & 1 deletion source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1955,14 +1955,16 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler(
__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) try {
ur_result_t result = UR_RESULT_SUCCESS;

// if the driver has created a custom function, then call it instead of using the generic path
auto pfnSetArgMemObj = d_context.urDdiTable.Kernel.pfnSetArgMemObj;
if (nullptr != pfnSetArgMemObj) {
result = pfnSetArgMemObj(hKernel, argIndex, hArgValue);
result = pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue);
} else {
// generic implementation
}
Expand Down
40 changes: 40 additions & 0 deletions source/common/ur_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,9 @@ inline std::ostream &operator<<(std::ostream &os,
inline std::ostream &operator<<(std::ostream &os,
enum ur_kernel_exec_info_t value);
inline std::ostream &
operator<<(std::ostream &os,
const struct ur_kernel_arg_mem_obj_properties_t params);
inline std::ostream &
operator<<(std::ostream &os, const struct ur_kernel_native_properties_t params);
inline std::ostream &operator<<(std::ostream &os, enum ur_queue_info_t value);
inline std::ostream &operator<<(std::ostream &os, enum ur_queue_flag_t value);
Expand Down Expand Up @@ -757,6 +760,10 @@ inline std::ostream &operator<<(std::ostream &os,
case UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES:
os << "UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES";
break;

case UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES:
os << "UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES";
break;
default:
os << "unknown enumerator";
break;
Expand Down Expand Up @@ -938,6 +945,12 @@ inline void serializeStruct(std::ostream &os, const void *ptr) {
(const ur_exp_sampler_mip_properties_t *)ptr;
ur_params::serializePtr(os, pstruct);
} break;

case UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES: {
const ur_kernel_arg_mem_obj_properties_t *pstruct =
(const ur_kernel_arg_mem_obj_properties_t *)ptr;
ur_params::serializePtr(os, pstruct);
} break;
default:
os << "unknown enumerator";
break;
Expand Down Expand Up @@ -7241,6 +7254,28 @@ inline void serializeTagged(std::ostream &os, const void *ptr,
}
} // namespace ur_params
inline std::ostream &
operator<<(std::ostream &os,
const struct ur_kernel_arg_mem_obj_properties_t params) {
os << "(struct ur_kernel_arg_mem_obj_properties_t){";

os << ".stype = ";

os << (params.stype);

os << ", ";
os << ".pNext = ";

ur_params::serializeStruct(os, (params.pNext));

os << ", ";
os << ".memoryAccess = ";

ur_params::serializeFlag<ur_mem_flag_t>(os, (params.memoryAccess));

os << "}";
return os;
}
inline std::ostream &
operator<<(std::ostream &os,
const struct ur_kernel_native_properties_t params) {
os << "(struct ur_kernel_native_properties_t){";
Expand Down Expand Up @@ -11835,6 +11870,11 @@ operator<<(std::ostream &os,

os << *(params->pargIndex);

os << ", ";
os << ".pProperties = ";

ur_params::serializePtr(os, *(params->ppProperties));

os << ", ";
os << ".hArgValue = ";

Expand Down
7 changes: 5 additions & 2 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2228,6 +2228,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler(
__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) {
auto pfnSetArgMemObj = context.urDdiTable.Kernel.pfnSetArgMemObj;
Expand All @@ -2237,11 +2239,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
}

ur_kernel_set_arg_mem_obj_params_t params = {&hKernel, &argIndex,
&hArgValue};
&pProperties, &hArgValue};
uint64_t instance = context.notify_begin(UR_FUNCTION_KERNEL_SET_ARG_MEM_OBJ,
"urKernelSetArgMemObj", &params);

ur_result_t result = pfnSetArgMemObj(hKernel, argIndex, hArgValue);
ur_result_t result =
pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue);

context.notify_end(UR_FUNCTION_KERNEL_SET_ARG_MEM_OBJ,
"urKernelSetArgMemObj", &params, &result, instance);
Expand Down
5 changes: 4 additions & 1 deletion source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2729,6 +2729,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler(
__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) {
auto pfnSetArgMemObj = context.urDdiTable.Kernel.pfnSetArgMemObj;
Expand All @@ -2743,7 +2745,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
}
}

ur_result_t result = pfnSetArgMemObj(hKernel, argIndex, hArgValue);
ur_result_t result =
pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue);

return result;
}
Expand Down
4 changes: 3 additions & 1 deletion source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2569,6 +2569,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgSampler(
__urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) {
ur_result_t result = UR_RESULT_SUCCESS;
Expand All @@ -2589,7 +2591,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
: nullptr;

// forward to device-platform
result = pfnSetArgMemObj(hKernel, argIndex, hArgValue);
result = pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue);

return result;
}
Expand Down
4 changes: 3 additions & 1 deletion source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3039,14 +3039,16 @@ ur_result_t UR_APICALL urKernelSetArgSampler(
ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) try {
auto pfnSetArgMemObj = ur_lib::context->urDdiTable.Kernel.pfnSetArgMemObj;
if (nullptr == pfnSetArgMemObj) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

return pfnSetArgMemObj(hKernel, argIndex, hArgValue);
return pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue);
} catch (...) {
return exceptionToResult(std::current_exception());
}
Expand Down
2 changes: 2 additions & 0 deletions source/ur_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2545,6 +2545,8 @@ ur_result_t UR_APICALL urKernelSetArgSampler(
ur_result_t UR_APICALL urKernelSetArgMemObj(
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
const ur_kernel_arg_mem_obj_properties_t
*pProperties, ///< [in][optional] pointer to Memory object properties.
ur_mem_handle_t hArgValue ///< [in][optional] handle of Memory object.
) {
ur_result_t result = UR_RESULT_SUCCESS;
Expand Down
9 changes: 5 additions & 4 deletions test/conformance/kernel/urKernelSetArgMemObj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,20 @@ struct urKernelSetArgMemObjTest : uur::urKernelTest {
UUR_INSTANTIATE_KERNEL_TEST_SUITE_P(urKernelSetArgMemObjTest);

TEST_P(urKernelSetArgMemObjTest, Success) {
ASSERT_SUCCESS(urKernelSetArgMemObj(kernel, 0, buffer));
ASSERT_SUCCESS(urKernelSetArgMemObj(kernel, 0, nullptr, buffer));
}

TEST_P(urKernelSetArgMemObjTest, InvalidNullHandleKernel) {
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urKernelSetArgMemObj(nullptr, 0, buffer));
urKernelSetArgMemObj(nullptr, 0, nullptr, buffer));
}

TEST_P(urKernelSetArgMemObjTest, InvalidKernelArgumentIndex) {
size_t num_kernel_args = 0;
ASSERT_SUCCESS(urKernelGetInfo(kernel, UR_KERNEL_INFO_NUM_ARGS,
sizeof(num_kernel_args), &num_kernel_args,
nullptr));
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX,
urKernelSetArgMemObj(kernel, num_kernel_args + 1, buffer));
ASSERT_EQ_RESULT(
UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX,
urKernelSetArgMemObj(kernel, num_kernel_args + 1, nullptr, buffer));
}
4 changes: 2 additions & 2 deletions test/conformance/testing/include/uur/fixtures.h
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,8 @@ struct urKernelExecutionTest : urKernelTest {
sizeof(zero), 0, size, 0, nullptr,
nullptr));
ASSERT_SUCCESS(urQueueFinish(queue));
ASSERT_SUCCESS(
urKernelSetArgMemObj(kernel, current_arg_index, mem_handle));
ASSERT_SUCCESS(urKernelSetArgMemObj(kernel, current_arg_index, nullptr,
mem_handle));

// This emulates the offset struct sycl adds for a 1D buffer accessor.
struct {
Expand Down

0 comments on commit 7320835

Please sign in to comment.