From 0c0f636ca5ae1735c798621f8dca91a06c15f8d8 Mon Sep 17 00:00:00 2001 From: Jemale Date: Mon, 14 Aug 2023 18:19:48 -0400 Subject: [PATCH] Add validation of module extended descriptor Signed-off-by: Jemale Lockett --- .../parameter_validation/extension_validation.inl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/source/layers/validation/parameter_validation/extension_validation.inl b/source/layers/validation/parameter_validation/extension_validation.inl index 141d5ae..e688f30 100644 --- a/source/layers/validation/parameter_validation/extension_validation.inl +++ b/source/layers/validation/parameter_validation/extension_validation.inl @@ -7,9 +7,10 @@ * @file extension_validation.inl * */ +#include template -inline ze_result_t validateStructureTypes(void *descriptorPtr, +inline ze_result_t validateStructureTypes(const void *descriptorPtr, std::vector &baseTypesVector, std::vector &extensionTypesVector) { @@ -38,6 +39,14 @@ inline ze_result_t validateStructureTypes(void *descriptorPtr, for (auto t : (extensionTypesVector)) { if (pBase->stype == t) { validExtensionTypeFound = true; + + // if extension type is ZE_STRUCTURE_TYPE_MODULE_PROGRAM_EXP_DESC + // then base type->format must be ZE_MODULE_FORMAT_IL_SPIRV + if (std::is_same::value && (ZE_STRUCTURE_TYPE_MODULE_PROGRAM_EXP_DESC == static_cast(t))){ + if (ZE_MODULE_FORMAT_IL_SPIRV != reinterpret_cast(descriptorPtr)->format){ + return ZE_RESULT_ERROR_INVALID_ARGUMENT; + } + } break; } } @@ -353,7 +362,7 @@ inline ze_result_t ParameterValidation::validateExtensions(ze_host_mem_alloc_des } template <> -inline ze_result_t ParameterValidation::validateExtensions(ze_module_desc_t *descriptor) { +inline ze_result_t ParameterValidation::validateExtensions(const ze_module_desc_t *descriptor) { std::vector baseTypes = {ZE_STRUCTURE_TYPE_MODULE_DESC}; std::vector types = {ZE_STRUCTURE_TYPE_MODULE_PROGRAM_EXP_DESC};