Skip to content

Commit

Permalink
Merge pull request #1856 from aarongreig/aaron/fixContextCreateWithNa…
Browse files Browse the repository at this point in the history
…tive

Make urContextCreateWithNativeHandle work for SYCL use with loader
  • Loading branch information
kbenzie committed Jul 23, 2024
2 parents a1295ba + 863c761 commit fa6bf97
Show file tree
Hide file tree
Showing 22 changed files with 172 additions and 115 deletions.
11 changes: 7 additions & 4 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2401,16 +2401,19 @@ typedef struct ur_context_native_properties_t {
/// - ::UR_RESULT_ERROR_UNINITIALIZED
/// - ::UR_RESULT_ERROR_DEVICE_LOST
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hAdapter`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevices`
/// + `NULL == phContext`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
/// + If the adapter has no underlying equivalent handle.
UR_APIEXPORT ur_result_t UR_APICALL
urContextCreateWithNativeHandle(
ur_native_handle_t hNativeContext, ///< [in][nocheck] the native handle of the context.
ur_adapter_handle_t hAdapter, ///< [in] handle of the adapter that owns the native handle
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
const ur_device_handle_t *phDevices, ///< [in][optional][range(0, numDevices)] list of devices associated with
///< the context
const ur_context_native_properties_t *pProperties, ///< [in][optional] pointer to native context properties struct
ur_context_handle_t *phContext ///< [out] pointer to the handle of the context object created.
);
Expand Down Expand Up @@ -5627,7 +5630,6 @@ typedef struct ur_queue_native_properties_t {
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phQueue`
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
Expand All @@ -5636,7 +5638,7 @@ UR_APIEXPORT ur_result_t UR_APICALL
urQueueCreateWithNativeHandle(
ur_native_handle_t hNativeQueue, ///< [in][nocheck] the native handle of the queue.
ur_context_handle_t hContext, ///< [in] handle of the context object
ur_device_handle_t hDevice, ///< [in] handle of the device object
ur_device_handle_t hDevice, ///< [in][optional] handle of the device object
const ur_queue_native_properties_t *pProperties, ///< [in][optional] pointer to native queue properties struct
ur_queue_handle_t *phQueue ///< [out] pointer to the handle of the queue object created.
);
Expand Down Expand Up @@ -9824,6 +9826,7 @@ typedef struct ur_context_get_native_handle_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_context_create_with_native_handle_params_t {
ur_native_handle_t *phNativeContext;
ur_adapter_handle_t *phAdapter;
uint32_t *pnumDevices;
const ur_device_handle_t **pphDevices;
const ur_context_native_properties_t **ppProperties;
Expand Down
1 change: 1 addition & 0 deletions include/ur_ddi.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnContextGetNativeHandle_t)(
/// @brief Function-pointer for urContextCreateWithNativeHandle
typedef ur_result_t(UR_APICALL *ur_pfnContextCreateWithNativeHandle_t)(
ur_native_handle_t,
ur_adapter_handle_t,
uint32_t,
const ur_device_handle_t *,
const ur_context_native_properties_t *,
Expand Down
6 changes: 6 additions & 0 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10668,6 +10668,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeContext)));

os << ", ";
os << ".hAdapter = ";

ur::details::printPtr(os,
*(params->phAdapter));

os << ", ";
os << ".numDevices = ";

Expand Down
5 changes: 4 additions & 1 deletion scripts/core/context.yml
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,15 @@ params:
name: hNativeContext
desc: |
[in][nocheck] the native handle of the context.
- type: $x_adapter_handle_t
name: hAdapter
desc: "[in] handle of the adapter that owns the native handle"
- type: uint32_t
name: numDevices
desc: "[in] number of devices associated with the context"
- type: "const $x_device_handle_t*"
name: phDevices
desc: "[in][range(0, numDevices)] list of devices associated with the context"
desc: "[in][optional][range(0, numDevices)] list of devices associated with the context"
- type: "const $x_context_native_properties_t*"
name: pProperties
desc: "[in][optional] pointer to native context properties struct"
Expand Down
2 changes: 1 addition & 1 deletion scripts/core/queue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ params:
desc: "[in] handle of the context object"
- type: $x_device_handle_t
name: hDevice
desc: "[in] handle of the device object"
desc: "[in][optional] handle of the device object"
- type: "const $x_queue_native_properties_t*"
name: pProperties
desc: "[in][optional] pointer to native queue properties struct"
Expand Down
1 change: 1 addition & 0 deletions source/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
[[maybe_unused]] ur_native_handle_t hNativeContext,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] uint32_t numDevices,
[[maybe_unused]] const ur_device_handle_t *phDevices,
[[maybe_unused]] const ur_context_native_properties_t *pProperties,
Expand Down
1 change: 1 addition & 0 deletions source/adapters/hip/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
[[maybe_unused]] ur_native_handle_t hNativeContext,
[[maybe_unused]] ur_adapter_handle_t hAdapter,
[[maybe_unused]] uint32_t numDevices,
[[maybe_unused]] const ur_device_handle_t *phDevices,
[[maybe_unused]] const ur_context_native_properties_t *pProperties,
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
NativeContext, ///< [in] the native handle of the context.
uint32_t NumDevices, const ur_device_handle_t *Devices,
ur_adapter_handle_t, uint32_t NumDevices, const ur_device_handle_t *Devices,
const ur_context_native_properties_t *Properties,
ur_context_handle_t
*Context ///< [out] pointer to the handle of the context object created.
Expand Down
12 changes: 8 additions & 4 deletions source/adapters/mock/ur_mockddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1264,10 +1264,13 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
/// @brief Intercept function for urContextCreateWithNativeHandle
__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
hNativeContext, ///< [in][nocheck] the native handle of the context.
hNativeContext, ///< [in][nocheck] the native handle of the context.
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter that owns the native handle
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *
phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
phDevices, ///< [in][optional][range(0, numDevices)] list of devices associated with
///< the context
const ur_context_native_properties_t *
pProperties, ///< [in][optional] pointer to native context properties struct
ur_context_handle_t *
Expand All @@ -1276,7 +1279,8 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_result_t result = UR_RESULT_SUCCESS;

ur_context_create_with_native_handle_params_t params = {
&hNativeContext, &numDevices, &phDevices, &pProperties, &phContext};
&hNativeContext, &hAdapter, &numDevices,
&phDevices, &pProperties, &phContext};

auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
mock::getCallbacks().get_before_callback(
Expand Down Expand Up @@ -4841,7 +4845,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_native_handle_t
hNativeQueue, ///< [in][nocheck] the native handle of the queue.
ur_context_handle_t hContext, ///< [in] handle of the context object
ur_device_handle_t hDevice, ///< [in] handle of the device object
ur_device_handle_t hDevice, ///< [in][optional] handle of the device object
const ur_queue_native_properties_t *
pProperties, ///< [in][optional] pointer to native queue properties struct
ur_queue_handle_t
Expand Down
5 changes: 3 additions & 2 deletions source/adapters/native_cpu/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t hNativeContext, uint32_t numDevices,
const ur_device_handle_t *phDevices,
ur_native_handle_t hNativeContext, ur_adapter_handle_t hAdapter,
uint32_t numDevices, const ur_device_handle_t *phDevices,
const ur_context_native_properties_t *pProperties,
ur_context_handle_t *phContext) {
std::ignore = hNativeContext;
std::ignore = hAdapter;
std::ignore = numDevices;
std::ignore = phDevices;
std::ignore = pProperties;
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/opencl/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
}

UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t hNativeContext, uint32_t, const ur_device_handle_t *,
ur_native_handle_t hNativeContext, ur_adapter_handle_t, uint32_t,
const ur_device_handle_t *,
const ur_context_native_properties_t *pProperties,
ur_context_handle_t *phContext) {

Expand Down
6 changes: 4 additions & 2 deletions source/loader/layers/sanitizer/ur_sanddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate(
__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
hNativeContext, ///< [in][nocheck] the native handle of the getContext()->
ur_adapter_handle_t hAdapter,
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *
phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
Expand All @@ -352,8 +353,9 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(

getContext()->logger.debug("==== urContextCreateWithNativeHandle");

ur_result_t result = pfnCreateWithNativeHandle(
hNativeContext, numDevices, phDevices, pProperties, phContext);
ur_result_t result =
pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices,
phDevices, pProperties, phContext);

if (result == UR_RESULT_SUCCESS) {
UR_CALL(setupContext(*phContext, numDevices, phDevices));
Expand Down
17 changes: 11 additions & 6 deletions source/loader/layers/tracing/ur_trcddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,10 +966,13 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
/// @brief Intercept function for urContextCreateWithNativeHandle
__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
hNativeContext, ///< [in][nocheck] the native handle of the context.
hNativeContext, ///< [in][nocheck] the native handle of the context.
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter that owns the native handle
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *
phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
phDevices, ///< [in][optional][range(0, numDevices)] list of devices associated with
///< the context
const ur_context_native_properties_t *
pProperties, ///< [in][optional] pointer to native context properties struct
ur_context_handle_t *
Expand All @@ -983,15 +986,17 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
}

ur_context_create_with_native_handle_params_t params = {
&hNativeContext, &numDevices, &phDevices, &pProperties, &phContext};
&hNativeContext, &hAdapter, &numDevices,
&phDevices, &pProperties, &phContext};
uint64_t instance = getContext()->notify_begin(
UR_FUNCTION_CONTEXT_CREATE_WITH_NATIVE_HANDLE,
"urContextCreateWithNativeHandle", &params);

getContext()->logger.info("---> urContextCreateWithNativeHandle");

ur_result_t result = pfnCreateWithNativeHandle(
hNativeContext, numDevices, phDevices, pProperties, phContext);
ur_result_t result =
pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices,
phDevices, pProperties, phContext);

getContext()->notify_end(UR_FUNCTION_CONTEXT_CREATE_WITH_NATIVE_HANDLE,
"urContextCreateWithNativeHandle", &params,
Expand Down Expand Up @@ -3695,7 +3700,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_native_handle_t
hNativeQueue, ///< [in][nocheck] the native handle of the queue.
ur_context_handle_t hContext, ///< [in] handle of the context object
ur_device_handle_t hDevice, ///< [in] handle of the device object
ur_device_handle_t hDevice, ///< [in][optional] handle of the device object
const ur_queue_native_properties_t *
pProperties, ///< [in][optional] pointer to native queue properties struct
ur_queue_handle_t
Expand Down
27 changes: 16 additions & 11 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,10 +976,13 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
/// @brief Intercept function for urContextCreateWithNativeHandle
__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
hNativeContext, ///< [in][nocheck] the native handle of the context.
hNativeContext, ///< [in][nocheck] the native handle of the context.
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter that owns the native handle
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *
phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
phDevices, ///< [in][optional][range(0, numDevices)] list of devices associated with
///< the context
const ur_context_native_properties_t *
pProperties, ///< [in][optional] pointer to native context properties struct
ur_context_handle_t *
Expand All @@ -993,17 +996,23 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
}

if (getContext()->enableParameterValidation) {
if (NULL == phDevices) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
if (NULL == hAdapter) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL == phContext) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}
}

ur_result_t result = pfnCreateWithNativeHandle(
hNativeContext, numDevices, phDevices, pProperties, phContext);
if (getContext()->enableLifetimeValidation &&
!getContext()->refCountContext->isReferenceValid(hAdapter)) {
getContext()->refCountContext->logInvalidReference(hAdapter);
}

ur_result_t result =
pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices,
phDevices, pProperties, phContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
getContext()->refCountContext->createRefCount(*phContext);
Expand Down Expand Up @@ -4175,7 +4184,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_native_handle_t
hNativeQueue, ///< [in][nocheck] the native handle of the queue.
ur_context_handle_t hContext, ///< [in] handle of the context object
ur_device_handle_t hDevice, ///< [in] handle of the device object
ur_device_handle_t hDevice, ///< [in][optional] handle of the device object
const ur_queue_native_properties_t *
pProperties, ///< [in][optional] pointer to native queue properties struct
ur_queue_handle_t
Expand All @@ -4193,10 +4202,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL == hDevice) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL == phQueue) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}
Expand Down
21 changes: 14 additions & 7 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1044,10 +1044,13 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
/// @brief Intercept function for urContextCreateWithNativeHandle
__urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
ur_native_handle_t
hNativeContext, ///< [in][nocheck] the native handle of the context.
hNativeContext, ///< [in][nocheck] the native handle of the context.
ur_adapter_handle_t
hAdapter, ///< [in] handle of the adapter that owns the native handle
uint32_t numDevices, ///< [in] number of devices associated with the context
const ur_device_handle_t *
phDevices, ///< [in][range(0, numDevices)] list of devices associated with the context
phDevices, ///< [in][optional][range(0, numDevices)] list of devices associated with
///< the context
const ur_context_native_properties_t *
pProperties, ///< [in][optional] pointer to native context properties struct
ur_context_handle_t *
Expand All @@ -1058,14 +1061,16 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
[[maybe_unused]] auto context = getContext();

// extract platform's function pointer table
auto dditable =
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
auto dditable = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->dditable;
auto pfnCreateWithNativeHandle =
dditable->ur.Context.pfnCreateWithNativeHandle;
if (nullptr == pfnCreateWithNativeHandle) {
return UR_RESULT_ERROR_UNINITIALIZED;
}

// convert loader handle to platform handle
hAdapter = reinterpret_cast<ur_adapter_object_t *>(hAdapter)->handle;

// convert loader handles to platform handles
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
for (size_t i = 0; i < numDevices; ++i) {
Expand All @@ -1074,7 +1079,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
}

// forward to device-platform
result = pfnCreateWithNativeHandle(hNativeContext, numDevices,
result = pfnCreateWithNativeHandle(hNativeContext, hAdapter, numDevices,
phDevicesLocal.data(), pProperties,
phContext);

Expand Down Expand Up @@ -3910,7 +3915,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
ur_native_handle_t
hNativeQueue, ///< [in][nocheck] the native handle of the queue.
ur_context_handle_t hContext, ///< [in] handle of the context object
ur_device_handle_t hDevice, ///< [in] handle of the device object
ur_device_handle_t hDevice, ///< [in][optional] handle of the device object
const ur_queue_native_properties_t *
pProperties, ///< [in][optional] pointer to native queue properties struct
ur_queue_handle_t
Expand All @@ -3932,7 +3937,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;

// convert loader handle to platform handle
hDevice = reinterpret_cast<ur_device_object_t *>(hDevice)->handle;
hDevice = (hDevice)
? reinterpret_cast<ur_device_object_t *>(hDevice)->handle
: nullptr;

// forward to device-platform
result = pfnCreateWithNativeHandle(hNativeQueue, hContext, hDevice,
Expand Down
Loading

0 comments on commit fa6bf97

Please sign in to comment.