Skip to content

Commit

Permalink
Introduce a utility to walk polymorphic linked lists
Browse files Browse the repository at this point in the history
Add `find_stype_node` for walking polymorphic linked lists looking for a
particular type.

The implementation here works with an auto-generated compile time map,
that links a given type to the structure enumeration tag. The
implementation simply walks the list looking for the .stype implied by
the template parameter type and casts it to the expected type, then
returning it.

It's one of those unfortunate cases where you can write pretty nasty
implementation code that makes the user code much much nicer. In this
case the user doesn't need to worry about the `.types` at all, and the
const-void casts can be eliminated from user code.
  • Loading branch information
ldrumm committed Mar 7, 2024
1 parent cc268e5 commit e8b3445
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 14 deletions.
18 changes: 16 additions & 2 deletions scripts/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,25 @@ def generate_layers(path, section, namespace, tags, version, specs, meta):
generates common utilities for unified_runtime
"""
def generate_common(path, section, namespace, tags, version, specs, meta):
template = "stype_map_helpers.hpp.mako"
fin = os.path.join("templates", template)

filename = "stype_map_helpers.def"
layer_dstpath = os.path.join(path, "common")
os.makedirs(layer_dstpath, exist_ok=True)
fout = os.path.join(layer_dstpath, filename)

print("Generating %s..." % fout)

loc = util.makoWrite(
fin, fout,
ver=version,
namespace=namespace,
tags=tags,
specs=specs,
meta=meta)
print("COMMON Generated %s lines of code.\n" % loc)

loc = 0
print("COMMON Generated %s lines of code.\n"%loc)

"""
Entry-point:
Expand Down
22 changes: 22 additions & 0 deletions scripts/templates/stype_map_helpers.hpp.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
<%!
import re
from templates import helper as th
%><%
n=namespace
N=n.upper()
x=tags['$x']
X=x.upper()
%>
// This file is autogenerated from the template at ${self.template.filename}

%for obj in th.extract_objs(specs, r"enum"):
%if obj["name"] == '$x_structure_type_t':
%for etor in obj['etors']:
%if 'UINT32' not in etor['name']:
template <>
struct stype_map<${x}_${etor['desc'][3:]}> : stype_map_impl<${X}_${etor['name'][3:]}> {};
%endif
%endfor
%endif
%endfor

