diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 5c68e291cdc..cc0700e8b2c 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1733,27 +1733,27 @@ bool CompositeConstructFeedingExtract( } // Walks the indexes chain from |start| to |end| of an OpCompositeInsert or -// OpCompositeExtract instruction, and returns the type of the final element +// OpCompositeExtract instruction, and returns the type id of the final element // being accessed. -const analysis::Type* GetElementType(uint32_t type_id, - Instruction::iterator start, - Instruction::iterator end, - const analysis::TypeManager* type_mgr) { - const analysis::Type* type = type_mgr->GetType(type_id); +// TODO: Check if this is already written somewhere else and use that instead. +uint32_t GetElementType(uint32_t type_id, Instruction::iterator start, + Instruction::iterator end, + const analysis::DefUseManager* def_use_manager) { for (auto index : make_range(std::move(start), std::move(end))) { + const Instruction* type_inst = def_use_manager->GetDef(type_id); assert(index.type == SPV_OPERAND_TYPE_LITERAL_INTEGER && index.words.size() == 1); - if (auto* array_type = type->AsArray()) { - type = array_type->element_type(); - } else if (auto* matrix_type = type->AsMatrix()) { - type = matrix_type->element_type(); - } else if (auto* struct_type = type->AsStruct()) { - type = struct_type->element_types()[index.words[0]]; + if (type_inst->opcode() == spv::Op::OpTypeArray) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeMatrix) { + type_id = type_inst->GetSingleWordInOperand(0); + } else if (type_inst->opcode() == spv::Op::OpTypeStruct) { + type_id = type_inst->GetSingleWordInOperand(index.words[0]); } else { - type = nullptr; + return 0; } } - return type; + return type_id; } // Returns true of |inst_1| and |inst_2| have the same indexes that will be used @@ -1838,16 +1838,11 @@ bool CompositeExtractFeedingConstruct( // The last check it to see that the object being extracted from is the // correct type. Instruction* original_inst = def_use_mgr->GetDef(original_id); - analysis::TypeManager* type_mgr = context->get_type_mgr(); - const analysis::Type* original_type = + uint32_t original_type_id = GetElementType(original_inst->type_id(), first_element_inst->begin() + 3, - first_element_inst->end() - 1, type_mgr); - - if (original_type == nullptr) { - return false; - } + first_element_inst->end() - 1, def_use_mgr); - if (inst->type_id() != type_mgr->GetId(original_type)) { + if (inst->type_id() != original_type_id) { return false; } @@ -2141,9 +2136,11 @@ bool DoInsertedValuesCoverEntireObject( // inserted by the OpCompositeInsert instruction |inst|. const analysis::Type* GetContainerType(Instruction* inst) { assert(inst->opcode() == spv::Op::OpCompositeInsert); + analysis::DefUseManager* def_use_manager = inst->context()->get_def_use_mgr(); + uint32_t container_type_id = GetElementType( + inst->type_id(), inst->begin() + 4, inst->end() - 1, def_use_manager); analysis::TypeManager* type_mgr = inst->context()->get_type_mgr(); - return GetElementType(inst->type_id(), inst->begin() + 4, inst->end() - 1, - type_mgr); + return type_mgr->GetType(container_type_id); } // Returns an OpCompositeConstruct instruction that build an object with