Skip to content

Commit

Permalink
Merge pull request #1226 from hdelan/get-native-mem-on-device2
Browse files Browse the repository at this point in the history
[UR] Add extra param to urMemGetNativeHandle
  • Loading branch information
kbenzie committed Jan 31, 2024
2 parents 40517d2 + fc1f306 commit d216eb4
Show file tree
Hide file tree
Showing 19 changed files with 87 additions and 30 deletions.
3 changes: 3 additions & 0 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2723,13 +2723,15 @@ urMemBufferPartition(
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hMem`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phNativeMem`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem.
);

Expand Down Expand Up @@ -9488,6 +9490,7 @@ typedef struct ur_mem_buffer_partition_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_mem_get_native_handle_params_t {
ur_mem_handle_t *phMem;
ur_device_handle_t *phDevice;
ur_native_handle_t **pphNativeMem;
} ur_mem_get_native_handle_params_t;

Expand Down
1 change: 1 addition & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnMemBufferPartition_t)(
/// @brief Function-pointer for urMemGetNativeHandle
typedef ur_result_t(UR_APICALL *ur_pfnMemGetNativeHandle_t)(
ur_mem_handle_t,
ur_device_handle_t,
ur_native_handle_t *);

///////////////////////////////////////////////////////////////////////////////
Expand Down
6 changes: 6 additions & 0 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11174,6 +11174,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
ur::details::printPtr(os,
*(params->phMem));

os << ", ";
os << ".hDevice = ";

ur::details::printPtr(os,
*(params->phDevice));

os << ", ";
os << ".phNativeMem = ";

Expand Down
4 changes: 4 additions & 0 deletions scripts/core/memory.yml
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ params:
name: hMem
desc: |
[in] handle of the mem.
- type: $x_device_handle_t
name: hDevice
desc: |
[in] handle of the device that the native handle will be resident on.
- type: $x_native_handle_t*
name: phNativeMem
desc: |
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/cuda/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
/// \param[out] phNativeMem Set to the native handle of the UR mem object.
///
/// \return UR_RESULT_SUCCESS
UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
std::get<BufferMem>(hMem->Mem).get());
return UR_RESULT_SUCCESS;
Expand Down
36 changes: 26 additions & 10 deletions source/adapters/hip/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,16 +279,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
/// \param[out] phNativeMem Set to the native handle of the UR mem object.
///
/// \return UR_RESULT_SUCCESS
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(ur_mem_handle_t,
ur_native_handle_t *) {
// FIXME: there is no good way of doing this with a multi device context.
// If we return a single pointer, how would we know which device's allocation
// it should be?
// If we return a vector of pointers, this is OK for read only access but if
// we write to a buffer, how would we know which one had been written to?
// Should unused allocations be updated afterwards? We have no way of knowing
// any of these things in the current API design.
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t Device,
ur_native_handle_t *phNativeMem) {
#if defined(__HIP_PLATFORM_NVIDIA__)
if (sizeof(BufferMem::native_type) > sizeof(ur_native_handle_t)) {
// Check that all the upper bits that cannot be represented by
// ur_native_handle_t are empty.
// NOTE: The following shift might trigger a warning, but the check in the
// if above makes sure that this does not underflow.
BufferMem::native_type UpperBits =
std::get<BufferMem>(hMem->Mem).getPtr(Device) >>
(sizeof(ur_native_handle_t) * CHAR_BIT);
if (UpperBits) {
// Return an error if any of the remaining bits is non-zero.
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
}
}
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
std::get<BufferMem>(hMem->Mem).getPtr(Device));
#elif defined(__HIP_PLATFORM_AMD__)
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
std::get<BufferMem>(hMem->Mem).getPtr(Device));
#else
#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
#endif
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
Expand Down
1 change: 1 addition & 0 deletions source/adapters/level_zero/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1856,6 +1856,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(

UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t Mem, ///< [in] handle of the mem.
ur_device_handle_t, ///< [in] handle of the device.
ur_native_handle_t
*NativeMem ///< [out] a pointer to the native handle of the mem.
) {
Expand Down
4 changes: 3 additions & 1 deletion source/adapters/native_cpu/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
}

UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t hDevice,
ur_native_handle_t *phNativeMem) {
std::ignore = hMem;
std::ignore = hDevice;
std::ignore = phNativeMem;

DIE_NO_IMPLEMENTATION
Expand Down
4 changes: 3 additions & 1 deletion source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition(
/// @brief Intercept function for urMemGetNativeHandle
__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) try {
Expand All @@ -924,7 +926,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
// if the driver has created a custom function, then call it instead of using the generic path
auto pfnGetNativeHandle = d_context.urDdiTable.Mem.pfnGetNativeHandle;
if (nullptr != pfnGetNativeHandle) {
result = pfnGetNativeHandle(hMem, phNativeMem);
result = pfnGetNativeHandle(hMem, hDevice, phNativeMem);
} else {
// generic implementation
*phNativeMem = reinterpret_cast<ur_native_handle_t>(d_context.get());
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/opencl/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
return mapCLErrorToUR(RetErr);
}

UR_APIEXPORT ur_result_t UR_APICALL
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
return getNativeHandle(hMem, phNativeMem);
}

Expand Down
6 changes: 4 additions & 2 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition(
/// @brief Intercept function for urMemGetNativeHandle
__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) {
Expand All @@ -995,11 +997,11 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

ur_mem_get_native_handle_params_t params = {&hMem, &phNativeMem};
ur_mem_get_native_handle_params_t params = {&hMem, &hDevice, &phNativeMem};
uint64_t instance = context.notify_begin(UR_FUNCTION_MEM_GET_NATIVE_HANDLE,
"urMemGetNativeHandle", &params);

ur_result_t result = pfnGetNativeHandle(hMem, phNativeMem);
ur_result_t result = pfnGetNativeHandle(hMem, hDevice, phNativeMem);

context.notify_end(UR_FUNCTION_MEM_GET_NATIVE_HANDLE,
"urMemGetNativeHandle", &params, &result, instance);
Expand Down
8 changes: 7 additions & 1 deletion source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition(
/// @brief Intercept function for urMemGetNativeHandle
__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) {
Expand All @@ -1202,12 +1204,16 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL == hDevice) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL == phNativeMem) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}
}

