diff --git a/include/ur_api.h b/include/ur_api.h index b25855be01..442c364e0c 100644 --- a/include/ur_api.h +++ b/include/ur_api.h @@ -2723,6 +2723,7 @@ urMemBufferPartition( /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hMem` +/// + `NULL == hDevice` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phNativeMem` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE @@ -2730,6 +2731,7 @@ urMemBufferPartition( UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ); @@ -9488,6 +9490,7 @@ typedef struct ur_mem_buffer_partition_params_t { /// allowing the callback the ability to modify the parameter's value typedef struct ur_mem_get_native_handle_params_t { ur_mem_handle_t *phMem; + ur_device_handle_t *phDevice; ur_native_handle_t **pphNativeMem; } ur_mem_get_native_handle_params_t; diff --git a/include/ur_ddi.h b/include/ur_ddi.h index 92fc742f72..77f2f35f70 100644 --- a/include/ur_ddi.h +++ b/include/ur_ddi.h @@ -770,6 +770,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnMemBufferPartition_t)( /// @brief Function-pointer for urMemGetNativeHandle typedef ur_result_t(UR_APICALL *ur_pfnMemGetNativeHandle_t)( ur_mem_handle_t, + ur_device_handle_t, ur_native_handle_t *); /////////////////////////////////////////////////////////////////////////////// diff --git a/include/ur_print.hpp b/include/ur_print.hpp index 63cf0e3aea..6b27b2a443 100644 --- a/include/ur_print.hpp +++ b/include/ur_print.hpp @@ -11174,6 +11174,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct ur::details::printPtr(os, *(params->phMem)); + os << ", "; + os << ".hDevice = "; + + ur::details::printPtr(os, + *(params->phDevice)); + os << ", "; os << ".phNativeMem = "; diff --git a/scripts/core/memory.yml b/scripts/core/memory.yml index ede16a1913..644832f1a3 100644 --- a/scripts/core/memory.yml +++ b/scripts/core/memory.yml @@ -432,6 +432,10 @@ params: name: hMem desc: | [in] handle of the mem. + - type: $x_device_handle_t + name: hDevice + desc: | + [in] handle of the device that the native handle will be resident on. - type: $x_native_handle_t* name: phNativeMem desc: | diff --git a/source/adapters/cuda/memory.cpp b/source/adapters/cuda/memory.cpp index 824ab1f580..f479522fb3 100644 --- a/source/adapters/cuda/memory.cpp +++ b/source/adapters/cuda/memory.cpp @@ -161,8 +161,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { /// \param[out] phNativeMem Set to the native handle of the UR mem object. /// /// \return UR_RESULT_SUCCESS -UR_APIEXPORT ur_result_t UR_APICALL -urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) { +UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle( + ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) { *phNativeMem = reinterpret_cast( std::get(hMem->Mem).get()); return UR_RESULT_SUCCESS; diff --git a/source/adapters/hip/memory.cpp b/source/adapters/hip/memory.cpp index 7be8f3f9c1..dcc3e34fad 100644 --- a/source/adapters/hip/memory.cpp +++ b/source/adapters/hip/memory.cpp @@ -279,16 +279,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, /// \param[out] phNativeMem Set to the native handle of the UR mem object. /// /// \return UR_RESULT_SUCCESS -UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(ur_mem_handle_t, - ur_native_handle_t *) { - // FIXME: there is no good way of doing this with a multi device context. - // If we return a single pointer, how would we know which device's allocation - // it should be? - // If we return a vector of pointers, this is OK for read only access but if - // we write to a buffer, how would we know which one had been written to? - // Should unused allocations be updated afterwards? We have no way of knowing - // any of these things in the current API design. - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; +UR_APIEXPORT ur_result_t UR_APICALL +urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t Device, + ur_native_handle_t *phNativeMem) { +#if defined(__HIP_PLATFORM_NVIDIA__) + if (sizeof(BufferMem::native_type) > sizeof(ur_native_handle_t)) { + // Check that all the upper bits that cannot be represented by + // ur_native_handle_t are empty. + // NOTE: The following shift might trigger a warning, but the check in the + // if above makes sure that this does not underflow. + BufferMem::native_type UpperBits = + std::get(hMem->Mem).getPtr(Device) >> + (sizeof(ur_native_handle_t) * CHAR_BIT); + if (UpperBits) { + // Return an error if any of the remaining bits is non-zero. + return UR_RESULT_ERROR_INVALID_MEM_OBJECT; + } + } + *phNativeMem = reinterpret_cast( + std::get(hMem->Mem).getPtr(Device)); +#elif defined(__HIP_PLATFORM_AMD__) + *phNativeMem = reinterpret_cast( + std::get(hMem->Mem).getPtr(Device)); +#else +#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__"); +#endif + return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( diff --git a/source/adapters/level_zero/memory.cpp b/source/adapters/level_zero/memory.cpp index fa3ef18e47..a423e55b71 100644 --- a/source/adapters/level_zero/memory.cpp +++ b/source/adapters/level_zero/memory.cpp @@ -1856,6 +1856,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t Mem, ///< [in] handle of the mem. + ur_device_handle_t, ///< [in] handle of the device. ur_native_handle_t *NativeMem ///< [out] a pointer to the native handle of the mem. ) { diff --git a/source/adapters/native_cpu/memory.cpp b/source/adapters/native_cpu/memory.cpp index a190208ab7..1f8a927c67 100644 --- a/source/adapters/native_cpu/memory.cpp +++ b/source/adapters/native_cpu/memory.cpp @@ -111,8 +111,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( } UR_APIEXPORT ur_result_t UR_APICALL -urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) { +urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t hDevice, + ur_native_handle_t *phNativeMem) { std::ignore = hMem; + std::ignore = hDevice; std::ignore = phNativeMem; DIE_NO_IMPLEMENTATION diff --git a/source/adapters/null/ur_nullddi.cpp b/source/adapters/null/ur_nullddi.cpp index d6887ee12f..464aa59d54 100644 --- a/source/adapters/null/ur_nullddi.cpp +++ b/source/adapters/null/ur_nullddi.cpp @@ -916,6 +916,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( /// @brief Intercept function for urMemGetNativeHandle __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) try { @@ -924,7 +926,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( // if the driver has created a custom function, then call it instead of using the generic path auto pfnGetNativeHandle = d_context.urDdiTable.Mem.pfnGetNativeHandle; if (nullptr != pfnGetNativeHandle) { - result = pfnGetNativeHandle(hMem, phNativeMem); + result = pfnGetNativeHandle(hMem, hDevice, phNativeMem); } else { // generic implementation *phNativeMem = reinterpret_cast(d_context.get()); diff --git a/source/adapters/opencl/memory.cpp b/source/adapters/opencl/memory.cpp index e93f0731c6..86800845e8 100644 --- a/source/adapters/opencl/memory.cpp +++ b/source/adapters/opencl/memory.cpp @@ -331,8 +331,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( return mapCLErrorToUR(RetErr); } -UR_APIEXPORT ur_result_t UR_APICALL -urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) { +UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle( + ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) { return getNativeHandle(hMem, phNativeMem); } diff --git a/source/loader/layers/tracing/ur_trcddi.cpp b/source/loader/layers/tracing/ur_trcddi.cpp index 402b64d638..5867d295ae 100644 --- a/source/loader/layers/tracing/ur_trcddi.cpp +++ b/source/loader/layers/tracing/ur_trcddi.cpp @@ -986,6 +986,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( /// @brief Intercept function for urMemGetNativeHandle __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) { @@ -995,11 +997,11 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; } - ur_mem_get_native_handle_params_t params = {&hMem, &phNativeMem}; + ur_mem_get_native_handle_params_t params = {&hMem, &hDevice, &phNativeMem}; uint64_t instance = context.notify_begin(UR_FUNCTION_MEM_GET_NATIVE_HANDLE, "urMemGetNativeHandle", ¶ms); - ur_result_t result = pfnGetNativeHandle(hMem, phNativeMem); + ur_result_t result = pfnGetNativeHandle(hMem, hDevice, phNativeMem); context.notify_end(UR_FUNCTION_MEM_GET_NATIVE_HANDLE, "urMemGetNativeHandle", ¶ms, &result, instance); diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index 72e225028c..db59ca3b11 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -1188,6 +1188,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( /// @brief Intercept function for urMemGetNativeHandle __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) { @@ -1202,12 +1204,16 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( return UR_RESULT_ERROR_INVALID_NULL_HANDLE; } + if (NULL == hDevice) { + return UR_RESULT_ERROR_INVALID_NULL_HANDLE; + } + if (NULL == phNativeMem) { return UR_RESULT_ERROR_INVALID_NULL_POINTER; } } - ur_result_t result = pfnGetNativeHandle(hMem, phNativeMem); + ur_result_t result = pfnGetNativeHandle(hMem, hDevice, phNativeMem); return result; } diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 201315272f..a3a4ccaaa0 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -1231,6 +1231,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition( /// @brief Intercept function for urMemGetNativeHandle __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) { @@ -1246,8 +1248,11 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle( // convert loader handle to platform handle hMem = reinterpret_cast(hMem)->handle; + // convert loader handle to platform handle + hDevice = reinterpret_cast(hDevice)->handle; + // forward to device-platform - result = pfnGetNativeHandle(hMem, phNativeMem); + result = pfnGetNativeHandle(hMem, hDevice, phNativeMem); if (UR_RESULT_SUCCESS != result) { return result; diff --git a/source/loader/ur_libapi.cpp b/source/loader/ur_libapi.cpp index 0a69fcd1e2..cd4a70c91e 100644 --- a/source/loader/ur_libapi.cpp +++ b/source/loader/ur_libapi.cpp @@ -1654,12 +1654,15 @@ ur_result_t UR_APICALL urMemBufferPartition( /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hMem` +/// + `NULL == hDevice` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phNativeMem` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If the adapter has no underlying equivalent handle. ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) try { @@ -1669,7 +1672,7 @@ ur_result_t UR_APICALL urMemGetNativeHandle( return UR_RESULT_ERROR_UNINITIALIZED; } - return pfnGetNativeHandle(hMem, phNativeMem); + return pfnGetNativeHandle(hMem, hDevice, phNativeMem); } catch (...) { return exceptionToResult(std::current_exception()); } diff --git a/source/ur_api.cpp b/source/ur_api.cpp index 2bcc229f29..26f24aba08 100644 --- a/source/ur_api.cpp +++ b/source/ur_api.cpp @@ -1419,12 +1419,15 @@ ur_result_t UR_APICALL urMemBufferPartition( /// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC /// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE /// + `NULL == hMem` +/// + `NULL == hDevice` /// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER /// + `NULL == phNativeMem` /// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE /// + If the adapter has no underlying equivalent handle. ur_result_t UR_APICALL urMemGetNativeHandle( ur_mem_handle_t hMem, ///< [in] handle of the mem. + ur_device_handle_t + hDevice, ///< [in] handle of the device that the native handle will be resident on. ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem. ) { diff --git a/test/conformance/memory/memory_adapter_hip.match b/test/conformance/memory/memory_adapter_hip.match index 02760dcb8a..4bb9904b04 100644 --- a/test/conformance/memory/memory_adapter_hip.match +++ b/test/conformance/memory/memory_adapter_hip.match @@ -1,7 +1,4 @@ -{{OPT}}urMemGetInfoTest.InvalidNullPointerParamValue/AMD_HIP_BACKEND___{{.*}} -{{OPT}}urMemGetInfoTest.InvalidNullPointerParamValue/AMD_HIP_BACKEND___{{.*}} -{{OPT}}urMemGetInfoTest.InvalidNullPointerPropSizeRet/AMD_HIP_BACKEND___{{.*}} -{{OPT}}urMemGetInfoTest.InvalidNullPointerPropSizeRet/AMD_HIP_BACKEND___{{.*}} +{{OPT}}urMemBufferCreateWithNativeHandleTest.Success/AMD_HIP_BACKEND___{{.*}} {{OPT}}urMemImageCreateTest.InvalidSize/AMD_HIP_BACKEND___{{.*}} {{OPT}}urMemImageGetInfoTest.Success/AMD_HIP_BACKEND___{{.*}} {{OPT}}urMemImageGetInfoTest.Success/AMD_HIP_BACKEND___{{.*}} diff --git a/test/conformance/memory/urMemBufferCreateWithNativeHandle.cpp b/test/conformance/memory/urMemBufferCreateWithNativeHandle.cpp index 2d6babd56e..573c9c0036 100644 --- a/test/conformance/memory/urMemBufferCreateWithNativeHandle.cpp +++ b/test/conformance/memory/urMemBufferCreateWithNativeHandle.cpp @@ -10,7 +10,7 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemBufferCreateWithNativeHandleTest); TEST_P(urMemBufferCreateWithNativeHandleTest, Success) { ur_native_handle_t hNativeMem = nullptr; - if (urMemGetNativeHandle(buffer, &hNativeMem)) { + if (urMemGetNativeHandle(buffer, device, &hNativeMem)) { GTEST_SKIP(); } diff --git a/test/conformance/memory/urMemGetNativeHandle.cpp b/test/conformance/memory/urMemGetNativeHandle.cpp index 1dbb959c3f..55f8910f72 100644 --- a/test/conformance/memory/urMemGetNativeHandle.cpp +++ b/test/conformance/memory/urMemGetNativeHandle.cpp @@ -9,7 +9,7 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemGetNativeHandleTest); TEST_P(urMemGetNativeHandleTest, Success) { ur_native_handle_t hNativeMem = nullptr; - if (auto error = urMemGetNativeHandle(buffer, &hNativeMem)) { + if (auto error = urMemGetNativeHandle(buffer, device, &hNativeMem)) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, error); } } @@ -17,10 +17,16 @@ TEST_P(urMemGetNativeHandleTest, Success) { TEST_P(urMemGetNativeHandleTest, InvalidNullHandleMem) { ur_native_handle_t phNativeMem; ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE, - urMemGetNativeHandle(nullptr, &phNativeMem)); + urMemGetNativeHandle(nullptr, device, &phNativeMem)); +} + +TEST_P(urMemGetNativeHandleTest, InvalidNullHandleDevice) { + ur_native_handle_t phNativeMem; + ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE, + urMemGetNativeHandle(buffer, nullptr, &phNativeMem)); } TEST_P(urMemGetNativeHandleTest, InvalidNullPointerNativeMem) { ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER, - urMemGetNativeHandle(buffer, nullptr)); + urMemGetNativeHandle(buffer, device, nullptr)); } diff --git a/test/conformance/memory/urMemImageCreateWithNativeHandle.cpp b/test/conformance/memory/urMemImageCreateWithNativeHandle.cpp index f22e6c38e5..9d800d90bc 100644 --- a/test/conformance/memory/urMemImageCreateWithNativeHandle.cpp +++ b/test/conformance/memory/urMemImageCreateWithNativeHandle.cpp @@ -10,7 +10,7 @@ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urMemImageCreateWithNativeHandleTest); TEST_P(urMemImageCreateWithNativeHandleTest, Success) { ur_native_handle_t native_handle = nullptr; - if (urMemGetNativeHandle(image, &native_handle)) { + if (urMemGetNativeHandle(image, device, &native_handle)) { GTEST_SKIP(); }