Skip to content

Commit

Permalink
[SandboxIR] Add tracking for ShuffleVectorInst::commute. (llvm#106644)
Browse files Browse the repository at this point in the history
Track it as an operand swap + a `setShuffleMask` and delegate to the
`llvm::ShuffleVectorInst` implementation.
  • Loading branch information
slackito committed Sep 3, 2024
1 parent 4f403e8 commit e89bcfc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ class ShuffleVectorInst final

/// Swap the operands and adjust the mask to preserve the semantics of the
/// instruction.
void commute() { cast<llvm::ShuffleVectorInst>(Val)->commute(); }
void commute();

/// Return true if a shufflevector instruction can be formed with the
/// specified operands.
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,13 @@ VectorType *ShuffleVectorInst::getType() const {
Ctx.getType(cast<llvm::ShuffleVectorInst>(Val)->getType()));
}

void ShuffleVectorInst::commute() {
Ctx.getTracker().emplaceIfTracking<ShuffleVectorSetMask>(this);
Ctx.getTracker().emplaceIfTracking<UseSwap>(getOperandUse(0),
getOperandUse(1));
cast<llvm::ShuffleVectorInst>(Val)->commute();
}

Constant *ShuffleVectorInst::getShuffleMaskForBitcode() const {
return Ctx.getOrCreateConstant(
cast<llvm::ShuffleVectorInst>(Val)->getShuffleMaskForBitcode());
Expand Down
20 changes: 16 additions & 4 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ define void @foo(i32 %cond0, i32 %cond1) {
EXPECT_EQ(Switch->findCaseDest(BB1), One);
}

TEST_F(TrackerTest, ShuffleVectorInstSetters) {
TEST_F(TrackerTest, ShuffleVectorInst) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
%shuf = shufflevector <2 x i8> %v1, <2 x i8> %v2, <2 x i32> <i32 1, i32 2>
Expand All @@ -983,10 +983,22 @@ define void @foo(<2 x i8> %v1, <2 x i8> %v2) {
SmallVector<int, 2> OrigMask(SVI->getShuffleMask());
Ctx.save();
SVI->setShuffleMask(ArrayRef<int>({0, 0}));
EXPECT_THAT(SVI->getShuffleMask(),
testing::Not(testing::ElementsAreArray(OrigMask)));
EXPECT_NE(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
Ctx.revert();
EXPECT_THAT(SVI->getShuffleMask(), testing::ElementsAreArray(OrigMask));
EXPECT_EQ(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));

// Check commute.
auto *Op0 = SVI->getOperand(0);
auto *Op1 = SVI->getOperand(1);
Ctx.save();
SVI->commute();
EXPECT_EQ(SVI->getOperand(0), Op1);
EXPECT_EQ(SVI->getOperand(1), Op0);
EXPECT_NE(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
Ctx.revert();
EXPECT_EQ(SVI->getOperand(0), Op0);
EXPECT_EQ(SVI->getOperand(1), Op1);
EXPECT_EQ(SVI->getShuffleMask(), ArrayRef<int>(OrigMask));
}

TEST_F(TrackerTest, PossiblyDisjointInstSetters) {
Expand Down

0 comments on commit e89bcfc

Please sign in to comment.