diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h index 91d6b58cfee00c..2fdbbbd094650f 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIR.h +++ b/llvm/include/llvm/SandboxIR/SandboxIR.h @@ -123,6 +123,7 @@ class ConstantFP; class ConstantAggregateZero; class ConstantPointerNull; class PoisonValue; +class BlockAddress; class Context; class Function; class Instruction; @@ -323,6 +324,7 @@ class Value { friend class ConstantPointerNull; // For `Val`. friend class UndefValue; // For `Val`. friend class PoisonValue; // For `Val`. + friend class BlockAddress; // For `Val`. /// All values point to the context. Context &Ctx; @@ -1112,6 +1114,33 @@ class PoisonValue final : public UndefValue { #endif }; +class BlockAddress final : public Constant { + BlockAddress(llvm::BlockAddress *C, Context &Ctx) + : Constant(ClassID::BlockAddress, C, Ctx) {} + friend class Context; // For constructor. + +public: + /// Return a BlockAddress for the specified function and basic block. + static BlockAddress *get(Function *F, BasicBlock *BB); + + /// Return a BlockAddress for the specified basic block. The basic + /// block must be embedded into a function. + static BlockAddress *get(BasicBlock *BB); + + /// Lookup an existing \c BlockAddress constant for the given BasicBlock. + /// + /// \returns 0 if \c !BB->hasAddressTaken(), otherwise the \c BlockAddress. + static BlockAddress *lookup(const BasicBlock *BB); + + Function *getFunction() const; + BasicBlock *getBasicBlock() const; + + /// For isa/dyn_cast. + static bool classof(const sandboxir::Value *From) { + return From->getSubclassID() == ClassID::BlockAddress; + } +}; + /// Iterator for `Instruction`s in a `BasicBlock. /// \Returns an sandboxir::Instruction & when derereferenced. class BBIterator { diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def index 459226216703d9..c29e8be24ea754 100644 --- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def +++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def @@ -34,6 +34,7 @@ DEF_CONST(ConstantAggregateZero, ConstantAggregateZero) DEF_CONST(ConstantPointerNull, ConstantPointerNull) DEF_CONST(UndefValue, UndefValue) DEF_CONST(PoisonValue, PoisonValue) +DEF_CONST(BlockAddress, BlockAddress) #ifndef DEF_INSTR #define DEF_INSTR(ID, OPCODE, CLASS) diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp index a4b68bd8ffd7c9..18fdcda15a1a91 100644 --- a/llvm/lib/SandboxIR/SandboxIR.cpp +++ b/llvm/lib/SandboxIR/SandboxIR.cpp @@ -2489,6 +2489,32 @@ PoisonValue *PoisonValue::getElementValue(unsigned Idx) const { cast(Val)->getElementValue(Idx))); } +BlockAddress *BlockAddress::get(Function *F, BasicBlock *BB) { + auto *LLVMC = llvm::BlockAddress::get(cast(F->Val), + cast(BB->Val)); + return cast(F->getContext().getOrCreateConstant(LLVMC)); +} + +BlockAddress *BlockAddress::get(BasicBlock *BB) { + auto *LLVMC = llvm::BlockAddress::get(cast(BB->Val)); + return cast(BB->getContext().getOrCreateConstant(LLVMC)); +} + +BlockAddress *BlockAddress::lookup(const BasicBlock *BB) { + auto *LLVMC = llvm::BlockAddress::lookup(cast(BB->Val)); + return cast_or_null(BB->getContext().getValue(LLVMC)); +} + +Function *BlockAddress::getFunction() const { + return cast( + Ctx.getValue(cast(Val)->getFunction())); +} + +BasicBlock *BlockAddress::getBasicBlock() const { + return cast( + Ctx.getValue(cast(Val)->getBasicBlock())); +} + FunctionType *Function::getFunctionType() const { return cast( Ctx.getType(cast(Val)->getFunctionType())); @@ -2585,6 +2611,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { It->second = std::unique_ptr( new ConstantFP(cast(C), *this)); return It->second.get(); + case llvm::Value::BlockAddressVal: + It->second = std::unique_ptr( + new BlockAddress(cast(C), *this)); + return It->second.get(); case llvm::Value::ConstantAggregateZeroVal: { auto *CAZ = cast(C); It->second = std::unique_ptr( @@ -2640,7 +2670,7 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) { return It->second.get(); } if (auto *BB = dyn_cast(LLVMV)) { - assert(isa(U) && + assert(isa(U) && "This won't create a SBBB, don't call this function directly!"); if (auto *SBBB = getValue(BB)) return SBBB; diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp index 1b939b4d047aaf..b76d24dc297b96 100644 --- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp +++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp @@ -729,6 +729,54 @@ define void @foo() { EXPECT_EQ(UndefStruct->getNumElements(), 2u); } +TEST_F(SandboxIRTest, BlockAddress) { + parseIR(C, R"IR( +define void @foo(ptr %ptr) { +bb0: + store ptr blockaddress(@foo, %bb0), ptr %ptr + ret void +bb1: + ret void +bb2: + ret void +} +)IR"); + Function &LLVMF = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + + auto &F = *Ctx.createFunction(&LLVMF); + auto *BB0 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb0"))); + auto *BB1 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb1"))); + auto *BB2 = cast( + Ctx.getValue(getBasicBlockByName(LLVMF, "bb2"))); + auto It = BB0->begin(); + auto *SI = cast(&*It++); + [[maybe_unused]] auto *Ret = cast(&*It++); + + // Check classof(), creation, getFunction(), getBasicBlock(). + auto *BB0Addr = cast(SI->getValueOperand()); + EXPECT_EQ(BB0Addr->getBasicBlock(), BB0); + EXPECT_EQ(BB0Addr->getFunction(), &F); + // Check get(F, BB). + auto *NewBB0Addr = sandboxir::BlockAddress::get(&F, BB0); + EXPECT_EQ(NewBB0Addr, BB0Addr); + // Check get(BB). + auto *NewBB0Addr2 = sandboxir::BlockAddress::get(BB0); + EXPECT_EQ(NewBB0Addr2, BB0Addr); + auto *BB1Addr = sandboxir::BlockAddress::get(BB1); + EXPECT_EQ(BB1Addr->getBasicBlock(), BB1); + EXPECT_NE(BB1Addr, BB0Addr); + // Check lookup(). + auto *LookupBB0Addr = sandboxir::BlockAddress::lookup(BB0); + EXPECT_EQ(LookupBB0Addr, BB0Addr); + auto *LookupBB1Addr = sandboxir::BlockAddress::lookup(BB1); + EXPECT_EQ(LookupBB1Addr, BB1Addr); + auto *LookupBB2Addr = sandboxir::BlockAddress::lookup(BB2); + EXPECT_EQ(LookupBB2Addr, nullptr); +} + TEST_F(SandboxIRTest, Use) { parseIR(C, R"IR( define i32 @foo(i32 %v0, i32 %v1) {