From a2f6d8316a4ab0b630168efa606f91283c73124b Mon Sep 17 00:00:00 2001 From: Ionel Gog Date: Thu, 14 Nov 2024 21:01:41 -0800 Subject: [PATCH] [IFRT] Ensure that VIFRT td file structure matches that of IFRT dialect. PiperOrigin-RevId: 696750835 --- xla/python/ifrt/ir/BUILD | 191 +++++++----------- xla/python/ifrt/ir/sharding_param.cc | 31 ++- xla/python/ifrt/ir/sharding_param.h | 10 + xla/python/ifrt/ir/tests/BUILD | 2 +- xla/python/ifrt/ir/tests/ifrt-translate.cc | 2 +- xla/python/ifrt/ir/vifrt_attrs.td | 122 ----------- xla/python/ifrt/ir/vifrt_base.td | 27 --- .../ir/{vifrt_ops.cc => vifrt_dialect.cc} | 82 ++++---- .../ifrt/ir/{vifrt_ops.h => vifrt_dialect.h} | 56 +++-- xla/python/ifrt/ir/vifrt_dialect.td | 135 +++++++++++++ xla/python/ifrt/ir/vifrt_interfaces.td | 86 ++++++++ xla/python/ifrt/ir/vifrt_ops.td | 45 +---- xla/python/ifrt/ir/vifrt_types.cc | 89 -------- xla/python/ifrt/ir/vifrt_types.h | 65 ------ xla/python/ifrt/ir/vifrt_types.td | 86 -------- xla/python/ifrt/support/BUILD | 1 + xla/python/ifrt/support/module_parsing.cc | 2 + 17 files changed, 428 insertions(+), 604 deletions(-) delete mode 100644 xla/python/ifrt/ir/vifrt_attrs.td delete mode 100644 xla/python/ifrt/ir/vifrt_base.td rename xla/python/ifrt/ir/{vifrt_ops.cc => vifrt_dialect.cc} (75%) rename xla/python/ifrt/ir/{vifrt_ops.h => vifrt_dialect.h} (55%) create mode 100644 xla/python/ifrt/ir/vifrt_interfaces.td delete mode 100644 xla/python/ifrt/ir/vifrt_types.cc delete mode 100644 xla/python/ifrt/ir/vifrt_types.h delete mode 100644 xla/python/ifrt/ir/vifrt_types.td diff --git a/xla/python/ifrt/ir/BUILD b/xla/python/ifrt/ir/BUILD index 9ae96c068bc2e..db790a84abd64 100644 --- a/xla/python/ifrt/ir/BUILD +++ b/xla/python/ifrt/ir/BUILD @@ -315,8 +315,23 @@ cc_library( ], ) +td_library( + name = "vifrt_td", + srcs = [ + "vifrt_dialect.td", + "vifrt_interfaces.td", + "vifrt_ops.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:ShapeOpsTdFiles", + ], +) + gentbl_cc_library( - name = "vifrt_attr_interfaces_inc_gen", + name = "vifrt_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), tbl_outs = [ ( @@ -327,17 +342,61 @@ gentbl_cc_library( ["-gen-attr-interface-defs"], "vifrt_attr_interfaces.cc.inc", ), + ( + ["-gen-type-interface-decls"], + "vifrt_type_interfaces.h.inc", + ), + ( + ["-gen-type-interface-defs"], + "vifrt_type_interfaces.cc.inc", + ), + ( + ["-gen-op-interface-decls"], + "vifrt_op_interfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "vifrt_op_interfaces.cc.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "vifrt_attrs.td", + td_file = "vifrt_interfaces.td", test = True, - deps = [":vifrt_ops_td_files"], + deps = [":vifrt_td"], ) gentbl_cc_library( - name = "vifrt_attrs_inc_gen", + name = "vifrt_dialect_inc_gen", compatible_with = get_compatible_with_portable(), tbl_outs = [ + ( + [ + "-gen-dialect-decls", + "-dialect=vifrt", + ], + "vifrt_dialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=vifrt", + ], + "vifrt_dialect.cc.inc", + ), + ( + [ + "-gen-typedef-decls", + "--typedefs-dialect=vifrt", + ], + "vifrt_types.h.inc", + ), + ( + [ + "-gen-typedef-defs", + "--typedefs-dialect=vifrt", + ], + "vifrt_types.cc.inc", + ), ( [ "-gen-attrdef-decls", @@ -354,50 +413,9 @@ gentbl_cc_library( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "vifrt_ops.td", + td_file = "vifrt_dialect.td", test = True, - deps = [":vifrt_ops_td_files"], -) - -gentbl_cc_library( - name = "vifrt_op_interfaces_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "vifrt_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "vifrt_op_interfaces.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "vifrt_ops.td", - deps = [":vifrt_ops_td_files"], -) - -cc_library( - name = "vifrt_ops", - srcs = [ - "vifrt_ops.cc", - ], - hdrs = [ - "vifrt_ops.h", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - ":version", - ":vifrt_attr_interfaces_inc_gen", - ":vifrt_attrs_inc_gen", - ":vifrt_op_interfaces_inc_gen", - ":vifrt_ops_inc_gen", - ":vifrt_types", - ":vifrt_types_inc_gen", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], + deps = [":vifrt_td"], ) gentbl_cc_library( @@ -416,85 +434,24 @@ gentbl_cc_library( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "vifrt_ops.td", test = True, - deps = [":vifrt_ops_td_files"], -) - -td_library( - name = "vifrt_ops_td_files", - srcs = [ - "vifrt_attrs.td", - "vifrt_base.td", - "vifrt_dialect.td", - "vifrt_ops.td", - "vifrt_types.td", - ], - compatible_with = get_compatible_with_portable(), - deps = [ - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:ShapeOpsTdFiles", - ], + deps = [":vifrt_td"], ) cc_library( - name = "vifrt_types", - srcs = ["vifrt_types.cc"], - hdrs = ["vifrt_types.h"], + name = "vifrt", + srcs = ["vifrt_dialect.cc"], + hdrs = ["vifrt_dialect.h"], compatible_with = get_compatible_with_portable(), + visibility = ["//xla/python/ifrt:friends"], deps = [ + ":sharding_param", ":version", - ":vifrt_type_interfaces_inc_gen", - ":vifrt_types_inc_gen", + ":vifrt_dialect_inc_gen", + ":vifrt_interfaces_inc_gen", + ":vifrt_ops_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@stablehlo//:stablehlo_assembly_format", - ], -) - -gentbl_cc_library( - name = "vifrt_type_interfaces_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-type-interface-decls"], - "vifrt_type_interfaces.h.inc", - ), - ( - ["-gen-type-interface-defs"], - "vifrt_type_interfaces.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "vifrt_types.td", - deps = [ - ":vifrt_ops_td_files", - ], -) - -gentbl_cc_library( - name = "vifrt_types_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - "--typedefs-dialect=vifrt", - ], - "vifrt_type_defs.h.inc", - ), - ( - [ - "-gen-typedef-defs", - "--typedefs-dialect=vifrt", - ], - "vifrt_type_defs.cc.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "vifrt_ops.td", - deps = [ - ":vifrt_ops_td_files", ], ) diff --git a/xla/python/ifrt/ir/sharding_param.cc b/xla/python/ifrt/ir/sharding_param.cc index 1bedd6e8d1d0c..39c286f5b55af 100644 --- a/xla/python/ifrt/ir/sharding_param.cc +++ b/xla/python/ifrt/ir/sharding_param.cc @@ -75,6 +75,15 @@ void PopulateDevices(llvm::ArrayRef permutation, } } +void PrintInternalV1(llvm::raw_ostream& os, const ShardingParam& sharding) { + PrintDims(os, sharding.dim_shards()); + os << " to ["; + llvm::interleaveComma( + llvm::ArrayRef(sharding.minor_to_major().permutation), os); + os << "] on "; + PrintDims(os, sharding.minor_to_major().axis_sizes); +} + } // namespace absl::Status ShardingParam::MinorToMajor::verify() const { @@ -124,6 +133,12 @@ void ShardingParam::MinorToMajor::ToDeviceList( mlir::FailureOr ShardingParam::Parse( mlir::AsmParser& ods_parser) { + // V1 is the current ShardingParam format. + return ParseV1(ods_parser); +} + +mlir::FailureOr ShardingParam::ParseV1( + mlir::AsmParser& ods_parser) { MinorToMajor minor_to_major; auto parseIntoPermutation = [&]() -> mlir::ParseResult { @@ -159,6 +174,11 @@ mlir::FailureOr ShardingParam::Parse( std::move(minor_to_major)); } +void ShardingParam::PrintV1(mlir::AsmPrinter& ods_printer, + const ShardingParam& sharding) { + PrintInternalV1(ods_printer.getStream(), sharding); +} + absl::Status ShardingParam::verify() const { TF_RETURN_IF_ERROR(minor_to_major().verify()); int dim_index = 0; @@ -275,17 +295,14 @@ llvm::hash_code hash_value(ShardingParam sharding) { } mlir::AsmPrinter& operator<<(mlir::AsmPrinter& os, ShardingParam sharding) { - os.getStream() << sharding; + // V1 if the current ShardingParam version. + PrintInternalV1(os.getStream(), sharding); return os; } llvm::raw_ostream& operator<<(llvm::raw_ostream& os, ShardingParam sharding) { - PrintDims(os, sharding.dim_shards()); - os << " to ["; - llvm::interleaveComma( - llvm::ArrayRef(sharding.minor_to_major().permutation), os); - os << "] on "; - PrintDims(os, sharding.minor_to_major().axis_sizes); + // V1 if the current ShardingParam version. + PrintInternalV1(os, sharding); return os; } diff --git a/xla/python/ifrt/ir/sharding_param.h b/xla/python/ifrt/ir/sharding_param.h index 5a96a9e25573d..fe851be123f08 100644 --- a/xla/python/ifrt/ir/sharding_param.h +++ b/xla/python/ifrt/ir/sharding_param.h @@ -100,6 +100,16 @@ class ShardingParam { minor_to_major_(std::move(minor_to_major)) {} static mlir::FailureOr Parse(mlir::AsmParser& ods_parser); + + // Parses V1 of ShardingParam. This method is meant to be used in the VIFRT + // dialect to parse versioned ShardingParams. + static mlir::FailureOr ParseV1(mlir::AsmParser& ods_parser); + + // Prints V1 of ShardingParam. This method is meant to be used in the VIFRT + // dialect to print versioned ShardingParams. + static void PrintV1(mlir::AsmPrinter& ods_printer, + const ShardingParam& sharding); + absl::Status verify() const; mlir::LogicalResult verify( llvm::function_ref emit_error) const; diff --git a/xla/python/ifrt/ir/tests/BUILD b/xla/python/ifrt/ir/tests/BUILD index a154eed9e4746..5b0ed0741181f 100644 --- a/xla/python/ifrt/ir/tests/BUILD +++ b/xla/python/ifrt/ir/tests/BUILD @@ -53,7 +53,7 @@ xla_cc_binary( "//xla/python/ifrt/ir:ifrt_ir_program", "//xla/python/ifrt/ir:ifrt_ir_program_serdes", # build_cleaner: keep "//xla/python/ifrt/ir:version", - "//xla/python/ifrt/ir:vifrt_ops", + "//xla/python/ifrt/ir:vifrt", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", diff --git a/xla/python/ifrt/ir/tests/ifrt-translate.cc b/xla/python/ifrt/ir/tests/ifrt-translate.cc index ae91e552a83f0..6c21bc3724af9 100644 --- a/xla/python/ifrt/ir/tests/ifrt-translate.cc +++ b/xla/python/ifrt/ir/tests/ifrt-translate.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/ir/version.h" -#include "xla/python/ifrt/ir/vifrt_ops.h" +#include "xla/python/ifrt/ir/vifrt_dialect.h" #include "xla/python/ifrt/serdes.h" namespace xla { diff --git a/xla/python/ifrt/ir/vifrt_attrs.td b/xla/python/ifrt/ir/vifrt_attrs.td deleted file mode 100644 index 8eec47d172c42..0000000000000 --- a/xla/python/ifrt/ir/vifrt_attrs.td +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_ -#define XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_ - -include "mlir/IR/AttrTypeBase.td" -include "xla/python/ifrt/ir/vifrt_base.td" -include "xla/python/ifrt/ir/vifrt_dialect.td" -include "xla/python/ifrt/ir/vifrt_types.td" - -def Vifrt_VersionedAttrInterface : AttrInterface<"VifrtVersionedAttrInterface"> { - let cppNamespace = "::xla::ifrt"; - let methods = [ - InterfaceMethod< - "The min version of the VIFRT dialect an attribute is supported in.", - "::xla::ifrt::Version", "getMinVersion">, - InterfaceMethod< - "The maxi version (inclusive) of the VIFRT dialect an attribute is supported in.", - "::xla::ifrt::Version", "getMaxVersion">, - ]; -} - -class Vifrt_AttrDef - : AttrDef { - let extraClassDeclaration = [{ - ::xla::ifrt::Version getMinVersion() { - return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{); - } - ::xla::ifrt::Version getMaxVersion() { - }] # !if( - !eq(max_version, "current"), - [{ return ::xla::ifrt::Version::getCurrentVersion(); }], - [{ return ::xla::ifrt::Version("}] # !subst(".", ", ", max_version) # [{"); }] - ) # [{ - } - }]; -} - -def Vifrt_DevicesAttrV1 : Vifrt_AttrDef<"VifrtDevicesV1", "0.1.0", "current"> { - let mnemonic = "devices_v1"; - let parameters = (ins ArrayRefParameter<"int">:$ids); - let assemblyFormat = "`[` $ids `]`"; -} - -def Vifrt_UnspecifiedShardingAttrV1 - : Vifrt_AttrDef<"VifrtUnspecifiedShardingV1", "0.1.0", "current"> { - let mnemonic = "sharding_unspecified_v1"; - let parameters = (ins); - let assemblyFormat = ""; -} - -// TODO(icgog): Introduce versioned ShardingParamV1. -def Vifrt_ShardingParamAttrV1 : Vifrt_AttrDef<"VifrtShardingParamV1", "0.1.0", - "current"> { - let mnemonic = "sharding_param_v1"; -} - -def Vifrt_IntervalAttrV1 : Vifrt_AttrDef<"VifrtIntervalV1", "0.1.0", - "current"> { - let mnemonic = "interval_v1"; - let parameters = (ins "int":$start, "int":$end, "int":$step); - let assemblyFormat = "`[`$start `:` $end `:` $step`]`"; -} - -def Vifrt_MappingAttrV1 : Vifrt_AttrDef<"VifrtMappingV1", "0.1.0", "current"> { - let mnemonic = "mapping_v1"; - let parameters = (ins - Vifrt_IntervalAttrV1:$from_shards, - Vifrt_IntervalAttrV1:$to_shards); - let assemblyFormat = "`<` $from_shards `to` $to_shards `>`"; -} - -def Vifrt_GenericArrayAttrV1 : Vifrt_AttrDef<"VifrtGenericArrayAttrV1", - "0.1.0", "current"> { - let mnemonic = "generic_array_attr_v1"; - let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value); - let genVerifyDecl = 1; - let extraClassDefinition = [{ - mlir::LogicalResult VifrtGenericArrayAttrV1Attr::verify( - llvm::function_ref err_fn, - llvm::ArrayRef value) { - if (!allFromVifrt(value)) return err_fn() << "expected array of VIFRT attributes"; - return mlir::success(); - } - }]; - let assemblyFormat = "`<` custom($value) `>`"; -} - -def Vifrt_ArrayMappingAttrV1 : Vifrt_AttrDef<"VifrtArrayMappingV1", "0.1.0", - "current"> { - let mnemonic = "array_mapping_v1"; - let parameters = (ins - "int32_t":$in_array_index, - "int32_t":$out_array_index, - Vifrt_GenericArrayAttrV1:$mappings); - let assemblyFormat = "`<`$in_array_index`,` $out_array_index`,` $mappings`>`"; -} - -def Vifrt_IoAliasesAttrV1 : Vifrt_AttrDef<"IfrtIoAliasesV1", "0.1.0", "current"> { - let mnemonic = "io_aliases_v1"; - let parameters = (ins - Vifrt_GenericArrayAttrV1:$io_aliases); - let assemblyFormat = "`<` $io_aliases `>`"; -} - -// TODO(icgog): Introduce Vifrt_MappingAttrArrayAttrV1, -// Vifrt_ArrayMappingAttrArrayAttrV1. - -#endif // XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_base.td b/xla/python/ifrt/ir/vifrt_base.td deleted file mode 100644 index d92bea913c133..0000000000000 --- a/xla/python/ifrt/ir/vifrt_base.td +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ -#define XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ - -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" - -// VIFRT represents the layout only. Therefore, it uses AnyType everywhere. -def Vifrt_AnyType : AnyTypeOf<[AnyType]>; -def Vifrt_AnyAttr : AnyAttrOf<[AnyAttr]>; -def Vifrt_AnyRegion : Region, "any region">; - -#endif // XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_ops.cc b/xla/python/ifrt/ir/vifrt_dialect.cc similarity index 75% rename from xla/python/ifrt/ir/vifrt_ops.cc rename to xla/python/ifrt/ir/vifrt_dialect.cc index 8aedd0ab40bbc..02f99e2e41eaa 100644 --- a/xla/python/ifrt/ir/vifrt_ops.cc +++ b/xla/python/ifrt/ir/vifrt_dialect.cc @@ -13,23 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/ir/vifrt_ops.h" +#include "xla/python/ifrt/ir/vifrt_dialect.h" #include #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" // IWYU pragma: export #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" -#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" -#include "xla/python/ifrt/ir/vifrt_types.h" namespace xla { namespace ifrt { @@ -37,6 +36,7 @@ namespace ifrt { namespace { // Verifies if a given type or attribute is from VIFRT dialect. +// Must be defined before importing the generated type interfaces and defs. template bool isFromVifrt(TypeOrAttr t) { return t.getDialect().getNamespace() == VifrtDialect::getDialectNamespace(); @@ -47,8 +47,6 @@ bool allFromVifrt(llvm::ArrayRef range) { return llvm::all_of(range, isFromVifrt); } -} // namespace - // Helper functions for VIFRT printers and parsers. static void printAttributeArray(mlir::AsmPrinter& os, llvm::ArrayRef arrayAttr) { @@ -66,57 +64,45 @@ mlir::ParseResult parseAttributeArray( return mlir::success(); } +} // namespace + } // namespace ifrt } // namespace xla +// Attributes #include "xla/python/ifrt/ir/vifrt_attr_interfaces.cc.inc" #define GET_ATTRDEF_CLASSES #include "xla/python/ifrt/ir/vifrt_attrs.cc.inc" +// Types +#include "xla/python/ifrt/ir/vifrt_type_interfaces.cc.inc" +#define GET_TYPEDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_types.cc.inc" +// Ops #include "xla/python/ifrt/ir/vifrt_op_interfaces.cc.inc" #define GET_OP_CLASSES #include "xla/python/ifrt/ir/vifrt_ops.cc.inc" -namespace xla { -namespace ifrt { - //===----------------------------------------------------------------------===// -// VIFRT Dialect Constructor +// VIFRT Dialect //===----------------------------------------------------------------------===// +namespace xla { +namespace ifrt { VifrtDialect::VifrtDialect(mlir::MLIRContext* context) : mlir::Dialect(getDialectNamespace(), context, mlir::TypeID::get()) { - addOperations< -#define GET_OP_LIST -#include "xla/python/ifrt/ir/vifrt_ops.cc.inc" - >(); - addVifrtTypes(); addAttributes< #define GET_ATTRDEF_LIST #include "xla/python/ifrt/ir/vifrt_attrs.cc.inc" >(); -} - -void VifrtDialect::addVifrtTypes() { - // Following the same solution as in VHLO; Idiomatically, this functionality - // is expressed as shown below: - // addTypes< - // #define GET_TYPEDEF_LIST - // #include - // "third_party/tensorflow/compiler/xla/python/ifrt/ir/vifrt_type_defs.cc.inc" - // >(); - // - // However, Dialect::addTypes doesn't work for our situation where we want to - // decouple the vifrt_ops and vifrt_types targets because - // vifrt_type_defs.h.inc only includes forward declarations of `TypeStorage` - // structs, and that's not sufficient for Dialect::addTypes to compile. - // Therefore, we introduce this function and then reimplementing - // Dialect::addTypes as shown below. - addTypesWithoutRegistering< + addTypes< #define GET_TYPEDEF_LIST -#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" +#include "xla/python/ifrt/ir/vifrt_types.cc.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "xla/python/ifrt/ir/vifrt_ops.cc.inc" >(); - registerVifrtTypes(getContext()); } mlir::Type VifrtDialect::parseType(mlir::DialectAsmParser& parser) const { @@ -158,5 +144,29 @@ void VifrtDialect::printAttribute(mlir::Attribute attr, assert(mlir::succeeded(result)); } +//===----------------------------------------------------------------------===// +// VIFRT Type Converter to/from Builtin +//===----------------------------------------------------------------------===// + +void VifrtTypeConverterBuiltin::addBuiltinToVifrtConversions() { + // We currently rely on the builtin types being stable, and thus we do not + // convert builtin types to VIFRT types. +} + +void VifrtTypeConverterBuiltin::addVifrtToBuiltinConversions() { + // We currently rely on the builtin types are stable, and thus we do not + // convert from VIFRT types to builtin types. +} + +mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter& printer) { + return generatedTypePrinter(type, printer); +} + +mlir::OptionalParseResult parseVifrtType(mlir::AsmParser& parser, + llvm::StringRef* mnemonic, + mlir::Type& type) { + return generatedTypeParser(parser, mnemonic, type); +} + } // namespace ifrt } // namespace xla diff --git a/xla/python/ifrt/ir/vifrt_ops.h b/xla/python/ifrt/ir/vifrt_dialect.h similarity index 55% rename from xla/python/ifrt/ir/vifrt_ops.h rename to xla/python/ifrt/ir/vifrt_dialect.h index 218aed758ca6b..9d54dcd6c63fb 100644 --- a/xla/python/ifrt/ir/vifrt_ops.h +++ b/xla/python/ifrt/ir/vifrt_dialect.h @@ -13,17 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ -#define XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_H_ +#define XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_H_ #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeSupport.h" #include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/python/ifrt/ir/sharding_param.h" // IWYU pragma: export #include "xla/python/ifrt/ir/version.h" // IWYU pragma: export -#include "xla/python/ifrt/ir/vifrt_types.h" // IWYU pragma: export namespace xla { namespace ifrt { @@ -47,31 +50,50 @@ class VifrtDialect : public mlir::Dialect { // Prints an attribute registered in the VIFRT dialect. void printAttribute(mlir::Attribute attr, mlir::DialectAsmPrinter &os) const override; +}; + +class VifrtTypeConverterBase : public mlir::TypeConverter { + public: + VifrtTypeConverterBase() : mlir::TypeConverter() {}; + + ~VifrtTypeConverterBase() override = default; +}; - private: - // Adds VIFRT types to this dialect. - // See implementation comment for additional details. - void addVifrtTypes(); - - // Does the same this as Dialect::addTypes but without calling `registerType`. - // See comments for `addVifrtTypes` for additional details. - template - void addTypesWithoutRegistering() { - (addType(Types::getTypeID(), mlir::AbstractType::get(*this)), ...); - } +// Class used to manage conversions between VIFRT and Builtin types. +class VifrtTypeConverterBuiltin : public VifrtTypeConverterBase { + public: + // A subclass can call this method to add conversions from VIFRT to Builtin + // types. Conversions are applied in reverse order, with the most recently + // added conversion attempted to be applied first. + void addVifrtToBuiltinConversions(); + + // A subclass can call this method to add conversions from Builtin to VIFRT + // types. Conversions are applied in reverse order, with the most recently + // added conversion attempted to be applied first. + void addBuiltinToVifrtConversions(); }; +// Auto-generated VIFRT type printers and parsers. +mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter &printer); +mlir::OptionalParseResult parseVifrtType(mlir::AsmParser &parser, + llvm::StringRef *mnemonic, + mlir::Type &type); + } // namespace ifrt } // namespace xla -// Attrs +// Generated definitions. +// Attributes #include "xla/python/ifrt/ir/vifrt_attr_interfaces.h.inc" #define GET_ATTRDEF_CLASSES #include "xla/python/ifrt/ir/vifrt_attrs.h.inc" - +// Types +#include "xla/python/ifrt/ir/vifrt_type_interfaces.h.inc" +#define GET_TYPEDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_types.h.inc" // Ops #include "xla/python/ifrt/ir/vifrt_op_interfaces.h.inc" #define GET_OP_CLASSES #include "xla/python/ifrt/ir/vifrt_ops.h.inc" -#endif // XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ +#endif // XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_H_ diff --git a/xla/python/ifrt/ir/vifrt_dialect.td b/xla/python/ifrt/ir/vifrt_dialect.td index 9aa84e19dbeff..19873884243f7 100644 --- a/xla/python/ifrt/ir/vifrt_dialect.td +++ b/xla/python/ifrt/ir/vifrt_dialect.td @@ -16,6 +16,14 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ #define XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/DialectBase.td" +include "xla/python/ifrt/ir/vifrt_interfaces.td" + +//===--------------------------------------------------------------------------- +// Dialect +//===--------------------------------------------------------------------------- + def Vifrt_Dialect : Dialect { let name = "vifrt"; let summary = "Versioned IFRT dialect"; @@ -34,4 +42,131 @@ def Vifrt_Dialect : Dialect { let usePropertiesForAttributes = 1; } +//===--------------------------------------------------------------------------- +// Attributes +//===--------------------------------------------------------------------------- + +class Vifrt_AttrDef + : AttrDef { + let extraClassDeclaration = [{ + ::xla::ifrt::Version getMinVersion() { + return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{); + } + ::xla::ifrt::Version getMaxVersion() { + }] # !if( + !eq(max_version, "current"), + [{ return ::xla::ifrt::Version::getCurrentVersion(); }], + [{ return ::xla::ifrt::Version("}] # !subst(".", ", ", max_version) # [{"); }] + ) # [{ + } + }]; +} + +def Vifrt_DevicesAttrV1 : Vifrt_AttrDef<"VifrtDevicesV1", "0.1.0", "current"> { + let mnemonic = "devices_v1"; + let parameters = (ins ArrayRefParameter<"int">:$ids); + let assemblyFormat = "`[` $ids `]`"; +} + +def Vifrt_UnspecifiedShardingAttrV1 + : Vifrt_AttrDef<"VifrtUnspecifiedShardingV1", "0.1.0", "current"> { + let mnemonic = "sharding_unspecified_v1"; + let parameters = (ins); + let assemblyFormat = ""; +} + +def Vifrt_ShardingParameterV1 : + AttrOrTypeParameter<"::xla::ifrt::ShardingParam", ""> { + let parser = "::xla::ifrt::ShardingParam::ParseV1($_parser)"; + let printer = "::xla::ifrt::ShardingParam::PrintV1($_printer, $_self)"; +} + +def Vifrt_ShardingParamAttrV1 : Vifrt_AttrDef<"VifrtShardingParamV1", "0.1.0", + "current"> { + let mnemonic = "sharding_param_v1"; + let parameters = (ins Vifrt_ShardingParameterV1:$sharding); + let assemblyFormat = "`<` $sharding `>`"; +} + +def Vifrt_IntervalAttrV1 : Vifrt_AttrDef<"VifrtIntervalV1", "0.1.0", + "current"> { + let mnemonic = "interval_v1"; + let parameters = (ins "int":$start, "int":$end, "int":$step); + let assemblyFormat = "`[`$start `:` $end `:` $step`]`"; +} + +def Vifrt_MappingAttrV1 : Vifrt_AttrDef<"VifrtMappingV1", "0.1.0", "current"> { + let mnemonic = "mapping_v1"; + let parameters = (ins + Vifrt_IntervalAttrV1:$from_shards, + Vifrt_IntervalAttrV1:$to_shards); + let assemblyFormat = "`<` $from_shards `to` $to_shards `>`"; +} + +// Equivalent to `mlir::ArrayAttr`, but with VIFRT verification. +// This can be used to represent `Ifrt_MappingAttrArrayAttr`, +// `Ifrt_ArrayMappingAttrArrayAttr` and `Ifrt_IoAliasesAttr`, which are just +// arrays. +def Vifrt_GenericArrayAttrV1 : Vifrt_AttrDef<"VifrtGenericArrayAttrV1", + "0.1.0", "current"> { + let mnemonic = "generic_array_attr_v1"; + let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value); + let genVerifyDecl = 1; + let extraClassDefinition = [{ + mlir::LogicalResult VifrtGenericArrayAttrV1Attr::verify( + llvm::function_ref err_fn, + llvm::ArrayRef value) { + if (!allFromVifrt(value)) return err_fn() << "expected array of VIFRT attributes"; + return mlir::success(); + } + }]; + let assemblyFormat = "`<` custom($value) `>`"; +} + +def Vifrt_ArrayMappingAttrV1 : Vifrt_AttrDef<"VifrtArrayMappingV1", "0.1.0", + "current"> { + let mnemonic = "array_mapping_v1"; + let parameters = (ins + "int32_t":$in_array_index, + "int32_t":$out_array_index, + Vifrt_GenericArrayAttrV1:$mappings); + let assemblyFormat = "`<`$in_array_index`,` $out_array_index`,` $mappings`>`"; +} + +//===--------------------------------------------------------------------------- +// Types +//===--------------------------------------------------------------------------- + +class Vifrt_TypeDef + : TypeDef { + let mnemonic = name; + let extraClassDeclaration = [{ + ::xla::ifrt::Version getMinVersion() { + return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{); + } + ::xla::ifrt::Version getMaxVersion() { + }] # !if( + !eq(max_version, "current"), + [{ return ::xla::ifrt::Version::getCurrentVersion(); }], + [{ return ::xla::ifrt::Version("}] # !subst(".", ", ", max_version) # [{"); }] + ) # [{ + } + }]; +} + +def Vifrt_Array : Vifrt_TypeDef<"VifrtArrayV1", "array_v1", "0.1.0", "current"> { + let parameters = (ins + "::mlir::RankedTensorType":$shape, + "::mlir::Attribute":$sharding_attr, + Vifrt_DevicesAttrV1:$devices_attr, + "::mlir::StringAttr":$memory_kind_attr); + + let assemblyFormat = [{ + `<` $shape`,` $sharding_attr `,` $devices_attr `,` `memory_kind` `=` + $memory_kind_attr `>` + }]; +} + +def Vifrt_Control : Vifrt_TypeDef<"VifrtControlV1", "control_v1", "0.1.0", "current">; + #endif // XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_interfaces.td b/xla/python/ifrt/ir/vifrt_interfaces.td new file mode 100644 index 0000000000000..dfaa90df292cf --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_interfaces.td @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_INTERFACES_TD_ +#define XLA_PYTHON_IFRT_IR_VIFRT_INTERFACES_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +// VIFRT represents the layout only. Therefore, it uses AnyType everywhere. +def Vifrt_AnyType : AnyTypeOf<[AnyType]>; +def Vifrt_AnyAttr : AnyAttrOf<[AnyAttr]>; +def Vifrt_AnyRegion : Region, "any region">; + +//===--------------------------------------------------------------------------- +// Op interfaces +//===--------------------------------------------------------------------------- + +def Vifrt_VersionedOpInterface : OpInterface<"VifrtVersionedOpInterface"> { + let cppNamespace = "::xla::ifrt"; + let methods = [ + InterfaceMethod< + "The min version of the VIFRT dialect an op is supported in.", + "::xla::ifrt::Version", "getMinVersion">, + InterfaceMethod< + "The max version (inclusive) of the VIFRT dialect an op is supported in.", + "::xla::ifrt::Version", "getMaxVersion">, + ]; +} + +def Vifrt_VersionedOpConstraintInterface : OpInterface<"VifrtVersionedOpConstraintInterface"> { + let cppNamespace = "::xla::ifrt"; + let methods = [ + InterfaceMethod< + [{Validates versioned constraints on a versioned op. + Used if the specified constraints of an op change over time.}], + "mlir::LogicalResult", "validateConstraint", + (ins "mlir::Operation*":$op, "::xla::ifrt::Version":$targetVersion)>, + ]; +} + +//===--------------------------------------------------------------------------- +// Attribute interfaces +//===--------------------------------------------------------------------------- + +def Vifrt_VersionedAttrInterface : AttrInterface<"VifrtVersionedAttrInterface"> { + let cppNamespace = "::xla::ifrt"; + let methods = [ + InterfaceMethod< + "The min version of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMinVersion">, + InterfaceMethod< + "The maxi version (inclusive) of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMaxVersion">, + ]; +} + +//===--------------------------------------------------------------------------- +// Type interfaces +//===--------------------------------------------------------------------------- + +def Vifrt_VersionedTypeInterface : TypeInterface<"VifrtVersionedTypeInterface"> { + let cppNamespace = "::xla::ifrt"; + let methods = [ + InterfaceMethod< + "The min version of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMinVersion">, + InterfaceMethod< + "The max version (inclusive) of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMaxVersion">, + ]; +} + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_INTERFACES_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_ops.td b/xla/python/ifrt/ir/vifrt_ops.td index 6f1c9044ab967..c0b67c3f5d276 100644 --- a/xla/python/ifrt/ir/vifrt_ops.td +++ b/xla/python/ifrt/ir/vifrt_ops.td @@ -17,35 +17,17 @@ limitations under the License. #define XLA_PYTHON_IFRT_IR_VIFRT_OPS_TD_ include "mlir/IR/OpBase.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/CallInterfaces.td" include "xla/python/ifrt/ir/vifrt_dialect.td" -include "xla/python/ifrt/ir/vifrt_types.td" -include "xla/python/ifrt/ir/vifrt_attrs.td" - -def Vifrt_VersionedOpInterface : OpInterface<"VifrtVersionedOpInterface"> { - let cppNamespace = "::xla::ifrt"; - let methods = [ - InterfaceMethod< - "The min version of the VIFRT dialect an op is supported in.", - "::xla::ifrt::Version", "getMinVersion">, - InterfaceMethod< - "The max version (inclusive) of the VIFRT dialect an op is supported in.", - "::xla::ifrt::Version", "getMaxVersion">, - ]; -} +include "xla/python/ifrt/ir/vifrt_interfaces.td" -def Vifrt_VersionedOpConstraintInterface : OpInterface<"VifrtVersionedOpConstraintInterface"> { - let cppNamespace = "::xla::ifrt"; - let methods = [ - InterfaceMethod< - [{Validates versioned constraints on a versioned op. - Used if the specified constraints of an op change over time.}], - "mlir::LogicalResult", "validateConstraint", - (ins "mlir::Operation*":$op, "::xla::ifrt::Version":$targetVersion)>, - ]; -} +// VIFRT is a bare mininum versioned copy of IFRT. In the pursuit of minimality, +// it adopts the following conventions (similar to VHLO): +// 1) Use Vifrt_AnyType or Variading for all operands and +// results. +// 2) Use traits only when strictly necessary (e.g., `AttrSizedOperandSegments` +// when multiple variable operands are present). +// 3) Use Vifrt_AnyAttr for all attributes. +// 4) Don't use verifiers. class Vifrt_Op traits = []> : Op for all operands and -// results. -// 2) Use traits only when strictly necessary (e.g., `AttrSizedOperandSegments` -// when multiple variable operands are present). -// 3) Use Vifrt_AnyAttr for all attributes. -// 4) Don't use verifiers. - def Vifrt_ReshardOpV1 : Vifrt_Op<"ReshardV1", "0.1.0", "current", [AttrSizedOperandSegments]> { let arguments = (ins diff --git a/xla/python/ifrt/ir/vifrt_types.cc b/xla/python/ifrt/ir/vifrt_types.cc deleted file mode 100644 index 5fb1c7d7e16da..0000000000000 --- a/xla/python/ifrt/ir/vifrt_types.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/ifrt/ir/vifrt_types.h" - -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: export -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: export - -namespace xla { -namespace ifrt { - -void VifrtTypeConverterBuiltin::addBuiltinToVifrtConversions() { - // We currently rely on the builtin types being stable, and thus we do not - // convert builtin types to VIFRT types. -} - -void VifrtTypeConverterBuiltin::addVifrtToBuiltinConversions() { - // We currently rely on the builtin types are stable, and thus we do not - // convert from VIFRT types to builtin types. -} - -namespace { - -// Verifies if a given type or attribute is from VIFRT dialect. -// Must be defined before importing the generated type interfaces and defs. -template -bool isFromVifrt(TypeOrAttr t) { - return t.getDialect().getNamespace() == "vifrt"; -} - -} // namespace - -} // namespace ifrt -} // namespace xla - -// Include order matters. -#include "xla/python/ifrt/ir/vifrt_type_interfaces.cc.inc" -#define GET_TYPEDEF_CLASSES -#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" - -namespace xla { -namespace ifrt { - -mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter& printer) { - return generatedTypePrinter(type, printer); -} - -mlir::OptionalParseResult parseVifrtType(mlir::AsmParser& parser, - llvm::StringRef* mnemonic, - mlir::Type& type) { - return generatedTypeParser(parser, mnemonic, type); -} - -namespace { -template -void registerVifrtTypes(mlir::MLIRContext* context) { - (mlir::detail::TypeUniquer::registerType(context), ...); -} -} // namespace - -void registerVifrtTypes(mlir::MLIRContext* context) { - registerVifrtTypes< -#define GET_TYPEDEF_LIST -#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" - >(context); -} - -} // namespace ifrt -} // namespace xla diff --git a/xla/python/ifrt/ir/vifrt_types.h b/xla/python/ifrt/ir/vifrt_types.h deleted file mode 100644 index d24d7b17eae32..0000000000000 --- a/xla/python/ifrt/ir/vifrt_types.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ -#define XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ - -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Transforms/DialectConversion.h" -#include "xla/python/ifrt/ir/version.h" // IWYU pragma: export - -namespace xla { -namespace ifrt { - -class VifrtTypeConverterBase : public mlir::TypeConverter { - public: - VifrtTypeConverterBase() : mlir::TypeConverter() {}; - - ~VifrtTypeConverterBase() override = default; -}; - -// Class used to manage conversions between VIFRT and Builtin types. -class VifrtTypeConverterBuiltin : public VifrtTypeConverterBase { - public: - // A subclass can call this method to add conversions from VIFRT to Builtin - // types. Conversions are applied in reverse order, with the most recently - // added conversion attempted to be applied first. - void addVifrtToBuiltinConversions(); - - // A subclass can call this method to add conversions from Builtin to VIFRT - // types. Conversions are applied in reverse order, with the most recently - // added conversion attempted to be applied first. - void addBuiltinToVifrtConversions(); -}; - -// Auto-generated VIFRT type printers and parsers. -mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter& printer); -mlir::OptionalParseResult parseVifrtType(mlir::AsmParser& parser, - llvm::StringRef* mnemonic, - mlir::Type& type); - -// Registers VIFRT types in a given MLIR context. -void registerVifrtTypes(mlir::MLIRContext* context); - -} // namespace ifrt -} // namespace xla - -#include "xla/python/ifrt/ir/vifrt_type_interfaces.h.inc" -#define GET_TYPEDEF_CLASSES -#include "xla/python/ifrt/ir/vifrt_type_defs.h.inc" - -#endif // XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ diff --git a/xla/python/ifrt/ir/vifrt_types.td b/xla/python/ifrt/ir/vifrt_types.td deleted file mode 100644 index e327a64dfe60c..0000000000000 --- a/xla/python/ifrt/ir/vifrt_types.td +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ -#define XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ - -include "mlir/IR/AttrTypeBase.td" -include "xla/python/ifrt/ir/vifrt_base.td" -include "xla/python/ifrt/ir/vifrt_dialect.td" - -def Vifrt_VersionedTypeInterface : TypeInterface<"VifrtVersionedTypeInterface"> { - let cppNamespace = "::xla::ifrt"; - let methods = [ - InterfaceMethod< - "The min version of the VIFRT dialect an attribute is supported in.", - "::xla::ifrt::Version", "getMinVersion">, - InterfaceMethod< - "The max version (inclusive) of the VIFRT dialect an attribute is supported in.", - "::xla::ifrt::Version", "getMaxVersion">, - ]; -} - -class Vifrt_TypeDef - : TypeDef { - let mnemonic = name; - let extraClassDeclaration = [{ - ::xla::ifrt::Version getMinVersion() { - return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{); - } - ::xla::ifrt::Version getMaxVersion() { - }] # !if( - !eq(max_version, "current"), - [{ return ::xla::ifrt::Version::getCurrentVersion(); }], - [{ return ::xla::ifrt::Version("}] # !subst(".", ", ", max_version) # [{"); }] - ) # [{ - } - }]; -} - -def Vifrt_Array : Vifrt_TypeDef<"VifrtArrayV1", "array_v1", "0.1.0", "current"> { - let parameters = (ins - "::mlir::Type":$shape, - "::mlir::Attribute":$sharding_attr, - "::mlir::Attribute":$devices_attr, - "::mlir::Attribute":$memory_kind_attr); - - let genVerifyDecl = 1; - // The verifier does not check `shape` and `memory_kind_attr` because they - // use builtin types `mlir::RankedTensor` and `mlir::StringAttr`. - let extraClassDefinition = [{ - ::llvm::LogicalResult VifrtArrayV1Type::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic ()> err_fn, - ::mlir::Type shape, - ::mlir::Attribute sharding_attr, - ::mlir::Attribute devices_attr, - ::mlir::Attribute memory_kind_attr - ) { - if (!isFromVifrt(sharding_attr) || - !isFromVifrt(devices_attr)) { - return err_fn() << "expected VIFRT type or attribute"; - } - return mlir::success(); - } - }]; - let assemblyFormat = [{ - `<` $shape`,` $sharding_attr `,` $devices_attr `,` `memory_kind` `=` - $memory_kind_attr `>` - }]; -} - -def Vifrt_Control : Vifrt_TypeDef<"VifrtControlV1", "control_v1", "0.1.0", "current">; - - -#endif // XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/support/BUILD b/xla/python/ifrt/support/BUILD index c26bb2a5efd73..a5a5494e0c801 100644 --- a/xla/python/ifrt/support/BUILD +++ b/xla/python/ifrt/support/BUILD @@ -16,6 +16,7 @@ cc_library( "//xla/mlir/utils:error_util", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/python/ifrt/ir", + "//xla/python/ifrt/ir:vifrt", "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/xla/python/ifrt/support/module_parsing.cc b/xla/python/ifrt/support/module_parsing.cc index 7a0a8fa7006f3..b1740cd5cf0ca 100644 --- a/xla/python/ifrt/support/module_parsing.cc +++ b/xla/python/ifrt/support/module_parsing.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" +#include "xla/python/ifrt/ir/vifrt_dialect.h" namespace xla { namespace ifrt { @@ -38,6 +39,7 @@ namespace support { void InitializeMlirDialectRegistry(mlir::DialectRegistry& registry) { registry.insert(); + registry.insert(); mlir::registerAllDialects(registry); mlir::func::registerAllExtensions(registry); mlir::mhlo::registerAllMhloDialects(registry);