Skip to content

Commit

Permalink
PR #19275: [NVIDIA] Add fixes for supporting determinism expander for…
Browse files Browse the repository at this point in the history
… high-dimensional scatter operation and a flag to disable it

Imported from GitHub PR #19275

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers.

Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable  the `scatter_determinism_expander` pass without getting blocked.

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
b016044 by Chenhao Jiang <chenhaoj@nvidia.com>:

PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations

Imported from GitHub PR #18326

This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.

The change of this PR is on top of #17886

Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit

Bugs resolved: jax-ml/jax#17844
Copybara import of the project:

--
de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:

Support scatter with non-scalar indices and updates

Merging this change closes #18326

PiperOrigin-RevId: 691023328

--
fbdb066 by Chenhao Jiang <chenhaoj@nvidia.com>:

Add the scatter indices to operand space mapping
and change the offset column-wise permutation
based on scatter_dims_to_operand_dims, so that
they can add together correctly.

--
d36c8ac by Chenhao Jiang <chenhaoj@nvidia.com>:

Fix the scatter determinism expander for various dimension numbers

--
678886f by Chenhao Jiang <chenhaoj@nvidia.com>:

Add a flag for enabling the scatter_determinism_expander on GPU.

Merging this change closes #19275

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
PiperOrigin-RevId: 696078761
  • Loading branch information
serach24 authored and Google-ML-Automation committed Nov 15, 2024
1 parent a7aff7f commit 3bf54ef
Show file tree
Hide file tree
Showing 6 changed files with 1,202 additions and 143 deletions.
11 changes: 11 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
opts.set_xla_enable_fast_math(false);
opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1);
opts.set_xla_pjrt_allow_auto_layout_in_hlo(false);
opts.set_xla_gpu_enable_scatter_determinism_expander(true);
return opts;
}

Expand Down Expand Up @@ -2064,6 +2065,16 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Experimental: Make unset entry computation layout mean auto layout "
"instead of default layout in HLO when run through PjRT. In other cases "
"(StableHLO or non-PjRT) the auto layout is already used."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_scatter_determinism_expander",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_scatter_determinism_expander),
debug_options->xla_gpu_enable_scatter_determinism_expander(),
"Enable the scatter determinism expander, an optimized pass that "
"rewrites scatter operations to ensure deterministic behavior with high "
"performance."
"Note that even when this flag is disabled, scatter operations may still "
"be deterministic, although with additional overhead."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,7 @@ cc_library(
"//xla/hlo/transforms:op_expander_pass",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
Expand Down
4 changes: 3 additions & 1 deletion xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,9 @@ absl::Status RunOptimizationPasses(
if (RequireDeterminism(hlo_module->config())) {
// Scatter can be indeterministic if indices are not unique or a non
// associative combiner function is used. Eliminate these Scatter ops.
pipeline.AddPass<ScatterDeterminismExpander>();
if (debug_options.xla_gpu_enable_scatter_determinism_expander()) {
pipeline.AddPass<ScatterDeterminismExpander>();
}
pipeline.AddPass<ScatterExpander>(
ScatterExpander::kEliminateIndeterministicScatters);
}
Expand Down
Loading

0 comments on commit 3bf54ef

Please sign in to comment.