15 changes: 3 additions & 12 deletions source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,24 +320,15 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc)
: Context(Context) {
const void *pNext = PoolDesc->pNext;
while (pNext != nullptr) {
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
switch (BaseDesc->stype) {
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
const ur_usm_pool_limits_desc_t *Limits =
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
if (PoolDesc) {
if (auto *Limits = find_stype_node<ur_usm_pool_limits_desc_t>(PoolDesc)) {
for (auto &config : DisjointPoolConfigs.Configs) {
config.MaxPoolableSize = Limits->maxPoolableSize;
config.SlabMinSize = Limits->minDriverAllocSize;
}
break;
}
default: {
} else {
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
}
}
pNext = BaseDesc->pNext;
}

auto MemProvider =
Expand Down
98 changes: 98 additions & 0 deletions source/common/stype_map_helpers.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@

// This file is autogenerated from the template at templates/stype_map_helpers.hpp.mako

template <>
struct stype_map<ur_context_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_PROPERTIES> {};
template <>
struct stype_map<ur_image_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_IMAGE_DESC> {};
template <>
struct stype_map<ur_buffer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_PROPERTIES> {};
template <>
struct stype_map<ur_buffer_region_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_REGION> {};
template <>
struct stype_map<ur_buffer_channel_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_CHANNEL_PROPERTIES> {};
template <>
struct stype_map<ur_buffer_alloc_location_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_BUFFER_ALLOC_LOCATION_PROPERTIES> {};
template <>
struct stype_map<ur_program_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_PROPERTIES> {};
template <>
struct stype_map<ur_usm_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DESC> {};
template <>
struct stype_map<ur_usm_host_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_HOST_DESC> {};
template <>
struct stype_map<ur_usm_device_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
template <>
struct stype_map<ur_usm_pool_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_DESC> {};
template <>
struct stype_map<ur_usm_pool_limits_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC> {};
template <>
struct stype_map<ur_device_binary_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_BINARY> {};
template <>
struct stype_map<ur_sampler_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_DESC> {};
template <>
struct stype_map<ur_queue_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_PROPERTIES> {};
template <>
struct stype_map<ur_queue_index_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_INDEX_PROPERTIES> {};
template <>
struct stype_map<ur_context_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_CONTEXT_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_queue_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_mem_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_MEM_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_event_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EVENT_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_platform_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PLATFORM_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_device_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_program_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PROGRAM_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_sampler_native_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_SAMPLER_NATIVE_PROPERTIES> {};
template <>
struct stype_map<ur_queue_native_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_QUEUE_NATIVE_DESC> {};
template <>
struct stype_map<ur_device_partition_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_arg_mem_obj_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES> {};
template <>
struct stype_map<ur_physical_mem_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_PHYSICAL_MEM_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_arg_pointer_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_POINTER_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_arg_sampler_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_SAMPLER_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_exec_info_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_EXEC_INFO_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_arg_value_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_VALUE_PROPERTIES> {};
template <>
struct stype_map<ur_kernel_arg_local_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_KERNEL_ARG_LOCAL_PROPERTIES> {};
template <>
struct stype_map<ur_usm_alloc_location_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_update_kernel_launch_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_update_memobj_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_MEMOBJ_ARG_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_update_pointer_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_update_value_arg_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC> {};
template <>
struct stype_map<ur_exp_command_buffer_update_exec_info_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_EXEC_INFO_DESC> {};
template <>
struct stype_map<ur_exp_sampler_mip_properties_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_MIP_PROPERTIES> {};
template <>
struct stype_map<ur_exp_interop_mem_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_MEM_DESC> {};
template <>
struct stype_map<ur_exp_interop_semaphore_desc_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_INTEROP_SEMAPHORE_DESC> {};
template <>
struct stype_map<ur_exp_file_descriptor_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_FILE_DESCRIPTOR> {};
template <>
struct stype_map<ur_exp_win32_handle_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_WIN32_HANDLE> {};
template <>
struct stype_map<ur_exp_sampler_addr_modes_t> : stype_map_impl<UR_STRUCTURE_TYPE_EXP_SAMPLER_ADDR_MODES> {};

53 changes: 53 additions & 0 deletions source/common/ur_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,59 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) {

template <class> inline constexpr bool ur_always_false_t = false;

namespace {
// Compile-time map, mapping a UR list node type, to the enum tag type
// These are helpers for the `find_stype_node` helper below
template <int val> struct stype_map_impl {
static constexpr int value = val;
};

template <typename T> struct stype_map {};
// contains definitions of the map specializations e.g.
// template <> struct stype_map<ur_usm_device_desc_t> :
// stype_map_impl<UR_STRUCTURE_TYPE_USM_DEVICE_DESC> {};
#include "stype_map_helpers.def"

template <typename T> constexpr int as_stype() { return stype_map<T>::value; };

/// Walk a generic UR linked list looking for a node of the given type. If it's
/// found, its address is returned, othewise `nullptr`. e.g. to find out whether
/// a `ur_usm_host_desc_t` exists in the given polymorphic list, `mylist`:
///
/// ```cpp
/// auto *node = find_stype_node<ur_usm_host_desc_t>(mylist);
/// if (!node)
/// printf("node of expected type not found!\n");
/// ```
///
/// To find multiple nodes of a given type:
///
/// ```cpp
/// std::vector<ur_usm_host_desc_t> nodes;
///
/// for (auto *node = find_stype_node<ur_usm_host_desc_t>(mylist); node; node =
/// node->pNext;) {
/// nodes.push_back(node);
/// }
template <typename T, typename P>
typename std::conditional_t<std::is_const_v<std::remove_pointer_t<P>>,
const T *, T *>
find_stype_node(P list_head) noexcept {
auto *list = reinterpret_cast<const T *>(list_head);
for (const auto *next = reinterpret_cast<const T *>(list->pNext); next;
next = reinterpret_cast<const T *>(next->pNext)) {
if (next->stype == as_stype<T>()) {
if constexpr (!std::is_const_v<P>) {
return const_cast<T *>(next);
} else {
return next;
}
}
}
return nullptr;
}
} // namespace

namespace ur {
[[noreturn]] inline void unreachable() {
#ifdef _MSC_VER
Expand Down

0 comments on commit e8b3445

Please sign in to comment.