Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make urInit and urTearDown loader-only #793

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/collector/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ $ mkdir build
$ cd build
$ cmake .. -DUR_ENABLE_TRACING=ON
$ make
$ UR_ADAPTERS_FORCE_LOAD=./lib/libur_adapter_null.so XPTI_TRACE_ENABLE=1 XPTI_FRAMEWORK_DISPATCHER=./lib/libxptifw.so XPTI_SUBSCRIBERS=./lib/libcollector.so ./bin/hello_world
$ UR_ADAPTERS_FORCE_LOAD=./lib/libur_adapter_null.so UR_ENABLE_LAYERS=UR_LAYER_TRACING XPTI_TRACE_ENABLE=1 XPTI_FRAMEWORK_DISPATCHER=./lib/libxptifw.so XPTI_SUBSCRIBERS=./lib/libcollector.so ./bin/hello_world
```

See [XPTI framework documentation](https://github.com/intel/llvm/blob/sycl/xptifw/doc/XPTI_Framework.md) for more information.
33 changes: 21 additions & 12 deletions examples/collector/collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,41 @@ constexpr uint16_t TRACE_FN_END =
constexpr std::string_view UR_STREAM_NAME = "ur";

/**
* @brief Formats the function parameters and arguments for urInit
* @brief Formats the function parameters and arguments for urAdapterGet
*/
std::ostream &operator<<(std::ostream &os,
const struct ur_init_params_t *params) {
os << ".device_flags = ";
if (*params->pdevice_flags & UR_DEVICE_INIT_FLAG_GPU) {
os << "UR_DEVICE_INIT_FLAG_GPU";
} else {
os << "0";
const struct ur_adapter_get_params_t *params) {
os << ".NumEntries = ";
os << *params->pNumEntries;
os << ", ";
os << ".phAdapters = ";
os << *params->pphAdapters;
if (*params->pphAdapters) {
os << " (" << **params->pphAdapters << ")";
}
os << ", ";
os << ".pNumAdapters = ";
os << *params->ppNumAdapters;
if (*params->ppNumAdapters) {
os << " (" << **params->ppNumAdapters << ")";
}
os << "";
return os;
}

/**
* A map of functions that format the parameters and arguments for each UR function.
* This example only implements a handler for one function, `urInit`, but it's
pbalcer marked this conversation as resolved.
Show resolved Hide resolved
* This example only implements a handler for one function, `urAdapterGet`, but it's
* trivial to expand it to support more.
*/
static std::unordered_map<
std::string_view,
std::function<void(const xpti::function_with_args_t *, std::ostream &)>>
handlers = {{"urInit", [](const xpti::function_with_args_t *fn_args,
std::ostream &os) {
auto params = static_cast<const struct ur_init_params_t *>(
fn_args->args_data);
handlers = {{"urAdapterGet", [](const xpti::function_with_args_t *fn_args,
std::ostream &os) {
auto params =
static_cast<const struct ur_adapter_get_params_t *>(
fn_args->args_data);
os << params;
}}};

Expand Down
7 changes: 4 additions & 3 deletions examples/hello_world/hello_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ int main(int argc, char *argv[]) {
ur_result_t status;

// Initialize the platform
status = urInit(0, nullptr);
status = urLoaderInit(0, nullptr);
if (status != UR_RESULT_SUCCESS) {
std::cout << "urInit failed with return code: " << status << std::endl;
std::cout << "urLoaderInit failed with return code: " << status
<< std::endl;
return 1;
}
std::cout << "Platform initialized.\n";
Expand Down Expand Up @@ -119,6 +120,6 @@ int main(int argc, char *argv[]) {
for (auto adapter : adapters) {
urAdapterRelease(adapter);
}
urTearDown(nullptr);
urLoaderTearDown();
return status == UR_RESULT_SUCCESS ? 0 : 1;
}
148 changes: 65 additions & 83 deletions include/ur.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ class ur_function_v(IntEnum):
QUEUE_CREATE_WITH_NATIVE_HANDLE = 96 ## Enumerator for ::urQueueCreateWithNativeHandle
QUEUE_FINISH = 97 ## Enumerator for ::urQueueFinish
QUEUE_FLUSH = 98 ## Enumerator for ::urQueueFlush
INIT = 99 ## Enumerator for ::urInit
TEAR_DOWN = 100 ## Enumerator for ::urTearDown
SAMPLER_CREATE = 101 ## Enumerator for ::urSamplerCreate
SAMPLER_RETAIN = 102 ## Enumerator for ::urSamplerRetain
SAMPLER_RELEASE = 103 ## Enumerator for ::urSamplerRelease
Expand Down Expand Up @@ -196,6 +194,8 @@ class ur_function_v(IntEnum):
ADAPTER_RETAIN = 179 ## Enumerator for ::urAdapterRetain
ADAPTER_GET_LAST_ERROR = 180 ## Enumerator for ::urAdapterGetLastError
ADAPTER_GET_INFO = 181 ## Enumerator for ::urAdapterGetInfo
LOADER_INIT = 182 ## Enumerator for ::urLoaderInit
LOADER_TEAR_DOWN = 183 ## Enumerator for ::urLoaderTearDown

class ur_function_t(c_int):
def __str__(self):
Expand Down Expand Up @@ -2870,6 +2870,53 @@ class ur_physical_mem_dditable_t(Structure):
("pfnRelease", c_void_p) ## _urPhysicalMemRelease_t
]

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
## @brief Function-pointer for urEnqueueKernelLaunch
if __use_win_types:
Expand Down Expand Up @@ -3543,69 +3590,6 @@ class ur_usm_p2p_exp_dditable_t(Structure):
("pfnPeerAccessGetInfoExp", c_void_p) ## _urUsmP2PPeerAccessGetInfoExp_t
]

###############################################################################
## @brief Function-pointer for urInit
if __use_win_types:
_urInit_t = WINFUNCTYPE( ur_result_t, ur_device_init_flags_t, ur_loader_config_handle_t )
else:
_urInit_t = CFUNCTYPE( ur_result_t, ur_device_init_flags_t, ur_loader_config_handle_t )

###############################################################################
## @brief Function-pointer for urTearDown
if __use_win_types:
_urTearDown_t = WINFUNCTYPE( ur_result_t, c_void_p )
else:
_urTearDown_t = CFUNCTYPE( ur_result_t, c_void_p )

###############################################################################
## @brief Function-pointer for urAdapterGet
if __use_win_types:
_urAdapterGet_t = WINFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )
else:
_urAdapterGet_t = CFUNCTYPE( ur_result_t, c_ulong, POINTER(ur_adapter_handle_t), POINTER(c_ulong) )

###############################################################################
## @brief Function-pointer for urAdapterRelease
if __use_win_types:
_urAdapterRelease_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRelease_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterRetain
if __use_win_types:
_urAdapterRetain_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t )
else:
_urAdapterRetain_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t )

###############################################################################
## @brief Function-pointer for urAdapterGetLastError
if __use_win_types:
_urAdapterGetLastError_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )
else:
_urAdapterGetLastError_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, POINTER(c_char_p), POINTER(c_long) )

###############################################################################
## @brief Function-pointer for urAdapterGetInfo
if __use_win_types:
_urAdapterGetInfo_t = WINFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )
else:
_urAdapterGetInfo_t = CFUNCTYPE( ur_result_t, ur_adapter_handle_t, ur_adapter_info_t, c_size_t, c_void_p, POINTER(c_size_t) )


###############################################################################
## @brief Table of Global functions pointers
class ur_global_dditable_t(Structure):
_fields_ = [
("pfnInit", c_void_p), ## _urInit_t
("pfnTearDown", c_void_p), ## _urTearDown_t
("pfnAdapterGet", c_void_p), ## _urAdapterGet_t
("pfnAdapterRelease", c_void_p), ## _urAdapterRelease_t
("pfnAdapterRetain", c_void_p), ## _urAdapterRetain_t
("pfnAdapterGetLastError", c_void_p), ## _urAdapterGetLastError_t
("pfnAdapterGetInfo", c_void_p) ## _urAdapterGetInfo_t
]

###############################################################################
## @brief Function-pointer for urVirtualMemGranularityGetInfo
if __use_win_types:
Expand Down Expand Up @@ -3759,14 +3743,14 @@ class ur_dditable_t(Structure):
("Sampler", ur_sampler_dditable_t),
("Mem", ur_mem_dditable_t),
("PhysicalMem", ur_physical_mem_dditable_t),
("Global", ur_global_dditable_t),
("Enqueue", ur_enqueue_dditable_t),
("Queue", ur_queue_dditable_t),
("BindlessImagesExp", ur_bindless_images_exp_dditable_t),
("USM", ur_usm_dditable_t),
("USMExp", ur_usm_exp_dditable_t),
("CommandBufferExp", ur_command_buffer_exp_dditable_t),
("UsmP2PExp", ur_usm_p2p_exp_dditable_t),
("Global", ur_global_dditable_t),
("VirtualMem", ur_virtual_mem_dditable_t),
("Device", ur_device_dditable_t)
]
Expand All @@ -3785,7 +3769,7 @@ def __init__(self, version : ur_api_version_t):
self.__dditable = ur_dditable_t()

# initialize the UR
self.__dll.urInit(0, 0)
self.__dll.urLoaderInit(0, 0)

# call driver to get function pointers
Platform = ur_platform_dditable_t()
Expand Down Expand Up @@ -3927,6 +3911,20 @@ def __init__(self, version : ur_api_version_t):
self.urPhysicalMemRetain = _urPhysicalMemRetain_t(self.__dditable.PhysicalMem.pfnRetain)
self.urPhysicalMemRelease = _urPhysicalMemRelease_t(self.__dditable.PhysicalMem.pfnRelease)

# call driver to get function pointers
Global = ur_global_dditable_t()
r = ur_result_v(self.__dll.urGetGlobalProcAddrTable(version, byref(Global)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.Global = Global

# attach function interface to function address
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
Enqueue = ur_enqueue_dditable_t()
r = ur_result_v(self.__dll.urGetEnqueueProcAddrTable(version, byref(Enqueue)))
Expand Down Expand Up @@ -4068,22 +4066,6 @@ def __init__(self, version : ur_api_version_t):
self.urUsmP2PDisablePeerAccessExp = _urUsmP2PDisablePeerAccessExp_t(self.__dditable.UsmP2PExp.pfnDisablePeerAccessExp)
self.urUsmP2PPeerAccessGetInfoExp = _urUsmP2PPeerAccessGetInfoExp_t(self.__dditable.UsmP2PExp.pfnPeerAccessGetInfoExp)

# call driver to get function pointers
Global = ur_global_dditable_t()
r = ur_result_v(self.__dll.urGetGlobalProcAddrTable(version, byref(Global)))
if r != ur_result_v.SUCCESS:
raise Exception(r)
self.__dditable.Global = Global

# attach function interface to function address
self.urInit = _urInit_t(self.__dditable.Global.pfnInit)
self.urTearDown = _urTearDown_t(self.__dditable.Global.pfnTearDown)
self.urAdapterGet = _urAdapterGet_t(self.__dditable.Global.pfnAdapterGet)
self.urAdapterRelease = _urAdapterRelease_t(self.__dditable.Global.pfnAdapterRelease)
self.urAdapterRetain = _urAdapterRetain_t(self.__dditable.Global.pfnAdapterRetain)
self.urAdapterGetLastError = _urAdapterGetLastError_t(self.__dditable.Global.pfnAdapterGetLastError)
self.urAdapterGetInfo = _urAdapterGetInfo_t(self.__dditable.Global.pfnAdapterGetInfo)

# call driver to get function pointers
VirtualMem = ur_virtual_mem_dditable_t()
r = ur_result_v(self.__dll.urGetVirtualMemProcAddrTable(version, byref(VirtualMem)))
Expand Down
Loading