Skip to content

Commit

Permalink
PR #19026: [NVIDIA GPU] LHS enhancement for multiple collective resou…
Browse files Browse the repository at this point in the history
…rces

Imported from GitHub PR #19026

With #17749, LHS can schedule for multiple collective resources. However, there are cases in which two collectives cannot be overlapped: when two collectives on different streams share at least two ranks, they can form a cyclic dependency, because the execution order of the NCCL kernels can differ from rank to rank. This PR refactors LHS to expose the comparator to the backend and enforces the above constraint for the GPU backend.
Copybara import of the project:

--
09ce310 by Terry Sun <tesun@nvidia.com>:

LHS deadlock avoidance

Merging this change closes #19026

FUTURE_COPYBARA_INTEGRATE_REVIEW=#19026 from terryysun:terryysun/overlapping_collectives 09ce310
PiperOrigin-RevId: 696020313
  • Loading branch information
terryysun authored and Google-ML-Automation committed Nov 13, 2024
1 parent 2a78903 commit 567c796
Show file tree
Hide file tree
Showing 16 changed files with 1,244 additions and 51 deletions.
184 changes: 184 additions & 0 deletions xla/python/ifrt/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,187 @@ cc_library(
"@llvm-project//mlir:Support",
],
)

# Runs mlir-tblgen over vifrt_attrs.td to generate the VIFRT attribute
# interface declarations (vifrt_attr_interfaces.h.inc) and definitions
# (vifrt_attr_interfaces.cc.inc).
gentbl_cc_library(
name = "vifrt_attr_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-attr-interface-decls"],
"vifrt_attr_interfaces.h.inc",
),
(
["-gen-attr-interface-defs"],
"vifrt_attr_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_attrs.td",
test = True,
# The .td sources (and the files they include) come from this td_library.
deps = [":vifrt_ops_td_files"],
)

# Runs mlir-tblgen to generate the attribute class declarations
# (vifrt_attrs.h.inc) and definitions (vifrt_attrs.cc.inc), restricted to the
# `vifrt` dialect via --attrdefs-dialect.
# NOTE(review): td_file is vifrt_ops.td rather than vifrt_attrs.td; presumably
# vifrt_ops.td includes the attribute definitions — confirm.
gentbl_cc_library(
name = "vifrt_attrs_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
"-gen-attrdef-decls",
"--attrdefs-dialect=vifrt",
],
"vifrt_attrs.h.inc",
),
(
[
"-gen-attrdef-defs",
"--attrdefs-dialect=vifrt",
],
"vifrt_attrs.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
test = True,
deps = [":vifrt_ops_td_files"],
)

# Runs mlir-tblgen over vifrt_ops.td to generate the VIFRT op interface
# declarations (vifrt_op_interfaces.h.inc) and definitions
# (vifrt_op_interfaces.cc.inc).
# NOTE(review): unlike vifrt_attrs_inc_gen and vifrt_ops_inc_gen, this rule
# does not set `test = True` — confirm that is intentional.
gentbl_cc_library(
name = "vifrt_op_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-op-interface-decls"],
"vifrt_op_interfaces.h.inc",
),
(
["-gen-op-interface-defs"],
"vifrt_op_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
deps = [":vifrt_ops_td_files"],
)

