Skip to content

Commit

Permalink
Do not fold mul and adds to generate fmas (KhronosGroup#5682)
Browse files Browse the repository at this point in the history
This removes the folding rules added in KhronosGroup#4783 and KhronosGroup#4808. They lead to
poor code generation on Adreno devices when 16-bit floating point values
are used. Since this transformation is supposed to be neutral,
there is no general reason to continue doing it.

I have talked to the owners of SwiftShader, and they do not mind if the
transform is removed. They were the ones that requested the change in the
first place.

Fixes KhronosGroup#5658
  • Loading branch information
s-perron authored May 22, 2024
1 parent ee749f5 commit 336b571
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 242 deletions.
128 changes: 0 additions & 128 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1459,132 +1459,6 @@ FoldingRule FactorAddMuls() {
};
}

// Rewrites |inst| in place into the GLSL.std.450 |Fma| extended instruction
// computing |(x*y)+a|.
//
// |x|, |y| and |a| are the ids used as the three Fma operands, in that order.
// If the module has no GLSL.std.450 import yet, one is added first so that the
// new OpExtInst has an instruction set id to reference.
void ReplaceWithFma(Instruction* inst, uint32_t x, uint32_t y, uint32_t a) {
  uint32_t glsl_import_id =
      inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl_import_id == 0) {
    // No import present: create it, then query the feature manager again to
    // pick up the freshly assigned id.
    inst->context()->AddExtInstImport("GLSL.std.450");
    glsl_import_id =
        inst->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    assert(glsl_import_id != 0 &&
           "Could not add the GLSL.std.450 extended instruction set");
  }

  // Small helper so each id operand is built the same way.
  const auto make_id_operand = [](uint32_t id) {
    return Operand{SPV_OPERAND_TYPE_ID, {id}};
  };

  // In-operand layout of the OpExtInst: <set> <opcode literal> <x> <y> <a>.
  std::vector<Operand> fma_operands;
  fma_operands.reserve(5);
  fma_operands.push_back(make_id_operand(glsl_import_id));
  fma_operands.push_back(
      Operand{SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  fma_operands.push_back(make_id_operand(x));
  fma_operands.push_back(make_id_operand(y));
  fma_operands.push_back(make_id_operand(a));

  inst->SetOpcode(spv::Op::OpExtInst);
  inst->SetInOperands(std::move(fma_operands));
}

// Folds a multiply followed by an add into a single Fma.
//
// Cases:
// (x * y) + a = Fma x y a
// a + (x * y) = Fma x y a
//
// |inst| must be an OpFAdd. Returns true if |inst| was rewritten into an Fma,
// and false if floating point folding is not allowed on |inst| or neither of
// its operands is a foldable OpFMul.
bool MergeMulAddArithmetic(IRContext* context, Instruction* inst,
                           const std::vector<const analysis::Constant*>&) {
  assert(inst->opcode() == spv::Op::OpFAdd);

  if (!inst->IsFloatingPointFoldingAllowed()) {
    return false;
  }

  analysis::DefUseManager* defs = context->get_def_use_mgr();

  // Try each operand of the add as the multiply; the other one is the addend.
  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
    Instruction* candidate = defs->GetDef(inst->GetSingleWordInOperand(mul_idx));

    const bool foldable_mul = candidate->opcode() == spv::Op::OpFMul &&
                              candidate->IsFloatingPointFoldingAllowed();
    if (foldable_mul) {
      const uint32_t addend_idx = 1u - mul_idx;
      const uint32_t lhs = candidate->GetSingleWordInOperand(0);
      const uint32_t rhs = candidate->GetSingleWordInOperand(1);
      const uint32_t addend = inst->GetSingleWordInOperand(addend_idx);
      ReplaceWithFma(inst, lhs, rhs, addend);
      return true;
    }
  }

  return false;
}

