Skip to content

Commit

Permalink
Make urInit and urTearDown loader-only
Browse files Browse the repository at this point in the history
  • Loading branch information
callumfare committed Aug 9, 2023
1 parent 703c172 commit b9ecf2d
Show file tree
Hide file tree
Showing 61 changed files with 1,487 additions and 1,550 deletions.
15 changes: 8 additions & 7 deletions examples/collector/collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ constexpr uint16_t TRACE_FN_END =
constexpr std::string_view UR_STREAM_NAME = "ur";

/**
* @brief Formats the function parameters and arguments for urInit
* @brief Formats the function parameters and arguments for urLoaderInit
*/
std::ostream &operator<<(std::ostream &os,
const struct ur_init_params_t *params) {
const struct ur_loader_init_params_t *params) {
os << ".device_flags = ";
if (*params->pdevice_flags & UR_DEVICE_INIT_FLAG_GPU) {
os << "UR_DEVICE_INIT_FLAG_GPU";
Expand All @@ -50,16 +50,17 @@ std::ostream &operator<<(std::ostream &os,

/**
* A map of functions that format the parameters and arguments for each UR function.
* This example only implements a handler for one function, `urInit`, but it's
* This example only implements a handler for one function, `urLoaderInit`, but it's
* trivial to expand it to support more.
*/
static std::unordered_map<
std::string_view,
std::function<void(const xpti::function_with_args_t *, std::ostream &)>>
handlers = {{"urInit", [](const xpti::function_with_args_t *fn_args,
std::ostream &os) {
auto params = static_cast<const struct ur_init_params_t *>(
fn_args->args_data);
handlers = {{"urLoaderInit", [](const xpti::function_with_args_t *fn_args,
std::ostream &os) {
auto params =
static_cast<const struct ur_loader_init_params_t *>(
fn_args->args_data);
os << params;
}}};

Expand Down
7 changes: 4 additions & 3 deletions examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ int main(int argc, char *argv[]) {
ur_result_t status;

// Initialize the platform
status = urInit(0, nullptr);
status = urLoaderInit(0, nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urInit failed with return code: " << status << std::endl;
std::cout << "urLoaderInit failed with return code: " << status
<< std::endl;
return 1;
}
std::cout << "Platform initialized.\n";
Expand Down Expand Up @@ -119,6 +120,6 @@ int main(int argc, char *argv[]) {
for (auto adapter : adapters) {
urAdapterRelease(adapter);
}
urTearDown(nullptr);
urLoaderTearDown(nullptr);
return status == UR_RESULT_SUCCESS ? 0 : 1;
}
200 changes: 91 additions & 109 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ class ur_function_v(IntEnum):
QUEUE_CREATE_WITH_NATIVE_HANDLE = 96 ## Enumerator for ::urQueueCreateWithNativeHandle
QUEUE_FINISH = 97 ## Enumerator for ::urQueueFinish
QUEUE_FLUSH = 98 ## Enumerator for ::urQueueFlush
INIT = 99 ## Enumerator for ::urInit
TEAR_DOWN = 100 ## Enumerator for ::urTearDown
SAMPLER_CREATE = 101 ## Enumerator for ::urSamplerCreate
SAMPLER_RETAIN = 102 ## Enumerator for ::urSamplerRetain
SAMPLER_RELEASE = 103 ## Enumerator for ::urSamplerRelease
Expand Down Expand Up @@ -196,6 +194,8 @@ class ur_function_v(IntEnum):
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
LOADER_INIT = 182 ## Enumerator for ::urLoaderInit
LOADER_TEAR_DOWN = 183 ## Enumerator for ::urLoaderTearDown

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -484,32 +484,6 @@ class ur_rect_region_t(Structure):
("depth", c_ulonglong) ## [in] scalar (scalar)
]

###############################################################################
## @brief Supported device initialization flags
class ur_device_init_flags_v(IntEnum):
GPU = UR_BIT(0) ## initialize GPU device adapters.
CPU = UR_BIT(1) ## initialize CPU device adapters.
FPGA = UR_BIT(2) ## initialize FPGA device adapters.
MCA = UR_BIT(3) ## initialize MCA device adapters.
VPU = UR_BIT(4) ## initialize VPU device adapters.

class ur_device_init_flags_t(c_int):
def __str__(self):
return hex(self.value)


###############################################################################
## @brief Supported loader info
class ur_loader_config_info_v(IntEnum):
AVAILABLE_LAYERS = 0 ## [char[]] Null-terminated, semi-colon separated list of available
## layers.
REFERENCE_COUNT = 1 ## [uint32_t] Reference count of the loader config object.

class ur_loader_config_info_t(c_int):
def __str__(self):
return str(ur_loader_config_info_v(self.value))


###############################################################################
## @brief Supported adapter info
class ur_adapter_info_v(IntEnum):
Expand Down Expand Up @@ -540,6 +514,32 @@ def __str__(self):
return str(ur_adapter_backend_v(self.value))


###############################################################################
## @brief Supported device initialization flags
class ur_device_init_flags_v(IntEnum):
GPU = UR_BIT(0) ## initialize GPU device adapters.
CPU = UR_BIT(1) ## initialize CPU device adapters.
FPGA = UR_BIT(2) ## initialize FPGA device adapters.
MCA = UR_BIT(3) ## initialize MCA device adapters.
VPU = UR_BIT(4) ## initialize VPU device adapters.

class ur_device_init_flags_t(c_int):
def __str__(self):
return hex(self.value)


###############################################################################
## @brief Supported loader info
class ur_loader_config_info_v(IntEnum):
AVAILABLE_LAYERS = 0 ## [char[]] Null-terminated, semi-colon separated list of available
## layers.
REFERENCE_COUNT = 1 ## [uint32_t] Reference count of the loader config object.

class ur_loader_config_info_t(c_int):
def __str__(self):
return str(ur_loader_config_info_v(self.value))


###############################################################################
## @brief Supported platform info
class ur_platform_info_v(IntEnum):
Expand Down Expand Up @@ -2864,6 +2864,53 @@ class ur_physical_mem_dditable_t(Structure):
("pfnRelease", c_void_p) ## _urPhysicalMemRelease_t
]

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
## @brief Function-pointer for urEnqueueKernelLaunch
if __use_win_types:
Expand Down Expand Up @@ -3537,69 +3584,6 @@ class ur_usm_p2p_exp_dditable_t(Structure):
("pfnPeerAccessGetInfoExp", c_void_p) ## _urUsmP2PPeerAccessGetInfoExp_t
]

###############################################################################
## @brief Function-pointer for urInit
if __use_win_types:
_urInit_t = WINFUNCTYPE( ur_result_t, ur_device_init_flags_t, ur_loader_config_handle_t )
else:
_urInit_t = CFUNCTYPE( ur_result_t, ur_device_init_flags_t, ur_loader_config_handle_t )

###############################################################################
## @brief Function-pointer for urTearDown
if __use_win_types:
_urTearDown_t = WINFUNCTYPE( ur_result_t, c_void_p )
else:
_urTearDown_t = CFUNCTYPE( ur_result_t, c_void_p )

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnInit", c_void_p), ## _urInit_t
("pfnTearDown", c_void_p), ## _urTearDown_t
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
## @brief Function-pointer for urVirtualMemGranularityGetInfo
if __use_win_types:
Expand Down Expand Up @@ -3753,14 +3737,14 @@ class ur_dditable_t(Structure):
("Sampler", ur_sampler_dditable_t),
("Mem", ur_mem_dditable_t),
("PhysicalMem", ur_physical_mem_dditable_t),
("Global", ur_global_dditable_t),
("Enqueue", ur_enqueue_dditable_t),
("Queue", ur_queue_dditable_t),
("BindlessImagesExp", ur_bindless_images_exp_dditable_t),
("USM", ur_usm_dditable_t),
("USMExp", ur_usm_exp_dditable_t),
("CommandBufferExp", ur_command_buffer_exp_dditable_t),
("UsmP2PExp", ur_usm_p2p_exp_dditable_t),
("Global", ur_global_dditable_t),
("VirtualMem", ur_virtual_mem_dditable_t),
("Device", ur_device_dditable_t)
]
Expand All @@ -3779,7 +3763,7 @@ def __init__(self, version : ur_api_version_t):
self.__dditable = ur_dditable_t()

# initialize the UR
self.__dll.urInit(0, 0)
self.__dll.urLoaderInit(0, 0)

# call driver to get function pointers
Platform = ur_platform_dditable_t()
Expand Down Expand Up @@ -3921,6 +3905,20 @@ def __init__(self, version : ur_api_version_t):
self.urPhysicalMemRetain = _urPhysicalMemRetain_t(self.__dditable.PhysicalMem.pfnRetain)
self.urPhysicalMemRelease = _urPhysicalMemRelease_t(self.__dditable.PhysicalMem.pfnRelease)

# call driver to get function pointers
Global = ur_global_dditable_t()
r = ur_result_v(self.__dll.urGetGlobalProcAddrTable(version, byref(Global)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.Global = Global

# attach function interface to function address
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
Enqueue = ur_enqueue_dditable_t()
r = ur_result_v(self.__dll.urGetEnqueueProcAddrTable(version, byref(Enqueue)))
Expand Down Expand Up @@ -4062,22 +4060,6 @@ def __init__(self, version : ur_api_version_t):
self.urUsmP2PDisablePeerAccessExp = _urUsmP2PDisablePeerAccessExp_t(self.__dditable.UsmP2PExp.pfnDisablePeerAccessExp)
self.urUsmP2PPeerAccessGetInfoExp = _urUsmP2PPeerAccessGetInfoExp_t(self.__dditable.UsmP2PExp.pfnPeerAccessGetInfoExp)

# call driver to get function pointers
Global = ur_global_dditable_t()
r = ur_result_v(self.__dll.urGetGlobalProcAddrTable(version, byref(Global)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.Global = Global

# attach function interface to function address
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
VirtualMem = ur_virtual_mem_dditable_t()
r = ur_result_v(self.__dll.urGetVirtualMemProcAddrTable(version, byref(VirtualMem)))
Expand Down
Loading

0 comments on commit b9ecf2d

Please sign in to comment.