Skip to content

Commit

Permalink
Do not distrubute OpSNegate into OpUDiv
Browse files Browse the repository at this point in the history
We cannot apply the negate to an operand of an OpUDiv instead of it
result. This is because the operands of the OpUDiv are interpreted as
unsigned. We stop the optimizer from doing that.

There were no tests for distributing a negate into OpIMul, OpSDiv, and
OpUDiv. Tests are added for all of these.

Fixes KhronosGroup#5822
  • Loading branch information
s-perron committed Sep 25, 2024
1 parent 44936c4 commit 16b75a0
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 29 deletions.
57 changes: 29 additions & 28 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,36 +422,37 @@ FoldingRule MergeNegateMulDivArithmetic() {
if (width != 32 && width != 64) return false;

spv::Op opcode = op_inst->opcode();
if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv ||
opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv ||
opcode == spv::Op::OpUDiv) {
std::vector<const analysis::Constant*> op_constants =
const_mgr->GetOperandConstants(op_inst);
// Merge negate into mul or div if one operand is constant.
if (op_constants[0] || op_constants[1]) {
bool zero_is_variable = op_constants[0] == nullptr;
const analysis::Constant* c = ConstInput(op_constants);
uint32_t neg_id = NegateConstant(const_mgr, c);
uint32_t non_const_id = zero_is_variable
? op_inst->GetSingleWordInOperand(0u)
: op_inst->GetSingleWordInOperand(1u);
// Change this instruction to a mul/div.
inst->SetOpcode(op_inst->opcode());
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
opcode == spv::Op::OpSDiv) {
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
} else {
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
{SPV_OPERAND_TYPE_ID, {neg_id}}});
}
return true;
}
if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv &&
opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) {
return false;
}

return false;
std::vector<const analysis::Constant*> op_constants =
const_mgr->GetOperandConstants(op_inst);
// Merge negate into mul or div if one operand is constant.
if (op_constants[0] == nullptr && op_constants[1] == nullptr) {
return false;
}

bool zero_is_variable = op_constants[0] == nullptr;
const analysis::Constant* c = ConstInput(op_constants);
uint32_t neg_id = NegateConstant(const_mgr, c);
uint32_t non_const_id = zero_is_variable
? op_inst->GetSingleWordInOperand(0u)
: op_inst->GetSingleWordInOperand(1u);
// Change this instruction to a mul/div.
inst->SetOpcode(op_inst->opcode());
if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv ||
opcode == spv::Op::OpSDiv) {
uint32_t op0 = zero_is_variable ? non_const_id : neg_id;
uint32_t op1 = zero_is_variable ? neg_id : non_const_id;
inst->SetInOperands(
{{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}});
} else {
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}},
{SPV_OPERAND_TYPE_ID, {neg_id}}});
}
return true;
};
}

Expand Down
190 changes: 189 additions & 1 deletion test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5940,7 +5940,195 @@ ::testing::Values(
"%2 = OpFNegate %v2double %v2double_null\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, true)
2, true),
// Test case 20: fold snegate with OpIMul.
// -(x * 2) = x * -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpIMul %long %2 %long_2\n" +
"%4 = OpSNegate %long %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 21: fold snegate with OpIMul.
// -(x * 2) = x * -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %uint_2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 22: fold snegate with OpIMul.
// -(-24 * x) = x * 24
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_24]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %int_n24 %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 23: fold snegate with OpIMul with UINT_MAX
// -(UINT_MAX * x) = x
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %uint_max %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 24: fold snegate with OpIMul using -INT_MAX
// -(x * 2147483649u) = x * 2147483647u
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_2147483647]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpIMul %int %2 %uint_2147483649\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 25: fold snegate with OpSDiv (long).
// -(x / 2) = x / -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" +
"; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" +
"; CHECK: %4 = OpSDiv [[long]] [[ld]] [[long_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_long Function\n" +
"%2 = OpLoad %long %var\n" +
"%3 = OpSDiv %long %2 %long_2\n" +
"%4 = OpSNegate %long %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 26: fold snegate with OpSDiv (int).
// -(x / 2) = x / -2
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_n2]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %2 %uint_2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 27: fold snegate with OpSDiv.
// -(-24 / x) = 24 / x
InstructionFoldingCase<bool>(
Header() +
"; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[int_24]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %int_n24 %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 28: fold snegate with OpSDiv with UINT_MAX
// -(UINT_MAX / x) = (1 / x)
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[uint_1]] [[ld]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %uint_max %2\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 29: fold snegate with OpSDiv using -INT_MAX
// -(x / 2147483647u) = x / 2147483647
InstructionFoldingCase<bool>(
Header() +
"; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" +
"; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" +
"; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" +
"; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" +
"; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_2147483647]]\n" +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpSDiv %int %2 %uint_2147483649\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, true),
// Test case 30: Don't fold snegate int OpUDiv. The operands are interpreted
// as unsigned, so negating an operand is not the same a negating the
// result.
InstructionFoldingCase<bool>(
Header() +
"%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%var = OpVariable %_ptr_int Function\n" +
"%2 = OpLoad %int %var\n" +
"%3 = OpUDiv %int %2 %uint_1\n" +
"%4 = OpSNegate %int %3\n" +
"OpReturn\n" +
"OpFunctionEnd",
4, false)
));

INSTANTIATE_TEST_SUITE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,
Expand Down

0 comments on commit 16b75a0

Please sign in to comment.