Skip to content

Commit

Permalink
[SandboxIR] Implement ConstantAggregate (llvm#107136)
Browse files Browse the repository at this point in the history
This patch implements sandboxir:: ConstantAggregate, ConstantStruct,
ConstantArray and ConstantVector, mirroring LLVM IR.
  • Loading branch information
vporpo committed Sep 4, 2024
1 parent 83ad644 commit 814aa43
Show file tree
Hide file tree
Showing 7 changed files with 310 additions and 8 deletions.
94 changes: 94 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class Value {
friend class PHINode; // For getting `Val`.
friend class UnreachableInst; // For getting `Val`.
friend class CatchSwitchAddHandler; // For `Val`.
friend class ConstantArray; // For `Val`.
friend class ConstantStruct; // For `Val`.

/// All values point to the context.
Context &Ctx;
Expand Down Expand Up @@ -840,6 +842,97 @@ class ConstantFP final : public Constant {
#endif
};

/// Base class for aggregate constants (with operands).
class ConstantAggregate : public Constant {
protected:
ConstantAggregate(ClassID ID, llvm::Constant *C, Context &Ctx)
: Constant(ID, C, Ctx) {}

public:
/// For isa/dyn_cast.
static bool classof(const sandboxir::Value *From) {
auto ID = From->getSubclassID();
return ID == ClassID::ConstantVector || ID == ClassID::ConstantStruct ||
ID == ClassID::ConstantArray;
}
};

class ConstantArray final : public ConstantAggregate {
ConstantArray(llvm::ConstantArray *C, Context &Ctx)
: ConstantAggregate(ClassID::ConstantArray, C, Ctx) {}
friend class Context; // For constructor.

public:
static Constant *get(ArrayType *T, ArrayRef<Constant *> V);
ArrayType *getType() const;

// TODO: Missing functions: getType(), getTypeForElements(), getAnon(), get().

/// For isa/dyn_cast.
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::ConstantArray;
}
};

class ConstantStruct final : public ConstantAggregate {
ConstantStruct(llvm::ConstantStruct *C, Context &Ctx)
: ConstantAggregate(ClassID::ConstantStruct, C, Ctx) {}
friend class Context; // For constructor.

public:
static Constant *get(StructType *T, ArrayRef<Constant *> V);

template <typename... Csts>
static std::enable_if_t<are_base_of<Constant, Csts...>::value, Constant *>
get(StructType *T, Csts *...Vs) {
return get(T, ArrayRef<Constant *>({Vs...}));
}
/// Return an anonymous struct that has the specified elements.
/// If the struct is possibly empty, then you must specify a context.
static Constant *getAnon(ArrayRef<Constant *> V, bool Packed = false) {
return get(getTypeForElements(V, Packed), V);
}
static Constant *getAnon(Context &Ctx, ArrayRef<Constant *> V,
bool Packed = false) {
return get(getTypeForElements(Ctx, V, Packed), V);
}
/// This version of the method allows an empty list.
static StructType *getTypeForElements(Context &Ctx, ArrayRef<Constant *> V,
bool Packed = false);
/// Return an anonymous struct type to use for a constant with the specified
/// set of elements. The list must not be empty.
static StructType *getTypeForElements(ArrayRef<Constant *> V,
bool Packed = false) {
assert(!V.empty() &&
"ConstantStruct::getTypeForElements cannot be called on empty list");
return getTypeForElements(V[0]->getContext(), V, Packed);
}

/// Specialization - reduce amount of casting.
inline StructType *getType() const {
return cast<StructType>(Value::getType());
}

/// For isa/dyn_cast.
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::ConstantStruct;
}
};

class ConstantVector final : public ConstantAggregate {
ConstantVector(llvm::ConstantVector *C, Context &Ctx)
: ConstantAggregate(ClassID::ConstantVector, C, Ctx) {}
friend class Context; // For constructor.

public:
// TODO: Missing functions: getSplat(), getType(), getSplatValue(), get().

/// For isa/dyn_cast.
static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::ConstantVector;
}
};

