Skip to content

Commit

Permalink
Merge pull request #1953 from aarongreig/aaron/changeDeviceCreateWith…
Browse files Browse the repository at this point in the history
…NativeParam

Change urDeviceCreateWithNativeHandle to take an adapter handle.
  • Loading branch information
omarahmed1111 authored Aug 15, 2024
2 parents d7e0fad + a4c6e91 commit 0342c95
Show file tree
Hide file tree
Showing 17 changed files with 53 additions and 69 deletions.
6 changes: 3 additions & 3 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2062,15 +2062,15 @@ typedef struct ur_device_native_properties_t {
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
UR_APIEXPORT ur_result_t UR_APICALL
urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t *phDevice ///< [out] pointer to the handle of the device object created.
);
Expand Down Expand Up @@ -11972,7 +11972,7 @@ typedef struct ur_device_get_native_handle_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_device_create_with_native_handle_params_t {
ur_native_handle_t *phNativeDevice;
ur_platform_handle_t *phPlatform;
ur_adapter_handle_t *phAdapter;
const ur_device_native_properties_t **ppProperties;
ur_device_handle_t **pphDevice;
} ur_device_create_with_native_handle_params_t;
Expand Down
2 changes: 1 addition & 1 deletion include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -2373,7 +2373,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnDeviceGetNativeHandle_t)(
/// @brief Function-pointer for urDeviceCreateWithNativeHandle
typedef ur_result_t(UR_APICALL *ur_pfnDeviceCreateWithNativeHandle_t)(
ur_native_handle_t,
ur_platform_handle_t,
ur_adapter_handle_t,
const ur_device_native_properties_t *,
ur_device_handle_t *);

Expand Down
4 changes: 2 additions & 2 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17357,10 +17357,10 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
*(params->phNativeDevice)));

os << ", ";
os << ".hPlatform = ";
os << ".hAdapter = ";

ur::details::printPtr(os,
*(params->phPlatform));
*(params->phAdapter));

os << ", ";
os << ".pProperties = ";
Expand Down
6 changes: 3 additions & 3 deletions scripts/core/device.yml
Original file line number Diff line number Diff line change
Expand Up @@ -820,9 +820,9 @@ params:
- type: $x_native_handle_t
name: hNativeDevice
desc: "[in][nocheck] the native handle of the device."
- type: $x_platform_handle_t
name: hPlatform
desc: "[in] handle of the platform instance"
- type: $x_adapter_handle_t
name: hAdapter
desc: "[in] handle of the adapter to which `hNativeDevice` belongs"
- type: const $x_device_native_properties_t*
name: pProperties
desc: "[in][optional] pointer to native device properties struct."
Expand Down
17 changes: 3 additions & 14 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1185,27 +1185,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// \return TBD

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
const ur_device_native_properties_t *pProperties,
ur_native_handle_t hNativeDevice,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
std::ignore = pProperties;

CUdevice CuDevice = static_cast<CUdevice>(hNativeDevice);

auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
return Dev->get() == CuDevice;
};

// If a platform is provided just check if the device is in it
if (hPlatform) {
auto SearchRes = std::find_if(begin(hPlatform->Devices),
end(hPlatform->Devices), IsDevice);
if (SearchRes != end(hPlatform->Devices)) {
*phDevice = SearchRes->get();
return UR_RESULT_SUCCESS;
}
}

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
Expand Down
13 changes: 2 additions & 11 deletions source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
ur_native_handle_t hNativeDevice,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
// We can't cast between ur_native_handle_t and hipDevice_t, so memcpy the
Expand All @@ -1000,16 +1001,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
return Dev->get() == HIPDevice;
};

// If a platform is provided just check if the device is in it
if (hPlatform) {
auto SearchRes = std::find_if(begin(hPlatform->Devices),
end(hPlatform->Devices), IsDevice);
if (SearchRes != end(hPlatform->Devices)) {
*phDevice = SearchRes->get();
return UR_RESULT_SUCCESS;
}
}

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
Expand Down
12 changes: 3 additions & 9 deletions source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1600,14 +1600,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t NativeDevice, ///< [in] the native handle of the device.
ur_platform_handle_t Platform, ///< [in] handle of the platform instance
const ur_device_native_properties_t
[[maybe_unused]] ur_adapter_handle_t
Adapter, ///< [in] handle of the platform instance
[[maybe_unused]] const ur_device_native_properties_t
*Properties, ///< [in][optional] pointer to native device properties
///< struct.
ur_device_handle_t
*Device ///< [out] pointer to the handle of the device object created.
) {
std::ignore = Properties;
auto ZeDevice = ur_cast<ze_device_handle_t>(NativeDevice);

// The SYCL spec requires that the set of devices must remain fixed for the
Expand All @@ -1620,12 +1620,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
if (const auto *platforms = GlobalAdapter->PlatformCache->get_value()) {
for (const auto &p : *platforms) {
Dev = p->getDeviceFromNativeHandle(ZeDevice);
if (Dev) {
// Check that the input Platform, if was given, matches the found one.
UR_ASSERT(!Platform || Platform == p.get(),
UR_RESULT_ERROR_INVALID_PLATFORM);
break;
}
}
} else {
return GlobalAdapter->PlatformCache->get_error();
Expand Down
5 changes: 3 additions & 2 deletions source/adapters/mock/ur_mockddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -930,7 +931,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

ur_device_create_with_native_handle_params_t params = {
&hNativeDevice, &hPlatform, &pProperties, &phDevice};
&hNativeDevice, &hAdapter, &pProperties, &phDevice};

auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
mock::getCallbacks().get_before_callback(
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/native_cpu/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t hPlatform,
ur_native_handle_t hNativeDevice, ur_adapter_handle_t hAdapter,
const ur_device_native_properties_t *pProperties,
ur_device_handle_t *phDevice) {
std::ignore = hNativeDevice;
std::ignore = hPlatform;
std::ignore = hAdapter;
std::ignore = pProperties;
std::ignore = phDevice;

Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t hNativeDevice, ur_platform_handle_t,
ur_native_handle_t hNativeDevice, ur_adapter_handle_t,
const ur_device_native_properties_t *, ur_device_handle_t *phDevice) {

*phDevice = reinterpret_cast<ur_device_handle_t>(hNativeDevice);
Expand Down
7 changes: 4 additions & 3 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -719,14 +720,14 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}

ur_device_create_with_native_handle_params_t params = {
&hNativeDevice, &hPlatform, &pProperties, &phDevice};
&hNativeDevice, &hAdapter, &pProperties, &phDevice};
uint64_t instance =
getContext()->notify_begin(UR_FUNCTION_DEVICE_CREATE_WITH_NATIVE_HANDLE,
"urDeviceCreateWithNativeHandle", &params);

