Skip to content

Commit

Permalink
Merge pull request #715 from callumfare/callum/adapter_handle
Browse files Browse the repository at this point in the history
Implement adapter instance handles
  • Loading branch information
callumfare committed Jul 19, 2023
2 parents e8e96ce + b279985 commit 974a7d6
Show file tree
Hide file tree
Showing 40 changed files with 2,402 additions and 514 deletions.
25 changes: 23 additions & 2 deletions examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,36 @@ int main(int argc, char *argv[]) {
}
std::cout << "Platform initialized.\n";

uint32_t adapterCount = 0;
std::vector<ur_adapter_handle_t> adapters;
uint32_t platformCount = 0;
std::vector<ur_platform_handle_t> platforms;

status = urPlatformGet(1, nullptr, &platformCount);
status = urAdapterGet(0, nullptr, &adapterCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urAdapterGet failed with return code: " << status
<< std::endl;
return 1;
}
adapters.resize(adapterCount);
status = urAdapterGet(adapterCount, adapters.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urAdapterGet failed with return code: " << status
<< std::endl;
return 1;
}

status = urPlatformGet(adapters.data(), adapterCount, 1, nullptr,
&platformCount);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
goto out;
}

platforms.resize(platformCount);
status = urPlatformGet(platformCount, platforms.data(), nullptr);
status = urPlatformGet(adapters.data(), adapterCount, platformCount,
platforms.data(), nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urPlatformGet failed with return code: " << status
<< std::endl;
Expand Down Expand Up @@ -98,6 +116,9 @@ int main(int argc, char *argv[]) {
}

out:
for (auto adapter : adapters) {
urAdapterRelease(adapter);
}
urTearDown(nullptr);
return status == UR_RESULT_SUCCESS ? 0 : 1;
}
101 changes: 88 additions & 13 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ class ur_function_v(IntEnum):
BINDLESS_IMAGES_DESTROY_EXTERNAL_SEMAPHORE_EXP = 147## Enumerator for ::urBindlessImagesDestroyExternalSemaphoreExp
BINDLESS_IMAGES_WAIT_EXTERNAL_SEMAPHORE_EXP = 148 ## Enumerator for ::urBindlessImagesWaitExternalSemaphoreExp
BINDLESS_IMAGES_SIGNAL_EXTERNAL_SEMAPHORE_EXP = 149 ## Enumerator for ::urBindlessImagesSignalExternalSemaphoreExp
PLATFORM_GET_LAST_ERROR = 150 ## Enumerator for ::urPlatformGetLastError
ENQUEUE_USM_FILL_2D = 151 ## Enumerator for ::urEnqueueUSMFill2D
ENQUEUE_USM_MEMCPY_2D = 152 ## Enumerator for ::urEnqueueUSMMemcpy2D
VIRTUAL_MEM_GRANULARITY_GET_INFO = 153 ## Enumerator for ::urVirtualMemGranularityGetInfo
Expand All @@ -192,6 +191,11 @@ class ur_function_v(IntEnum):
LOADER_CONFIG_RETAIN = 174 ## Enumerator for ::urLoaderConfigRetain
LOADER_CONFIG_GET_INFO = 175 ## Enumerator for ::urLoaderConfigGetInfo
LOADER_CONFIG_ENABLE_LAYER = 176 ## Enumerator for ::urLoaderConfigEnableLayer
ADAPTER_RELEASE = 177 ## Enumerator for ::urAdapterRelease
ADAPTER_GET = 178 ## Enumerator for ::urAdapterGet
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -288,6 +292,11 @@ class ur_bool_t(c_ubyte):
class ur_loader_config_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of an adapter instance
class ur_adapter_handle_t(c_void_p):
pass

###############################################################################
## @brief Handle of a platform instance
class ur_platform_handle_t(c_void_p):
Expand Down Expand Up @@ -501,6 +510,36 @@ def __str__(self):
return str(ur_loader_config_info_v(self.value))


###############################################################################
## @brief Supported adapter info
class ur_adapter_info_v(IntEnum):
BACKEND = 0 ## [::ur_adapter_backend_t] Identifies the native backend supported by
## the adapter.
REFERENCE_COUNT = 1 ## [uint32_t] Reference count of the adapter.
## The reference count returned should be considered immediately stale.
## It is unsuitable for general use in applications. This feature is
## provided for identifying memory leaks.