# C++ library built from vifrt_ops.h / vifrt_ops.cc. It depends on the
# tblgen-generated attribute, attribute-interface, op, op-interface and type
# targets, plus the hand-written :vifrt_types and :version libraries.
cc_library(
name = "vifrt_ops",
srcs = [
"vifrt_ops.cc",
],
hdrs = [
"vifrt_ops.h",
],
compatible_with = get_compatible_with_portable(),
deps = [
":version",
":vifrt_attr_interfaces_inc_gen",
":vifrt_attrs_inc_gen",
":vifrt_op_interfaces_inc_gen",
":vifrt_ops_inc_gen",
":vifrt_types",
":vifrt_types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

# Runs mlir-tblgen over vifrt_ops.td to generate the VIFRT op class
# declarations (vifrt_ops.h.inc) and definitions (vifrt_ops.cc.inc).
gentbl_cc_library(
name = "vifrt_ops_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-op-decls"],
"vifrt_ops.h.inc",
),
(
["-gen-op-defs"],
"vifrt_ops.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
test = True,
deps = [":vifrt_ops_td_files"],
)

# Bundles all VIFRT TableGen sources (attrs, base, dialect, ops, types) into a
# single td_library; the gentbl_cc_library rules in this package list it in
# their deps.
td_library(
name = "vifrt_ops_td_files",
srcs = [
"vifrt_attrs.td",
"vifrt_base.td",
"vifrt_dialect.td",
"vifrt_ops.td",
"vifrt_types.td",
],
compatible_with = get_compatible_with_portable(),
# Upstream MLIR .td bundles made available to the sources above.
deps = [
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ShapeOpsTdFiles",
],
)

# C++ library built from vifrt_types.h / vifrt_types.cc. It depends on the
# tblgen-generated type and type-interface targets and on :version.
cc_library(
name = "vifrt_types",
srcs = ["vifrt_types.cc"],
hdrs = ["vifrt_types.h"],
compatible_with = get_compatible_with_portable(),
deps = [
":version",
":vifrt_type_interfaces_inc_gen",
":vifrt_types_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@stablehlo//:stablehlo_assembly_format",
],
)

# Runs mlir-tblgen over vifrt_types.td to generate the VIFRT type interface
# declarations (vifrt_type_interfaces.h.inc) and definitions
# (vifrt_type_interfaces.cc.inc).
gentbl_cc_library(
name = "vifrt_type_interfaces_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
["-gen-type-interface-decls"],
"vifrt_type_interfaces.h.inc",
),
(
["-gen-type-interface-defs"],
"vifrt_type_interfaces.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_types.td",
deps = [
":vifrt_ops_td_files",
],
)

# Runs mlir-tblgen to generate the type class declarations
# (vifrt_type_defs.h.inc) and definitions (vifrt_type_defs.cc.inc), restricted
# to the `vifrt` dialect via --typedefs-dialect.
# NOTE(review): td_file is vifrt_ops.td rather than vifrt_types.td; presumably
# vifrt_ops.td includes the type definitions — confirm.
gentbl_cc_library(
name = "vifrt_types_inc_gen",
compatible_with = get_compatible_with_portable(),
tbl_outs = [
(
[
"-gen-typedef-decls",
"--typedefs-dialect=vifrt",
],
"vifrt_type_defs.h.inc",
),
(
[
"-gen-typedef-defs",
"--typedefs-dialect=vifrt",
],
"vifrt_type_defs.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "vifrt_ops.td",
deps = [
":vifrt_ops_td_files",
],
)
122 changes: 122 additions & 0 deletions xla/python/ifrt/ir/vifrt_attrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_
#define XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_

include "mlir/IR/AttrTypeBase.td"
include "xla/python/ifrt/ir/vifrt_base.td"
include "xla/python/ifrt/ir/vifrt_dialect.td"
include "xla/python/ifrt/ir/vifrt_types.td"

// Attribute interface attached to every VIFRT attribute (see Vifrt_AttrDef).
// It exposes the range of VIFRT dialect versions in which an attribute is
// supported: `getMinVersion` returns the first supporting version and
// `getMaxVersion` returns the last supporting version (inclusive).
def Vifrt_VersionedAttrInterface : AttrInterface<"VifrtVersionedAttrInterface"> {
  let cppNamespace = "::xla::ifrt";
  let methods = [
    InterfaceMethod<
      "The min version of the VIFRT dialect an attribute is supported in.",
      "::xla::ifrt::Version", "getMinVersion">,
    InterfaceMethod<
      "The max version (inclusive) of the VIFRT dialect an attribute is supported in.",
      "::xla::ifrt::Version", "getMaxVersion">,
  ];
}

// Base class for all versioned VIFRT attributes.
//
// Template arguments:
//   name:        C++ class name stem passed through to AttrDef.
//   min_version: dotted version string, e.g. "0.1.0".
//   max_version: dotted version string, or the literal "current".
//
// The dotted strings are rewritten with !subst from "a.b.c" to "a, b, c" and
// spliced into the generated C++ as the argument list of
// ::xla::ifrt::Version(...). When max_version is "current", getMaxVersion
// returns ::xla::ifrt::Version::getCurrentVersion() instead.
class Vifrt_AttrDef<string name, string min_version, string max_version>
  : AttrDef<Vifrt_Dialect, name, [Vifrt_VersionedAttrInterface]> {
  let extraClassDeclaration = [{
    ::xla::ifrt::Version getMinVersion() {
      return ::xla::ifrt::Version(}] # !subst(".", ", ", min_version) # [{);
    }
    ::xla::ifrt::Version getMaxVersion() {
  }] # !if(
    !eq(max_version, "current"),
    [{ return ::xla::ifrt::Version::getCurrentVersion(); }],
    // No quotes around the substituted arguments: this must expand to
    // Version(a, b, c), matching getMinVersion, not to the single string
    // literal Version("a, b, c").
    [{ return ::xla::ifrt::Version(}] # !subst(".", ", ", max_version) # [{); }]
  ) # [{
    }
  }];
}