getContext()->logger.info("---> urDeviceCreateWithNativeHandle");

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform,
ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter,
pProperties, phDevice);

getContext()->notify_end(UR_FUNCTION_DEVICE_CREATE_WITH_NATIVE_HANDLE,
Expand Down
12 changes: 9 additions & 3 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -733,7 +734,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}

if (getContext()->enableParameterValidation) {
if (NULL == hPlatform) {
if (NULL == hAdapter) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

Expand All @@ -742,7 +743,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
}
}

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform,
if (getContext()->enableLifetimeValidation &&
!getContext()->refCountContext->isReferenceValid(hAdapter)) {
getContext()->refCountContext->logInvalidReference(hAdapter);
}

ur_result_t result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter,
pProperties, phDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
Expand Down
10 changes: 5 additions & 5 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,8 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
__urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -775,19 +776,18 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
[[maybe_unused]] auto context = getContext();

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
auto dditable = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Device.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
hAdapter = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->handle;

// forward to device-platform
result = pfnCreateWithNativeHandle(hNativeDevice, hPlatform, pProperties,
result = pfnCreateWithNativeHandle(hNativeDevice, hAdapter, pProperties,
phDevice);

if (UR_RESULT_SUCCESS != result) {
Expand Down
7 changes: 4 additions & 3 deletions source/loader/ur_libapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1135,15 +1135,16 @@ ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand All @@ -1155,7 +1156,7 @@ ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
return UR_RESULT_ERROR_UNINITIALIZED;
}

return pfnCreateWithNativeHandle(hNativeDevice, hPlatform, pProperties,
return pfnCreateWithNativeHandle(hNativeDevice, hAdapter, pProperties,
phDevice);
} catch (...) {
return exceptionToResult(std::current_exception());
Expand Down
5 changes: 3 additions & 2 deletions source/ur_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,15 +997,16 @@ ur_result_t UR_APICALL urDeviceGetNativeHandle(
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hPlatform`
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevice`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_native_handle_t
hNativeDevice, ///< [in][nocheck] the native handle of the device.
ur_platform_handle_t hPlatform, ///< [in] handle of the platform instance
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter to which `hNativeDevice` belongs
const ur_device_native_properties_t *
pProperties, ///< [in][optional] pointer to native device properties struct.
ur_device_handle_t
Expand Down
2 changes: 1 addition & 1 deletion test/adapters/cuda/urDeviceCreateWithNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ TEST_F(urCudaDeviceCreateWithNativeHandle, Success) {

ur_native_handle_t nativeCuda = static_cast<ur_native_handle_t>(cudaDevice);
ur_device_handle_t urDevice;
ASSERT_SUCCESS(urDeviceCreateWithNativeHandle(nativeCuda, platform, nullptr,
ASSERT_SUCCESS(urDeviceCreateWithNativeHandle(nativeCuda, adapter, nullptr,
&urDevice));
}
8 changes: 4 additions & 4 deletions test/conformance/device/urDeviceCreateWithNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, Success) {
// and perform some query on it to verify that it works.
ur_device_handle_t dev = nullptr;
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, nullptr, &dev));
native_handle, adapter, nullptr, &dev));
ASSERT_NE(dev, nullptr);

uint32_t dev_id = 0;
Expand All @@ -41,7 +41,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, SuccessWithOwnedNativeHandle) {
ur_device_native_properties_t props{
UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES, nullptr, true};
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, &props, &dev));
native_handle, adapter, &props, &dev));
ASSERT_NE(dev, nullptr);

uint32_t ref_count = 0;
Expand All @@ -64,7 +64,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, SuccessWithUnOwnedNativeHandle) {
ur_device_native_properties_t props{
UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES, nullptr, false};
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(urDeviceCreateWithNativeHandle(
native_handle, platform, &props, &dev));
native_handle, adapter, &props, &dev));
ASSERT_NE(dev, nullptr);

uint32_t ref_count = 0;
Expand Down Expand Up @@ -93,7 +93,7 @@ TEST_F(urDeviceCreateWithNativeHandleTest, InvalidNullPointerDevice) {
ASSERT_SUCCESS(urDeviceGetNativeHandle(device, &native_handle));

ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
urDeviceCreateWithNativeHandle(native_handle, platform,
urDeviceCreateWithNativeHandle(native_handle, adapter,
nullptr, nullptr));
}
}

0 comments on commit 0342c95

Please sign in to comment.