diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 2ebc385cb4..b1152f473b 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -422,36 +422,37 @@ FoldingRule MergeNegateMulDivArithmetic() { if (width != 32 && width != 64) return false; spv::Op opcode = op_inst->opcode(); - if (opcode == spv::Op::OpFMul || opcode == spv::Op::OpFDiv || - opcode == spv::Op::OpIMul || opcode == spv::Op::OpSDiv || - opcode == spv::Op::OpUDiv) { - std::vector op_constants = - const_mgr->GetOperandConstants(op_inst); - // Merge negate into mul or div if one operand is constant. - if (op_constants[0] || op_constants[1]) { - bool zero_is_variable = op_constants[0] == nullptr; - const analysis::Constant* c = ConstInput(op_constants); - uint32_t neg_id = NegateConstant(const_mgr, c); - uint32_t non_const_id = zero_is_variable - ? op_inst->GetSingleWordInOperand(0u) - : op_inst->GetSingleWordInOperand(1u); - // Change this instruction to a mul/div. - inst->SetOpcode(op_inst->opcode()); - if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || - opcode == spv::Op::OpSDiv) { - uint32_t op0 = zero_is_variable ? non_const_id : neg_id; - uint32_t op1 = zero_is_variable ? neg_id : non_const_id; - inst->SetInOperands( - {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); - } else { - inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, - {SPV_OPERAND_TYPE_ID, {neg_id}}}); - } - return true; - } + if (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFDiv && + opcode != spv::Op::OpIMul && opcode != spv::Op::OpSDiv) { + return false; } - return false; + std::vector op_constants = + const_mgr->GetOperandConstants(op_inst); + // Merge negate into mul or div if one operand is constant. + if (op_constants[0] == nullptr && op_constants[1] == nullptr) { + return false; + } + + bool zero_is_variable = op_constants[0] == nullptr; + const analysis::Constant* c = ConstInput(op_constants); + uint32_t neg_id = NegateConstant(const_mgr, c); + uint32_t non_const_id = zero_is_variable + ? op_inst->GetSingleWordInOperand(0u) + : op_inst->GetSingleWordInOperand(1u); + // Change this instruction to a mul/div. + inst->SetOpcode(op_inst->opcode()); + if (opcode == spv::Op::OpFDiv || opcode == spv::Op::OpUDiv || + opcode == spv::Op::OpSDiv) { + uint32_t op0 = zero_is_variable ? non_const_id : neg_id; + uint32_t op1 = zero_is_variable ? neg_id : non_const_id; + inst->SetInOperands( + {{SPV_OPERAND_TYPE_ID, {op0}}, {SPV_OPERAND_TYPE_ID, {op1}}}); + } else { + inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {non_const_id}}, + {SPV_OPERAND_TYPE_ID, {neg_id}}}); + } + return true; }; } diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index e2d9d7cc18..003c44915e 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -5940,7 +5940,195 @@ ::testing::Values( "%2 = OpFNegate %v2double %v2double_null\n" + "OpReturn\n" + "OpFunctionEnd", - 2, true) + 2, true), + // Test case 20: fold snegate with OpIMul. + // -(x * 2) = x * -2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpIMul [[long]] [[ld]] [[long_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpIMul %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 21: fold snegate with OpIMul. + // -(x * 2) = x * -2 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %uint_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 22: fold snegate with OpIMul. + // -(-24 * x) = x * 24 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[int_24]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %int_n24 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 23: fold snegate with OpIMul with UINT_MAX + // -(UINT_MAX * x) = x + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpCopyObject [[int]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %uint_max %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 24: fold snegate with OpIMul using -INT_MAX + // -(x * 2147483649u) = x * 2147483647u + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpIMul [[int]] [[ld]] [[uint_2147483647]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpIMul %int %2 %uint_2147483649\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 25: fold snegate with OpSDiv (long). + // -(x / 2) = x / -2 + InstructionFoldingCase( + Header() + + "; CHECK: [[long:%\\w+]] = OpTypeInt 64 1\n" + + "; CHECK: [[long_n2:%\\w+]] = OpConstant [[long]] -2\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[long]]\n" + + "; CHECK: %4 = OpSDiv [[long]] [[ld]] [[long_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_long Function\n" + + "%2 = OpLoad %long %var\n" + + "%3 = OpSDiv %long %2 %long_2\n" + + "%4 = OpSNegate %long %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 26: fold snegate with OpSDiv (int). + // -(x / 2) = x / -2 + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK-DAG: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_n2:%\\w+]] = OpConstant [[uint]] 4294967294\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_n2]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %uint_2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 27: fold snegate with OpSDiv. + // -(-24 / x) = 24 / x + InstructionFoldingCase( + Header() + + "; CHECK-DAG: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[int_24:%\\w+]] = OpConstant [[int]] 24\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[int_24]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %int_n24 %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 28: fold snegate with OpSDiv with UINT_MAX + // -(UINT_MAX / x) = (1 / x) + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_1:%\\w+]] = OpConstant [[uint]] 1\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[uint_1]] [[ld]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %uint_max %2\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 29: fold snegate with OpSDiv using -INT_MAX + // -(x / 2147483647u) = x / 2147483647 + InstructionFoldingCase( + Header() + + "; CHECK: [[int:%\\w+]] = OpTypeInt 32 1\n" + + "; CHECK: [[uint:%\\w+]] = OpTypeInt 32 0\n" + + "; CHECK: [[uint_2147483647:%\\w+]] = OpConstant [[uint]] 2147483647\n" + + "; CHECK: [[ld:%\\w+]] = OpLoad [[int]]\n" + + "; CHECK: %4 = OpSDiv [[int]] [[ld]] [[uint_2147483647]]\n" + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpSDiv %int %2 %uint_2147483649\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, true), + // Test case 30: Don't fold snegate int OpUDiv. The operands are interpreted + // as unsigned, so negating an operand is not the same a negating the + // result. + InstructionFoldingCase( + Header() + + "%main = OpFunction %void None %void_func\n" + + "%main_lab = OpLabel\n" + + "%var = OpVariable %_ptr_int Function\n" + + "%2 = OpLoad %int %var\n" + + "%3 = OpUDiv %int %2 %uint_1\n" + + "%4 = OpSNegate %int %3\n" + + "OpReturn\n" + + "OpFunctionEnd", + 4, false) )); INSTANTIATE_TEST_SUITE_P(ReciprocalFDivTest, MatchingInstructionFoldingTest,