Skip to content

Commit

Permalink
[mlir] Expose type and attribute names in the MLIRContext and abstrac…
Browse files Browse the repository at this point in the history
…t type/attr classes (llvm#72189)

This patch expose the type and attribute names in C++ as methods in the
`AbstractType` and `AbstractAttribute` classes, and keep a map of names
to `AbstractType` and `AbstractAttribute` in the `MLIRContext`. Type and
attribute names should be unique.

It adds support in ODS to generate the `getName` methods in
`AbstractType` and `AbstractAttribute`, through the use of two new
variables, `typeName` and `attrName`. It also adds names to C++-defined
type and attributes.
  • Loading branch information
math-fehr committed Nov 30, 2023
1 parent 6688657 commit 3dbac2c
Show file tree
Hide file tree
Showing 30 changed files with 409 additions and 80 deletions.
8 changes: 8 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRAttr.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class ExactTypeAttr
using Base::Base;
using ValueType = mlir::Type;

static constexpr llvm::StringLiteral name = "fir.type_is";
static constexpr llvm::StringRef getAttrName() { return "type_is"; }
static ExactTypeAttr get(mlir::Type value);

Expand All @@ -51,6 +52,7 @@ class SubclassAttr
using Base::Base;
using ValueType = mlir::Type;

static constexpr llvm::StringLiteral name = "fir.class_is";
static constexpr llvm::StringRef getAttrName() { return "class_is"; }
static SubclassAttr get(mlir::Type value);

Expand All @@ -63,6 +65,7 @@ class MustBeHeapAttr : public mlir::BoolAttr {
public:
using BoolAttr::BoolAttr;

static constexpr llvm::StringLiteral name = "fir.must_be_heap";
static constexpr llvm::StringRef getAttrName() { return "fir.must_be_heap"; }
};

Expand All @@ -78,6 +81,7 @@ class ClosedIntervalAttr
public:
using Base::Base;

static constexpr llvm::StringLiteral name = "fir.interval";
static constexpr llvm::StringRef getAttrName() { return "interval"; }
static ClosedIntervalAttr get(mlir::MLIRContext *ctxt);
};
Expand All @@ -92,6 +96,7 @@ class UpperBoundAttr
public:
using Base::Base;

static constexpr llvm::StringLiteral name = "fir.upper";
static constexpr llvm::StringRef getAttrName() { return "upper"; }
static UpperBoundAttr get(mlir::MLIRContext *ctxt);
};
Expand All @@ -106,6 +111,7 @@ class LowerBoundAttr
public:
using Base::Base;

static constexpr llvm::StringLiteral name = "fir.lower";
static constexpr llvm::StringRef getAttrName() { return "lower"; }
static LowerBoundAttr get(mlir::MLIRContext *ctxt);
};
Expand All @@ -120,6 +126,7 @@ class PointIntervalAttr
public:
using Base::Base;

static constexpr llvm::StringLiteral name = "fir.point";
static constexpr llvm::StringRef getAttrName() { return "point"; }
static PointIntervalAttr get(mlir::MLIRContext *ctxt);
};
Expand All @@ -135,6 +142,7 @@ class RealAttr
using Base::Base;
using ValueType = std::pair<int, llvm::APFloat>;

static constexpr llvm::StringLiteral name = "fir.real";
static constexpr llvm::StringRef getAttrName() { return "real"; }
static RealAttr get(mlir::MLIRContext *ctxt, const ValueType &key);

Expand Down
3 changes: 3 additions & 0 deletions mlir/examples/toy/Ch7/include/toy/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class StructType : public mlir::Type::TypeBase<StructType, mlir::Type,

/// Returns the number of element type held by this struct.
size_t getNumElementTypes() { return getElementTypes().size(); }

/// The name of this struct type.
static constexpr StringLiteral name = "toy.struct";
};
} // namespace toy
} // namespace mlir
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/DLTI/DLTI.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class DataLayoutEntryAttr

/// Prints this attribute.
void print(AsmPrinter &os) const;

static constexpr StringLiteral name = "builtin.data_layout_entry";
};

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -109,6 +111,8 @@ class DataLayoutSpecAttr

/// Prints this attribute.
void print(AsmPrinter &os) const;

static constexpr StringLiteral name = "builtin.data_layout_spec";
};

} // namespace mlir
Expand Down
35 changes: 28 additions & 7 deletions mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class AsyncTokenType
public:
// Used for generic hooks in TypeBase.
using Base::Base;