// Versioned attribute (mnemonic `devices_v1`, supported from 0.1.0 to the
// current version) holding an array of `int` ids. Its assembly form is the
// ids enclosed in square brackets.
def Vifrt_DevicesAttrV1 : Vifrt_AttrDef<"VifrtDevicesV1", "0.1.0", "current"> {
let mnemonic = "devices_v1";
let parameters = (ins ArrayRefParameter<"int">:$ids);
let assemblyFormat = "`[` $ids `]`";
}

// Versioned attribute (mnemonic `sharding_unspecified_v1`, supported from
// 0.1.0 to the current version) that carries no parameters and has an empty
// assembly format.
def Vifrt_UnspecifiedShardingAttrV1
: Vifrt_AttrDef<"VifrtUnspecifiedShardingV1", "0.1.0", "current"> {
let mnemonic = "sharding_unspecified_v1";
let parameters = (ins);
let assemblyFormat = "";
}

// TODO(icgog): Introduce versioned ShardingParamV1.
// Versioned attribute with mnemonic `sharding_param_v1`, supported from 0.1.0
// to the current version.
// NOTE(review): this def sets only the mnemonic; it declares neither
// `parameters` nor `assemblyFormat`. Presumably this is a placeholder until
// the TODO is resolved — confirm.
def Vifrt_ShardingParamAttrV1 : Vifrt_AttrDef<"VifrtShardingParamV1", "0.1.0",
"current"> {
let mnemonic = "sharding_param_v1";
}

// Versioned attribute (mnemonic `interval_v1`, supported from 0.1.0 to the
// current version) with three `int` parameters: start, end and step. Its
// assembly form is `[start:end:step]`.
def Vifrt_IntervalAttrV1 : Vifrt_AttrDef<"VifrtIntervalV1", "0.1.0",
"current"> {
let mnemonic = "interval_v1";
let parameters = (ins "int":$start, "int":$end, "int":$step);
let assemblyFormat = "`[`$start `:` $end `:` $step`]`";
}

// Versioned attribute (mnemonic `mapping_v1`, supported from 0.1.0 to the
// current version) pairing two Vifrt_IntervalAttrV1 values: `from_shards` and
// `to_shards`. Its assembly form is `<from_shards to to_shards>`.
def Vifrt_MappingAttrV1 : Vifrt_AttrDef<"VifrtMappingV1", "0.1.0", "current"> {
let mnemonic = "mapping_v1";
let parameters = (ins
Vifrt_IntervalAttrV1:$from_shards,
Vifrt_IntervalAttrV1:$to_shards);
let assemblyFormat = "`<` $from_shards `to` $to_shards `>`";
}

// Versioned attribute (mnemonic `generic_array_attr_v1`, supported from 0.1.0
// to the current version) wrapping an array of arbitrary mlir::Attribute
// values.
//
// `genVerifyDecl = 1` requests a verify() declaration; the definition below
// emits the diagnostic "expected array of VIFRT attributes" when
// allFromVifrt(value) returns false, and succeeds otherwise.
// The array is printed/parsed through the `custom<AttributeArray>` directive,
// whose implementation is not visible in this file.
def Vifrt_GenericArrayAttrV1 : Vifrt_AttrDef<"VifrtGenericArrayAttrV1",
"0.1.0", "current"> {
let mnemonic = "generic_array_attr_v1";
let parameters = (ins ArrayRefParameter<"mlir::Attribute">:$value);
let genVerifyDecl = 1;
let extraClassDefinition = [{
mlir::LogicalResult VifrtGenericArrayAttrV1Attr::verify(
llvm::function_ref<mlir::InFlightDiagnostic ()> err_fn,
llvm::ArrayRef<mlir::Attribute> value) {
if (!allFromVifrt(value)) return err_fn() << "expected array of VIFRT attributes";
return mlir::success();
}
}];
let assemblyFormat = "`<` custom<AttributeArray>($value) `>`";
}

