Skip to content

Commit

Permalink
Change ur_native_handle_t to be uintptr_t
Browse files Browse the repository at this point in the history
  • Loading branch information
callumfare committed Jul 3, 2024
1 parent 00ca0da commit 36ca9f1
Show file tree
Hide file tree
Showing 37 changed files with 93 additions and 86 deletions.
2 changes: 1 addition & 1 deletion include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ typedef struct ur_queue_handle_t_ *ur_queue_handle_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Handle of a native object
typedef struct ur_native_handle_t_ *ur_native_handle_t;
typedef uintptr_t ur_native_handle_t;

///////////////////////////////////////////////////////////////////////////////
/// @brief Handle of a Sampler object
Expand Down
40 changes: 20 additions & 20 deletions include/ur_print.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10376,8 +10376,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativePlatform = ";

ur::details::printPtr(os,
*(params->phNativePlatform));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativePlatform)));

os << ", ";
os << ".hAdapter = ";
Expand Down Expand Up @@ -10573,8 +10573,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeContext = ";

ur::details::printPtr(os,
*(params->phNativeContext));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeContext)));

os << ", ";
os << ".numDevices = ";
Expand Down Expand Up @@ -10783,8 +10783,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeEvent = ";

ur::details::printPtr(os,
*(params->phNativeEvent));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeEvent)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -11377,8 +11377,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeProgram = ";

ur::details::printPtr(os,
*(params->phNativeProgram));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeProgram)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -11597,8 +11597,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeKernel = ";

ur::details::printPtr(os,
*(params->phNativeKernel));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeKernel)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -12046,8 +12046,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeQueue = ";

ur::details::printPtr(os,
*(params->phNativeQueue));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeQueue)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -12220,8 +12220,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeSampler = ";

ur::details::printPtr(os,
*(params->phNativeSampler));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeSampler)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -12424,8 +12424,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeMem = ";

ur::details::printPtr(os,
*(params->phNativeMem));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeMem)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -12456,8 +12456,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeMem = ";

ur::details::printPtr(os,
*(params->phNativeMem));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeMem)));

os << ", ";
os << ".hContext = ";
Expand Down Expand Up @@ -17206,8 +17206,8 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct

os << ".hNativeDevice = ";

ur::details::printPtr(os,
*(params->phNativeDevice));
ur::details::printPtr(os, reinterpret_cast<void *>(
*(params->phNativeDevice)));