static constexpr StringLiteral name = "gpu.async_token";
};

/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
Expand Down Expand Up @@ -128,6 +130,8 @@ class MMAMatrixType
public:
using Base::Base;

static constexpr StringLiteral name = "gpu.mma_matrix";

/// Get MMAMatrixType and verify construction Invariants.
static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
StringRef operand);
Expand Down Expand Up @@ -168,18 +172,35 @@ void addAsyncDependency(Operation *op, Value token);
// Handle types for sparse.
enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };

template <SparseHandleKind K>
class SparseHandleType
: public Type::TypeBase<SparseHandleType<K>, Type, TypeStorage> {
class SparseDnTensorHandleType
: public Type::TypeBase<SparseDnTensorHandleType, Type, TypeStorage> {
public:
using Base = typename Type::TypeBase<SparseDnTensorHandleType, Type,
TypeStorage>::Base;
using Base::Base;

static constexpr StringLiteral name = "gpu.sparse.dntensor_handle";
};

class SparseSpMatHandleType
: public Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage> {
public:
using Base =
typename Type::TypeBase<SparseHandleType<K>, Type, TypeStorage>::Base;
typename Type::TypeBase<SparseSpMatHandleType, Type, TypeStorage>::Base;
using Base::Base;

static constexpr StringLiteral name = "gpu.sparse.spmat_handle";
};

using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
class SparseSpGEMMOpHandleType
: public Type::TypeBase<SparseSpGEMMOpHandleType, Type, TypeStorage> {
public:
using Base = typename Type::TypeBase<SparseSpGEMMOpHandleType, Type,
TypeStorage>::Base;
using Base::Base;

static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle";
};

} // namespace gpu
} // namespace mlir
Expand Down
17 changes: 10 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,19 @@ namespace LLVM {
//===----------------------------------------------------------------------===//

// Batch-define trivial types.
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \
#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName) \
class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> { \
public: \
using Base::Base; \
static constexpr StringLiteral name = TypeName; \
}

DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label");
DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata");

#undef DEFINE_TRIVIAL_LLVM_TYPE

Expand Down Expand Up @@ -110,6 +111,8 @@ class LLVMStructType
/// Inherit base constructors.
using Base::Base;

static constexpr StringLiteral name = "llvm.struct";

/// Checks if the given type can be contained in a structure type.
static bool isValidElementType(Type type);

Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> {
elements can be processed as one in SIMD context.
}];

let typeName = "llvm.fixed_vec";

let parameters = (ins "Type":$elementType, "unsigned":$numElements);
let assemblyFormat = [{
`<` $numElements `x` custom<PrettyLLVMType>($elementType) `>`
Expand Down Expand Up @@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> {
elements can be processed as one in SIMD context.
}];

let typeName = "llvm.scalable_vec";

let parameters = (ins "Type":$elementType, "unsigned":$minNumElements);
let assemblyFormat = [{
`<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>`
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Quant/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class AnyQuantizedType
using Base::Base;
using Base::getChecked;

static constexpr StringLiteral name = "quant.any";

/// Gets an instance of the type with all parameters specified but not
/// checked.
static AnyQuantizedType get(unsigned flags, Type storageType,
Expand Down Expand Up @@ -257,6 +259,8 @@ class UniformQuantizedType
using Base::Base;
using Base::getChecked;

static constexpr StringLiteral name = "quant.uniform";

/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedType get(unsigned flags, Type storageType,
Expand Down Expand Up @@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType
using Base::Base;
using Base::getChecked;

static constexpr StringLiteral name = "quant.uniform_per_axis";

/// Gets an instance of the type with all parameters specified but not
/// checked.
static UniformQuantizedPerAxisType
Expand Down Expand Up @@ -383,6 +389,8 @@ class CalibratedQuantizedType
using Base::Base;
using Base::getChecked;

static constexpr StringLiteral name = "quant.calibrated";

/// Gets an instance of the type with all parameters specified but not
/// checked.
static CalibratedQuantizedType get(Type expressedType, double min,
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class InterfaceVarABIAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr descriptorSet, IntegerAttr binding,
IntegerAttr storageClass);

static constexpr StringLiteral name = "spirv.interface_var_abi";
};

/// An attribute that specifies the SPIR-V (version, capabilities, extensions)
Expand Down Expand Up @@ -129,6 +131,8 @@ class VerCapExtAttr
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
IntegerAttr version, ArrayAttr capabilities,
ArrayAttr extensions);

static constexpr StringLiteral name = "spirv.ver_cap_ext";
};

/// An attribute that specifies the target version, allowed extensions and
Expand Down Expand Up @@ -183,6 +187,8 @@ class TargetEnvAttr

/// Returns the target resource limits.
ResourceLimitsAttr getResourceLimits() const;

static constexpr StringLiteral name = "spirv.target_env";
};
} // namespace spirv
} // namespace mlir
Expand Down
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.array";

