From 0c3f4fc43e0927742ce2b54e5ecc20f6cb9c9d9d Mon Sep 17 00:00:00 2001 From: Luke Drummond Date: Thu, 7 Mar 2024 17:22:00 +0000 Subject: [PATCH] Introduce a utility to walk polymorphic linked lists Add `find_stype_node` for walking polymorphic linked lists looking for a particular type. The implementation here works with an auto-generated compile time map, that links a given type to the structure enumeration tag. The implementation simply walks the list looking for the .stype implied by the template parameter type and casts it to the expected type, then returning it. It's one of those unfortunate cases where you can write pretty nasty implementation code that makes the user code much much nicer. In this case the user doesn't need to worry about the `.types` at all, and the const-void casts can be eliminated from user code. --- scripts/generate_code.py | 18 +++- scripts/templates/stype_map_helpers.hpp.mako | 22 +++++ source/adapters/hip/usm.cpp | 15 +-- source/common/stype_map_helpers.def | 98 ++++++++++++++++++++ source/common/ur_util.hpp | 53 +++++++++++ 5 files changed, 192 insertions(+), 14 deletions(-) create mode 100644 scripts/templates/stype_map_helpers.hpp.mako create mode 100644 source/common/stype_map_helpers.def diff --git a/scripts/generate_code.py b/scripts/generate_code.py index bc891f62e0..b8bfa97ba5 100644 --- a/scripts/generate_code.py +++ b/scripts/generate_code.py @@ -411,11 +411,25 @@ def generate_layers(path, section, namespace, tags, version, specs, meta): generates common utilities for unified_runtime """ def generate_common(path, section, namespace, tags, version, specs, meta): + template = "stype_map_helpers.hpp.mako" + fin = os.path.join("templates", template) + + filename = "stype_map_helpers.def" layer_dstpath = os.path.join(path, "common") os.makedirs(layer_dstpath, exist_ok=True) + fout = os.path.join(layer_dstpath, filename) + + print("Generating %s..." % fout) + + loc = util.makoWrite( + fin, fout, + ver=version, + namespace=namespace, + tags=tags, + specs=specs, + meta=meta) + print("COMMON Generated %s lines of code.\n" % loc) - loc = 0 - print("COMMON Generated %s lines of code.\n"%loc) """ Entry-point: diff --git a/scripts/templates/stype_map_helpers.hpp.mako b/scripts/templates/stype_map_helpers.hpp.mako new file mode 100644 index 0000000000..26aff00cd5 --- /dev/null +++ b/scripts/templates/stype_map_helpers.hpp.mako @@ -0,0 +1,22 @@ +<%! +import re +from templates import helper as th +%><% + n=namespace + N=n.upper() + x=tags['$x'] + X=x.upper() +%> +// This file is autogenerated from the template at ${self.template.filename} + +%for obj in th.extract_objs(specs, r"enum"): + %if obj["name"] == '$x_structure_type_t': + %for etor in obj['etors']: + %if 'UINT32' not in etor['name']: +template <> +struct stype_map<${x}_${etor['desc'][3:]}> : stype_map_impl<${X}_${etor['name'][3:]}> {}; + %endif + %endfor + %endif +%endfor + diff --git a/source/adapters/hip/usm.cpp b/source/adapters/hip/usm.cpp index 4e140ce5c1..90a451d9b8 100644 --- a/source/adapters/hip/usm.cpp +++ b/source/adapters/hip/usm.cpp @@ -320,24 +320,15 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc) : Context(Context) { - const void *pNext = PoolDesc->pNext; - while (pNext != nullptr) { - const ur_base_desc_t *BaseDesc = static_cast(pNext); - switch (BaseDesc->stype) { - case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: { - const ur_usm_pool_limits_desc_t *Limits = - reinterpret_cast(BaseDesc); + if (PoolDesc) { + if (auto *Limits = find_stype_node(PoolDesc)) { for (auto &config : DisjointPoolConfigs.Configs) { config.MaxPoolableSize = Limits->maxPoolableSize; config.SlabMinSize = Limits->minDriverAllocSize; } - break; - } - default: { + } else { throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT); } - } - pNext = BaseDesc->pNext; } auto MemProvider = diff --git a/source/common/stype_map_helpers.def b/source/common/stype_map_helpers.def new file mode 100644 index 0000000000..cc18d0f3f4 --- /dev/null +++ b/source/common/stype_map_helpers.def @@ -0,0 +1,98 @@ + +// This file is autogenerated from the template at templates/stype_map_helpers.hpp.mako + +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; + diff --git a/source/common/ur_util.hpp b/source/common/ur_util.hpp index 3bd7214b8c..018462e9c9 100644 --- a/source/common/ur_util.hpp +++ b/source/common/ur_util.hpp @@ -266,6 +266,59 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) { template inline constexpr bool ur_always_false_t = false; +namespace { +// Compile-time map, mapping a UR list node type, to the enum tag type +// These are helpers for the `find_stype_node` helper below +template struct stype_map_impl { + static constexpr ur_structure_type_t value = val; +}; + +template struct stype_map {}; +// contains definitions of the map specializations e.g. +// template <> struct stype_map : +// stype_map_impl {}; +#include "stype_map_helpers.def" + +template constexpr int as_stype() { return stype_map::value; }; + +/// Walk a generic UR linked list looking for a node of the given type. If it's +/// found, its address is returned, othewise `nullptr`. e.g. to find out whether +/// a `ur_usm_host_desc_t` exists in the given polymorphic list, `mylist`: +/// +/// ```cpp +/// auto *node = find_stype_node(mylist); +/// if (!node) +/// printf("node of expected type not found!\n"); +/// ``` +/// +/// To find multiple nodes of a given type: +/// +/// ```cpp +/// std::vector nodes; +/// +/// for (auto *node = find_stype_node(mylist); node; node = +/// node->pNext;) { +/// nodes.push_back(node); +/// } +template +typename std::conditional_t>, + const T *, T *> +find_stype_node(P list_head) noexcept { + auto *list = reinterpret_cast(list_head); + for (const auto *next = reinterpret_cast(list); next; + next = reinterpret_cast(next->pNext)) { + if (next->stype == as_stype()) { + if constexpr (!std::is_const_v

) { + return const_cast(next); + } else { + return next; + } + } + } + return nullptr; +} +} // namespace + namespace ur { [[noreturn]] inline void unreachable() { #ifdef _MSC_VER