From 72941708b5897c020e8a71bf2a71d8d0a383414f Mon Sep 17 00:00:00 2001 From: Fraser Cormack Date: Mon, 20 May 2024 16:44:20 +0100 Subject: [PATCH] [HIP] Implement urDeviceGetNativeHandle This is mostly just a copy of the CUDA version of this implementation. --- source/adapters/hip/device.cpp | 54 ++++++++++++++++++- .../device/device_adapter_hip.match | 2 - 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/source/adapters/hip/device.cpp b/source/adapters/hip/device.cpp index dd20a4f50f..8436849b28 100644 --- a/source/adapters/hip/device.cpp +++ b/source/adapters/hip/device.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "device.hpp" +#include "adapter.hpp" #include "context.hpp" #include "event.hpp" @@ -950,8 +951,57 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle( } UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle( - ur_native_handle_t, ur_platform_handle_t, - const ur_device_native_properties_t *, ur_device_handle_t *) { + ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform, + [[maybe_unused]] const ur_device_native_properties_t *pProperties, + ur_device_handle_t *phDevice) { + // We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the + // bits instead + hipDevice_t HIPDevice = 0; + memcpy(&HIPDevice, &hNativeDevice, sizeof(hipDevice_t)); + + auto IsDevice = [=](std::unique_ptr &Dev) { + return Dev->get() == HIPDevice; + }; + + // If a platform is provided just check if the device is in it + if (hPlatform) { + auto SearchRes = std::find_if(begin(hPlatform->Devices), + end(hPlatform->Devices), IsDevice); + if (SearchRes != end(hPlatform->Devices)) { + *phDevice = SearchRes->get(); + return UR_RESULT_SUCCESS; + } + } + + // Get list of platforms + uint32_t NumPlatforms = 0; + ur_adapter_handle_t AdapterHandle = &adapter; + ur_result_t Result = + urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms); + if (Result != UR_RESULT_SUCCESS) + return Result; + + // We can only have a maximum of one platform. + if (NumPlatforms != 1) + return UR_RESULT_ERROR_INVALID_OPERATION; + + ur_platform_handle_t Platform = nullptr; + + Result = urPlatformGet(&AdapterHandle, 1, NumPlatforms, &Platform, nullptr); + if (Result != UR_RESULT_SUCCESS) + return Result; + + // Iterate through the platform's devices to find the device that matches + // nativeHandle + auto SearchRes = std::find_if(std::begin(Platform->Devices), + std::end(Platform->Devices), IsDevice); + if (SearchRes != end(Platform->Devices)) { + *phDevice = static_cast((*SearchRes).get()); + return UR_RESULT_SUCCESS; + } + + // If the provided nativeHandle cannot be matched to an + // existing device return error return UR_RESULT_ERROR_INVALID_OPERATION; } diff --git a/test/conformance/device/device_adapter_hip.match b/test/conformance/device/device_adapter_hip.match index f64efa4bac..9989fbd774 100644 --- a/test/conformance/device/device_adapter_hip.match +++ b/test/conformance/device/device_adapter_hip.match @@ -1,4 +1,2 @@ -urDeviceCreateWithNativeHandleTest.Success -urDeviceCreateWithNativeHandleTest.SuccessWithOwnedNativeHandle urDeviceCreateWithNativeHandleTest.SuccessWithUnOwnedNativeHandle {{OPT}}urDeviceGetGlobalTimestampTest.SuccessSynchronizedTime