// Versioned attribute (mnemonic `array_mapping_v1`, supported from 0.1.0 to
// the current version) with an input array index, an output array index (both
// int32_t) and a Vifrt_GenericArrayAttrV1 of mappings. Its assembly form is
// `<in_array_index, out_array_index, mappings>`.
def Vifrt_ArrayMappingAttrV1 : Vifrt_AttrDef<"VifrtArrayMappingV1", "0.1.0",
"current"> {
let mnemonic = "array_mapping_v1";
let parameters = (ins
"int32_t":$in_array_index,
"int32_t":$out_array_index,
Vifrt_GenericArrayAttrV1:$mappings);
let assemblyFormat = "`<`$in_array_index`,` $out_array_index`,` $mappings`>`";
}

// Versioned attribute (mnemonic `io_aliases_v1`, supported from 0.1.0 to the
// current version) wrapping a Vifrt_GenericArrayAttrV1 of io aliases. Its
// assembly form is `<io_aliases>`.
// NOTE(review): the C++ name passed to Vifrt_AttrDef is "IfrtIoAliasesV1",
// while every other attribute in this file uses a "Vifrt" prefix. This looks
// like a typo for "VifrtIoAliasesV1", but renaming changes the generated class
// name — confirm against users of the generated class before changing it.
def Vifrt_IoAliasesAttrV1 : Vifrt_AttrDef<"IfrtIoAliasesV1", "0.1.0", "current"> {
let mnemonic = "io_aliases_v1";
let parameters = (ins
Vifrt_GenericArrayAttrV1:$io_aliases);
let assemblyFormat = "`<` $io_aliases `>`";
}

// TODO(icgog): Introduce Vifrt_MappingAttrArrayAttrV1,
// Vifrt_ArrayMappingAttrArrayAttrV1.

#endif // XLA_PYTHON_IFRT_IR_VIFRT_ATTRS_TD_
27 changes: 27 additions & 0 deletions xla/python/ifrt/ir/vifrt_base.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_
#define XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"

// VIFRT represents the layout only. Therefore, it uses AnyType everywhere.
// The three constraints below are deliberately unconstrained:
//  - Vifrt_AnyType is built from AnyType only.
//  - Vifrt_AnyAttr is built from AnyAttr only.
//  - Vifrt_AnyRegion uses the always-true predicate CPred<"true">.
def Vifrt_AnyType : AnyTypeOf<[AnyType]>;
def Vifrt_AnyAttr : AnyAttrOf<[AnyAttr]>;
def Vifrt_AnyRegion : Region<CPred<"true">, "any region">;

#endif // XLA_PYTHON_IFRT_IR_VIFRT_BASE_TD_
37 changes: 37 additions & 0 deletions xla/python/ifrt/ir/vifrt_dialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2024 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_
#define XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_

// Definition of the `vifrt` dialect; generated C++ lives in ::xla::ifrt.
// NOTE(review): this file does not itself include the .td that declares
// `Dialect`; presumably it relies on being included after vifrt_base.td (which
// includes mlir/IR/OpBase.td) — confirm.
def Vifrt_Dialect : Dialect {
let name = "vifrt";
let summary = "Versioned IFRT dialect";
let cppNamespace = "::xla::ifrt";

let description = [{
A versioned copy of the IFRT IR dialect that is used for forward and
backward compatible serialization/deserialization.

Version log:
0.1.0: Initial IFRT IR stability guarantees.
}];

// The default attribute/type printer-parser generation is switched off;
// presumably the dialect supplies hand-written ones in C++ — TODO confirm.
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
// Op attributes are stored using MLIR properties.
let usePropertiesForAttributes = 1;
}

#endif // XLA_PYTHON_IFRT_IR_VIFRT_DIALECT_TD_
Loading

0 comments on commit 567c796

Please sign in to comment.