/// Iterator for `Instruction`s in a `BasicBlock.
/// \Returns an sandboxir::Instruction & when derereferenced.
class BBIterator {
Expand Down Expand Up @@ -3353,6 +3446,7 @@ class Context {
friend class Type; // For LLVMCtx.
friend class PointerType; // For LLVMCtx.
friend class IntegerType; // For LLVMCtx.
friend class StructType; // For LLVMCtx.
Tracker IRTracker;

/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ DEF_VALUE(Block, BasicBlock)
DEF_CONST(Constant, Constant)
DEF_CONST(ConstantInt, ConstantInt)
DEF_CONST(ConstantFP, ConstantFP)
DEF_CONST(ConstantArray, ConstantArray)
DEF_CONST(ConstantStruct, ConstantStruct)
DEF_CONST(ConstantVector, ConstantVector)

#ifndef DEF_INSTR
#define DEF_INSTR(ID, OPCODE, CLASS)
Expand Down
45 changes: 38 additions & 7 deletions llvm/include/llvm/SandboxIR/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class PointerType;
class VectorType;
class IntegerType;
class FunctionType;
class ArrayType;
class StructType;
#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
#define DEF_CONST(ID, CLASS) class CLASS;
#include "llvm/SandboxIR/SandboxIRValues.def"
Expand All @@ -36,13 +38,19 @@ class FunctionType;
class Type {
protected:
llvm::Type *LLVMTy;
friend class VectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
friend class Function; // For LLVMTy.
friend class CallBase; // For LLVMTy.
friend class ConstantInt; // For LLVMTy.
friend class ArrayType; // For LLVMTy.
friend class StructType; // For LLVMTy.
friend class VectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
friend class Function; // For LLVMTy.
friend class CallBase; // For LLVMTy.
friend class ConstantInt; // For LLVMTy.
friend class ConstantArray; // For LLVMTy.
friend class ConstantStruct; // For LLVMTy.
friend class ConstantVector; // For LLVMTy.

// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
#define DEF_CONST(ID, CLASS) friend class CLASS;
Expand Down Expand Up @@ -281,8 +289,31 @@ class PointerType : public Type {
}
};

class ArrayType : public Type {
public:
// TODO: add missing functions
static bool classof(const Type *From) {
return isa<llvm::ArrayType>(From->LLVMTy);
}
};

class StructType : public Type {
public:
/// This static method is the primary way to create a literal StructType.
static StructType *get(Context &Ctx, ArrayRef<Type *> Elements,
bool IsPacked = false);

bool isPacked() const { return cast<llvm::StructType>(LLVMTy)->isPacked(); }

// TODO: add missing functions
static bool classof(const Type *From) {
return isa<llvm::StructType>(From->LLVMTy);
}
};

class VectorType : public Type {
public:
static VectorType *get(Type *ElementType, ElementCount EC);
// TODO: add missing functions
static bool classof(const Type *From) {
return isa<llvm::VectorType>(From->LLVMTy);
Expand Down
48 changes: 47 additions & 1 deletion llvm/lib/SandboxIR/SandboxIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,44 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
}

Constant *ConstantArray::get(ArrayType *T, ArrayRef<Constant *> V) {
auto &Ctx = T->getContext();
SmallVector<llvm::Constant *> LLVMValues;
LLVMValues.reserve(V.size());
for (auto *Elm : V)
LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
auto *LLVMC =
llvm::ConstantArray::get(cast<llvm::ArrayType>(T->LLVMTy), LLVMValues);
return cast<ConstantArray>(Ctx.getOrCreateConstant(LLVMC));
}

ArrayType *ConstantArray::getType() const {
return cast<ArrayType>(
Ctx.getType(cast<llvm::ConstantArray>(Val)->getType()));
}

Constant *ConstantStruct::get(StructType *T, ArrayRef<Constant *> V) {
auto &Ctx = T->getContext();
SmallVector<llvm::Constant *> LLVMValues;
LLVMValues.reserve(V.size());
for (auto *Elm : V)
LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
auto *LLVMC =
llvm::ConstantStruct::get(cast<llvm::StructType>(T->LLVMTy), LLVMValues);
return cast<ConstantStruct>(Ctx.getOrCreateConstant(LLVMC));
}

StructType *ConstantStruct::getTypeForElements(Context &Ctx,
ArrayRef<Constant *> V,
bool Packed) {
unsigned VecSize = V.size();
SmallVector<Type *, 16> EltTypes;
EltTypes.reserve(VecSize);
for (Constant *Elm : V)
EltTypes.push_back(Elm->getType());
return StructType::get(Ctx, EltTypes, Packed);
}

FunctionType *Function::getFunctionType() const {
return cast<FunctionType>(
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
Expand Down Expand Up @@ -2459,7 +2497,15 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
return It->second.get();
}
if (auto *F = dyn_cast<llvm::Function>(LLVMV))
if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
It->second =
std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
It->second =
std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
It->second = std::unique_ptr<Function>(new Function(F, *this));
else
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/SandboxIR/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
}

StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
bool IsPacked) {
SmallVector<llvm::Type *> LLVMElements;
LLVMElements.reserve(Elements.size());
for (Type *Elm : Elements)
LLVMElements.push_back(Elm->LLVMTy);
return cast<StructType>(
Ctx.getType(llvm::StructType::get(Ctx.LLVMCtx, LLVMElements, IsPacked)));
}

VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
return cast<VectorType>(ElementType->getContext().getType(
llvm::VectorType::get(ElementType->LLVMTy, EC)));
}

IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
Expand Down
75 changes: 75 additions & 0 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,81 @@ define void @foo(float %v0, double %v1) {
EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
}

// Tests ConstantArray, ConstantStruct and ConstantVector.
TEST_F(SandboxIRTest, ConstantAggregate) {
// Note: we are using i42 to avoid the creation of ConstantDataVector or
// ConstantDataArray.
parseIR(C, R"IR(
define void @foo() {
%array = extractvalue [2 x i42] [i42 0, i42 1], 0
%struct = extractvalue {i42, i42} {i42 0, i42 1}, 0
%vector = extractelement <2 x i42> <i42 0, i42 1>, i32 0
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto &F = *Ctx.createFunction(&LLVMF);
auto &BB = *F.begin();
auto It = BB.begin();
auto *I0 = &*It++;
auto *I1 = &*It++;
auto *I2 = &*It++;
// Check classof() and creation.
auto *Array = cast<sandboxir::ConstantArray>(I0->getOperand(0));
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Array));
auto *Struct = cast<sandboxir::ConstantStruct>(I1->getOperand(0));
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Struct));
auto *Vector = cast<sandboxir::ConstantVector>(I2->getOperand(0));
EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Vector));

