From 16b75a0e43b503120c01eac9625390aa43ade5bf Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Wed, 25 Sep 2024 11:36:14 -0400 Subject: [PATCH] Do not distrubute OpSNegate into OpUDiv We cannot apply the negate to an operand of an OpUDiv instead of it result. This is because the operands of the OpUDiv are interpreted as unsigned. We stop the optimizer from doing that. There were no tests for distributing a negate into OpIMul, OpSDiv, and OpUDiv. Tests are added for all of these. Fixes #5822 --- source/opt/folding_rules.cpp | 57 +++++------ test/opt/fold_test.cpp | 190 ++++++++++++++++++++++++++++++++++- 2 files changed, 218 insertions(+), 29 deletions(-) diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 2ebc385cb4..b1152f473b 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -422,36 +422,37 @@ FoldingRule MergeNegateMulDivArithmetic() { if (width != 32 && width != 64) return false; spv::Op opcode = op_inst->opcode(); - if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv || - opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv || - opcode == spv::Op::OpUDiv) { - std::vector op_constants = - const_mgr->GetOperandConstants(op_inst); - // Merge negate into mul or div if one operand is constant. - if (op_constants[0] || op_constants[1]) { - bool zero_is_variable = op_constants[0] == nullptr; - const analysis::Constant* c = ConstInput(op_constants); - uint32_t neg_id = NegateConstant(const_mgr, c); - uint32_t non_const_id = zero_is_variable - ? op_inst->GetSingleWordInOperand(0u) - : op_inst->GetSingleWordInOperand(1u); - // Change this instruction to a mul/div. - inst->SetOpcode(op_inst->opcode()); - if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || - opcode == spv::Op::OpSDiv) { - uint32_t op0 = zero_is_variable ? non_const_id : neg_id; - uint32_t op1 = zero_is_variable ? neg_id : non_const_id; - inst->SetInOperands( - {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); - } else { - inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, - {SPV_OPERAND_TYPE_ID, {neg_id}}}); - } - return true; - } + if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv && + opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) { + return false; } - return false; + std::vector op_constants = + const_mgr->GetOperandConstants(op_inst); + // Merge negate into mul or div if one operand is constant. + if (op_constants[0] == nullptr && op_constants[1] == nullptr) { + return false; + } + + bool zero_is_variable = op_constants[0] == nullptr; + const analysis::Constant* c = ConstInput(op_constants); + uint32_t neg_id = NegateConstant(const_mgr, c); + uint32_t non_const_id = zero_is_variable + ? op_inst->GetSingleWordInOperand(0u) + : op_inst->GetSingleWordInOperand(1u); + // Change this instruction to a mul/div. + inst->SetOpcode(op_inst->opcode()); + if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || + opcode == spv::Op::OpSDiv) { + uint32_t op0 = zero_is_variable ? non_const_id : neg_id; + uint32_t op1 = zero_is_variable ? neg_id : non_const_id; + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); + } else { + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, + {SPV_OPERAND_TYPE_ID, {neg_id}}}); + } + return true; }; } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index e2d9d7cc18..003c44915e 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -5940,7 +5940,195 @@ ::testing::Values( "%2 = OpFNegate %v2double %v2double_null\n" + "OpReturn\n" + "OpFunctionEnd", - 2, true) + 2, true), + // Test case 20: fold snegate with OpIMul. + // -(x * 2) = x * -2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIMul %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 21: fold snegate with OpIMul. + // -(x * 2) = x * -2 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %uint_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 22: fold snegate with OpIMul. + // -(-24 * x) = x * 24 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_24]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %int_n24 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 23: fold snegate with OpIMul with UINT_MAX + // -(UINT_MAX * x) = x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %uint_max %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 24: fold snegate with OpIMul using -INT_MAX + // -(x * 2147483649u) = x * 2147483647u + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_2147483647]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %uint_2147483649\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 25: fold snegate with OpSDiv (long). + // -(x / 2) = x / -2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpSDiv [[long]] [[ld]] [[long_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpSDiv %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 26: fold snegate with OpSDiv (int). + // -(x / 2) = x / -2 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %uint_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 27: fold snegate with OpSDiv. + // -(-24 / x) = 24 / x + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[int_24]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %int_n24 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 28: fold snegate with OpSDiv with UINT_MAX + // -(UINT_MAX / x) = (1 / x) + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[uint_1]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %uint_max %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 29: fold snegate with OpSDiv using -INT_MAX + // -(x / 2147483647u) = x / 2147483647 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_2147483647]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %uint_2147483649\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 30: Don't fold snegate int OpUDiv. The operands are interpreted + // as unsigned, so negating an operand is not the same a negating the + // result. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpUDiv %int %2 %uint_1\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, false) )); INSTANTIATE_TEST_SUITE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,