diff --git a/scripts/generate_code.py b/scripts/generate_code.py index bc891f62e0..b8bfa97ba5 100644 --- a/scripts/generate_code.py +++ b/scripts/generate_code.py @@ -411,11 +411,25 @@ def generate_layers(path, section, namespace, tags, version, specs, meta): generates common utilities for unified_runtime """ def generate_common(path, section, namespace, tags, version, specs, meta): + template = "stype_map_helpers.hpp.mako" + fin = os.path.join("templates", template) + + filename = "stype_map_helpers.def" layer_dstpath = os.path.join(path, "common") os.makedirs(layer_dstpath, exist_ok=True) + fout = os.path.join(layer_dstpath, filename) + + print("Generating %s..." % fout) + + loc = util.makoWrite( + fin, fout, + ver=version, + namespace=namespace, + tags=tags, + specs=specs, + meta=meta) + print("COMMON Generated %s lines of code.\n" % loc) - loc = 0 - print("COMMON Generated %s lines of code.\n"%loc) """ Entry-point: diff --git a/scripts/templates/stype_map_helpers.hpp.mako b/scripts/templates/stype_map_helpers.hpp.mako new file mode 100644 index 0000000000..26aff00cd5 --- /dev/null +++ b/scripts/templates/stype_map_helpers.hpp.mako @@ -0,0 +1,22 @@ +<%! +import re +from templates import helper as th +%><% + n=namespace + N=n.upper() + x=tags['$x'] + X=x.upper() +%> +// This file is autogenerated from the template at ${self.template.filename} + +%for obj in th.extract_objs(specs, r"enum"): + %if obj["name"] == '$x_structure_type_t': + %for etor in obj['etors']: + %if 'UINT32' not in etor['name']: +template <> +struct stype_map<${x}_${etor['desc'][3:]}> : stype_map_impl<${X}_${etor['name'][3:]}> {}; + %endif + %endfor + %endif +%endfor + diff --git a/source/adapters/hip/usm.cpp b/source/adapters/hip/usm.cpp index 4e140ce5c1..90a451d9b8 100644 --- a/source/adapters/hip/usm.cpp +++ b/source/adapters/hip/usm.cpp @@ -320,24 +320,15 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size, ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_usm_pool_desc_t *PoolDesc) : Context(Context) { - const void *pNext = PoolDesc->pNext; - while (pNext != nullptr) { - const ur_base_desc_t *BaseDesc = static_cast(pNext); - switch (BaseDesc->stype) { - case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: { - const ur_usm_pool_limits_desc_t *Limits = - reinterpret_cast(BaseDesc); + if (PoolDesc) { + if (auto *Limits = find_stype_node(PoolDesc)) { for (auto &config : DisjointPoolConfigs.Configs) { config.MaxPoolableSize = Limits->maxPoolableSize; config.SlabMinSize = Limits->minDriverAllocSize; } - break; - } - default: { + } else { throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT); } - } - pNext = BaseDesc->pNext; } auto MemProvider = diff --git a/source/common/stype_map_helpers.def b/source/common/stype_map_helpers.def new file mode 100644 index 0000000000..cc18d0f3f4 --- /dev/null +++ b/source/common/stype_map_helpers.def @@ -0,0 +1,98 @@ + +// This file is autogenerated from the template at templates/stype_map_helpers.hpp.mako + +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; +template <> +struct stype_map : stype_map_impl {}; + diff --git a/source/common/ur_util.hpp b/source/common/ur_util.hpp index 3bd7214b8c..018462e9c9 100644 --- a/source/common/ur_util.hpp +++ b/source/common/ur_util.hpp @@ -266,6 +266,59 @@ inline ur_result_t exceptionToResult(std::exception_ptr eptr) { template inline constexpr bool ur_always_false_t = false; +namespace { +// Compile-time map, mapping a UR list node type, to the enum tag type +// These are helpers for the `find_stype_node` helper below +template struct stype_map_impl { + static constexpr ur_structure_type_t value = val; +}; + +template struct stype_map {}; +// contains definitions of the map specializations e.g. +// template <> struct stype_map : +// stype_map_impl {}; +#include "stype_map_helpers.def" + +template constexpr int as_stype() { return stype_map::value; }; + +/// Walk a generic UR linked list looking for a node of the given type. If it's +/// found, its address is returned, othewise `nullptr`. e.g. to find out whether +/// a `ur_usm_host_desc_t` exists in the given polymorphic list, `mylist`: +/// +/// ```cpp +/// auto *node = find_stype_node(mylist); +/// if (!node) +/// printf("node of expected type not found!\n"); +/// ``` +/// +/// To find multiple nodes of a given type: +/// +/// ```cpp +/// std::vector nodes; +/// +/// for (auto *node = find_stype_node(mylist); node; node = +/// node->pNext;) { +/// nodes.push_back(node); +/// } +template +typename std::conditional_t>, + const T *, T *> +find_stype_node(P list_head) noexcept { + auto *list = reinterpret_cast(list_head); + for (const auto *next = reinterpret_cast(list); next; + next = reinterpret_cast(next->pNext)) { + if (next->stype == as_stype()) { + if constexpr (!std::is_const_v

) { + return const_cast(next); + } else { + return next; + } + } + } + return nullptr; +} +} // namespace + namespace ur { [[noreturn]] inline void unreachable() { #ifdef _MSC_VER