From 3dbac2c007c114a720300d2a4d79abe9ca1351e7 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Fri, 1 Dec 2023 00:39:34 +0100 Subject: [PATCH] [mlir] Expose type and attribute names in the MLIRContext and abstract type/attr classes (#72189) This patch expose the type and attribute names in C++ as methods in the `AbstractType` and `AbstractAttribute` classes, and keep a map of names to `AbstractType` and `AbstractAttribute` in the `MLIRContext`. Type and attribute names should be unique. It adds support in ODS to generate the `getName` methods in `AbstractType` and `AbstractAttribute`, through the use of two new variables, `typeName` and `attrName`. It also adds names to C++-defined type and attributes. --- .../include/flang/Optimizer/Dialect/FIRAttr.h | 8 ++ mlir/examples/toy/Ch7/include/toy/Dialect.h | 3 + mlir/include/mlir/Dialect/DLTI/DLTI.h | 4 + mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 35 ++++++-- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 17 ++-- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 4 + mlir/include/mlir/Dialect/Quant/QuantTypes.h | 8 ++ .../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h | 6 ++ .../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 20 +++++ mlir/include/mlir/IR/AttrTypeBase.td | 12 ++- mlir/include/mlir/IR/AttributeSupport.h | 22 +++-- mlir/include/mlir/IR/BuiltinAttributes.h | 2 + mlir/include/mlir/IR/BuiltinAttributes.td | 43 ++++----- .../mlir/IR/BuiltinLocationAttributes.td | 6 ++ mlir/include/mlir/IR/BuiltinTypes.td | 55 ++++++------ mlir/include/mlir/IR/TypeSupport.h | 23 +++-- mlir/include/mlir/TableGen/AttrOrTypeDef.h | 8 ++ mlir/lib/IR/ExtensibleDialect.cpp | 16 +++- mlir/lib/IR/MLIRContext.cpp | 45 +++++++++- mlir/lib/TableGen/AttrOrTypeDef.cpp | 16 ++++ mlir/test/lib/Dialect/Test/TestTypes.h | 4 +- mlir/test/mlir-tblgen/attrdefs.td | 5 ++ mlir/test/mlir-tblgen/op-attribute.td | 1 + mlir/test/mlir-tblgen/op-decl-and-defs.td | 4 +- mlir/test/mlir-tblgen/typedefs.td | 2 + mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 17 ++++ mlir/unittests/IR/CMakeLists.txt | 1 + mlir/unittests/IR/TypeAttrNamesTest.cpp | 90 +++++++++++++++++++ mlir/unittests/IR/TypeTest.cpp | 5 ++ .../Interfaces/DataLayoutInterfacesTest.cpp | 7 ++ 30 files changed, 409 insertions(+), 80 deletions(-) create mode 100644 mlir/unittests/IR/TypeAttrNamesTest.cpp diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h index 2b14e15c906c3c..c427a6576b5dab 100644 --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h @@ -38,6 +38,7 @@ class ExactTypeAttr using Base::Base; using ValueType = mlir::Type; + static constexpr llvm::StringLiteral name = "fir.type_is"; static constexpr llvm::StringRef getAttrName() { return "type_is"; } static ExactTypeAttr get(mlir::Type value); @@ -51,6 +52,7 @@ class SubclassAttr using Base::Base; using ValueType = mlir::Type; + static constexpr llvm::StringLiteral name = "fir.class_is"; static constexpr llvm::StringRef getAttrName() { return "class_is"; } static SubclassAttr get(mlir::Type value); @@ -63,6 +65,7 @@ class MustBeHeapAttr : public mlir::BoolAttr { public: using BoolAttr::BoolAttr; + static constexpr llvm::StringLiteral name = "fir.must_be_heap"; static constexpr llvm::StringRef getAttrName() { return "fir.must_be_heap"; } }; @@ -78,6 +81,7 @@ class ClosedIntervalAttr public: using Base::Base; + static constexpr llvm::StringLiteral name = "fir.interval"; static constexpr llvm::StringRef getAttrName() { return "interval"; } static ClosedIntervalAttr get(mlir::MLIRContext *ctxt); }; @@ -92,6 +96,7 @@ class UpperBoundAttr public: using Base::Base; + static constexpr llvm::StringLiteral name = "fir.upper"; static constexpr llvm::StringRef getAttrName() { return "upper"; } static UpperBoundAttr get(mlir::MLIRContext *ctxt); }; @@ -106,6 +111,7 @@ class LowerBoundAttr public: using Base::Base; + static constexpr llvm::StringLiteral name = "fir.lower"; static constexpr llvm::StringRef getAttrName() { return "lower"; } static LowerBoundAttr get(mlir::MLIRContext *ctxt); }; @@ -120,6 +126,7 @@ class PointIntervalAttr public: using Base::Base; + static constexpr llvm::StringLiteral name = "fir.point"; static constexpr llvm::StringRef getAttrName() { return "point"; } static PointIntervalAttr get(mlir::MLIRContext *ctxt); }; @@ -135,6 +142,7 @@ class RealAttr using Base::Base; using ValueType = std::pair; + static constexpr llvm::StringLiteral name = "fir.real"; static constexpr llvm::StringRef getAttrName() { return "real"; } static RealAttr get(mlir::MLIRContext *ctxt, const ValueType &key); diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h index bbcc6cd7f0b184..64094c3515915d 100644 --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -72,6 +72,9 @@ class StructType : public mlir::Type::TypeBase shape, Type elementType, StringRef operand); @@ -168,18 +172,35 @@ void addAsyncDependency(Operation *op, Value token); // Handle types for sparse. enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp }; -template -class SparseHandleType - : public Type::TypeBase, Type, TypeStorage> { +class SparseDnTensorHandleType + : public Type::TypeBase { +public: + using Base = typename Type::TypeBase::Base; + using Base::Base; + + static constexpr StringLiteral name = "gpu.sparse.dntensor_handle"; +}; + +class SparseSpMatHandleType + : public Type::TypeBase { public: using Base = - typename Type::TypeBase, Type, TypeStorage>::Base; + typename Type::TypeBase::Base; using Base::Base; + + static constexpr StringLiteral name = "gpu.sparse.spmat_handle"; }; -using SparseDnTensorHandleType = SparseHandleType; -using SparseSpMatHandleType = SparseHandleType; -using SparseSpGEMMOpHandleType = SparseHandleType; +class SparseSpGEMMOpHandleType + : public Type::TypeBase { +public: + using Base = typename Type::TypeBase::Base; + using Base::Base; + + static constexpr StringLiteral name = "gpu.sparse.spgemmop_handle"; +}; } // namespace gpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index 06a41c7c1a7245..93733ccd4929ae 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -58,18 +58,19 @@ namespace LLVM { //===----------------------------------------------------------------------===// // Batch-define trivial types. -#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \ +#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, TypeName) \ class ClassName : public Type::TypeBase { \ public: \ using Base::Base; \ + static constexpr StringLiteral name = TypeName; \ } -DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void"); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128"); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, "llvm.x86_mmx"); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token"); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label"); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata"); #undef DEFINE_TRIVIAL_LLVM_TYPE @@ -110,6 +111,8 @@ class LLVMStructType /// Inherit base constructors. using Base::Base; + static constexpr StringLiteral name = "llvm.struct"; + /// Checks if the given type can be contained in a structure type. static bool isValidElementType(Type type); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index 0bd068c1be7c90..96cdbf01b4bd91 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -162,6 +162,8 @@ def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> { elements can be processed as one in SIMD context. }]; + let typeName = "llvm.fixed_vec"; + let parameters = (ins "Type":$elementType, "unsigned":$numElements); let assemblyFormat = [{ `<` $numElements `x` custom($elementType) `>` @@ -192,6 +194,8 @@ def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> { elements can be processed as one in SIMD context. }]; + let typeName = "llvm.scalable_vec"; + let parameters = (ins "Type":$elementType, "unsigned":$minNumElements); let assemblyFormat = [{ `<` `?` `x` $minNumElements `x` ` ` custom($elementType) `>` diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h index 2776b3e6e17ba5..de5aed0a91a209 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -198,6 +198,8 @@ class AnyQuantizedType using Base::Base; using Base::getChecked; + static constexpr StringLiteral name = "quant.any"; + /// Gets an instance of the type with all parameters specified but not /// checked. static AnyQuantizedType get(unsigned flags, Type storageType, @@ -257,6 +259,8 @@ class UniformQuantizedType using Base::Base; using Base::getChecked; + static constexpr StringLiteral name = "quant.uniform"; + /// Gets an instance of the type with all parameters specified but not /// checked. static UniformQuantizedType get(unsigned flags, Type storageType, @@ -315,6 +319,8 @@ class UniformQuantizedPerAxisType using Base::Base; using Base::getChecked; + static constexpr StringLiteral name = "quant.uniform_per_axis"; + /// Gets an instance of the type with all parameters specified but not /// checked. static UniformQuantizedPerAxisType @@ -383,6 +389,8 @@ class CalibratedQuantizedType using Base::Base; using Base::getChecked; + static constexpr StringLiteral name = "quant.calibrated"; + /// Gets an instance of the type with all parameters specified but not /// checked. static CalibratedQuantizedType get(Type expressedType, double min, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h index 3b914dc4cc82f1..5ebfa9ca5ec25c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h @@ -79,6 +79,8 @@ class InterfaceVarABIAttr static LogicalResult verify(function_ref emitError, IntegerAttr descriptorSet, IntegerAttr binding, IntegerAttr storageClass); + + static constexpr StringLiteral name = "spirv.interface_var_abi"; }; /// An attribute that specifies the SPIR-V (version, capabilities, extensions) @@ -129,6 +131,8 @@ class VerCapExtAttr static LogicalResult verify(function_ref emitError, IntegerAttr version, ArrayAttr capabilities, ArrayAttr extensions); + + static constexpr StringLiteral name = "spirv.ver_cap_ext"; }; /// An attribute that specifies the target version, allowed extensions and @@ -183,6 +187,8 @@ class TargetEnvAttr /// Returns the target resource limits. ResourceLimitsAttr getResourceLimits() const; + + static constexpr StringLiteral name = "spirv.target_env"; }; } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 4be2582f8fd68c..d946d936d4e6cf 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -132,6 +132,8 @@ class ArrayType : public Type::TypeBase emitError, diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td index 42a611ee8e4220..91c9283de8bd41 100644 --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -264,6 +264,9 @@ class AttrDef traits = [], // Make it possible to use such attributes as parameters for other attributes. string cppType = dialect.cppNamespace # "::" # cppClassName; + // The unique attribute name. + string attrName = dialect.name # "." # mnemonic; + // The call expression to convert from the storage type to the return // type. For example, an enum can be stored as an int but returned as an // enum class. @@ -289,6 +292,9 @@ class TypeDef traits = [], // Make it possible to use such type as parameters for other types. string cppType = dialect.cppNamespace # "::" # cppClassName; + // The unique type name. + string typeName = dialect.name # "." # mnemonic; + // A constant builder provided when the type has no parameters. let builderCall = !if(!empty(parameters), "$_builder.getType<" # dialect.cppNamespace # @@ -431,15 +437,15 @@ class AttributeSelfTypeParameter traits = []> - : AttrDef { + : AttrDef { let parameters = (ins OptionalArrayRefParameter:$value); let mnemonic = attrMnemonic; let assemblyFormat = "`[` (`]`) : ($value^ `]`)?"; let returnType = "::llvm::ArrayRef<" # eltName # ">"; - let constBuilderCall = "$_builder.getAttr<" # attrName # "Attr>($0)"; + let constBuilderCall = "$_builder.getAttr<" # name # "Attr>($0)"; let convertFromStorage = "$_self.getValue()"; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h index 75ea1ce24753c9..9f36ee4aae2783 100644 --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -38,6 +38,11 @@ class AbstractAttribute { /// reference to it. static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context); + /// Look up the specified abstract attribute in the MLIRContext and return a + /// reference to it if it exists. + static std::optional> + lookup(StringRef name, MLIRContext *context); + /// This method is used by Dialect objects when they register the list of /// attributes they contain. template @@ -45,7 +50,7 @@ class AbstractAttribute { return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(), T::getWalkImmediateSubElementsFn(), T::getReplaceImmediateSubElementsFn(), - T::getTypeID()); + T::getTypeID(), T::name); } /// This method is used by Dialect objects to register attributes with @@ -57,10 +62,10 @@ class AbstractAttribute { HasTraitFn &&hasTrait, WalkImmediateSubElementsFn walkImmediateSubElementsFn, ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, - TypeID typeID) { + TypeID typeID, StringRef name) { return AbstractAttribute(dialect, std::move(interfaceMap), std::move(hasTrait), walkImmediateSubElementsFn, - replaceImmediateSubElementsFn, typeID); + replaceImmediateSubElementsFn, typeID, name); } /// Return the dialect this attribute was registered to. @@ -102,17 +107,20 @@ class AbstractAttribute { /// Return the unique identifier representing the concrete attribute class. TypeID getTypeID() const { return typeID; } + /// Return the unique name representing the type. + StringRef getName() const { return name; } + private: AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTraitFn, WalkImmediateSubElementsFn walkImmediateSubElementsFn, ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, - TypeID typeID) + TypeID typeID, StringRef name) : dialect(dialect), interfaceMap(std::move(interfaceMap)), hasTraitFn(std::move(hasTraitFn)), walkImmediateSubElementsFn(walkImmediateSubElementsFn), replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), - typeID(typeID) {} + typeID(typeID), name(name) {} /// Give StorageUserBase access to the mutable lookup. template traits = [], +class Builtin_Attr traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { let mnemonic = ?; + let attrName = "builtin." # attrMnemonic; } //===----------------------------------------------------------------------===// // AffineMapAttr //===----------------------------------------------------------------------===// -def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [ +def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", "affine_map", [ MemRefLayoutAttrInterface ]> { let summary = "An Attribute containing an AffineMap object"; @@ -70,7 +71,7 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [ // ArrayAttr //===----------------------------------------------------------------------===// -def Builtin_ArrayAttr : Builtin_Attr<"Array"> { +def Builtin_ArrayAttr : Builtin_Attr<"Array", "array"> { let summary = "A collection of other Attribute values"; let description = [{ Syntax: @@ -152,7 +153,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter< }]; } -def Builtin_DenseArray : Builtin_Attr<"DenseArray"> { +def Builtin_DenseArray : Builtin_Attr<"DenseArray", "dense_array"> { let summary = "A dense array of integer or floating point elements."; let description = [{ A dense array attribute is an attribute that represents a dense array of @@ -218,7 +219,7 @@ def Builtin_DenseArray : Builtin_Attr<"DenseArray"> { //===----------------------------------------------------------------------===// def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< - "DenseIntOrFPElements", [ElementsAttrInterface], + "DenseIntOrFPElements", "dense_int_or_fp_elements", [ElementsAttrInterface], "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " @@ -359,7 +360,7 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< //===----------------------------------------------------------------------===// def Builtin_DenseStringElementsAttr : Builtin_Attr< - "DenseStringElements", [ElementsAttrInterface], + "DenseStringElements", "dense_string_elements", [ElementsAttrInterface], "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " @@ -429,9 +430,8 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr< // DenseResourceElementsAttr //===----------------------------------------------------------------------===// -def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [ - ElementsAttrInterface - ]> { +def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", + "dense_resource_elements", [ElementsAttrInterface]> { let summary = "An Attribute containing a dense multi-dimensional array " "backed by a resource"; let description = [{ @@ -487,7 +487,7 @@ def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [ // DictionaryAttr //===----------------------------------------------------------------------===// -def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { +def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", "dictionary"> { let summary = "An dictionary of named Attribute values"; let description = [{ Syntax: @@ -585,7 +585,7 @@ def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { // FloatAttr //===----------------------------------------------------------------------===// -def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> { +def Builtin_FloatAttr : Builtin_Attr<"Float", "float", [TypedAttrInterface]> { let summary = "An Attribute containing a floating-point value"; let description = [{ Syntax: @@ -648,7 +648,8 @@ def Builtin_FloatAttr : Builtin_Attr<"Float", [TypedAttrInterface]> { // IntegerAttr //===----------------------------------------------------------------------===// -def Builtin_IntegerAttr : Builtin_Attr<"Integer", [TypedAttrInterface]> { +def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer", + [TypedAttrInterface]> { let summary = "An Attribute containing a integer value"; let description = [{ Syntax: @@ -736,7 +737,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", [TypedAttrInterface]> { // IntegerSetAttr //===----------------------------------------------------------------------===// -def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet"> { +def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet", "integer_set"> { let summary = "An Attribute containing an IntegerSet object"; let description = [{ Syntax: @@ -765,7 +766,8 @@ def Builtin_IntegerSetAttr : Builtin_Attr<"IntegerSet"> { // OpaqueAttr //===----------------------------------------------------------------------===// -def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> { +def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", "opaque", + [TypedAttrInterface]> { let summary = "An opaque representation of another Attribute"; let description = [{ Syntax: @@ -803,7 +805,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque", [TypedAttrInterface]> { //===----------------------------------------------------------------------===// def Builtin_SparseElementsAttr : Builtin_Attr< - "SparseElements", [ElementsAttrInterface] + "SparseElements", "sparse_elements", [ElementsAttrInterface] > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ @@ -958,7 +960,7 @@ def Builtin_SparseElementsAttr : Builtin_Attr< // StridedLayoutAttr //===----------------------------------------------------------------------===// -def StridedLayoutAttr : Builtin_Attr<"StridedLayout", +def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout", [DeclareAttrInterfaceMethods]> { let summary = "An Attribute representing a strided layout of a shaped type"; @@ -1012,7 +1014,8 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", // StringAttr //===----------------------------------------------------------------------===// -def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> { +def Builtin_StringAttr : Builtin_Attr<"String", "string", + [TypedAttrInterface]> { let summary = "An Attribute containing a string"; let description = [{ Syntax: @@ -1093,7 +1096,7 @@ def Builtin_StringAttr : Builtin_Attr<"String", [TypedAttrInterface]> { // SymbolRefAttr //===----------------------------------------------------------------------===// -def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> { +def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef", "symbol_ref"> { let summary = "An Attribute containing a symbolic reference to an Operation"; let description = [{ Syntax: @@ -1159,7 +1162,7 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> { // TypeAttr //===----------------------------------------------------------------------===// -def Builtin_TypeAttr : Builtin_Attr<"Type"> { +def Builtin_TypeAttr : Builtin_Attr<"Type", "type"> { let summary = "An Attribute containing a Type"; let description = [{ Syntax: @@ -1192,7 +1195,7 @@ def Builtin_TypeAttr : Builtin_Attr<"Type"> { // UnitAttr //===----------------------------------------------------------------------===// -def Builtin_UnitAttr : Builtin_Attr<"Unit"> { +def Builtin_UnitAttr : Builtin_Attr<"Unit", "unit"> { let summary = "An Attribute value of `unit` type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/BuiltinLocationAttributes.td b/mlir/include/mlir/IR/BuiltinLocationAttributes.td index 3c9d2c57f4d1bd..e1656f268795d0 100644 --- a/mlir/include/mlir/IR/BuiltinLocationAttributes.td +++ b/mlir/include/mlir/IR/BuiltinLocationAttributes.td @@ -56,6 +56,7 @@ def CallSiteLoc : Builtin_LocationAttr<"CallSiteLoc"> { "ArrayRef":$frames)> ]; let skipDefaultBuilders = 1; + let attrName = "builtin.call_site_loc"; } //===----------------------------------------------------------------------===// @@ -98,6 +99,7 @@ def FileLineColLoc : Builtin_LocationAttr<"FileLineColLoc"> { }]> ]; let skipDefaultBuilders = 1; + let attrName = "builtin.file_line_loc"; } //===----------------------------------------------------------------------===// @@ -137,6 +139,7 @@ def FusedLoc : Builtin_LocationAttr<"FusedLoc"> { return get(locs, Attribute(), context); } }]; + let attrName = "builtin.fused_loc"; } //===----------------------------------------------------------------------===// @@ -174,6 +177,7 @@ def NameLoc : Builtin_LocationAttr<"NameLoc"> { }]> ]; let skipDefaultBuilders = 1; + let attrName = "builtin.name_loc"; } //===----------------------------------------------------------------------===// @@ -239,6 +243,7 @@ def OpaqueLoc : Builtin_LocationAttr<"OpaqueLoc"> { } }]; let skipDefaultBuilders = 1; + let attrName = "builtin.opaque_loc"; } //===----------------------------------------------------------------------===// @@ -268,6 +273,7 @@ def UnknownLoc : Builtin_LocationAttr<"UnknownLoc"> { let extraClassDeclaration = [{ static UnknownLoc get(MLIRContext *context); }]; + let attrName = "builtin.unknown_loc"; } #endif // BUILTIN_LOCATION_ATTRIBUTES_TD diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 5ec986ac26de06..1d7772810ae6e8 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -23,17 +23,18 @@ include "mlir/IR/BuiltinTypeInterfaces.td" // remove the definitions in OpBase.td, and repoint users to this file instead. // Base class for Builtin dialect types. -class Builtin_Type traits = [], +class Builtin_Type traits = [], string baseCppClass = "::mlir::Type"> : TypeDef { let mnemonic = ?; + let typeName = "builtin." # typeMnemonic; } //===----------------------------------------------------------------------===// // ComplexType //===----------------------------------------------------------------------===// -def Builtin_Complex : Builtin_Type<"Complex"> { +def Builtin_Complex : Builtin_Type<"Complex", "complex"> { let summary = "Complex number with a parameterized element type"; let description = [{ Syntax: @@ -68,8 +69,8 @@ def Builtin_Complex : Builtin_Type<"Complex"> { //===----------------------------------------------------------------------===// // Base class for Builtin dialect float types. -class Builtin_FloatType - : Builtin_Type { +class Builtin_FloatType + : Builtin_Type { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; @@ -78,7 +79,7 @@ class Builtin_FloatType //===----------------------------------------------------------------------===// // Float8E5M2Type -def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> { +def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2"> { let summary = "8-bit floating point with 2 bit mantissa"; let description = [{ An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits @@ -99,7 +100,7 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2"> { //===----------------------------------------------------------------------===// // Float8E4M3FNType -def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> { +def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN"> { let summary = "8-bit floating point with 3 bit mantissa"; let description = [{ An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits @@ -121,7 +122,7 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN"> { //===----------------------------------------------------------------------===// // Float8E5M2FNUZType -def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> { +def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ", "f8E5M2FNUZ"> { let summary = "8-bit floating point with 2 bit mantissa"; let description = [{ An 8-bit floating point type with 1 sign bit, 5 bits exponent and 2 bits @@ -143,7 +144,7 @@ def Builtin_Float8E5M2FNUZ : Builtin_FloatType<"Float8E5M2FNUZ"> { //===----------------------------------------------------------------------===// // Float8E4M3FNUZType -def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> { +def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ", "f8E4M3FNUZ"> { let summary = "8-bit floating point with 3 bit mantissa"; let description = [{ An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits @@ -165,7 +166,7 @@ def Builtin_Float8E4M3FNUZ : Builtin_FloatType<"Float8E4M3FNUZ"> { //===----------------------------------------------------------------------===// // Float8E4M3B11FNUZType -def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ"> { +def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ", "f8E4M3B11FNUZ"> { let summary = "8-bit floating point with 3 bit mantissa"; let description = [{ An 8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits @@ -187,49 +188,49 @@ def Builtin_Float8E4M3B11FNUZ : Builtin_FloatType<"Float8E4M3B11FNUZ"> { //===----------------------------------------------------------------------===// // BFloat16Type -def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> { +def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> { let summary = "bfloat16 floating-point type"; } //===----------------------------------------------------------------------===// // Float16Type -def Builtin_Float16 : Builtin_FloatType<"Float16"> { +def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> { let summary = "16-bit floating-point type"; } //===----------------------------------------------------------------------===// // FloatTF32Type -def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32"> { +def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> { let summary = "TF32 floating-point type"; } //===----------------------------------------------------------------------===// // Float32Type -def Builtin_Float32 : Builtin_FloatType<"Float32"> { +def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> { let summary = "32-bit floating-point type"; } //===----------------------------------------------------------------------===// // Float64Type -def Builtin_Float64 : Builtin_FloatType<"Float64"> { +def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> { let summary = "64-bit floating-point type"; } //===----------------------------------------------------------------------===// // Float80Type -def Builtin_Float80 : Builtin_FloatType<"Float80"> { +def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> { let summary = "80-bit floating-point type"; } //===----------------------------------------------------------------------===// // Float128Type -def Builtin_Float128 : Builtin_FloatType<"Float128"> { +def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> { let summary = "128-bit floating-point type"; } @@ -237,7 +238,7 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> { // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function"> { +def Builtin_Function : Builtin_Type<"Function", "function"> { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -289,7 +290,7 @@ def Builtin_Function : Builtin_Type<"Function"> { // IndexType //===----------------------------------------------------------------------===// -def Builtin_Index : Builtin_Type<"Index"> { +def Builtin_Index : Builtin_Type<"Index", "index"> { let summary = "Integer-like type with unknown platform-dependent bit width"; let description = [{ Syntax: @@ -319,7 +320,7 @@ def Builtin_Index : Builtin_Type<"Index"> { // IntegerType //===----------------------------------------------------------------------===// -def Builtin_Integer : Builtin_Type<"Integer"> { +def Builtin_Integer : Builtin_Type<"Integer", "integer"> { let summary = "Integer type with arbitrary precision up to a fixed limit"; let description = [{ Syntax: @@ -383,7 +384,7 @@ def Builtin_Integer : Builtin_Type<"Integer"> { // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", [ +def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [ ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; @@ -663,7 +664,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [ // NoneType //===----------------------------------------------------------------------===// -def Builtin_None : Builtin_Type<"None"> { +def Builtin_None : Builtin_Type<"None", "none"> { let summary = "A unit type"; let description = [{ NoneType is a unit type, i.e. a type with exactly one possible value, where @@ -678,7 +679,7 @@ def Builtin_None : Builtin_Type<"None"> { // OpaqueType //===----------------------------------------------------------------------===// -def Builtin_Opaque : Builtin_Type<"Opaque"> { +def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> { let summary = "Type of a non-registered dialect"; let description = [{ Syntax: @@ -718,7 +719,7 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> { // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ +def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [ ShapedTypeInterface ], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; @@ -829,7 +830,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ // TupleType //===----------------------------------------------------------------------===// -def Builtin_Tuple : Builtin_Type<"Tuple"> { +def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> { let summary = "Fixed-sized collection of other types"; let description = [{ Syntax: @@ -896,7 +897,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple"> { // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ +def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [ ShapedTypeInterface ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; @@ -974,7 +975,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ +def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [ ShapedTypeInterface ], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; @@ -1023,7 +1024,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", [ShapedTypeInterface], "Type"> { +def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Type"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h index 2aa6b1a59e8696..36ef696ba4f6fb 100644 --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -39,13 +39,19 @@ class AbstractType { /// reference to it. static const AbstractType &lookup(TypeID typeID, MLIRContext *context); + /// Look up the specified abstract type in the MLIRContext and return a + /// reference to it if it exists. + static std::optional> + lookup(StringRef name, MLIRContext *context); + /// This method is used by Dialect objects when they register the list of /// types they contain. template static AbstractType get(Dialect &dialect) { return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), T::getWalkImmediateSubElementsFn(), - T::getReplaceImmediateSubElementsFn(), T::getTypeID()); + T::getReplaceImmediateSubElementsFn(), T::getTypeID(), + T::name); } /// This method is used by Dialect objects to register types with @@ -57,10 +63,10 @@ class AbstractType { HasTraitFn &&hasTrait, WalkImmediateSubElementsFn walkImmediateSubElementsFn, ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, - TypeID typeID) { + TypeID typeID, StringRef name) { return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), walkImmediateSubElementsFn, - replaceImmediateSubElementsFn, typeID); + replaceImmediateSubElementsFn, typeID, name); } /// Return the dialect this type was registered to. @@ -100,17 +106,20 @@ class AbstractType { /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } + /// Return the unique name representing the type. + StringRef getName() const { return name; } + private: AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, WalkImmediateSubElementsFn walkImmediateSubElementsFn, ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, - TypeID typeID) + TypeID typeID, StringRef name) : dialect(dialect), interfaceMap(std::move(interfaceMap)), hasTraitFn(std::move(hasTrait)), walkImmediateSubElementsFn(walkImmediateSubElementsFn), replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), - typeID(typeID) {} + typeID(typeID), name(name) {} /// Give StorageUserBase access to the mutable lookup. template getTypeBuilder() const; static bool classof(const AttrOrTypeDef *def); + + /// Get the unique attribute name "dialect.attrname". + StringRef getAttrName() const; }; //===----------------------------------------------------------------------===// @@ -267,6 +270,11 @@ class AttrDef : public AttrOrTypeDef { class TypeDef : public AttrOrTypeDef { public: using AttrOrTypeDef::AttrOrTypeDef; + + static bool classof(const AttrOrTypeDef *def); + + /// Get the unique type name "dialect.typename". + StringRef getTypeName() const; }; } // namespace tblgen diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp index 2225a8f2d1b918..ca7b0e1cbb6b39 100644 --- a/mlir/lib/IR/ExtensibleDialect.cpp +++ b/mlir/lib/IR/ExtensibleDialect.cpp @@ -407,10 +407,16 @@ void ExtensibleDialect::registerDynamicType( assert(registered && "Trying to create a new dynamic type with an existing name"); + // The StringAttr allocates the type name StringRef for the duration of the + // MLIR context. + MLIRContext *ctx = getContext(); + auto nameAttr = + StringAttr::get(ctx, getNamespace() + "." + typePtr->getName()); + auto abstractType = AbstractType::get( *dialect, DynamicAttr::getInterfaceMap(), DynamicType::getHasTraitFn(), DynamicType::getWalkImmediateSubElementsFn(), - DynamicType::getReplaceImmediateSubElementsFn(), typeID); + DynamicType::getReplaceImmediateSubElementsFn(), typeID, nameAttr); /// Add the type to the dialect and the type uniquer. addType(typeID, std::move(abstractType)); @@ -437,10 +443,16 @@ void ExtensibleDialect::registerDynamicAttr( assert(registered && "Trying to create a new dynamic attribute with an existing name"); + // The StringAttr allocates the attribute name StringRef for the duration of + // the MLIR context. + MLIRContext *ctx = getContext(); + auto nameAttr = + StringAttr::get(ctx, getNamespace() + "." + attrPtr->getName()); + auto abstractAttr = AbstractAttribute::get( *dialect, DynamicAttr::getInterfaceMap(), DynamicAttr::getHasTraitFn(), DynamicAttr::getWalkImmediateSubElementsFn(), - DynamicAttr::getReplaceImmediateSubElementsFn(), typeID); + DynamicAttr::getReplaceImmediateSubElementsFn(), typeID, nameAttr); /// Add the type to the dialect and the type uniquer. addAttribute(typeID, std::move(abstractAttr)); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index e5a397c9d65c1a..2fd9cac6df3d09 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -212,6 +212,13 @@ class MLIRContextImpl { DenseMap registeredTypes; StorageUniquer typeUniquer; + /// This is a mapping from type name to the abstract type describing it. + /// It is used by `AbstractType::lookup` to get an `AbstractType` from a name. + /// As this map needs to be populated before `StringAttr` is loaded, we + /// cannot use `StringAttr` as the key. The context does not take ownership + /// of the key, so the `StringRef` must outlive the context. + llvm::DenseMap nameToType; + /// Cached Type Instances. Float8E5M2Type f8E5M2Ty; Float8E4M3FNType f8E4M3FNTy; @@ -236,6 +243,14 @@ class MLIRContextImpl { DenseMap registeredAttributes; StorageUniquer attributeUniquer; + /// This is a mapping from attribute name to the abstract attribute describing + /// it. It is used by `AbstractType::lookup` to get an `AbstractType` from a + /// name. + /// As this map needs to be populated before `StringAttr` is loaded, we + /// cannot use `StringAttr` as the key. The context does not take ownership + /// of the key, so the `StringRef` must outlive the context. + llvm::DenseMap nameToAttribute; + /// Cached Attribute Instances. BoolAttr falseAttr, trueAttr; UnitAttr unitAttr; @@ -697,6 +712,9 @@ void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) { AbstractType(std::move(typeInfo)); if (!impl.registeredTypes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Type already registered."); + if (!impl.nameToType.insert({newInfo->getName(), newInfo}).second) + llvm::report_fatal_error("Dialect Type with name " + newInfo->getName() + + " is already registered."); } void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { @@ -709,6 +727,9 @@ void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { AbstractAttribute(std::move(attrInfo)); if (!impl.registeredAttributes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Attribute already registered."); + if (!impl.nameToAttribute.insert({newInfo->getName(), newInfo}).second) + llvm::report_fatal_error("Dialect Attribute with name " + + newInfo->getName() + " is already registered."); } //===----------------------------------------------------------------------===// @@ -731,6 +752,16 @@ AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID, return impl.registeredAttributes.lookup(typeID); } +std::optional> +AbstractAttribute::lookup(StringRef name, MLIRContext *context) { + MLIRContextImpl &impl = context->getImpl(); + const AbstractAttribute *type = impl.nameToAttribute.lookup(name); + + if (!type) + return std::nullopt; + return {*type}; +} + //===----------------------------------------------------------------------===// // OperationName //===----------------------------------------------------------------------===// @@ -864,8 +895,8 @@ void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs, OpaqueProperties rhs) { *lhs.as() = *rhs.as(); } -bool OperationName::UnregisteredOpModel::compareProperties(OpaqueProperties lhs, - OpaqueProperties rhs) { +bool OperationName::UnregisteredOpModel::compareProperties( + OpaqueProperties lhs, OpaqueProperties rhs) { return *lhs.as() == *rhs.as(); } llvm::hash_code @@ -945,6 +976,16 @@ AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) { return impl.registeredTypes.lookup(typeID); } +std::optional> +AbstractType::lookup(StringRef name, MLIRContext *context) { + MLIRContextImpl &impl = context->getImpl(); + const AbstractType *type = impl.nameToType.lookup(name); + + if (!type) + return std::nullopt; + return {*type}; +} + //===----------------------------------------------------------------------===// // Type uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index 56c96ffd159cfd..c9dbb3bc76b1fa 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -220,6 +220,22 @@ bool AttrDef::classof(const AttrOrTypeDef *def) { return def->getDef()->isSubClassOf("AttrDef"); } +StringRef AttrDef::getAttrName() const { + return def->getValueAsString("attrName"); +} + +//===----------------------------------------------------------------------===// +// TypeDef +//===----------------------------------------------------------------------===// + +bool TypeDef::classof(const AttrOrTypeDef *def) { + return def->getDef()->isSubClassOf("TypeDef"); +} + +StringRef TypeDef::getTypeName() const { + return def->getValueAsString("typeName"); +} + //===----------------------------------------------------------------------===// // AttrOrTypeParameter //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index 0ce86dd70ab904..b1b5921d8faddd 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -14,8 +14,8 @@ #ifndef MLIR_TESTTYPES_H #define MLIR_TESTTYPES_H -#include #include +#include #include "TestTraits.h" #include "mlir/IR/Diagnostics.h" @@ -132,6 +132,8 @@ class TestRecursiveType public: using Base::Base; + static constexpr ::mlir::StringLiteral name = "test.recursive"; + static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) { return Base::get(ctx, name); diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index 0fb6958799b2ec..5683f343c6bbc8 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -63,6 +63,7 @@ def C_IndexAttr : TestAttr<"Index"> { } def A_SimpleAttrA : TestAttr<"SimpleA"> { + let attrName = "test.simple"; // DECL: class SimpleAAttr : public ::mlir::Attribute } @@ -123,6 +124,7 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> { ins "int": $num ); + let attrName = "test.single_parameter"; // DECL-LABEL: struct SingleParameterAttrStorage; // DECL-LABEL: class SingleParameterAttr // DECL-SAME: detail::SingleParameterAttrStorage @@ -130,6 +132,7 @@ def D_SingleParameterAttr : TestAttr<"SingleParameter"> { def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> { let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param); + let attrName = "test.param_with_accessor_type"; } // DECL-LABEL: class ParamWithAccessorTypeAttr @@ -142,6 +145,7 @@ def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> { let parameters = (ins "int":$a); let genVerifyDecl = 1; let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">]; + let attrName = "test.builder_with_return_type"; } // DECL-LABEL: class BuilderWithReturnTypeAttr @@ -158,6 +162,7 @@ def H_TestExtraClassAttr : TestAttr<"TestExtraClass"> { return i+1; } }]; + let attrName = "test.test_extra_class"; } // DECL-LABEL: TestExtraClassAttr : public ::mlir::Attribute diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index f81fc3df857ec2..07636f53c1e15c 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -22,6 +22,7 @@ def SomeAttr : Attr, "some attribute kind"> { } def SomeAttrDef : AttrDef { + let attrName = "test.some_attr"; } diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index 85d68cc42ccbd7..ca133fafdcb576 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -358,7 +358,9 @@ def Test_Dialect2 : Dialect { let name = "test"; let cppNamespace = "::mlir::dialect2"; } -def TestDialect2Type : TypeDef; +def TestDialect2Type : TypeDef { + let typeName = "test.type"; +} def NS_ResultWithDialectTypeOp : NS_Op<"op_with_dialect_type", []> { let results = (outs TestDialect2Type); diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td index 51006635a75a4a..705da4e41b78b7 100644 --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -45,6 +45,7 @@ class TestType : TypeDef { } def A_SimpleTypeA : TestType<"SimpleA"> { // DECL: class SimpleAType : public ::mlir::Type + let typeName = "test.simple_a"; } def RTLValueType : Type, "Type"> { @@ -97,6 +98,7 @@ def C_IndexType : TestType<"Index"> { } def D_SingleParameterType : TestType<"SingleParameter"> { + let typeName = "test.d_single_parameter"; let parameters = (ins "int": $num ); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 8889dff7a58724..b9a72119790e5a 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -87,6 +87,8 @@ class DefGen { /// Emit top-level declarations: using declarations and any extra class /// declarations. void emitTopLevelDeclarations(); + /// Emit the function that returns the type or attribute name. + void emitName(); /// Emit attribute or type builders. void emitBuilders(); /// Emit a verifier for the def. @@ -180,6 +182,8 @@ DefGen::DefGen(const AttrOrTypeDef &def) // Emit builders for defs with parameters if (storageCls) emitBuilders(); + // Emit the type name. + emitName(); // Emit the verifier. if (storageCls && def.genVerifyDecl()) emitVerifier(); @@ -264,6 +268,19 @@ void DefGen::emitTopLevelDeclarations() { std::move(extraDef)); } +void DefGen::emitName() { + StringRef name; + if (auto *attrDef = dyn_cast(&def)) { + name = attrDef->getAttrName(); + } else { + auto *typeDef = cast(&def); + name = typeDef->getTypeName(); + } + std::string nameDecl = + strfmt("static constexpr ::llvm::StringLiteral name = \"{0}\";\n", name); + defCls.declare(std::move(nameDecl)); +} + void DefGen::emitBuilders() { if (!def.skipDefaultBuilders()) { emitDefaultBuilder(); diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 6d05af193dfae0..1ed46869c2c8a9 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_unittest(MLIRIRTests ShapedTypeTest.cpp SymbolTableTest.cpp TypeTest.cpp + TypeAttrNamesTest.cpp OpPropertiesTest.cpp DEPENDS diff --git a/mlir/unittests/IR/TypeAttrNamesTest.cpp b/mlir/unittests/IR/TypeAttrNamesTest.cpp new file mode 100644 index 00000000000000..488c164b23b4bd --- /dev/null +++ b/mlir/unittests/IR/TypeAttrNamesTest.cpp @@ -0,0 +1,90 @@ +//===- TypeAttrNamesTest.cpp - Type API unit tests ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file test the lookup of AbstractType / AbstractAttribute by name. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/TypeID.h" +#include "gtest/gtest.h" + +using namespace mlir; + +namespace { +struct FooType : Type::TypeBase { + using Base::Base; + + static constexpr StringLiteral name = "fake.foo"; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooType) +}; + +struct BarAttr : Attribute::AttrBase { + using Base::Base; + + static constexpr StringLiteral name = "fake.bar"; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAttr) +}; + +struct FakeDialect : Dialect { + FakeDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addTypes(); + addAttributes(); + } + + static constexpr ::llvm::StringLiteral getDialectNamespace() { + return ::llvm::StringLiteral("fake"); + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FakeDialect) +}; +} // namespace + +TEST(AbstractType, LookupWithString) { + MLIRContext ctx; + ctx.loadDialect(); + + // Check that we can lookup an abstract type by name. + auto fooAbstractType = AbstractType::lookup("fake.foo", &ctx); + EXPECT_TRUE(fooAbstractType.has_value()); + EXPECT_TRUE(fooAbstractType->get().getName() == "fake.foo"); + + // Check that the abstract type is the same as the one used by the type. + auto fooType = FooType::get(&ctx); + EXPECT_TRUE(&fooType.getAbstractType() == &fooAbstractType->get()); + + // Check that lookups of non-existing types returns nullopt. + // Even if an attribute with the same name exists. + EXPECT_FALSE(AbstractType::lookup("fake.bar", &ctx).has_value()); +} + +TEST(AbstractAttribute, LookupWithString) { + MLIRContext ctx; + ctx.loadDialect(); + + // Check that we can lookup an abstract type by name. + auto barAbstractAttr = AbstractAttribute::lookup("fake.bar", &ctx); + EXPECT_TRUE(barAbstractAttr.has_value()); + EXPECT_TRUE(barAbstractAttr->get().getName() == "fake.bar"); + + // Check that the abstract Attribute is the same as the one used by the + // Attribute. + auto barAttr = BarAttr::get(&ctx); + EXPECT_TRUE(&barAttr.getAbstractAttribute() == &barAbstractAttr->get()); + + // Check that lookups of non-existing Attributes returns nullopt. + // Even if an attribute with the same name exists. + EXPECT_FALSE(AbstractAttribute::lookup("fake.foo", &ctx).has_value()); +} diff --git a/mlir/unittests/IR/TypeTest.cpp b/mlir/unittests/IR/TypeTest.cpp index 1bb9d077d8ed59..30f6642a9ca71d 100644 --- a/mlir/unittests/IR/TypeTest.cpp +++ b/mlir/unittests/IR/TypeTest.cpp @@ -19,6 +19,9 @@ struct LeafType; struct MiddleType : Type::TypeBase { using Base::Base; + + static constexpr StringLiteral name = "test.middle"; + static bool classof(Type ty) { return ty.getTypeID() == TypeID::get() || Base::classof(ty); } @@ -26,6 +29,8 @@ struct MiddleType : Type::TypeBase { struct LeafType : Type::TypeBase { using Base::Base; + + static constexpr StringLiteral name = "test.leaf"; }; struct FakeDialect : Dialect { diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp index 21b9e7b5ac2c26..79599b8c485094 100644 --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -56,6 +56,9 @@ struct CustomDataLayoutSpec MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CustomDataLayoutSpec) using Base::Base; + + static constexpr StringLiteral name = "test.custom_data_layout_spec"; + static CustomDataLayoutSpec get(MLIRContext *ctx, ArrayRef entries) { return Base::get(ctx, entries); @@ -83,6 +86,8 @@ struct SingleQueryType using Base::Base; + static constexpr StringLiteral name = "test.single_query"; + static SingleQueryType get(MLIRContext *ctx) { return Base::get(ctx); } llvm::TypeSize getTypeSizeInBits(const DataLayout &layout, @@ -131,6 +136,8 @@ struct TypeNoLayout : public Type::TypeBase { using Base::Base; + static constexpr StringLiteral name = "test.no_layout"; + static TypeNoLayout get(MLIRContext *ctx) { return Base::get(ctx); } };