os << ", ";
os << ".hPlatform = ";
Expand Down
4 changes: 3 additions & 1 deletion scripts/templates/api.h.mako
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ ${th.make_func_name(n, tags, obj)}(
);
## HANDLE #####################################################################
%elif re.match(r"handle", obj['type']):
%if 'alias' in obj:
%if th.type_traits.is_native_handle(obj['name']):
typedef uintptr_t ${th.subt(n, tags, obj['name'])};
%elif 'alias' in obj:
typedef ${th.subt(n, tags, obj['alias'])} ${th.subt(n, tags, obj['name'])};
%else:
typedef struct ${th.subt(n, tags, obj['name'])}_ *${th.subt(n, tags, obj['name'])};
Expand Down
8 changes: 8 additions & 0 deletions scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def get_handle(item, meta):
"""
class type_traits:
RE_HANDLE = r"(.*)handle_t"
RE_NATIVE_HANDLE = r"(.*)native_handle_t"
RE_IPC = r"(.*)ipc(.*)handle_t"
RE_POINTER = r"(.*\w+)\*+"
RE_PPOINTER = r"(.*\w+)\*{2,}"
Expand All @@ -128,6 +129,13 @@ def is_handle(cls, name):
except:
return False

@classmethod
def is_native_handle(cls, name):
try:
return True if re.match(cls.RE_NATIVE_HANDLE, name) else False
except:
return False

@classmethod
def is_pointer_to_pointer(cls, name):
try:
Expand Down
2 changes: 2 additions & 0 deletions scripts/templates/print.hpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ from templates import helper as th
${x}::details::printPtr(os, ${caller.body()});
%elif loop and th.type_traits.is_pointer_to_pointer(itype):
${x}::details::printPtr(os, ${caller.body()});
%elif th.type_traits.is_native_handle(itype):
${x}::details::printPtr(os, reinterpret_cast<void*>(${caller.body()}));
%elif th.type_traits.is_handle(itype):
${x}::details::printPtr(os, ${caller.body()});
%elif iname and iname.startswith("pfn"):
Expand Down
5 changes: 1 addition & 4 deletions source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1185,10 +1185,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
ur_device_handle_t *phDevice) {
std::ignore = pProperties;

// We can't cast between ur_native_handle_t and CUdevice, so memcpy the bits
// instead
CUdevice CuDevice = 0;
memcpy(&CuDevice, &hNativeDevice, sizeof(CUdevice));
CUdevice CuDevice = static_cast<CUdevice>(hNativeDevice);

auto IsDevice = [=](std::unique_ptr<ur_device_handle_t_> &Dev) {
return Dev->get() == CuDevice;
Expand Down
3 changes: 1 addition & 2 deletions source/adapters/cuda/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t Device,
ur_native_handle_t *phNativeMem) {
UR_ASSERT(Device != nullptr, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
try {
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
std::get<BufferMem>(hMem->Mem).getPtr(Device));
*phNativeMem = std::get<BufferMem>(hMem->Mem).getPtr(Device);
} catch (ur_result_t Err) {
return Err;
} catch (...) {
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/hip/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGet(ur_platform_handle_t hPlatform,
/// \return UR_RESULT_SUCCESS
UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetNativeHandle(
ur_device_handle_t hDevice, ur_native_handle_t *phNativeHandle) {
*phNativeHandle = reinterpret_cast<ur_native_handle_t>(hDevice->get());
*phNativeHandle = hDevice->get();
return UR_RESULT_SUCCESS;
}

Expand Down
2 changes: 1 addition & 1 deletion source/adapters/level_zero/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventCreateWithNativeHandle(

// we dont have urEventCreate, so use this check for now to know that
// the call comes from urEventCreate()
if (NativeEvent == nullptr) {
if (reinterpret_cast<ze_event_handle_t>(NativeEvent) == nullptr) {
UR_CALL(EventCreate(Context, nullptr, false, true, Event));

(*Event)->RefCountExternal++;
Expand Down
2 changes: 1 addition & 1 deletion source/adapters/opencl/queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(

*phQueue = reinterpret_cast<ur_queue_handle_t>(hNativeQueue);
cl_int RetErr =
clRetainCommandQueue(cl_adapter::cast<cl_command_queue>(hNativeQueue));
clRetainCommandQueue(cl_adapter::cast<cl_command_queue>(*phQueue));
CL_RETURN_ON_FAILURE(RetErr);
return UR_RESULT_SUCCESS;
}
Expand Down
3 changes: 1 addition & 2 deletions source/adapters/opencl/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
ur_native_handle_t hNativeSampler, ur_context_handle_t,
const ur_sampler_native_properties_t *pProperties,
ur_sampler_handle_t *phSampler) {
*phSampler = reinterpret_cast<ur_sampler_handle_t>(
cl_adapter::cast<cl_sampler>(hNativeSampler));
*phSampler = reinterpret_cast<ur_sampler_handle_t>(hNativeSampler);
if (!pProperties || !pProperties->isNativeHandleOwned) {
return urSamplerRetain(*phSampler);
}
Expand Down
4 changes: 2 additions & 2 deletions source/common/ur_pool_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
const pool_descriptor &lhs = *this;
const pool_descriptor &rhs = other;
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;
ur_native_handle_t lhsNative = 0, rhsNative = 0;

// We want to share a memory pool for sub-devices and sub-sub devices.
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
Expand Down Expand Up @@ -264,7 +264,7 @@ namespace std {
/// @brief hash specialization for usm::pool_descriptor
template <> struct hash<usm::pool_descriptor> {
inline size_t operator()(const usm::pool_descriptor &desc) const {
ur_native_handle_t native = nullptr;
ur_native_handle_t native = 0;
if (desc.hDevice) {
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
if (ret != UR_RESULT_SUCCESS) {
Expand Down
4 changes: 2 additions & 2 deletions test/adapters/cuda/context_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ TEST_P(cudaUrContextCreateTest, ActiveContext) {
ASSERT_SUCCESS_CUDA(cuCtxGetCurrent(&cudaCtx));
ASSERT_NE(cudaCtx, nullptr);

ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));
ASSERT_NE(native_context, nullptr);
ASSERT_NE(reinterpret_cast<CUcontext>(native_context), nullptr);
ASSERT_EQ(cudaCtx, reinterpret_cast<CUcontext>(native_context));
}

Expand Down
2 changes: 1 addition & 1 deletion test/adapters/cuda/urContextGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using urCudaContextGetNativeHandle = uur::urContextTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urCudaContextGetNativeHandle);

TEST_P(urCudaContextGetNativeHandle, Success) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));
CUcontext cuda_context = reinterpret_cast<CUcontext>(native_context);

Expand Down
2 changes: 1 addition & 1 deletion test/adapters/cuda/urEventGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ TEST_P(urCudaEventGetNativeHandleTest, Success) {
0, buffer_size, 0, nullptr,
event.ptr()));

ur_native_handle_t native_event = nullptr;
ur_native_handle_t native_event = 0;
ASSERT_SUCCESS(urEventGetNativeHandle(event, &native_event));
CUevent cuda_event = reinterpret_cast<CUevent>(native_event);

Expand Down
2 changes: 1 addition & 1 deletion test/adapters/hip/urContextGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using urHipContextGetNativeHandleTest = uur::urContextTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urHipContextGetNativeHandleTest);

TEST_P(urHipContextGetNativeHandleTest, Success) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));
hipCtx_t hip_context = reinterpret_cast<hipCtx_t>(native_context);
std::ignore = hip_context;
Expand Down
2 changes: 1 addition & 1 deletion test/adapters/hip/urEventGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ TEST_P(urHipEventGetNativeHandleTest, Success) {
0, buffer_size, 0, nullptr,
event.ptr()));

ur_native_handle_t native_event = nullptr;
ur_native_handle_t native_event = 0;
ASSERT_SUCCESS(urEventGetNativeHandle(event, &native_event));
hipEvent_t hip_event = reinterpret_cast<hipEvent_t>(native_event);

Expand Down
12 changes: 6 additions & 6 deletions test/adapters/level_zero/multi_device_event_cache_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ TEST_F(urMultiQueueMultiDeviceEventCacheTest,
uur::raii::Event eventWait = nullptr;
uur::raii::Event eventWaitDummy = nullptr;
(*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context2, nullptr,
eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context1, nullptr,
EXPECT_SUCCESS(
urEventCreateWithNativeHandle(0, context2, nullptr, eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(0, context1, nullptr,
eventWaitDummy.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue1, 1, eventWaitDummy.ptr(), eventWait.ptr()));
Expand Down Expand Up @@ -90,9 +90,9 @@ TEST_F(urMultiQueueMultiDeviceEventCacheTest,
uur::raii::Event eventWait = nullptr;
uur::raii::Event eventWaitDummy = nullptr;
(*ZeCallCount)["zeCommandListAppendWaitOnEvents"] = 0;
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context2, nullptr,
eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(nullptr, context1, nullptr,
EXPECT_SUCCESS(
urEventCreateWithNativeHandle(0, context2, nullptr, eventWait.ptr()));
EXPECT_SUCCESS(urEventCreateWithNativeHandle(0, context1, nullptr,
eventWaitDummy.ptr()));
EXPECT_SUCCESS(
urEnqueueEventsWait(queue1, 1, eventWaitDummy.ptr(), eventWait.ptr()));
Expand Down
10 changes: 5 additions & 5 deletions test/conformance/context/urContextCreateWithNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using urContextCreateWithNativeHandleTest = uur::urContextTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urContextCreateWithNativeHandleTest);

TEST_P(urContextCreateWithNativeHandleTest, Success) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
{
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(
urContextGetNativeHandle(context, &native_context));
Expand All @@ -33,7 +33,7 @@ TEST_P(urContextCreateWithNativeHandleTest, Success) {
}

TEST_P(urContextCreateWithNativeHandleTest, SuccessWithOwnedNativeHandle) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
{
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(
urContextGetNativeHandle(context, &native_context));
Expand All @@ -53,7 +53,7 @@ TEST_P(urContextCreateWithNativeHandleTest, SuccessWithOwnedNativeHandle) {
}

TEST_P(urContextCreateWithNativeHandleTest, SuccessWithUnOwnedNativeHandle) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
{
UUR_ASSERT_SUCCESS_OR_UNSUPPORTED(
urContextGetNativeHandle(context, &native_context));
Expand All @@ -75,7 +75,7 @@ TEST_P(urContextCreateWithNativeHandleTest, SuccessWithUnOwnedNativeHandle) {
}

TEST_P(urContextCreateWithNativeHandleTest, InvalidNullPointerDevices) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));

ur_context_handle_t ctx = nullptr;
Expand All @@ -85,7 +85,7 @@ TEST_P(urContextCreateWithNativeHandleTest, InvalidNullPointerDevices) {
}

TEST_P(urContextCreateWithNativeHandleTest, InvalidNullPointerContext) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));

ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_POINTER,
Expand Down
4 changes: 2 additions & 2 deletions test/conformance/context/urContextGetNativeHandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ using urContextGetNativeHandleTest = uur::urContextTest;
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urContextGetNativeHandleTest);

TEST_P(urContextGetNativeHandleTest, Success) {
ur_native_handle_t native_context = nullptr;
ur_native_handle_t native_context = 0;
if (auto error = urContextGetNativeHandle(context, &native_context)) {
ASSERT_EQ_RESULT(UR_RESULT_ERROR_UNSUPPORTED_FEATURE, error);
}
}

TEST_P(urContextGetNativeHandleTest, InvalidNullHandleContext) {
ur_native_handle_t native_handle = nullptr;
ur_native_handle_t native_handle = 0;
ASSERT_EQ_RESULT(UR_RESULT_ERROR_INVALID_NULL_HANDLE,
urContextGetNativeHandle(nullptr, &native_handle));
}
Expand Down
Loading

0 comments on commit 36ca9f1

Please sign in to comment.