diff --git a/source/opt/fix_storage_class.cpp b/source/opt/fix_storage_class.cpp index 6345667f88a..2085dda414f 100644 --- a/source/opt/fix_storage_class.cpp +++ b/source/opt/fix_storage_class.cpp @@ -141,22 +141,26 @@ bool FixStorageClass::IsPointerResultType(Instruction* inst) { if (inst->type_id() == 0) { return false; } - const analysis::Type* ret_type = - context()->get_type_mgr()->GetType(inst->type_id()); - return ret_type->AsPointer() != nullptr; + + Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id()); + return type_def->opcode() == spv::Op::OpTypePointer; } bool FixStorageClass::IsPointerToStorageClass(Instruction* inst, spv::StorageClass storage_class) { - analysis::TypeManager* type_mgr = context()->get_type_mgr(); - analysis::Type* pType = type_mgr->GetType(inst->type_id()); - const analysis::Pointer* result_type = pType->AsPointer(); + if (inst->type_id() == 0) { + return false; + } - if (result_type == nullptr) { + Instruction* type_def = get_def_use_mgr()->GetDef(inst->type_id()); + if (type_def->opcode() != spv::Op::OpTypePointer) { return false; } - return (result_type->storage_class() == storage_class); + const uint32_t kPointerTypeStorageClassIndex = 0; + spv::StorageClass pointer_storage_class = static_cast( + type_def->GetSingleWordInOperand(kPointerTypeStorageClassIndex)); + return pointer_storage_class == storage_class; } bool FixStorageClass::ChangeResultType(Instruction* inst, @@ -301,9 +305,9 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) { break; } - Instruction* orig_type_inst = get_def_use_mgr()->GetDef(id); - assert(orig_type_inst->opcode() == spv::Op::OpTypePointer); - id = orig_type_inst->GetSingleWordInOperand(1); + Instruction* id_type_inst = get_def_use_mgr()->GetDef(id); + assert(id_type_inst->opcode() == spv::Op::OpTypePointer); + id = id_type_inst->GetSingleWordInOperand(1); for (uint32_t i = start_idx; i < inst->NumInOperands(); ++i) { Instruction* type_inst = get_def_use_mgr()->GetDef(id); @@ -336,6 +340,15 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) { "Tried to extract from an object where it cannot be done."); } + Instruction* orig_type_inst = get_def_use_mgr()->GetDef(inst->type_id()); + assert(orig_type_inst->opcode() == spv::Op::OpTypePointer); + if (orig_type_inst->GetSingleWordInOperand(1) == id) { + // The existing type is correct. Avoid the search for the type. Note that if + // there is a duplicate type, the search below could return a different type + // forcing more changes to the code than necessary. + return inst->type_id(); + } + return context()->get_type_mgr()->FindPointerToType( id, static_cast( orig_type_inst->GetSingleWordInOperand(0))); diff --git a/source/opt/pass.cpp b/source/opt/pass.cpp index 75c37407fdf..28f26c58682 100644 --- a/source/opt/pass.cpp +++ b/source/opt/pass.cpp @@ -83,7 +83,6 @@ uint32_t Pass::GetNullId(uint32_t type_id) { uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, Instruction* insertion_position) { - analysis::TypeManager* type_mgr = context()->get_type_mgr(); analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); uint32_t original_type_id = object_to_copy->type_id(); @@ -95,55 +94,52 @@ uint32_t Pass::GenerateCopy(Instruction* object_to_copy, uint32_t new_type_id, context(), insertion_position, IRContext::kAnalysisInstrToBlockMapping | IRContext::kAnalysisDefUse); - analysis::Type* original_type = type_mgr->GetType(original_type_id); - analysis::Type* new_type = type_mgr->GetType(new_type_id); - - if (const analysis::Array* original_array_type = original_type->AsArray()) { - uint32_t original_element_type_id = - type_mgr->GetId(original_array_type->element_type()); - - analysis::Array* new_array_type = new_type->AsArray(); - assert(new_array_type != nullptr && "Can't copy an array to a non-array."); - uint32_t new_element_type_id = - type_mgr->GetId(new_array_type->element_type()); - - std::vector element_ids; - const analysis::Constant* length_const = - const_mgr->FindDeclaredConstant(original_array_type->LengthId()); - assert(length_const->AsIntConstant()); - uint32_t array_length = length_const->AsIntConstant()->GetU32(); - for (uint32_t i = 0; i < array_length; i++) { - Instruction* extract = ir_builder.AddCompositeExtract( - original_element_type_id, object_to_copy->result_id(), {i}); - element_ids.push_back( - GenerateCopy(extract, new_element_type_id, insertion_position)); + Instruction* original_type = get_def_use_mgr()->GetDef(original_type_id); + Instruction* new_type = get_def_use_mgr()->GetDef(new_type_id); + assert(new_type->opcode() == original_type->opcode() && + "Can't copy an aggragate type unless the type correspond."); + + switch (original_type->opcode()) { + case spv::Op::OpTypeArray: { + uint32_t original_element_type_id = + original_type->GetSingleWordInOperand(0); + uint32_t new_element_type_id = new_type->GetSingleWordInOperand(0); + + std::vector element_ids; + uint32_t length_id = original_type->GetSingleWordInOperand(1); + const analysis::Constant* length_const = + const_mgr->FindDeclaredConstant(length_id); + assert(length_const->AsIntConstant()); + uint32_t array_length = length_const->AsIntConstant()->GetU32(); + for (uint32_t i = 0; i < array_length; i++) { + Instruction* extract = ir_builder.AddCompositeExtract( + original_element_type_id, object_to_copy->result_id(), {i}); + element_ids.push_back( + GenerateCopy(extract, new_element_type_id, insertion_position)); + } + + return ir_builder.AddCompositeConstruct(new_type_id, element_ids) + ->result_id(); } - - return ir_builder.AddCompositeConstruct(new_type_id, element_ids) - ->result_id(); - } else if (const analysis::Struct* original_struct_type = - original_type->AsStruct()) { - analysis::Struct* new_struct_type = new_type->AsStruct(); - - const std::vector& original_types = - original_struct_type->element_types(); - const std::vector& new_types = - new_struct_type->element_types(); - std::vector element_ids; - for (uint32_t i = 0; i < original_types.size(); i++) { - Instruction* extract = ir_builder.AddCompositeExtract( - type_mgr->GetId(original_types[i]), object_to_copy->result_id(), {i}); - element_ids.push_back(GenerateCopy(extract, type_mgr->GetId(new_types[i]), - insertion_position)); + case spv::Op::OpTypeStruct: { + std::vector element_ids; + for (uint32_t i = 0; i < original_type->NumInOperands(); i++) { + uint32_t orig_member_type_id = original_type->GetSingleWordInOperand(i); + uint32_t new_member_type_id = new_type->GetSingleWordInOperand(i); + Instruction* extract = ir_builder.AddCompositeExtract( + orig_member_type_id, object_to_copy->result_id(), {i}); + element_ids.push_back( + GenerateCopy(extract, new_member_type_id, insertion_position)); + } + return ir_builder.AddCompositeConstruct(new_type_id, element_ids) + ->result_id(); } - return ir_builder.AddCompositeConstruct(new_type_id, element_ids) - ->result_id(); - } else { - // If we do not have an aggregate type, then we have a problem. Either we - // found multiple instances of the same type, or we are copying to an - // incompatible type. Either way the code is illegal. - assert(false && - "Don't know how to copy this type. Code is likely illegal."); + default: + // If we do not have an aggregate type, then we have a problem. Either we + // found multiple instances of the same type, or we are copying to an + // incompatible type. Either way the code is illegal. + assert(false && + "Don't know how to copy this type. Code is likely illegal."); } return 0; } diff --git a/test/opt/fix_storage_class_test.cpp b/test/opt/fix_storage_class_test.cpp index 18afccbe724..410f140ea9d 100644 --- a/test/opt/fix_storage_class_test.cpp +++ b/test/opt/fix_storage_class_test.cpp @@ -953,6 +953,40 @@ OpFunctionEnd SinglePassRunAndCheck(text, text, false, false); } +// Tests that the pass is not confused when there are multiple definitions +// of a pointer type to the same type with the same storage class. +TEST_F(FixStorageClassTest, DuplicatePointerType) { + const std::string text = R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %1 "main" +OpExecutionMode %1 LocalSize 64 1 1 +OpSource HLSL 600 +%uint = OpTypeInt 32 0 +%uint_0 = OpConstant %uint 0 +%uint_3 = OpConstant %uint 3 +%_arr_uint_uint_3 = OpTypeArray %uint %uint_3 +%void = OpTypeVoid +%7 = OpTypeFunction %void +%_struct_8 = OpTypeStruct %_arr_uint_uint_3 +%_ptr_Function__struct_8 = OpTypePointer Function %_struct_8 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Function__arr_uint_uint_3 = OpTypePointer Function %_arr_uint_uint_3 +%_ptr_Function_uint_0 = OpTypePointer Function %uint +%_ptr_Function__ptr_Function_uint_0 = OpTypePointer Function %_ptr_Function_uint_0 +%1 = OpFunction %void None %7 +%14 = OpLabel +%15 = OpVariable %_ptr_Function__ptr_Function_uint_0 Function +%16 = OpVariable %_ptr_Function__struct_8 Function +%17 = OpAccessChain %_ptr_Function__arr_uint_uint_3 %16 %uint_0 +%18 = OpAccessChain %_ptr_Function_uint_0 %17 %uint_0 +OpStore %15 %18 +OpReturn +OpFunctionEnd +)"; + + SinglePassRunAndCheck(text, text, false); +} + } // namespace } // namespace opt } // namespace spvtools