diff --git a/xla/python/ifrt/ir/BUILD b/xla/python/ifrt/ir/BUILD index 47654c5b33bde9..9ae96c068bc2e7 100644 --- a/xla/python/ifrt/ir/BUILD +++ b/xla/python/ifrt/ir/BUILD @@ -314,3 +314,187 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +gentbl_cc_library( + name = "vifrt_attr_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-attr-interface-decls"], + "vifrt_attr_interfaces.h.inc", + ), + ( + ["-gen-attr-interface-defs"], + "vifrt_attr_interfaces.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vifrt_attrs.td", + test = True, + deps = [":vifrt_ops_td_files"], +) + +gentbl_cc_library( + name = "vifrt_attrs_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + [ + "-gen-attrdef-decls", + "--attrdefs-dialect=vifrt", + ], + "vifrt_attrs.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + "--attrdefs-dialect=vifrt", + ], + "vifrt_attrs.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vifrt_ops.td", + test = True, + deps = [":vifrt_ops_td_files"], +) + +gentbl_cc_library( + name = "vifrt_op_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "vifrt_op_interfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "vifrt_op_interfaces.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vifrt_ops.td", + deps = [":vifrt_ops_td_files"], +) + +cc_library( + name = "vifrt_ops", + srcs = [ + "vifrt_ops.cc", + ], + hdrs = [ + "vifrt_ops.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":version", + ":vifrt_attr_interfaces_inc_gen", + ":vifrt_attrs_inc_gen", + ":vifrt_op_interfaces_inc_gen", + ":vifrt_ops_inc_gen", + ":vifrt_types", + ":vifrt_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +gentbl_cc_library( + name = "vifrt_ops_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-op-decls"], + "vifrt_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "vifrt_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vifrt_ops.td", + test = True, + deps = [":vifrt_ops_td_files"], +) + +td_library( + name = "vifrt_ops_td_files", + srcs = [ + "vifrt_attrs.td", + "vifrt_base.td", + "vifrt_dialect.td", + "vifrt_ops.td", + "vifrt_types.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:ShapeOpsTdFiles", + ], +) + +cc_library( + name = "vifrt_types", + srcs = ["vifrt_types.cc"], + hdrs = ["vifrt_types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":version", + ":vifrt_type_interfaces_inc_gen", + ":vifrt_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_assembly_format", + ], +) + +gentbl_cc_library( + name = "vifrt_type_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-type-interface-decls"], + "vifrt_type_interfaces.h.inc", + ), + ( + ["-gen-type-interface-defs"], + "vifrt_type_interfaces.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "vifrt_types.td", + deps = [ + ":vifrt_ops_td_files", + ], +) + +gentbl_cc_library( + name = "vifrt_types_inc_gen", + compatible_with = 
get_compatible_with_portable(),
+    tbl_outs = [
+        (
+            [
+                "-gen-typedef-decls",
+                "--typedefs-dialect=vifrt",
+            ],
+            "vifrt_type_defs.h.inc",
+        ),
+        (
+            [
+                "-gen-typedef-defs",
+                "--typedefs-dialect=vifrt",
+            ],
+            "vifrt_type_defs.cc.inc",
+        ),
+    ],
+    tblgen = "@llvm-project//mlir:mlir-tblgen",
+    td_file = "vifrt_ops.td",
+    deps = [
+        ":vifrt_ops_td_files",
+    ],
+)
diff --git a/xla/python/ifrt/ir/vifrt_attrs.td b/xla/python/ifrt/ir/vifrt_attrs.td
new file mode 100644
index 00000000000000..8eec47d172c425
--- /dev/null
+++ b/xla/python/ifrt/ir/vifrt_attrs.td
@@ -0,0 +1,122 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_
+#define XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_
+
+include "mlir/IR/AttrTypeBase.td"
+include "xla/python/ifrt/ir/vifrt_base.td"
+include "xla/python/ifrt/ir/vifrt_dialect.td"
+include "xla/python/ifrt/ir/vifrt_types.td"
+
+def Vifrt_VersionedAttrInterface : AttrInterface<"VifrtVersionedAttrInterface"> {
+  let cppNamespace = "::xla::ifrt";
+  let methods = [
+    InterfaceMethod<
+      "The min version of the VIFRT dialect an attribute is supported in.",
+      "::xla::ifrt::Version", "getMinVersion">,
+    InterfaceMethod<
+      "The max version (inclusive) of the VIFRT dialect an attribute is supported in.",
+      "::xla::ifrt::Version", "getMaxVersion">,
+  ];
+}
+
+class Vifrt_AttrDef<string name, string min_version, string max_version>
+    : AttrDef<Vifrt_Dialect, name, [Vifrt_VersionedAttrInterface]> {
+  let extraClassDeclaration = [{
+    ::xla::ifrt::Version getMinVersion() {
+      return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{);
+    }
+    ::xla::ifrt::Version getMaxVersion() {
+      }] # !if(
+        !eq(max_version, "current"),
+        [{ return ::xla::ifrt::Version::getCurrentVersion(); }],
+        [{ return ::xla::ifrt::Version(}] # !subst(".", ", ", max_version) # [{); }]
+      ) # [{
+    }
+  }];
+}
+
+def Vifrt_DevicesAttrV1 : Vifrt_AttrDef<"VifrtDevicesV1", "0.1.0", "current"> {
+  let mnemonic = "devices_v1";
+  let parameters = (ins ArrayRefParameter<"int">:$ids);
+  let assemblyFormat = "`[` $ids `]`";
+}
+
+def Vifrt_UnspecifiedShardingAttrV1
+    : Vifrt_AttrDef<"VifrtUnspecifiedShardingV1", "0.1.0", "current"> {
+  let mnemonic = "sharding_unspecified_v1";
+  let parameters = (ins);
+  let assemblyFormat = "";
+}
+
+
+// TODO(icgog): Introduce versioned ShardingParamV1.
+def Vifrt_ShardingParamAttrV1 : Vifrt_AttrDef<"VifrtShardingParamV1", "0.1.0", + "current"> { + let mnemonic = "sharding_param_v1"; +} + +def Vifrt_IntervalAttrV1 : Vifrt_AttrDef<"VifrtIntervalV1", "0.1.0", + "current"> { + let mnemonic = "interval_v1"; + let parameters = (ins "int":$start, "int":$end, "int":$step); + let assemblyFormat = "`[`$start `:` $end `:` $step`]`"; +} + +def Vifrt_MappingAttrV1 : Vifrt_AttrDef<"VifrtMappingV1", "0.1.0", "current"> { + let mnemonic = "mapping_v1"; + let parameters = (ins + Vifrt_IntervalAttrV1:$from_shards, + Vifrt_IntervalAttrV1:$to_shards); + let assemblyFormat = "`<` $from_shards `to` $to_shards `>`"; +} + +def Vifrt_GenericArrayAttrV1 : Vifrt_AttrDef<"VifrtGenericArrayAttrV1", + "0.1.0", "current"> { + let mnemonic = "generic_array_attr_v1"; + let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value); + let genVerifyDecl = 1; + let extraClassDefinition = [{ + mlir::LogicalResult VifrtGenericArrayAttrV1Attr::verify( + llvm::function_ref err_fn, + llvm::ArrayRef value) { + if (!allFromVifrt(value)) return err_fn() << "expected array of VIFRT attributes"; + return mlir::success(); + } + }]; + let assemblyFormat = "`<` custom($value) `>`"; +} + +def Vifrt_ArrayMappingAttrV1 : Vifrt_AttrDef<"VifrtArrayMappingV1", "0.1.0", + "current"> { + let mnemonic = "array_mapping_v1"; + let parameters = (ins + "int32_t":$in_array_index, + "int32_t":$out_array_index, + Vifrt_GenericArrayAttrV1:$mappings); + let assemblyFormat = "`<`$in_array_index`,` $out_array_index`,` $mappings`>`"; +} + +def Vifrt_IoAliasesAttrV1 : Vifrt_AttrDef<"IfrtIoAliasesV1", "0.1.0", "current"> { + let mnemonic = "io_aliases_v1"; + let parameters = (ins + Vifrt_GenericArrayAttrV1:$io_aliases); + let assemblyFormat = "`<` $io_aliases `>`"; +} + +// TODO(icgog): Introduce Vifrt_MappingAttrArrayAttrV1, +// Vifrt_ArrayMappingAttrArrayAttrV1. + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_base.td b/xla/python/ifrt/ir/vifrt_base.td new file mode 100644 index 00000000000000..d92bea913c1336 --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_base.td @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ +#define XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +// VIFRT represents the layout only. Therefore, it uses AnyType everywhere. +def Vifrt_AnyType : AnyTypeOf<[AnyType]>; +def Vifrt_AnyAttr : AnyAttrOf<[AnyAttr]>; +def Vifrt_AnyRegion : Region, "any region">; + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_dialect.td b/xla/python/ifrt/ir/vifrt_dialect.td new file mode 100644 index 00000000000000..9aa84e19dbeff2 --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_dialect.td @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ +#define XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ + +def Vifrt_Dialect : Dialect { + let name = "vifrt"; + let summary = "Versioned IFRT dialect"; + let cppNamespace = "::xla::ifrt"; + + let description = [{ + A versioned copy of the IFRT IR dialect that is used for forward and + backward compatible serialization/deserialization. + + Version log: + 0.1.0: Initial IFRT IR stability guarantees. + }]; + + let useDefaultAttributePrinterParser = 0; + let useDefaultTypePrinterParser = 0; + let usePropertiesForAttributes = 1; +} + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_ops.cc b/xla/python/ifrt/ir/vifrt_ops.cc new file mode 100644 index 00000000000000..8aedd0ab40bbc9 --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_ops.cc @@ -0,0 +1,162 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/ir/vifrt_ops.h" + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: export +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "xla/python/ifrt/ir/vifrt_types.h" + +namespace xla { +namespace ifrt { + +namespace { + +// Verifies if a given type or attribute is from VIFRT dialect. +template +bool isFromVifrt(TypeOrAttr t) { + return t.getDialect().getNamespace() == VifrtDialect::getDialectNamespace(); +} + +template +bool allFromVifrt(llvm::ArrayRef range) { + return llvm::all_of(range, isFromVifrt); +} + +} // namespace + +// Helper functions for VIFRT printers and parsers. 
+static void printAttributeArray(mlir::AsmPrinter& os, + llvm::ArrayRef arrayAttr) { + os << '[' << arrayAttr << ']'; +} + +// Parse attributes in brackets: [#virt.attr, #virt.attr] +mlir::ParseResult parseAttributeArray( + mlir::AsmParser& parser, llvm::SmallVector& arrayAttr) { + mlir::ArrayAttr array; + if (mlir::failed(parser.parseAttribute(array))) { + return mlir::failure(); + } + arrayAttr.append(array.begin(), array.end()); + return mlir::success(); +} + +} // namespace ifrt +} // namespace xla + +#include "xla/python/ifrt/ir/vifrt_attr_interfaces.cc.inc" +#define GET_ATTRDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_attrs.cc.inc" +#include "xla/python/ifrt/ir/vifrt_op_interfaces.cc.inc" +#define GET_OP_CLASSES +#include "xla/python/ifrt/ir/vifrt_ops.cc.inc" + +namespace xla { +namespace ifrt { + +//===----------------------------------------------------------------------===// +// VIFRT Dialect Constructor +//===----------------------------------------------------------------------===// + +VifrtDialect::VifrtDialect(mlir::MLIRContext* context) + : mlir::Dialect(getDialectNamespace(), context, + mlir::TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "xla/python/ifrt/ir/vifrt_ops.cc.inc" + >(); + addVifrtTypes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "xla/python/ifrt/ir/vifrt_attrs.cc.inc" + >(); +} + +void VifrtDialect::addVifrtTypes() { + // Following the same solution as in VHLO; Idiomatically, this functionality + // is expressed as shown below: + // addTypes< + // #define GET_TYPEDEF_LIST + // #include + // "third_party/tensorflow/compiler/xla/python/ifrt/ir/vifrt_type_defs.cc.inc" + // >(); + // + // However, Dialect::addTypes doesn't work for our situation where we want to + // decouple the vifrt_ops and vifrt_types targets because + // vifrt_type_defs.h.inc only includes forward declarations of `TypeStorage` + // structs, and that's not sufficient for Dialect::addTypes to compile. + // Therefore, we introduce this function and then reimplementing + // Dialect::addTypes as shown below. + addTypesWithoutRegistering< +#define GET_TYPEDEF_LIST +#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" + >(); + registerVifrtTypes(getContext()); +} + +mlir::Type VifrtDialect::parseType(mlir::DialectAsmParser& parser) const { + llvm::StringRef data_type; + mlir::Type type; + auto parse_result_opt = parseVifrtType(parser, &data_type, type); + if (parse_result_opt.has_value() && mlir::succeeded(*parse_result_opt)) { + return type; + } + parser.emitError(parser.getNameLoc()) << "unknown vifrt type: " << data_type; + return nullptr; +} + +void VifrtDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter& os) const { + if (mlir::succeeded(printVifrtType(type, os))) { + return; + } + os << ""; +} + +mlir::Attribute VifrtDialect::parseAttribute(mlir::DialectAsmParser& parser, + mlir::Type type) const { + llvm::StringRef attr_tag; + mlir::Attribute attr; + auto parse_result = generatedAttributeParser(parser, &attr_tag, type, attr); + if (parse_result.has_value()) { + return attr; + } + parser.emitError(parser.getNameLoc(), "unknown vifrt attribute"); + return mlir::Attribute(); +} + +void VifrtDialect::printAttribute(mlir::Attribute attr, + mlir::DialectAsmPrinter& os) const { + mlir::LogicalResult result = generatedAttributePrinter(attr, os); + // Avoid clang unused variable error. 
+ (void)result; + assert(mlir::succeeded(result)); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/vifrt_ops.h b/xla/python/ifrt/ir/vifrt_ops.h new file mode 100644 index 00000000000000..218aed758ca6be --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_ops.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ +#define XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/version.h" // IWYU pragma: export +#include "xla/python/ifrt/ir/vifrt_types.h" // IWYU pragma: export + +namespace xla { +namespace ifrt { + +class VifrtDialect : public mlir::Dialect { + public: + explicit VifrtDialect(mlir::MLIRContext *context); + + static mlir::StringRef getDialectNamespace() { return "vifrt"; } + + // Parses a type registered in the VIFRT dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + + // Prints a type registered in the VIFRT dialect. + void printType(mlir::Type type, mlir::DialectAsmPrinter &os) const override; + + // Parses an attribute registered in the VIFRT dialect. + mlir::Attribute parseAttribute(mlir::DialectAsmParser &parser, + mlir::Type type) const override; + + // Prints an attribute registered in the VIFRT dialect. + void printAttribute(mlir::Attribute attr, + mlir::DialectAsmPrinter &os) const override; + + private: + // Adds VIFRT types to this dialect. + // See implementation comment for additional details. + void addVifrtTypes(); + + // Does the same this as Dialect::addTypes but without calling `registerType`. + // See comments for `addVifrtTypes` for additional details. + template + void addTypesWithoutRegistering() { + (addType(Types::getTypeID(), mlir::AbstractType::get(*this)), ...); + } +}; + +} // namespace ifrt +} // namespace xla + +// Attrs +#include "xla/python/ifrt/ir/vifrt_attr_interfaces.h.inc" +#define GET_ATTRDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_attrs.h.inc" + +// Ops +#include "xla/python/ifrt/ir/vifrt_op_interfaces.h.inc" +#define GET_OP_CLASSES +#include "xla/python/ifrt/ir/vifrt_ops.h.inc" + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_OPS_H_ diff --git a/xla/python/ifrt/ir/vifrt_ops.td b/xla/python/ifrt/ir/vifrt_ops.td new file mode 100644 index 00000000000000..6f1c9044ab967e --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_ops.td @@ -0,0 +1,162 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_PYTHON_IFRT_IR_VIFRT_OPS_TD_
+#define XLA_PYTHON_IFRT_IR_VIFRT_OPS_TD_
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/BuiltinTypes.td"
+include "mlir/IR/SymbolInterfaces.td"
+include "mlir/Interfaces/CallInterfaces.td"
+include "xla/python/ifrt/ir/vifrt_dialect.td"
+include "xla/python/ifrt/ir/vifrt_types.td"
+include "xla/python/ifrt/ir/vifrt_attrs.td"
+
+def Vifrt_VersionedOpInterface : OpInterface<"VifrtVersionedOpInterface"> {
+  let cppNamespace = "::xla::ifrt";
+  let methods = [
+    InterfaceMethod<
+      "The min version of the VIFRT dialect an op is supported in.",
+      "::xla::ifrt::Version", "getMinVersion">,
+    InterfaceMethod<
+      "The max version (inclusive) of the VIFRT dialect an op is supported in.",
+      "::xla::ifrt::Version", "getMaxVersion">,
+  ];
+}
+
+def Vifrt_VersionedOpConstraintInterface : OpInterface<"VifrtVersionedOpConstraintInterface"> {
+  let cppNamespace = "::xla::ifrt";
+  let methods = [
+    InterfaceMethod<
+      [{Validates versioned constraints on a versioned op.
+        Used if the specified constraints of an op change over time.}],
+      "mlir::LogicalResult", "validateConstraint",
+      (ins "mlir::Operation*":$op, "::xla::ifrt::Version":$targetVersion)>,
+  ];
+}
+
+class Vifrt_Op<string mnemonic, string min_version, string max_version,
+               list<Trait> traits = []> :
+    Op<Vifrt_Dialect, mnemonic,
+       [DeclareOpInterfaceMethods<Vifrt_VersionedOpInterface>] # traits> {
+  let extraClassDefinition = [{
+    ::xla::ifrt::Version $cppClass::getMinVersion() {
+      return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{);
+    }
+    ::xla::ifrt::Version $cppClass::getMaxVersion() {
+      }] # !if(
+        !eq(max_version, "current"),
+        [{ return ::xla::ifrt::Version::getCurrentVersion(); }],
+        [{ return ::xla::ifrt::Version(}] # !subst(".", ", ", max_version) # [{); }]
+      ) # [{
+    }
+  }];
+}
+
+// VIFRT is a bare minimum versioned copy of IFRT. In the pursuit of minimality,
+// it adopts the following conventions (similar to VHLO):
+// 1) Use Vifrt_AnyType or Variadic<Vifrt_AnyType> for all operands and
+// results.
+// 2) Use traits only when strictly necessary (e.g., `AttrSizedOperandSegments`
+// when multiple variadic operands are present).
+// 3) Use Vifrt_AnyAttr for all attributes.
+// 4) Don't use verifiers.
+ +def Vifrt_ReshardOpV1 : Vifrt_Op<"ReshardV1", "0.1.0", "current", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Vifrt_AnyAttr:$donated, + Variadic:$control_inputs); + let results = (outs + Variadic:$outputs, + Vifrt_AnyType:$control_output); +} + +def Vifrt_CopyArraysOpV1 : Vifrt_Op<"CopyArraysV1", "0.1.0", "current", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Vifrt_AnyAttr:$donated, + Variadic:$control_inputs); + let results = (outs + Variadic:$outputs, + Vifrt_AnyType:$control_output); +} + +def Vifrt_AssembleOpV1 + : Vifrt_Op<"AssembleV1", "0.1.0", "current", [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$control_inputs); + let results = (outs Vifrt_AnyType:$output); +} + +def Vifrt_DisassembleOpV1 : Vifrt_Op<"DisassembleV1", "0.1.0", "current", []> { + let arguments = (ins + Vifrt_AnyType:$input, + Variadic:$control_inputs); + let results = (outs Variadic:$outputs); +} + +def Vifrt_RemapArraysOpV1 : Vifrt_Op<"RemapArraysV1", "0.1.0", "current", []> { + let arguments = (ins + Variadic:$inputs, + Vifrt_AnyAttr:$mappings, + Vifrt_AnyAttr:$donated); + let results = (outs Variadic:$outputs); +} + +def Vifrt_CallOpV1 : Vifrt_Op<"CallV1", "0.1.0", "current", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$control_inputs, + Vifrt_AnyAttr:$callee, + Vifrt_AnyAttr:$devices, + Vifrt_AnyAttr:$io_aliases, + Vifrt_AnyAttr:$donated_input_indices); + let results = (outs + Variadic:$outputs, + Vifrt_AnyType:$control_output); +} + +def Vifrt_CallLoadedExecutableOpV1 : Vifrt_Op<"CallLoadedExecutableV1", "0.1.0", + "current", [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, + Variadic:$control_inputs, + Vifrt_AnyAttr:$callee, + Vifrt_AnyAttr:$io_aliases, + Vifrt_AnyAttr:$donated_input_indices); + let results = (outs + Variadic:$outputs, + Vifrt_AnyType:$control_output); +} + +def Vifrt_LoadedExecutableOpV1 : Vifrt_Op<"LoadedExecutableV1", "0.1.0", + "current", []> { + let arguments = (ins + Vifrt_AnyAttr:$sym_name, + Vifrt_AnyAttr:$function_type, + Vifrt_AnyAttr:$devices + ); +} + +def Vifrt_AfterOpV1 : Vifrt_Op<"AfterV1", "0.1.0", "current", []> { + let arguments = (ins Variadic:$inputs); + let results = (outs Vifrt_AnyType:$control_output); +} + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_OPS_TD_ \ No newline at end of file diff --git a/xla/python/ifrt/ir/vifrt_types.cc b/xla/python/ifrt/ir/vifrt_types.cc new file mode 100644 index 00000000000000..5fb1c7d7e16da5 --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_types.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "xla/python/ifrt/ir/vifrt_types.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: export +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "stablehlo/dialect/AssemblyFormat.h" // IWYU pragma: export + +namespace xla { +namespace ifrt { + +void VifrtTypeConverterBuiltin::addBuiltinToVifrtConversions() { + // We currently rely on the builtin types being stable, and thus we do not + // convert builtin types to VIFRT types. +} + +void VifrtTypeConverterBuiltin::addVifrtToBuiltinConversions() { + // We currently rely on the builtin types are stable, and thus we do not + // convert from VIFRT types to builtin types. +} + +namespace { + +// Verifies if a given type or attribute is from VIFRT dialect. +// Must be defined before importing the generated type interfaces and defs. +template +bool isFromVifrt(TypeOrAttr t) { + return t.getDialect().getNamespace() == "vifrt"; +} + +} // namespace + +} // namespace ifrt +} // namespace xla + +// Include order matters. +#include "xla/python/ifrt/ir/vifrt_type_interfaces.cc.inc" +#define GET_TYPEDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" + +namespace xla { +namespace ifrt { + +mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter& printer) { + return generatedTypePrinter(type, printer); +} + +mlir::OptionalParseResult parseVifrtType(mlir::AsmParser& parser, + llvm::StringRef* mnemonic, + mlir::Type& type) { + return generatedTypeParser(parser, mnemonic, type); +} + +namespace { +template +void registerVifrtTypes(mlir::MLIRContext* context) { + (mlir::detail::TypeUniquer::registerType(context), ...); +} +} // namespace + +void registerVifrtTypes(mlir::MLIRContext* context) { + registerVifrtTypes< +#define GET_TYPEDEF_LIST +#include "xla/python/ifrt/ir/vifrt_type_defs.cc.inc" + >(context); +} + +} // namespace ifrt +} // namespace xla diff --git a/xla/python/ifrt/ir/vifrt_types.h b/xla/python/ifrt/ir/vifrt_types.h new file mode 100644 index 00000000000000..d24d7b17eae324 --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_types.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ +#define XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/python/ifrt/ir/version.h" // IWYU pragma: export + +namespace xla { +namespace ifrt { + +class VifrtTypeConverterBase : public mlir::TypeConverter { + public: + VifrtTypeConverterBase() : mlir::TypeConverter() {}; + + ~VifrtTypeConverterBase() override = default; +}; + +// Class used to manage conversions between VIFRT and Builtin types. +class VifrtTypeConverterBuiltin : public VifrtTypeConverterBase { + public: + // A subclass can call this method to add conversions from VIFRT to Builtin + // types. Conversions are applied in reverse order, with the most recently + // added conversion attempted to be applied first. + void addVifrtToBuiltinConversions(); + + // A subclass can call this method to add conversions from Builtin to VIFRT + // types. Conversions are applied in reverse order, with the most recently + // added conversion attempted to be applied first. + void addBuiltinToVifrtConversions(); +}; + +// Auto-generated VIFRT type printers and parsers. +mlir::LogicalResult printVifrtType(mlir::Type type, mlir::AsmPrinter& printer); +mlir::OptionalParseResult parseVifrtType(mlir::AsmParser& parser, + llvm::StringRef* mnemonic, + mlir::Type& type); + +// Registers VIFRT types in a given MLIR context. +void registerVifrtTypes(mlir::MLIRContext* context); + +} // namespace ifrt +} // namespace xla + +#include "xla/python/ifrt/ir/vifrt_type_interfaces.h.inc" +#define GET_TYPEDEF_CLASSES +#include "xla/python/ifrt/ir/vifrt_type_defs.h.inc" + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_TYPES_H_ diff --git a/xla/python/ifrt/ir/vifrt_types.td b/xla/python/ifrt/ir/vifrt_types.td new file mode 100644 index 00000000000000..e327a64dfe60cf --- /dev/null +++ b/xla/python/ifrt/ir/vifrt_types.td @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ +#define XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ + +include "mlir/IR/AttrTypeBase.td" +include "xla/python/ifrt/ir/vifrt_base.td" +include "xla/python/ifrt/ir/vifrt_dialect.td" + +def Vifrt_VersionedTypeInterface : TypeInterface<"VifrtVersionedTypeInterface"> { + let cppNamespace = "::xla::ifrt"; + let methods = [ + InterfaceMethod< + "The min version of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMinVersion">, + InterfaceMethod< + "The max version (inclusive) of the VIFRT dialect an attribute is supported in.", + "::xla::ifrt::Version", "getMaxVersion">, + ]; +} + +class Vifrt_TypeDef + : TypeDef { + let mnemonic = name; + let extraClassDeclaration = [{ + ::xla::ifrt::Version getMinVersion() { + return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{); + } + ::xla::ifrt::Version getMaxVersion() { + }] # !if( + !eq(max_version, "current"), + [{ return ::xla::ifrt::Version::getCurrentVersion(); }], + [{ return ::xla::ifrt::Version("}] # !subst(".", ", ", max_version) # [{"); }] + ) # [{ + } + }]; +} + +def Vifrt_Array : Vifrt_TypeDef<"VifrtArrayV1", "array_v1", "0.1.0", "current"> { + let parameters = (ins + "::mlir::Type":$shape, + "::mlir::Attribute":$sharding_attr, + "::mlir::Attribute":$devices_attr, + "::mlir::Attribute":$memory_kind_attr); + + let genVerifyDecl = 1; + // The verifier does not check `shape` and `memory_kind_attr` because they + // use builtin types `mlir::RankedTensor` and `mlir::StringAttr`. + let extraClassDefinition = [{ + ::llvm::LogicalResult VifrtArrayV1Type::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic ()> err_fn, + ::mlir::Type shape, + ::mlir::Attribute sharding_attr, + ::mlir::Attribute devices_attr, + ::mlir::Attribute memory_kind_attr + ) { + if (!isFromVifrt(sharding_attr) || + !isFromVifrt(devices_attr)) { + return err_fn() << "expected VIFRT type or attribute"; + } + return mlir::success(); + } + }]; + let assemblyFormat = [{ + `<` $shape`,` $sharding_attr `,` $devices_attr `,` `memory_kind` `=` + $memory_kind_attr `>` + }]; +} + +def Vifrt_Control : Vifrt_TypeDef<"VifrtControlV1", "control_v1", "0.1.0", "current">; + + +#endif // XLA_PYTHON_IFRT_IR_VIFRT_TYPES_TD_ \ No newline at end of file diff --git a/xla/service/gpu/gpu_hlo_schedule.cc b/xla/service/gpu/gpu_hlo_schedule.cc index 6962e86babe8c9..d03a673eab54b4 100644 --- a/xla/service/gpu/gpu_hlo_schedule.cc +++ b/xla/service/gpu/gpu_hlo_schedule.cc @@ -518,8 +518,8 @@ absl::StatusOr ScheduleGpuModule( return GetSizeOfShape(shape, pointer_size); }; auto scheduler_core = std::make_unique( - shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), - config); + shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config, + GpuIsResourceConstrained); pipeline.AddPass(); pipeline.AddPass( std::move(latency_estimator), std::move(async_tracker), diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 418d29dbbd3362..1e555f7b022015 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -114,6 +114,27 @@ bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) { return IsGpuAsyncStart(from) && IsGpuAsyncDone(target); } +// Count the maximum overlapping count in subgroups of group and other +size_t CountOverlappingRanks(const std::vector& group, + const std::vector& 
other) {
+  size_t overlapping_count = 0;
+  for (const auto& curr_replica_group : group) {
+    absl::flat_hash_set<int64_t> curr_replica_ids;
+    for (const auto curr_replica_id : curr_replica_group.replica_ids()) {
+      curr_replica_ids.insert(curr_replica_id);
+    }
+
+    for (const auto& replica_group : other) {
+      size_t subgroup_count = 0;
+      for (const auto replica_id : replica_group.replica_ids()) {
+        if (curr_replica_ids.contains(replica_id)) ++subgroup_count;
+      }
+      overlapping_count = std::max(overlapping_count, subgroup_count);
+    }
+  }
+  return overlapping_count;
+}
+
 }  // namespace
 
 int64_t GetSizeOfShape(const Shape& shape, int pointer_size) {
@@ -141,6 +162,78 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo) {
   }
 }
 
+bool GpuIsResourceConstrained(
+    const DefaultSchedulerCore::SchedulingState& sched_state,
+    DefaultSchedulerCore::ScheduleCandidate& cand) {
+  if (cand.resource_constrained) {
+    return *cand.resource_constrained;
+  }
+  if (cand.node->GetResources().empty()) {
+    cand.resource_constrained = false;
+    return *(cand.resource_constrained);
+  }
+  cand.resource_constrained = false;
+  for (const auto& [resource_type, usage_type] : cand.node->GetResources()) {
+    auto max_it = sched_state.max_concurrent_resource.find(resource_type);
+    auto res_it = sched_state.resource_users_in_queue.find(resource_type);
+    cand.resource_constrained =
+        max_it != sched_state.max_concurrent_resource.end() &&
+        max_it->second == 0 &&
+        res_it != sched_state.resource_users_in_queue.end() &&
+        res_it->second > 0;
+
+    // If the candidate collective has more than one overlapping rank with an
+    // in-flight collective, the two can form a cyclic dependency and must not
+    // be overlapped.
+    if ((resource_type - AsyncTracker::GetFirstTargetDefinedResource()) ==
+            static_cast<int64_t>(GpuResourceType::kGpuAsyncStreamCollectives) &&
+        sched_state.resource_occupiers_in_flight.contains(resource_type) &&
+        sched_state.resource_occupiers_in_flight.at(resource_type).size() > 0) {
+      const HloInstruction& curr_hlo_inst = cand.node->GetInstr();
+      if (hlo_query::IsAsyncCollectiveDoneOp(&curr_hlo_inst, true)) {
+        CHECK(hlo_query::IsAsyncCollectiveStartOp(curr_hlo_inst.operand(0),
+                                                  true));
+        const HloInstruction* curr_start_inst =
+            curr_hlo_inst.operand(0)->async_wrapped_instruction();
+
+        // Check whether the candidate can be overlapped with the in-flight
+        // collectives.
+        bool can_overlap = true;
+        for (const auto occupier :
+             sched_state.resource_occupiers_in_flight.at(resource_type)) {
+          const HloInstruction& hlo_inst = occupier->GetInstr();
+          if (hlo_query::IsAsyncCollectiveDoneOp(&hlo_inst, true)) {
+            const HloInstruction* start_inst = hlo_inst.operand(0);
+            // Number of overlapping ranks between this occupier and the
+            // candidate.
+            size_t overlapping_count =
+                CountOverlappingRanks(curr_start_inst->replica_groups(),
+                                      start_inst->replica_groups());
+            if (overlapping_count > 1) {
+              can_overlap = false;
+              VLOG(3)
+                  << "Collectives have " << overlapping_count
+                  << " overlapping ranks and cannot be overlapped. 
Candidate " + "collective: " + << curr_start_inst->ToString() + << ", in flight collective: " << start_inst->ToString(); + break; + } + } + } + + if (!can_overlap) { + cand.resource_constrained = true; + return *cand.resource_constrained; + } + } + } + + if (*cand.resource_constrained) { + return *cand.resource_constrained; + } + } + return *cand.resource_constrained; +} + //===--------------------------------------------------------------------===// // GpuAsyncTrackerBase //===--------------------------------------------------------------------===// diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler.h b/xla/service/gpu/gpu_latency_hiding_scheduler.h index ad6f67d774924b..85b527f8208d03 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler.h +++ b/xla/service/gpu/gpu_latency_hiding_scheduler.h @@ -34,6 +34,15 @@ CanonicalAsyncOp GpuGetCanonicalAsyncOp(const HloInstruction& hlo); // Returns size of the `shape` given the `pointer_size`. int64_t GetSizeOfShape(const Shape& shape, int pointer_size); +// GPU resource constrain rule for scheduling candidate. +// On top of the default rule, we do not allow collectives with more than 1 +// overlapping ranks to overlap. This is because the execution order of NCCL +// kernels is not deterministic and cannot be controlled by launch order at the +// moment. A cyclic dependency can be formed with at least 2 overlapping ranks. +bool GpuIsResourceConstrained( + const DefaultSchedulerCore::SchedulingState& sched_state, + DefaultSchedulerCore::ScheduleCandidate& cand); + // GPU specific resources for latency hiding scheduler. // // We use two different set of resources to model the scheduling of asynchronous diff --git a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 42e05cf9db71cd..d3895ac6926c80 100644 --- a/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -371,7 +371,7 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, std::vector instruction_sequence = schedule.sequence(module->entry_computation()).instructions(); // Since we allow 2 collectives in-flight, we should expect this pattern: - // ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> ar(rs)-done + // ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> rs(ar)-done EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_0") < GetIndexByName(instruction_sequence, "rs_1") && GetIndexByName(instruction_sequence, "rs_0") < @@ -386,5 +386,59 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, GetIndexByName(instruction_sequence, "rs_1")); } +TEST_F(GpuLatencyHidingSchedulerBaseTest, + OverlappingRanksPreventOverlappingCollectives) { + absl::string_view kFdoProfile = R"pb( + costs { name: "add_0" cost_us: 100000.0 } + costs { name: "ar_0" cost_us: 10.0 } + costs { name: "rs_0" cost_us: 10.0 } + )pb"; + ; + absl::string_view kHloModule = R"( + HloModule m + + reduce { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] add(x, y) + } + + ENTRY main { + p0 = f32[] parameter(0) + p1 = f32[2] parameter(1) + p2 = f32[2] parameter(2) + ar_0 = f32[] all-reduce-start(p0), to_apply=reduce, replica_groups={{0,1}} + ar_1 = f32[] all-reduce-done(ar_0) + rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0}, replica_groups={{0, 1}} + rs_1 = f32[1] reduce-scatter-done(rs_0) + add_0 = f32[2] add(p1, p2) + ROOT _ = (f32[], f32[1], f32[2]) tuple(ar_1, rs_1, add_0) + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + 
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + TF_EXPECT_OK(ScheduleModule(module.get(), /*num_parallel_resources=*/2)); + auto schedule = module->schedule(); + std::vector instruction_sequence = + schedule.sequence(module->entry_computation()).instructions(); + // AR and RS have two ranks in common so cannot be overlapped, expect pattern: + // rs(ar)-start -> add -> rs(ar)-done -> ar(rs)-start -> ar(rs)-done + EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_1") < + GetIndexByName(instruction_sequence, "rs_0") || + GetIndexByName(instruction_sequence, "rs_1") < + GetIndexByName(instruction_sequence, "ar_0")); + EXPECT_TRUE((GetIndexByName(instruction_sequence, "ar_0") < + GetIndexByName(instruction_sequence, "add_0") && + GetIndexByName(instruction_sequence, "add_0") < + GetIndexByName(instruction_sequence, "ar_1")) || + (GetIndexByName(instruction_sequence, "rs_0") < + GetIndexByName(instruction_sequence, "add_0") && + GetIndexByName(instruction_sequence, "add_0") < + GetIndexByName(instruction_sequence, "rs_1"))); +} + } // namespace } // namespace xla::gpu diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 05b7e3aa872368..1eefc664f9c711 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -47,6 +47,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/map_util.h" #include "xla/service/dump.h" #include "xla/service/hlo_buffer.h" @@ -825,10 +826,12 @@ class ReadySetLt { explicit ReadySetLt( const DefaultSchedulerCore::SchedulingState* sched_state, DefaultSchedulerCore::TargetSchedulingRule target_scheduling_rule, - DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule) + DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule, + DefaultSchedulerCore::ResourceConstrainRule is_resource_constrained) : sched_state_(*sched_state), target_scheduling_rule_(target_scheduling_rule), - early_target_scheduling_rule_(early_target_scheduling_rule) {} + early_target_scheduling_rule_(early_target_scheduling_rule), + is_resource_constrained_(is_resource_constrained) {} // The comparison here implements the priority for the nodes in the ready set. 
DefaultSchedulerCore::CandidateResult operator()( DefaultSchedulerCore::ScheduleCandidate& a, @@ -913,7 +916,7 @@ class ReadySetLt { if (auto value = DefaultSchedulerCore::ChooseBestCandidate( ShouldScheduleAsyncDone(a), a, ShouldScheduleAsyncDone(b), b, "kScheduleDone")) { - return *value; + if (!is_resource_constrained_(sched_state_, value->result)) return *value; } // The following rule targets the async ops using resources that should be @@ -969,11 +972,12 @@ class ReadySetLt { if (auto value = DefaultSchedulerCore::ChooseBestCandidate( /*first_cond=*/!(a.node->DoesReleaseAnyResource() && a.node->GetAsyncDepth() == 0 && - !IsResourceConstrained(a)), + !is_resource_constrained_(sched_state_, a)), a, /*second_cond=*/ !(b.node->DoesReleaseAnyResource() && - b.node->GetAsyncDepth() == 0 && !IsResourceConstrained(b)), + b.node->GetAsyncDepth() == 0 && + !is_resource_constrained_(sched_state_, b)), b, "kStartAtZeroDepth")) { return value; } @@ -997,7 +1001,7 @@ class ReadySetLt { if (auto value = DefaultSchedulerCore::ChooseBestCandidate( a_ready_interval < b_ready_interval, a, b_ready_interval < a_ready_interval, b, "kLessStall")) { - return *value; + if (!is_resource_constrained_(sched_state_, value->result)) return *value; } if (sched_state_.config.resource_serializing) { // Prioritize scheduling the instruction which has less serial-resource @@ -1016,13 +1020,17 @@ class ReadySetLt { if (sched_state_.config.aggressive_scheduling_policies && !sched_state_.config.prioritize_async_depth_over_stall) { if (auto value = async_depth_0_candidate(a, b)) { - return *value; + if (!is_resource_constrained_(sched_state_, value->result)) + return *value; } } if (auto value = DefaultSchedulerCore::ChooseBestCandidate( - a.node->DoesReleaseAnyResource() && IsResourceConstrained(a), a, - b.node->DoesReleaseAnyResource() && IsResourceConstrained(b), b, - "kFreeBackedupResource")) { + a.node->DoesReleaseAnyResource() && + is_resource_constrained_(sched_state_, a), + a, + b.node->DoesReleaseAnyResource() && + is_resource_constrained_(sched_state_, b), + b, "kFreeBackedupResource")) { return *value; } if (sched_state_.config.aggressive_scheduling_policies) { @@ -1033,7 +1041,8 @@ class ReadySetLt { a.node->GetAsyncDepth() > b.node->GetAsyncDepth(), a, b.node->GetAsyncDepth() > a.node->GetAsyncDepth(), b, "kAsyncDepth")) { - return *value; + if (!is_resource_constrained_(sched_state_, value->result)) + return *value; } // Favor nodes that are the closest in amount of latency they hide with // the highest amount of latency that needs to be hidden to avoid @@ -1125,10 +1134,12 @@ class ReadySetLt { } // If none of the heuristics above triggers then prefer to schedule // according the original order so that we don't impact memory pressure. 
- if (sched_state_.sched_graph.OriginalInstructionPosition( - &a.node->GetInstr()) > - sched_state_.sched_graph.OriginalInstructionPosition( - &b.node->GetInstr())) { + if ((sched_state_.sched_graph.OriginalInstructionPosition( + &a.node->GetInstr()) > + sched_state_.sched_graph.OriginalInstructionPosition( + &b.node->GetInstr())) || + (!is_resource_constrained_(sched_state_, a) && + is_resource_constrained_(sched_state_, b))) { return {a, "kOriginalOrder"}; } return {b, "kOriginalOrder"}; @@ -1138,6 +1149,7 @@ class ReadySetLt { const DefaultSchedulerCore::SchedulingState& sched_state_; DefaultSchedulerCore::TargetSchedulingRule target_scheduling_rule_; DefaultSchedulerCore::TargetSchedulingRule early_target_scheduling_rule_; + DefaultSchedulerCore::ResourceConstrainRule is_resource_constrained_; int ReadyIfScheduled(const HloGraphNode& gn) const { int ready_nodes_if_scheduled = 0; @@ -1151,30 +1163,6 @@ class ReadySetLt { static bool IsNop(const HloGraphNode& gn) { return IsNopInstruction(gn.GetInstr()); } - bool IsResourceConstrained( - DefaultSchedulerCore::ScheduleCandidate& cand) const { - if (cand.resource_constrained) { - return *cand.resource_constrained; - } - if (cand.node->GetResources().empty()) { - cand.resource_constrained = false; - return *(cand.resource_constrained); - } - cand.resource_constrained = false; - for (const auto& [resource_type, usage_type] : cand.node->GetResources()) { - auto max_it = sched_state_.max_concurrent_resource.find(resource_type); - auto res_it = sched_state_.resource_users_in_queue.find(resource_type); - cand.resource_constrained = - max_it != sched_state_.max_concurrent_resource.end() && - max_it->second == 0 && - res_it != sched_state_.resource_users_in_queue.end() && - res_it->second > 0; - if (*cand.resource_constrained) { - return *cand.resource_constrained; - } - } - return *cand.resource_constrained; - } bool ShouldScheduleAsyncDone( DefaultSchedulerCore::ScheduleCandidate& gn_cand) const { if (!gn_cand.node->DoesOccupyAnyResource()) { @@ -1271,9 +1259,10 @@ class ReadySetLt { cand.node->GetResources()); int64_t num_conflicting_resources = 0; for (int64_t resource : resources) { - if (!sched_state_.resources_in_flight.contains(resource)) continue; + if (!sched_state_.resource_occupiers_in_flight.contains(resource)) + continue; num_conflicting_resources += - sched_state_.resources_in_flight.at(resource); + sched_state_.resource_occupiers_in_flight.at(resource).size(); } return num_conflicting_resources; } @@ -1317,8 +1306,9 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable( for (const auto& [resource, limit] : sched_state.max_concurrent_resource) { // No resources in flight of this kind. Continue. - auto it = sched_state.resources_in_flight.find(resource); - if (it == sched_state.resources_in_flight.end() || it->second == 0) { + auto it = sched_state.resource_occupiers_in_flight.find(resource); + if (it == sched_state.resource_occupiers_in_flight.end() || + it->second.size() == 0) { continue; } // Number of instances of 'resource' needed if this instruction was to @@ -1334,7 +1324,7 @@ DefaultSchedulerCore::FindAndExtractBestNodeAvailable( }; VLOG(2) << "Current time: " << sched_state.current_time; ReadySetLt ready_lt{&sched_state, target_scheduling_rule_, - early_target_scheduling_rule_}; + early_target_scheduling_rule_, is_resource_constrained_}; // Construct a schedule candidate for caching. 
  ScheduleCandidate ready_chosen;
  auto chosen_it = sched_state.ready_set.end();
@@ -1432,6 +1422,33 @@ void DefaultSchedulerCore::LogInstruction(const HloInstruction* instr) const {
   VLOG(5) << instr->ToString();
 }
 
+bool DefaultSchedulerCore::DefaultIsResourceConstrained(
+    const DefaultSchedulerCore::SchedulingState& sched_state,
+    DefaultSchedulerCore::ScheduleCandidate& cand) {
+  if (cand.resource_constrained) {
+    return *cand.resource_constrained;
+  }
+  if (cand.node->GetResources().empty()) {
+    cand.resource_constrained = false;
+    return *(cand.resource_constrained);
+  }
+  cand.resource_constrained = false;
+  for (const auto& [resource_type, usage_type] : cand.node->GetResources()) {
+    auto max_it = sched_state.max_concurrent_resource.find(resource_type);
+    auto res_it = sched_state.resource_users_in_queue.find(resource_type);
+    cand.resource_constrained =
+        max_it != sched_state.max_concurrent_resource.end() &&
+        max_it->second == 0 &&
+        res_it != sched_state.resource_users_in_queue.end() &&
+        res_it->second > 0;
+
+    if (*cand.resource_constrained) {
+      return *cand.resource_constrained;
+    }
+  }
+  return *cand.resource_constrained;
+}
+
 void PrintOccupierList(
     std::vector<std::pair<HloEdge*, HloGraphNode::TimeCost>>& occupiers) {
   for (int64_t i = 0; i < occupiers.size(); i++) {
@@ -1902,9 +1919,9 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
   ++sched_state->scheduled_count;
   for (auto& resource : n->GetResources()) {
     if (resource.second == ResourceUsageType::kResourceRelease) {
-      --sched_state->resources_in_flight[resource.first];
+      sched_state->resource_occupiers_in_flight[resource.first].erase(n);
     } else if (resource.second == ResourceUsageType::kResourceOccupy) {
-      ++sched_state->resources_in_flight[resource.first];
+      sched_state->resource_occupiers_in_flight[resource.first].insert(n);
     }
   }
   VLOG(10) << "Memory pressure before schedule: "
diff --git a/xla/service/latency_hiding_scheduler.h b/xla/service/latency_hiding_scheduler.h
index e902f1ceefb761..7c3b0f03e7a641 100644
--- a/xla/service/latency_hiding_scheduler.h
+++ b/xla/service/latency_hiding_scheduler.h
@@ -957,8 +957,9 @@ class DefaultSchedulerCore : public SchedulerCore {
     std::vector<HloInstruction*> new_sequence_reversed;
     // Units of time passed in the schedule. To keep track of latency hiding.
     HloGraphNode::TimeCost current_time = 0;
-    // Number of resources in flight.
-    ResourceMap resources_in_flight;
+    // Resources and corresponding occupiers in flight.
+    absl::flat_hash_map<int64_t, absl::flat_hash_set<HloGraphNode*>>
+        resource_occupiers_in_flight;
     // Number of instructions using the key resource type in the set waiting to
    // be scheduled.
     ResourceMap resource_users_in_queue;
@@ -1008,12 +1009,16 @@ class DefaultSchedulerCore : public SchedulerCore {
           config(config) {}
   };
 
+  using ResourceConstrainRule =
+      std::function<bool(const SchedulingState&, ScheduleCandidate&)>;
   using PostProcessingFn = std::function<void(SchedulingState&)>;
 
   DefaultSchedulerCore(
       HloCostAnalysis::ShapeSizeFunction shape_size_bytes,
       const AsyncTracker* async_tracker,
       const LatencyEstimator* latency_estimator, const SchedulerConfig& config,
+      ResourceConstrainRule is_resource_constrained =
+          DefaultIsResourceConstrained,
       TargetSchedulingRule target_scheduling_rule = nullptr,
       TargetSchedulingRule early_target_scheduling_rule = nullptr,
       PostProcessingFn post_processing_fn = nullptr)
@@ -1021,6 +1026,7 @@ class DefaultSchedulerCore : public SchedulerCore {
         async_tracker_(async_tracker),
         latency_estimator_(latency_estimator),
         config_(config),
+        is_resource_constrained_(is_resource_constrained),
         target_scheduling_rule_(target_scheduling_rule),
         early_target_scheduling_rule_(early_target_scheduling_rule),
         post_processing_fn_(post_processing_fn) {}
@@ -1066,6 +1072,8 @@ class DefaultSchedulerCore : public SchedulerCore {
       const HloComputation* computation, const HloScheduleGraph& schedule_graph,
       const std::vector<HloInstruction*>& instructions,
       int cycles_per_microsecond, const DebugOptions& debug_options);
+  static bool DefaultIsResourceConstrained(const SchedulingState& sched_state,
+                                           ScheduleCandidate& cand);
 
   HloCostAnalysis::ShapeSizeFunction shape_size_bytes_;
   std::unique_ptr<ModulePressureState> module_pressure_state_;
@@ -1075,6 +1083,7 @@ class DefaultSchedulerCore : public SchedulerCore {
   SchedulerConfig config_;
   TargetSchedulingRule target_scheduling_rule_ = nullptr;
   TargetSchedulingRule early_target_scheduling_rule_ = nullptr;
+  ResourceConstrainRule is_resource_constrained_ = nullptr;
   PostProcessingFn post_processing_fn_ = nullptr;
   std::unique_ptr<AnnotationTracker> annotation_tracker_;
 };
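
Note on the GPU overlap rule: the change to gpu_latency_hiding_scheduler.cc hinges on CountOverlappingRanks, and two collectives are only allowed in flight together when no pair of their replica groups shares more than one rank. The following is a minimal, self-contained approximation of that check, not part of the patch; plain vectors of ints stand in for xla::ReplicaGroup, and all names here are illustrative.

// Standalone sketch of the overlapping-ranks check (assumed simplification:
// a replica group is just a vector of rank ids).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

using ReplicaGroup = std::vector<int64_t>;

// Returns the maximum number of ranks shared by any group from `group` and
// any group from `other`, mirroring the logic added in the patch.
size_t CountOverlappingRanks(const std::vector<ReplicaGroup>& group,
                             const std::vector<ReplicaGroup>& other) {
  size_t overlapping_count = 0;
  for (const ReplicaGroup& curr : group) {
    std::unordered_set<int64_t> curr_ids(curr.begin(), curr.end());
    for (const ReplicaGroup& candidate : other) {
      size_t subgroup_count = 0;
      for (int64_t id : candidate) {
        if (curr_ids.count(id)) ++subgroup_count;
      }
      overlapping_count = std::max(overlapping_count, subgroup_count);
    }
  }
  return overlapping_count;
}

int main() {
  // {{0,1}} vs {{0,1}}: two shared ranks, so the collectives could deadlock if
  // overlapped; the scheduler marks the candidate resource-constrained.
  std::cout << CountOverlappingRanks({{0, 1}}, {{0, 1}}) << "\n";  // prints 2
  // {{0,1}} vs {{2,3}}: disjoint ranks, safe to overlap.
  std::cout << CountOverlappingRanks({{0, 1}}, {{2, 3}}) << "\n";  // prints 0
  return 0;
}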
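Note on the new hook: the latency_hiding_scheduler.h change makes the resource-constrained check injectable. DefaultSchedulerCore now takes a ResourceConstrainRule that defaults to DefaultIsResourceConstrained, and gpu_hlo_schedule.cc passes GpuIsResourceConstrained. The sketch below shows only the wiring pattern under simplified stand-in types; it is not XLA's actual classes or signatures.

// Minimal sketch of a pluggable constraint rule via std::function with a
// default argument (all types below are hypothetical stand-ins).
#include <functional>
#include <iostream>
#include <utility>

struct SchedulingState { /* trimmed */ };
struct ScheduleCandidate { /* trimmed */ };

using ResourceConstrainRule =
    std::function<bool(const SchedulingState&, ScheduleCandidate&)>;

// Stand-in for the queue/limit bookkeeping done by the default rule.
bool DefaultIsResourceConstrained(const SchedulingState&, ScheduleCandidate&) {
  return false;
}

class SchedulerCore {
 public:
  explicit SchedulerCore(
      ResourceConstrainRule rule = DefaultIsResourceConstrained)
      : is_resource_constrained_(std::move(rule)) {}

  bool Constrained(const SchedulingState& s, ScheduleCandidate& c) const {
    return is_resource_constrained_(s, c);
  }

 private:
  ResourceConstrainRule is_resource_constrained_;
};

int main() {
  SchedulingState state;
  ScheduleCandidate cand;
  SchedulerCore default_core;  // keeps the old behaviour
  SchedulerCore gpu_core([](const SchedulingState&, ScheduleCandidate&) {
    return true;  // stand-in for a backend-specific rule
  });
  std::cout << default_core.Constrained(state, cand) << " "
            << gpu_core.Constrained(state, cand) << "\n";  // prints "0 1"
  return 0;
}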
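Note on the bookkeeping change: SchedulingState replaces the per-resource in-flight count with a map from resource to the set of occupying nodes, because the GPU rule needs to know which collectives are in flight (to compare replica groups), not just how many. The sketch below illustrates that design choice with standard containers and hypothetical names; it is not the scheduler's real data structures.

// Count vs. set of occupiers (illustrative only).
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Node { std::string name; };

int main() {
  // Before: only the number of occupiers per resource is known.
  std::unordered_map<int64_t, int64_t> resources_in_flight;
  ++resources_in_flight[/*collective stream resource*/ 0];

  // After: the occupiers themselves are tracked, so a scheduling rule can
  // iterate over them, e.g. to compare replica groups rank by rank.
  std::unordered_map<int64_t, std::unordered_set<Node*>>
      resource_occupiers_in_flight;
  Node ar0{"ar_0"};
  resource_occupiers_in_flight[0].insert(&ar0);  // on kResourceOccupy
  for (Node* occupier : resource_occupiers_in_flight[0]) {
    std::cout << "in flight: " << occupier->name << "\n";
  }
  resource_occupiers_in_flight[0].erase(&ar0);   // on kResourceRelease
  return 0;
}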