auto *ZeroI42 = cast<sandboxir::ConstantInt>(Array->getOperand(0));
auto *OneI42 = cast<sandboxir::ConstantInt>(Array->getOperand(1));
// Check ConstantArray::get(), getType().
auto *NewCA =
sandboxir::ConstantArray::get(Array->getType(), {ZeroI42, OneI42});
EXPECT_EQ(NewCA, Array);

// Check ConstantStruct::get(), getType().
auto *NewCS =
sandboxir::ConstantStruct::get(Struct->getType(), {ZeroI42, OneI42});
EXPECT_EQ(NewCS, Struct);
// Check ConstantStruct::get(...).
auto *NewCS2 =
sandboxir::ConstantStruct::get(Struct->getType(), ZeroI42, OneI42);
EXPECT_EQ(NewCS2, Struct);
// Check ConstantStruct::getAnon(ArayRef).
auto *AnonCS = sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42});
EXPECT_FALSE(cast<sandboxir::StructType>(AnonCS->getType())->isPacked());
auto *AnonCSPacked =
sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42}, /*Packed=*/true);
EXPECT_TRUE(cast<sandboxir::StructType>(AnonCSPacked->getType())->isPacked());
// Check ConstantStruct::getAnon(Ctx, ArrayRef).
auto *AnonCS2 = sandboxir::ConstantStruct::getAnon(Ctx, {ZeroI42, OneI42});
EXPECT_EQ(AnonCS2, AnonCS);
auto *AnonCS2Packed = sandboxir::ConstantStruct::getAnon(
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
EXPECT_EQ(AnonCS2Packed, AnonCSPacked);
// Check ConstantStruct::getTypeForElements(Ctx, ArrayRef).
auto *StructTy =
sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
EXPECT_EQ(StructTy, Struct->getType());
EXPECT_FALSE(StructTy->isPacked());
// Check ConstantStruct::getTypeForElements(Ctx, ArrayRef, Packed).
auto *StructTyPacked = sandboxir::ConstantStruct::getTypeForElements(
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
EXPECT_TRUE(StructTyPacked->isPacked());
// Check ConstantStruct::getTypeForElements(ArrayRef).
auto *StructTy2 =
sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
EXPECT_EQ(StructTy2, Struct->getType());
// Check ConstantStruct::getTypeForElements(ArrayRef, Packed).
auto *StructTy2Packed = sandboxir::ConstantStruct::getTypeForElements(
Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
EXPECT_EQ(StructTy2Packed, StructTyPacked);
}

TEST_F(SandboxIRTest, Use) {
parseIR(C, R"IR(
define i32 @foo(i32 %v0, i32 %v1) {
Expand Down
Loading

0 comments on commit 814aa43

Please sign in to comment.