static ArrayType get(Type elementType, unsigned elementCount);

/// Returns an array type with the given stride in bytes.
Expand Down Expand Up @@ -162,6 +164,8 @@ class ImageType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.image";

static ImageType
get(Type elementType, Dim dim,
ImageDepthInfo depth = ImageDepthInfo::DepthUnknown,
Expand Down Expand Up @@ -201,6 +205,8 @@ class PointerType : public Type::TypeBase<PointerType, SPIRVType,
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.pointer";

static PointerType get(Type pointeeType, StorageClass storageClass);

Type getPointeeType() const;
Expand All @@ -220,6 +226,8 @@ class RuntimeArrayType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.rtarray";

static RuntimeArrayType get(Type elementType);

/// Returns a runtime array type with the given stride in bytes.
Expand All @@ -244,6 +252,8 @@ class SampledImageType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.sampled_image";

static SampledImageType get(Type imageType);

static SampledImageType
Expand Down Expand Up @@ -288,6 +298,8 @@ class StructType
// Type for specifying the offset of the struct members
using OffsetInfo = uint32_t;

static constexpr StringLiteral name = "spirv.struct";

// Type for specifying the decoration(s) on struct members
struct MemberDecorationInfo {
uint32_t memberIndex : 31;
Expand Down Expand Up @@ -387,6 +399,8 @@ class CooperativeMatrixType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.coopmatrix";

static CooperativeMatrixType get(Type elementType, uint32_t rows,
uint32_t columns, Scope scope,
CooperativeMatrixUseKHR use);
Expand Down Expand Up @@ -414,6 +428,8 @@ class CooperativeMatrixNVType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.NV.coopmatrix";

static CooperativeMatrixNVType get(Type elementType, Scope scope,
unsigned rows, unsigned columns);
Type getElementType() const;
Expand All @@ -438,6 +454,8 @@ class JointMatrixINTELType
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.jointmatrix";

static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows,
unsigned columns, MatrixLayout matrixLayout);
Type getElementType() const;
Expand All @@ -464,6 +482,8 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
public:
using Base::Base;

static constexpr StringLiteral name = "spirv.matrix";

static MatrixType get(Type columnType, uint32_t columnCount);

static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such attributes as parameters for other attributes.
string cppType = dialect.cppNamespace # "::" # cppClassName;

// The unique attribute name.
string attrName = dialect.name # "." # mnemonic;

// The call expression to convert from the storage type to the return
// type. For example, an enum can be stored as an int but returned as an
// enum class.
Expand All @@ -289,6 +292,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;

// The unique type name.
string typeName = dialect.name # "." # mnemonic;

// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
"$_builder.getType<" # dialect.cppNamespace #
Expand Down Expand Up @@ -431,15 +437,15 @@ class AttributeSelfTypeParameter<string desc,
/// This class defines an attribute that contains an array of elements. The
/// elements can be any type, but if they are attributes, the nested elements
/// are parsed and printed using the custom attribute syntax.
class ArrayOfAttr<Dialect dialect, string attrName, string attrMnemonic,
class ArrayOfAttr<Dialect dialect, string name, string attrMnemonic,
string eltName, list<Trait> traits = []>
: AttrDef<dialect, attrName, traits> {
: AttrDef<dialect, name, traits> {
let parameters = (ins OptionalArrayRefParameter<eltName>:$value);
let mnemonic = attrMnemonic;
let assemblyFormat = "`[` (`]`) : ($value^ `]`)?";

let returnType = "::llvm::ArrayRef<" # eltName # ">";
let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)";
let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)";
let convertFromStorage = "$_self.getValue()";

let extraClassDeclaration = [{
Expand Down
Loading

0 comments on commit 3dbac2c

Please sign in to comment.