class ur_adapter_info_t(c_int):
def __str__(self):
return str(ur_adapter_info_v(self.value))


###############################################################################
## @brief Identifies backend of the adapter
class ur_adapter_backend_v(IntEnum):
UNKNOWN = 0 ## The backend is not a recognized one
LEVEL_ZERO = 1 ## The backend is Level Zero
OPENCL = 2 ## The backend is OpenCL
CUDA = 3 ## The backend is CUDA
HIP = 4 ## The backend is HIP
NATIVE_CPU = 5 ## The backend is Native CPU

class ur_adapter_backend_t(c_int):
def __str__(self):
return str(ur_adapter_backend_v(self.value))


###############################################################################
## @brief Supported platform info
class ur_platform_info_v(IntEnum):
Expand Down Expand Up @@ -2273,9 +2312,9 @@ class ur_loader_config_dditable_t(Structure):
###############################################################################
## @brief Function-pointer for urPlatformGet
if __use_win_types:
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
_urPlatformGet_t = WINFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
else:
_urPlatformGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )
_urPlatformGet_t = CFUNCTYPE( ur_result_t, POINTER(ur_adapter_handle_t), c_ulong, c_ulong, POINTER(ur_platform_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urPlatformGetInfo
Expand All @@ -2298,13 +2337,6 @@ class ur_loader_config_dditable_t(Structure):
else:
_urPlatformCreateWithNativeHandle_t = CFUNCTYPE( ur_result_t, ur_native_handle_t, POINTER(ur_platform_native_properties_t), POINTER(ur_platform_handle_t) )

###############################################################################
## @brief Function-pointer for urPlatformGetLastError
if __use_win_types:
_urPlatformGetLastError_t = WINFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urPlatformGetLastError_t = CFUNCTYPE( ur_result_t, ur_platform_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urPlatformGetApiVersion
if __use_win_types:
Expand All @@ -2328,7 +2360,6 @@ class ur_platform_dditable_t(Structure):
("pfnGetInfo", c_void_p), ## _urPlatformGetInfo_t
("pfnGetNativeHandle", c_void_p), ## _urPlatformGetNativeHandle_t
("pfnCreateWithNativeHandle", c_void_p), ## _urPlatformCreateWithNativeHandle_t
("pfnGetLastError", c_void_p), ## _urPlatformGetLastError_t
("pfnGetApiVersion", c_void_p), ## _urPlatformGetApiVersion_t
("pfnGetBackendOption", c_void_p) ## _urPlatformGetBackendOption_t
]
Expand Down Expand Up @@ -3565,13 +3596,53 @@ class ur_usm_p2p_exp_dditable_t(Structure):
else:
_urTearDown_t = CFUNCTYPE( ur_result_t, c_void_p )

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnInit", c_void_p), ## _urInit_t
("pfnTearDown", c_void_p) ## _urTearDown_t
("pfnTearDown", c_void_p), ## _urTearDown_t
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
Expand Down Expand Up @@ -3768,7 +3839,6 @@ def __init__(self, version : ur_api_version_t):
self.urPlatformGetInfo = _urPlatformGetInfo_t(self.__dditable.Platform.pfnGetInfo)
self.urPlatformGetNativeHandle = _urPlatformGetNativeHandle_t(self.__dditable.Platform.pfnGetNativeHandle)
self.urPlatformCreateWithNativeHandle = _urPlatformCreateWithNativeHandle_t(self.__dditable.Platform.pfnCreateWithNativeHandle)
self.urPlatformGetLastError = _urPlatformGetLastError_t(self.__dditable.Platform.pfnGetLastError)
self.urPlatformGetApiVersion = _urPlatformGetApiVersion_t(self.__dditable.Platform.pfnGetApiVersion)
self.urPlatformGetBackendOption = _urPlatformGetBackendOption_t(self.__dditable.Platform.pfnGetBackendOption)

Expand Down Expand Up @@ -4048,6 +4118,11 @@ def __init__(self, version : ur_api_version_t):
# attach function interface to function address
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
VirtualMem = ur_virtual_mem_dditable_t()
Expand Down
Loading

0 comments on commit 974a7d6

Please sign in to comment.