Skip to content

Commit

Permalink
[InstCombine] Teach takeLog2 about right shifts, truncation and bitwi…
Browse files Browse the repository at this point in the history
…se-and

We left some easy opportunities for further simplifications.

log2(trunc(x)) is simply trunc(log2(x)). This is safe if we know that
trunc is NUW because it means that the truncation didn't drop any bits.
It is also safe if the caller is OK with zero as a possible answer.

log2(x >>u y) is simply `log2(x) - y`.

log2(x & y) is a funny one. It comes up when doing something like:
```
unsigned int f(unsigned int x, unsigned int y) {
  unsigned char a = 1u << x;
  return y / a;
}
```

LLVM would canonicalize this to:
```
  %shl = shl nuw i32 1, %x
  %conv1 = and i32 %shl, 255
  %div = udiv i32 %y, %conv1
```

In cases like these, we can ignore the mask entirely.
This is equivalent to `y >> x`.
  • Loading branch information
majnemer committed Oct 28, 2024
1 parent d3f70db commit 5d4a0d5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
30 changes: 30 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1427,6 +1427,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });

// log2(trunc x) -> trunc log2(X)
// FIXME: Require one use?
if (match(Op, m_Trunc(m_Value(X)))) {
auto *TI = cast<TruncInst>(Op);
if (AssumeNonZero || TI->hasNoUnsignedWrap())
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateTrunc(LogX, Op->getType(), "",
/*IsNUW=*/TI->hasNoUnsignedWrap());
});
}

// log2(X << Y) -> log2(X) + Y
// FIXME: Require one use unless X is 1?
if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
Expand All @@ -1437,6 +1449,24 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
}

// log2(X >>u Y) -> log2(X) - Y
// FIXME: Require one use?
if (match(Op, m_LShr(m_Value(X), m_Value(Y)))) {
auto *PEO = cast<PossiblyExactOperator>(Op);
if (AssumeNonZero || PEO->isExact())
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateSub(LogX, Y); });
}

// log2(X & Y) -> either log2(X) or log2(Y)
// This requires `AssumeNonZero` as `X & Y` may be zero when X != Y.
if (AssumeNonZero && match(Op, m_And(m_Value(X), m_Value(Y)))) {
if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return LogX; });
if (Value *LogY = takeLog2(Builder, Y, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return LogY; });
}

// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
// FIXME: Require one use?
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
Expand Down
43 changes: 40 additions & 3 deletions llvm/test/Transforms/InstCombine/div.ll
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,8 @@ define <2 x i32> @test31(<2 x i32> %x) {

define i32 @test32(i32 %a, i32 %b) {
; CHECK-LABEL: @test32(
; CHECK-NEXT: [[SHL:%.*]] = shl i32 2, [[B:%.*]]
; CHECK-NEXT: [[DIV:%.*]] = lshr i32 [[SHL]], 2
; CHECK-NEXT: [[DIV2:%.*]] = udiv i32 [[A:%.*]], [[DIV]]
; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[B:%.*]], -1
; CHECK-NEXT: [[DIV2:%.*]] = lshr i32 [[A:%.*]], [[TMP1]]
; CHECK-NEXT: ret i32 [[DIV2]]
;
%shl = shl i32 2, %b
Expand Down Expand Up @@ -1832,3 +1831,41 @@ define i32 @fold_disjoint_or_over_udiv(i32 %x) {
%r = udiv i32 %or, 9
ret i32 %r
}

define i8 @udiv_trunc_shl(i32 %x) {
; CHECK-LABEL: @udiv_trunc_shl(
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
; CHECK-NEXT: [[UDIV1:%.*]] = lshr i8 8, [[TMP1]]
; CHECK-NEXT: ret i8 [[UDIV1]]
;
%lshr = shl i32 1, %x
%trunc = trunc i32 %lshr to i8
%div = udiv i8 8, %trunc
ret i8 %div
}

define i32 @zext_udiv_trunc_lshr(i32 %x) {
; CHECK-LABEL: @zext_udiv_trunc_lshr(
; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
; CHECK-NEXT: [[TMP2:%.*]] = sub i8 5, [[TMP1]]
; CHECK-NEXT: [[UDIV1:%.*]] = lshr i8 8, [[TMP2]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[UDIV1]] to i32
; CHECK-NEXT: ret i32 [[ZEXT]]
;
%lshr = lshr i32 32, %x
%trunc = trunc i32 %lshr to i8
%div = udiv i8 8, %trunc
%zext = zext i8 %div to i32
ret i32 %zext
}

define i32 @udiv_and_shl(i32 %a, i32 %b, i32 %c) {
; CHECK-LABEL: @udiv_and_shl(
; CHECK-NEXT: [[DIV1:%.*]] = lshr i32 [[C:%.*]], [[A:%.*]]
; CHECK-NEXT: ret i32 [[DIV1]]
;
%shl = shl i32 1, %a
%and = and i32 %b, %shl
%div = udiv i32 %c, %and
ret i32 %div
}
8 changes: 4 additions & 4 deletions llvm/test/Transforms/InstCombine/shift.ll
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@ entry:

define i32 @test42(i32 %a, i32 %b) {
; CHECK-LABEL: @test42(
; CHECK-NEXT: [[DIV:%.*]] = lshr exact i32 4096, [[B:%.*]]
; CHECK-NEXT: [[DIV2:%.*]] = udiv i32 [[A:%.*]], [[DIV]]
; CHECK-NEXT: [[TMP1:%.*]] = sub i32 12, [[B:%.*]]
; CHECK-NEXT: [[DIV2:%.*]] = lshr i32 [[A:%.*]], [[TMP1]]
; CHECK-NEXT: ret i32 [[DIV2]]
;
%div = lshr i32 4096, %b ; must be exact otherwise we'd divide by zero
Expand All @@ -688,8 +688,8 @@ define i32 @test42(i32 %a, i32 %b) {

define <2 x i32> @test42vec(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @test42vec(
; CHECK-NEXT: [[DIV:%.*]] = lshr exact <2 x i32> <i32 4096, i32 4096>, [[B:%.*]]
; CHECK-NEXT: [[DIV2:%.*]] = udiv <2 x i32> [[A:%.*]], [[DIV]]
; CHECK-NEXT: [[TMP1:%.*]] = sub <2 x i32> <i32 12, i32 12>, [[B:%.*]]
; CHECK-NEXT: [[DIV2:%.*]] = lshr <2 x i32> [[A:%.*]], [[TMP1]]
; CHECK-NEXT: ret <2 x i32> [[DIV2]]
;
%div = lshr <2 x i32> <i32 4096, i32 4096>, %b ; must be exact otherwise we'd divide by zero
Expand Down

0 comments on commit 5d4a0d5

Please sign in to comment.