From 336b5710a5bec2d2da95cbdd7cf20a4c67e9b51d Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 22 May 2024 13:01:26 -0400 Subject: [PATCH] Do not fold mul and adds to generate fmas (#5682) This removes the folding rules added in #4783 and #4808. They lead to poor code generation on Adreno devices when 16-bit floating point values were used. Since this change is transformation is suppose to be neutral, there is no general reason to continue doing it. I have talked to the owners of SwiftShader, and they do not mind if the transform is removed. They were the ones the requested the change in the first place. Fixes #5658 --- source/opt/folding_rules.cpp | 128 --------------------------------- test/opt/fold_test.cpp | 132 +++++------------------------------ 2 files changed, 18 insertions(+), 242 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 5c68e291cd..5f83669940 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -1459,132 +1459,6 @@ FoldingRule FactorAddMuls() { }; } -// Replaces |inst| inplace with an FMA instruction |(x*y)+a|. -void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) { - uint32_t ext = - inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); - - if (ext == 0) { - inst->context()->AddExtInstImport("GLSL.std.450"); - ext = inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); - assert(ext != 0 && - "Could not add the GLSL.std.450 extended instruction set"); - } - - std::vector operands; - operands.push_back({SPV_OPERAND_TYPE_ID, {ext}}); - operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {x}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {y}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {a}}); - - inst->SetOpcode(spv::Op::OpExtInst); - inst->SetInOperands(std::move(operands)); -} - -// Folds a multiple and add into an Fma. -// -// Cases: -// (x * y) + a = Fma x y a -// a + (x * y) = Fma x y a -bool MergeMulAddArithmetic(IRContext* context, Instruction* inst, - const std::vector&) { - assert(inst->opcode() == spv::Op::OpFAdd); - - if (!inst->IsFloatingPointFoldingAllowed()) { - return false; - } - - analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - for (int i = 0; i < 2; i++) { - uint32_t op_id = inst->GetSingleWordInOperand(i); - Instruction* op_inst = def_use_mgr->GetDef(op_id); - - if (op_inst->opcode() != spv::Op::OpFMul) { - continue; - } - - if (!op_inst->IsFloatingPointFoldingAllowed()) { - continue; - } - - uint32_t x = op_inst->GetSingleWordInOperand(0); - uint32_t y = op_inst->GetSingleWordInOperand(1); - uint32_t a = inst->GetSingleWordInOperand((i + 1) % 2); - ReplaceWithFma(inst, x, y, a); - return true; - } - return false; -} - -// Replaces |sub| inplace with an FMA instruction |(x*y)+a| where |a| first gets -// negated if |negate_addition| is true, otherwise |x| gets negated. -void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y, - uint32_t a, bool negate_addition) { - uint32_t ext = - sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); - - if (ext == 0) { - sub->context()->AddExtInstImport("GLSL.std.450"); - ext = sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450(); - assert(ext != 0 && - "Could not add the GLSL.std.450 extended instruction set"); - } - - InstructionBuilder ir_builder( - sub->context(), sub, - IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); - - Instruction* neg = ir_builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate, - negate_addition ? a : x); - uint32_t neg_op = neg->result_id(); // -a : -x - - std::vector operands; - operands.push_back({SPV_OPERAND_TYPE_ID, {ext}}); - operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? x : neg_op}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {y}}); - operands.push_back({SPV_OPERAND_TYPE_ID, {negate_addition ? neg_op : a}}); - - sub->SetOpcode(spv::Op::OpExtInst); - sub->SetInOperands(std::move(operands)); -} - -// Folds a multiply and subtract into an Fma and negation. -// -// Cases: -// (x * y) - a = Fma x y -a -// a - (x * y) = Fma -x y a -bool MergeMulSubArithmetic(IRContext* context, Instruction* sub, - const std::vector&) { - assert(sub->opcode() == spv::Op::OpFSub); - - if (!sub->IsFloatingPointFoldingAllowed()) { - return false; - } - - analysis::DefUseManager* def_use_mgr = context->get_def_use_mgr(); - for (int i = 0; i < 2; i++) { - uint32_t op_id = sub->GetSingleWordInOperand(i); - Instruction* mul = def_use_mgr->GetDef(op_id); - - if (mul->opcode() != spv::Op::OpFMul) { - continue; - } - - if (!mul->IsFloatingPointFoldingAllowed()) { - continue; - } - - uint32_t x = mul->GetSingleWordInOperand(0); - uint32_t y = mul->GetSingleWordInOperand(1); - uint32_t a = sub->GetSingleWordInOperand((i + 1) % 2); - ReplaceWithFmaAndNegate(sub, x, y, a, i == 0); - return true; - } - return false; -} - FoldingRule IntMultipleBy1() { return [](IRContext*, Instruction* inst, const std::vector& constants) { @@ -2941,7 +2815,6 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic()); rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic()); rules_[spv::Op::OpFAdd].push_back(FactorAddMuls()); - rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic); rules_[spv::Op::OpFDiv].push_back(RedundantFDiv()); rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv()); @@ -2962,7 +2835,6 @@ void FoldingRules::AddFoldingRules() { rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic()); rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic()); rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic()); - rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic); rules_[spv::Op::OpIAdd].push_back(RedundantIAdd()); rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic()); diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index a4e0447c10..255449dbbf 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -7933,21 +7933,15 @@ ::testing::Values( 3, true) )); +// Issue #5658: The Adreno compiler does not handle 16-bit FMA instructions well. +// We want to avoid this by not generating FMA. We decided to never generate +// FMAs because, from a SPIR-V perspective, it is neutral. The ICD can generate +// the FMA if it wants. The simplest code is no code. INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest, ::testing::Values( - // Test case 0: (x * y) + a = Fma(x, y, a) + // Test case 0: Don't fold (x * y) + a InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_float Function\n" + @@ -7961,20 +7955,10 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true), - // Test case 1: a + (x * y) = Fma(x, y, a) + 3, false), + // Test case 1: Don't fold a + (x * y) InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_float Function\n" + @@ -7988,20 +7972,10 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true), - // Test case 2: (x * y) + a = Fma(x, y, a) with vectors + 3, false), + // Test case 2: Don't fold (x * y) + a with vectors InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_v4float Function\n" + @@ -8015,20 +7989,10 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true), - // Test case 3: a + (x * y) = Fma(x, y, a) with vectors + 3,false), + // Test case 3: Don't fold a + (x * y) with vectors InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_float Function\n" + @@ -8042,46 +8006,8 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true), - // Test 4: that the OpExtInstImport instruction is generated if it is missing. - InstructionFoldingCase( - std::string() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + - "OpCapability Shader\n" + - "OpMemoryModel Logical GLSL450\n" + - "OpEntryPoint Fragment %main \"main\"\n" + - "OpExecutionMode %main OriginUpperLeft\n" + - "OpSource GLSL 140\n" + - "OpName %main \"main\"\n" + - "%void = OpTypeVoid\n" + - "%void_func = OpTypeFunction %void\n" + - "%bool = OpTypeBool\n" + - "%float = OpTypeFloat 32\n" + - "%_ptr_float = OpTypePointer Function %float\n" + - "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%x = OpVariable %_ptr_float Function\n" + - "%y = OpVariable %_ptr_float Function\n" + - "%a = OpVariable %_ptr_float Function\n" + - "%lx = OpLoad %float %x\n" + - "%ly = OpLoad %float %y\n" + - "%mul = OpFMul %float %lx %ly\n" + - "%la = OpLoad %float %a\n" + - "%3 = OpFAdd %float %mul %la\n" + - "OpStore %a %3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 3, true), - // Test 5: Don't fold if the multiple is marked no contract. + 3, false), + // Test 4: Don't fold if the multiple is marked no contract. InstructionFoldingCase( std::string() + "OpCapability Shader\n" + @@ -8110,7 +8036,7 @@ ::testing::Values( "OpReturn\n" + "OpFunctionEnd", 3, false), - // Test 6: Don't fold if the add is marked no contract. + // Test 5: Don't fold if the add is marked no contract. InstructionFoldingCase( std::string() + "OpCapability Shader\n" + @@ -8139,20 +8065,9 @@ ::testing::Values( "OpReturn\n" + "OpFunctionEnd", 3, false), - // Test case 7: (x * y) - a = Fma(x, y, -a) + // Test case 6: Don't fold (x * y) - a InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_float Function\n" + @@ -8166,21 +8081,10 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true), - // Test case 8: a - (x * y) = Fma(-x, y, a) + 3, false), + // Test case 7: Don't fold a - (x * y) InstructionFoldingCase( Header() + - "; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" + - "; CHECK: OpFunction\n" + - "; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" + - "; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" + - "; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" + - "; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" + - "; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" + - "; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" + - "; CHECK: OpStore {{%\\w+}} [[fma]]\n" + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + "%x = OpVariable %_ptr_float Function\n" + @@ -8194,7 +8098,7 @@ ::testing::Values( "OpStore %a %3\n" + "OpReturn\n" + "OpFunctionEnd", - 3, true) + 3, false) )); using MatchingInstructionWithNoResultFoldingTest =