diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 04c54ed69e93f1..c73d7c8d83bec1 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -302,6 +302,97 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { return Res; } +/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael +/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for +/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic. +/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every +/// even x in Bitwidth-bit arithmetic. +static unsigned CarmichaelShift(unsigned Bitwidth) { + if (Bitwidth < 3) + return Bitwidth - 1; + return Bitwidth - 2; +} + +/// Add the extra weight 'RHS' to the existing weight 'LHS', +/// reducing the combined weight using any special properties of the operation. +/// The existing weight LHS represents the computation X op X op ... op X where +/// X occurs LHS times. The combined weight represents X op X op ... op X with +/// X occurring LHS + RHS times. If op is "Xor" for example then the combined +/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even; +/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second. +static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { + // If we were working with infinite precision arithmetic then the combined + // weight would be LHS + RHS. But we are using finite precision arithmetic, + // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct + // for nilpotent operations and addition, but not for idempotent operations + // and multiplication), so it is important to correctly reduce the combined + // weight back into range if wrapping would be wrong. + + // If RHS is zero then the weight didn't change. + if (RHS.isMinValue()) + return; + // If LHS is zero then the combined weight is RHS. + if (LHS.isMinValue()) { + LHS = RHS; + return; + } + // From this point on we know that neither LHS nor RHS is zero. + + if (Instruction::isIdempotent(Opcode)) { + // Idempotent means X op X === X, so any non-zero weight is equivalent to a + // weight of 1. Keeping weights at zero or one also means that wrapping is + // not a problem. + assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); + return; // Return a weight of 1. + } + if (Instruction::isNilpotent(Opcode)) { + // Nilpotent means X op X === 0, so reduce weights modulo 2. + assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); + LHS = 0; // 1 + 1 === 0 modulo 2. + return; + } + if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) { + // TODO: Reduce the weight by exploiting nsw/nuw? + LHS += RHS; + return; + } + + assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) && + "Unknown associative operation!"); + unsigned Bitwidth = LHS.getBitWidth(); + // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth + // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth + // bit number x, since either x is odd in which case x^CM = 1, or x is even in + // which case both x^W and x^(W - CM) are zero. By subtracting off multiples + // of CM like this weights can always be reduced to the range [0, CM+Bitwidth) + // which by a happy accident means that they can always be represented using + // Bitwidth bits. + // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than + // the Carmichael number). + if (Bitwidth > 3) { + /// CM - The value of Carmichael's lambda function. + APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth)); + // Any weight W >= Threshold can be replaced with W - CM. + APInt Threshold = CM + Bitwidth; + assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!"); + // For Bitwidth 4 or more the following sum does not overflow. + LHS += RHS; + while (LHS.uge(Threshold)) + LHS -= CM; + } else { + // To avoid problems with overflow do everything the same as above but using + // a larger type. + unsigned CM = 1U << CarmichaelShift(Bitwidth); + unsigned Threshold = CM + Bitwidth; + assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold && + "Weights not reduced!"); + unsigned Total = LHS.getZExtValue() + RHS.getZExtValue(); + while (Total >= Threshold) + Total -= CM; + LHS = Total; + } +} + using RepeatedValue = std::pair; /// Given an associative binary expression, return the leaf @@ -471,7 +562,7 @@ static bool LinearizeExprTree(Instruction *I, "In leaf map but not visited!"); // Update the number of paths to the leaf. - It->second += Weight; + IncorporateWeight(It->second, Weight, Opcode); // If we still have uses that are not accounted for by the expression // then it is not safe to modify the value. diff --git a/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll b/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll index bd0060cc5abbd9..d4aa5c507ec8be 100644 --- a/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll +++ b/llvm/test/Transforms/Reassociate/reassoc_bool_vec.ll @@ -56,19 +56,20 @@ define <8 x i1> @vector2(<8 x i1> %a, <8 x i1> %b0, <8 x i1> %b1, <8 x i1> %b2, ; CHECK-NEXT: [[OR5:%.*]] = or <8 x i1> [[B5]], [[A]] ; CHECK-NEXT: [[OR6:%.*]] = or <8 x i1> [[B6]], [[A]] ; CHECK-NEXT: [[OR7:%.*]] = or <8 x i1> [[B7]], [[A]] -; CHECK-NEXT: [[XOR0:%.*]] = xor <8 x i1> [[OR1]], [[OR0]] -; CHECK-NEXT: [[XOR2:%.*]] = xor <8 x i1> [[XOR0]], [[OR2]] -; CHECK-NEXT: [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR3]] -; CHECK-NEXT: [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR4]] -; CHECK-NEXT: [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR5]] -; CHECK-NEXT: [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR6]] -; CHECK-NEXT: [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR7]] +; CHECK-NEXT: [[XOR2:%.*]] = xor <8 x i1> [[OR1]], [[OR0]] +; CHECK-NEXT: [[OR045:%.*]] = xor <8 x i1> [[XOR2]], [[OR2]] +; CHECK-NEXT: [[XOR3:%.*]] = xor <8 x i1> [[OR045]], [[OR3]] +; CHECK-NEXT: [[XOR4:%.*]] = xor <8 x i1> [[XOR3]], [[OR4]] +; CHECK-NEXT: [[XOR5:%.*]] = xor <8 x i1> [[XOR4]], [[OR5]] +; CHECK-NEXT: [[XOR6:%.*]] = xor <8 x i1> [[XOR5]], [[OR6]] +; CHECK-NEXT: [[XOR7:%.*]] = xor <8 x i1> [[XOR6]], [[OR7]] ; CHECK-NEXT: [[OR4560:%.*]] = or <8 x i1> [[OR045]], [[XOR2]] ; CHECK-NEXT: [[OR023:%.*]] = or <8 x i1> [[OR4560]], [[XOR3]] ; CHECK-NEXT: [[OR001:%.*]] = or <8 x i1> [[OR023]], [[XOR4]] ; CHECK-NEXT: [[OR0123:%.*]] = or <8 x i1> [[OR001]], [[XOR5]] ; CHECK-NEXT: [[OR01234567:%.*]] = or <8 x i1> [[OR0123]], [[XOR6]] -; CHECK-NEXT: ret <8 x i1> [[OR01234567]] +; CHECK-NEXT: [[OR1234567:%.*]] = or <8 x i1> [[OR01234567]], [[XOR7]] +; CHECK-NEXT: ret <8 x i1> [[OR1234567]] ; %or0 = or <8 x i1> %b0, %a %or1 = or <8 x i1> %b1, %a diff --git a/llvm/test/Transforms/Reassociate/repeats.ll b/llvm/test/Transforms/Reassociate/repeats.ll index 28177f1c0ba5ee..ba25c4bfc643cd 100644 --- a/llvm/test/Transforms/Reassociate/repeats.ll +++ b/llvm/test/Transforms/Reassociate/repeats.ll @@ -15,7 +15,7 @@ define i8 @nilpotent(i8 %x) { define i2 @idempotent(i2 %x) { ; CHECK-LABEL: define i2 @idempotent( ; CHECK-SAME: i2 [[X:%.*]]) { -; CHECK-NEXT: ret i2 -1 +; CHECK-NEXT: ret i2 [[X]] ; %tmp1 = and i2 %x, %x %tmp2 = and i2 %tmp1, %x @@ -60,8 +60,7 @@ define i3 @foo3x5(i3 %x) { ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]] ; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP3]], [[X]] -; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[TMP4]], [[TMP3]] -; CHECK-NEXT: ret i3 [[TMP5]] +; CHECK-NEXT: ret i3 [[TMP4]] ; %tmp1 = mul i3 %x, %x %tmp2 = mul i3 %tmp1, %x @@ -75,8 +74,7 @@ define i3 @foo3x5_nsw(i3 %x) { ; CHECK-LABEL: define i3 @foo3x5_nsw( ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i3 [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[TMP4:%.*]] = mul nsw i3 [[TMP3]], [[X]] ; CHECK-NEXT: ret i3 [[TMP4]] ; %tmp1 = mul i3 %x, %x @@ -91,8 +89,7 @@ define i3 @foo3x6(i3 %x) { ; CHECK-LABEL: define i3 @foo3x6( ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i3 [[X]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP3]], [[TMP3]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i3 [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret i3 [[TMP2]] ; %tmp1 = mul i3 %x, %x @@ -108,9 +105,7 @@ define i3 @foo3x7(i3 %x) { ; CHECK-LABEL: define i3 @foo3x7( ; CHECK-SAME: i3 [[X:%.*]]) { ; CHECK-NEXT: [[TMP5:%.*]] = mul i3 [[X]], [[X]] -; CHECK-NEXT: [[TMP7:%.*]] = mul i3 [[TMP5]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i3 [[TMP7]], [[X]] -; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP3]], [[TMP7]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i3 [[TMP5]], [[X]] ; CHECK-NEXT: ret i3 [[TMP6]] ; %tmp1 = mul i3 %x, %x @@ -127,8 +122,7 @@ define i4 @foo4x8(i4 %x) { ; CHECK-LABEL: define i4 @foo4x8( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP1]], [[TMP1]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[TMP3]] +; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret i4 [[TMP4]] ; %tmp1 = mul i4 %x, %x @@ -146,9 +140,8 @@ define i4 @foo4x9(i4 %x) { ; CHECK-LABEL: define i4 @foo4x9( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]] -; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP3]], [[TMP2]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] +; CHECK-NEXT: [[TMP8:%.*]] = mul i4 [[TMP2]], [[TMP1]] ; CHECK-NEXT: ret i4 [[TMP8]] ; %tmp1 = mul i4 %x, %x @@ -167,8 +160,7 @@ define i4 @foo4x10(i4 %x) { ; CHECK-LABEL: define i4 @foo4x10( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] ; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]] ; CHECK-NEXT: ret i4 [[TMP3]] ; @@ -189,8 +181,7 @@ define i4 @foo4x11(i4 %x) { ; CHECK-LABEL: define i4 @foo4x11( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[TMP1]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP4]], [[X]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] ; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[X]] ; CHECK-NEXT: [[TMP10:%.*]] = mul i4 [[TMP3]], [[TMP2]] ; CHECK-NEXT: ret i4 [[TMP10]] @@ -213,9 +204,7 @@ define i4 @foo4x12(i4 %x) { ; CHECK-LABEL: define i4 @foo4x12( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]] -; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP3]], [[TMP3]] +; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[TMP1]] ; CHECK-NEXT: ret i4 [[TMP2]] ; %tmp1 = mul i4 %x, %x @@ -238,9 +227,7 @@ define i4 @foo4x13(i4 %x) { ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] ; CHECK-NEXT: [[TMP2:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP2]], [[TMP2]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP3]], [[X]] -; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP4]], [[TMP3]] +; CHECK-NEXT: [[TMP12:%.*]] = mul i4 [[TMP2]], [[TMP1]] ; CHECK-NEXT: ret i4 [[TMP12]] ; %tmp1 = mul i4 %x, %x @@ -263,9 +250,7 @@ define i4 @foo4x14(i4 %x) { ; CHECK-LABEL: define i4 @foo4x14( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP4]], [[TMP4]] -; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP5]], [[X]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]] ; CHECK-NEXT: [[TMP7:%.*]] = mul i4 [[TMP6]], [[TMP6]] ; CHECK-NEXT: ret i4 [[TMP7]] ; @@ -290,9 +275,7 @@ define i4 @foo4x15(i4 %x) { ; CHECK-LABEL: define i4 @foo4x15( ; CHECK-SAME: i4 [[X:%.*]]) { ; CHECK-NEXT: [[TMP1:%.*]] = mul i4 [[X]], [[X]] -; CHECK-NEXT: [[TMP4:%.*]] = mul i4 [[TMP1]], [[X]] -; CHECK-NEXT: [[TMP3:%.*]] = mul i4 [[TMP4]], [[TMP4]] -; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP3]], [[X]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i4 [[TMP1]], [[X]] ; CHECK-NEXT: [[TMP5:%.*]] = mul i4 [[TMP6]], [[X]] ; CHECK-NEXT: [[TMP14:%.*]] = mul i4 [[TMP5]], [[TMP6]] ; CHECK-NEXT: ret i4 [[TMP14]]