// Rewrites |sub| in place into the GLSL.std.450 |Fma| extended instruction
// computing |(x*y)+a|, with one input negated first:
//   - if |negate_addition| is true, |a| is negated: Fma x y -a
//   - otherwise |x| is negated:                      Fma -x y a
//
// The OpFNegate is created through an InstructionBuilder positioned at |sub|
// and uses |sub|'s result type. A GLSL.std.450 import is added to the module
// if none exists yet.
void ReplaceWithFmaAndNegate(Instruction* sub, uint32_t x, uint32_t y,
                             uint32_t a, bool negate_addition) {
  uint32_t glsl_import_id =
      sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  if (glsl_import_id == 0) {
    // No import present: create it, then query the feature manager again to
    // pick up the freshly assigned id.
    sub->context()->AddExtInstImport("GLSL.std.450");
    glsl_import_id =
        sub->context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
    assert(glsl_import_id != 0 &&
           "Could not add the GLSL.std.450 extended instruction set");
  }

  InstructionBuilder builder(
      sub->context(), sub,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);

  // Emit the negation of whichever input has to change sign.
  const uint32_t id_to_negate = negate_addition ? a : x;
  Instruction* negate_inst =
      builder.AddUnaryOp(sub->type_id(), spv::Op::OpFNegate, id_to_negate);
  const uint32_t negated_id = negate_inst->result_id();  // -a : -x

  // Substitute the negated id into the slot it came from.
  uint32_t first_factor = x;
  uint32_t addend = a;
  if (negate_addition) {
    addend = negated_id;
  } else {
    first_factor = negated_id;
  }

  // In-operand layout of the OpExtInst: <set> <opcode literal> <x> <y> <a>.
  std::vector<Operand> fma_operands;
  fma_operands.reserve(5);
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl_import_id}});
  fma_operands.push_back({SPV_OPERAND_TYPE_LITERAL_INTEGER, {GLSLstd450Fma}});
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {first_factor}});
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {y}});
  fma_operands.push_back({SPV_OPERAND_TYPE_ID, {addend}});

  sub->SetOpcode(spv::Op::OpExtInst);
  sub->SetInOperands(std::move(fma_operands));
}

// Folds a multiply followed by a subtract into an Fma plus a negation.
//
// Cases:
// (x * y) - a = Fma x y -a
// a - (x * y) = Fma -x y a
//
// |sub| must be an OpFSub. Returns true if |sub| was rewritten into an Fma,
// and false if floating point folding is not allowed on |sub| or neither of
// its operands is a foldable OpFMul.
bool MergeMulSubArithmetic(IRContext* context, Instruction* sub,
                           const std::vector<const analysis::Constant*>&) {
  assert(sub->opcode() == spv::Op::OpFSub);

  if (!sub->IsFloatingPointFoldingAllowed()) {
    return false;
  }

  analysis::DefUseManager* defs = context->get_def_use_mgr();

  // Try each operand of the subtract as the multiply; the other one is |a|.
  for (uint32_t mul_idx = 0; mul_idx < 2; ++mul_idx) {
    Instruction* candidate = defs->GetDef(sub->GetSingleWordInOperand(mul_idx));

    const bool foldable_mul = candidate->opcode() == spv::Op::OpFMul &&
                              candidate->IsFloatingPointFoldingAllowed();
    if (foldable_mul) {
      const uint32_t other_idx = 1u - mul_idx;
      const uint32_t lhs = candidate->GetSingleWordInOperand(0);
      const uint32_t rhs = candidate->GetSingleWordInOperand(1);
      const uint32_t other = sub->GetSingleWordInOperand(other_idx);
      // When the multiply is the minuend (operand 0) the subtracted value is
      // negated; when it is the subtrahend, one multiply factor is negated.
      const bool multiply_is_minuend = (mul_idx == 0);
      ReplaceWithFmaAndNegate(sub, lhs, rhs, other, multiply_is_minuend);
      return true;
    }
  }

  return false;
}

