From 519a78d25c58fcfdbcb5ac09f251ae38c36a9bb4 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 22 May 2024 13:52:07 -0400 Subject: [PATCH] Avoid use of type manager in extact->construct folding When dealing with structs the type manager merge two different structs into a single entry if they have all of the same decorations and element types. This is because they hash to the same value in the hash table. This can cause problems if you need to get the id of a type from the type manager because you could get either one. In this case, it returns the wrong one. The fix avoids using the type manager in one place. I have not looked closely at other places the type manager is used to make sure it is used safely everywhere. Fixes #5624 --- source/opt/folding_rules.cpp | 46 ++++++++++++++++-------------------- test/opt/fold_test.cpp | 16 ++++++++++++- 2 files changed, 36 insertions(+), 26 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 5f83669940..f41ef868a3 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1607,27 +1607,26 @@ bool CompositeConstructFeedingExtract( } // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or -// OpCompositeExtract instruction, and returns the type of the final element -// being accessed. -const analysis::Type* GetElementType(uint32_t type_id, - Instruction::iterator start, - Instruction::iterator end, - const analysis::TypeManager* type_mgr) { - const analysis::Type* type = type_mgr->GetType(type_id); +// OpCompositeExtract instruction, and returns the type id of the final element +// being accessed. Returns 0 if a valid type could not be found. +uint32_t GetElementType(uint32_t type_id, Instruction::iterator start, + Instruction::iterator end, + const analysis::DefUseManager* def_use_manager) { for (auto index : make_range(std::move(start), std::move(end))) { + const Instruction* type_inst = def_use_manager->GetDef(type_id); assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && index.words.size() == 1); - if (auto* array_type = type->AsArray()) { - type = array_type->element_type(); - } else if (auto* matrix_type = type->AsMatrix()) { - type = matrix_type->element_type(); - } else if (auto* struct_type = type->AsStruct()) { - type = struct_type->element_types()[index.words[0]]; + if (type_inst->opcode() == spv::Op::OpTypeArray) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeStruct) { + type_id = type_inst->GetSingleWordInOperand(index.words[0]); } else { - type = nullptr; + return 0; } } - return type; + return type_id; } // Returns true of |inst_1| and |inst_2| have the same indexes that will be used @@ -1712,16 +1711,11 @@ bool CompositeExtractFeedingConstruct( // The last check it to see that the object being extracted from is the // correct type. Instruction* original_inst = def_use_mgr->GetDef(original_id); - analysis::TypeManager* type_mgr = context->get_type_mgr(); - const analysis::Type* original_type = + uint32_t original_type_id = GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, - first_element_inst->end() - 1, type_mgr); - - if (original_type == nullptr) { - return false; - } + first_element_inst->end() - 1, def_use_mgr); - if (inst->type_id() != type_mgr->GetId(original_type)) { + if (inst->type_id() != original_type_id) { return false; } @@ -2015,9 +2009,11 @@ bool DoInsertedValuesCoverEntireObject( // inserted by the OpCompositeInsert instruction |inst|. const analysis::Type* GetContainerType(Instruction* inst) { assert(inst->opcode() == spv::Op::OpCompositeInsert); + analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr(); + uint32_t container_type_id = GetElementType( + inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager); analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); - return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1, - type_mgr); + return type_mgr->GetType(container_type_id); } // Returns an OpCompositeConstruct instruction that build an object with diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 255449dbbf..35828ab22f 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -7827,7 +7827,21 @@ ::testing::Values( "%5 = OpCompositeInsert %int_arr_2 %int_1 %4 1\n" + "OpReturn\n" + "OpFunctionEnd", - 5, true) + 5, true), + // Test case 19: Don't fold for isomorphic structs + InstructionFoldingCase( + Header() + + "%structA = OpTypeStruct %ulong\n" + + "%structB = OpTypeStruct %ulong\n" + + "%structC = OpTypeStruct %structB\n" + + "%struct_a_undef = OpUndef %structA\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%3 = OpCompositeExtract %ulong %struct_a_undef 0\n" + + "%4 = OpCompositeConstruct %structB %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, false) )); INSTANTIATE_TEST_SUITE_P(DotProductMatchingTest, MatchingInstructionFoldingTest,