ur_result_t result = pfnGetNativeHandle(hMem, phNativeMem);
ur_result_t result = pfnGetNativeHandle(hMem, hDevice, phNativeMem);

return result;
}
Expand Down
7 changes: 6 additions & 1 deletion source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition(
/// @brief Intercept function for urMemGetNativeHandle
__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) {
Expand All @@ -1246,8 +1248,11 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
// convert loader handle to platform handle
hMem = reinterpret_cast<ur_mem_object_t *>(hMem)->handle;

// convert loader handle to platform handle
hDevice = reinterpret_cast<ur_device_object_t *>(hDevice)->handle;

// forward to device-platform
result = pfnGetNativeHandle(hMem, phNativeMem);
result = pfnGetNativeHandle(hMem, hDevice, phNativeMem);

if (UR_RESULT_SUCCESS != result) {
return result;
Expand Down
5 changes: 4 additions & 1 deletion source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1654,12 +1654,15 @@ ur_result_t UR_APICALL urMemBufferPartition(
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hMem`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phNativeMem`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) try {
Expand All @@ -1669,7 +1672,7 @@ ur_result_t UR_APICALL urMemGetNativeHandle(
return UR_RESULT_ERROR_UNINITIALIZED;
}

return pfnGetNativeHandle(hMem, phNativeMem);
return pfnGetNativeHandle(hMem, hDevice, phNativeMem);
} catch (...) {
return exceptionToResult(std::current_exception());
}
Expand Down
3 changes: 3 additions & 0 deletions source/ur_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1419,12 +1419,15 @@ ur_result_t UR_APICALL urMemBufferPartition(
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hMem`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phNativeMem`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ///< [in] handle of the mem.
ur_device_handle_t
hDevice, ///< [in] handle of the device that the native handle will be resident on.
ur_native_handle_t
*phNativeMem ///< [out] a pointer to the native handle of the mem.
) {
Expand Down
5 changes: 1 addition & 4 deletions test/conformance/memory/memory_adapter_hip.match
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
{{OPT}}urMemGetInfoTest.InvalidNullPointerParamValue/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemGetInfoTest.InvalidNullPointerParamValue/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemGetInfoTest.InvalidNullPointerPropSizeRet/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemGetInfoTest.InvalidNullPointerPropSizeRet/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemBufferCreateWithNativeHandleTest.Success/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemImageCreateTest.InvalidSize/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemImageGetInfoTest.Success/AMD_HIP_BACKEND___{{.*}}
{{OPT}}urMemImageGetInfoTest.Success/AMD_HIP_BACKEND___{{.*}}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemBufferCreateWithNativeHandleTest);

TEST_P(urMemBufferCreateWithNativeHandleTest, Success) {
ur_native_handle_t hNativeMem = nullptr;
if (urMemGetNativeHandle(buffer, &hNativeMem)) {
if (urMemGetNativeHandle(buffer, device, &hNativeMem)) {
GTEST_SKIP();
}

Expand Down
12 changes: 9 additions & 3 deletions test/conformance/memory/urMemGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,24 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemGetNativeHandleTest);

TEST_P(urMemGetNativeHandleTest, Success) {
ur_native_handle_t hNativeMem = nullptr;
if (auto error = urMemGetNativeHandle(buffer, &hNativeMem)) {
if (auto error = urMemGetNativeHandle(buffer, device, &hNativeMem)) {
ASSERT_EQ_RESULT(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, error);
}
}

TEST_P(urMemGetNativeHandleTest, InvalidNullHandleMem) {
ur_native_handle_t phNativeMem;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urMemGetNativeHandle(nullptr, &phNativeMem));
urMemGetNativeHandle(nullptr, device, &phNativeMem));
}

TEST_P(urMemGetNativeHandleTest, InvalidNullHandleDevice) {
ur_native_handle_t phNativeMem;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urMemGetNativeHandle(buffer, nullptr, &phNativeMem));
}

TEST_P(urMemGetNativeHandleTest, InvalidNullPointerNativeMem) {
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
urMemGetNativeHandle(buffer, nullptr));
urMemGetNativeHandle(buffer, device, nullptr));
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemImageCreateWithNativeHandleTest);

TEST_P(urMemImageCreateWithNativeHandleTest, Success) {
ur_native_handle_t native_handle = nullptr;
if (urMemGetNativeHandle(image, &native_handle)) {
if (urMemGetNativeHandle(image, device, &native_handle)) {
GTEST_SKIP();
}

Expand Down

0 comments on commit d216eb4

Please sign in to comment.