FoldingRule IntMultipleBy1() {
return [](IRContext*, Instruction* inst,
const std::vector<const analysis::Constant*>& constants) {
Expand Down Expand Up @@ -2941,7 +2815,6 @@ void FoldingRules::AddFoldingRules() {
rules_[spv::Op::OpFAdd].push_back(MergeAddSubArithmetic());
rules_[spv::Op::OpFAdd].push_back(MergeGenericAddSubArithmetic());
rules_[spv::Op::OpFAdd].push_back(FactorAddMuls());
rules_[spv::Op::OpFAdd].push_back(MergeMulAddArithmetic);

rules_[spv::Op::OpFDiv].push_back(RedundantFDiv());
rules_[spv::Op::OpFDiv].push_back(ReciprocalFDiv());
Expand All @@ -2962,7 +2835,6 @@ void FoldingRules::AddFoldingRules() {
rules_[spv::Op::OpFSub].push_back(MergeSubNegateArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeSubAddArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeSubSubArithmetic());
rules_[spv::Op::OpFSub].push_back(MergeMulSubArithmetic);

rules_[spv::Op::OpIAdd].push_back(RedundantIAdd());
rules_[spv::Op::OpIAdd].push_back(MergeAddNegateArithmetic());
Expand Down
132 changes: 18 additions & 114 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7933,21 +7933,15 @@ ::testing::Values(
3, true)
));

// Issue #5658: The Adreno compiler does not handle 16-bit FMA instructions well.
// We want to avoid this by not generating FMA. We decided to never generate
// FMAs because, from a SPIR-V perspective, it is neutral. The ICD can generate
// the FMA if it wants. The simplest code is no code.
INSTANTIATE_TEST_SUITE_P(FmaGenerationMatchingTest, MatchingInstructionFoldingTest,
::testing::Values(
// Test case 0: (x * y) + a = Fma(x, y, a)
// Test case 0: Don't fold (x * y) + a
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
Expand All @@ -7961,20 +7955,10 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 1: a + (x * y) = Fma(x, y, a)
3, false),
// Test case 1: Don't fold a + (x * y)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
Expand All @@ -7988,20 +7972,10 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 2: (x * y) + a = Fma(x, y, a) with vectors
3, false),
// Test case 2: Don't fold (x * y) + a with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_v4float Function\n" +
Expand All @@ -8015,20 +7989,10 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 3: a + (x * y) = Fma(x, y, a) with vectors
3,false),
// Test case 3: Don't fold a + (x * y) with vectors
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
Expand All @@ -8042,46 +8006,8 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 4: that the OpExtInstImport instruction is generated if it is missing.
InstructionFoldingCase<bool>(
std::string() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"OpCapability Shader\n" +
"OpMemoryModel Logical GLSL450\n" +
"OpEntryPoint Fragment %main \"main\"\n" +
"OpExecutionMode %main OriginUpperLeft\n" +
"OpSource GLSL 140\n" +
"OpName %main \"main\"\n" +
"%void = OpTypeVoid\n" +
"%void_func = OpTypeFunction %void\n" +
"%bool = OpTypeBool\n" +
"%float = OpTypeFloat 32\n" +
"%_ptr_float = OpTypePointer Function %float\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
"%y = OpVariable %_ptr_float Function\n" +
"%a = OpVariable %_ptr_float Function\n" +
"%lx = OpLoad %float %x\n" +
"%ly = OpLoad %float %y\n" +
"%mul = OpFMul %float %lx %ly\n" +
"%la = OpLoad %float %a\n" +
"%3 = OpFAdd %float %mul %la\n" +
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test 5: Don't fold if the multiple is marked no contract.
3, false),
// Test 4: Don't fold if the multiply is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
Expand Down Expand Up @@ -8110,7 +8036,7 @@ ::testing::Values(
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test 6: Don't fold if the add is marked no contract.
// Test 5: Don't fold if the add is marked no contract.
InstructionFoldingCase<bool>(
std::string() +
"OpCapability Shader\n" +
Expand Down Expand Up @@ -8139,20 +8065,9 @@ ::testing::Values(
"OpReturn\n" +
"OpFunctionEnd",
3, false),
// Test case 7: (x * y) - a = Fma(x, y, -a)
// Test case 6: Don't fold (x * y) - a
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[na:%\\w+]] = OpFNegate {{%\\w+}} [[la]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[lx]] [[ly]] [[na]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
Expand All @@ -8166,21 +8081,10 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true),
// Test case 8: a - (x * y) = Fma(-x, y, a)
3, false),
// Test case 7: Don't fold a - (x * y)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[ext:%\\w+]] = OpExtInstImport \"GLSL.std.450\"\n" +
"; CHECK: OpFunction\n" +
"; CHECK: [[x:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[y:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[a:%\\w+]] = OpVariable {{%\\w+}} Function\n" +
"; CHECK: [[lx:%\\w+]] = OpLoad {{%\\w+}} [[x]]\n" +
"; CHECK: [[ly:%\\w+]] = OpLoad {{%\\w+}} [[y]]\n" +
"; CHECK: [[la:%\\w+]] = OpLoad {{%\\w+}} [[a]]\n" +
"; CHECK: [[nx:%\\w+]] = OpFNegate {{%\\w+}} [[lx]]\n" +
"; CHECK: [[fma:%\\w+]] = OpExtInst {{%\\w+}} [[ext]] Fma [[nx]] [[ly]] [[la]]\n" +
"; CHECK: OpStore {{%\\w+}} [[fma]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%x = OpVariable %_ptr_float Function\n" +
Expand All @@ -8194,7 +8098,7 @@ ::testing::Values(
"OpStore %a %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
3, true)
3, false)
));

using MatchingInstructionWithNoResultFoldingTest =
Expand Down

0 comments on commit 336b571

Please sign in to comment.