Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PR #19275: [NVIDIA] Add fixes for supporting determinism expander for…
… high-dimensional scatter operation and a flag to disable it Imported from GitHub PR #19275 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886, and has fixed issues reported in the reverted PR #18326. The issue was that the changes in #18326 were not able to handle different kinds of complicated but realistic scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operation to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and thoroughly handles all cases of different kinds of dimension numbers. Moreover, this PR also adds an option `xla_gpu_enable_scatter_determinism_expander`, the default value of which is set to be true. This option could make sure that although unlikely, if anything happens with changes in this PR, the user can easily disable the `scatter_determinism_expander` pass without getting blocked. Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- b016044 by Chenhao Jiang <chenhaoj@nvidia.com>: PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations Imported from GitHub PR #18326 This PR is the 2nd step (out of 2) to improve the performance of deterministic scatter. Originally, the scatter op will be expanded to be deterministic in xla/service/ScatterExpander.cc. However, since it took a while-loop-based approach, the performance is extremely poor. We designed and implemented a prefix-scan-based approach to rewrite the scatter operation to be an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates. The change of this PR is on top of #17886 Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit Bugs resolved: jax-ml/jax#17844 Copybara import of the project: -- de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>: Support scatter with non-scalar indices and updates Merging this change closes #18326 PiperOrigin-RevId: 691023328 -- fbdb066 by Chenhao Jiang <chenhaoj@nvidia.com>: Add the scatter indices to operand space mapping and change the offset column-wise permutation based on scatter_dims_to_operand_dims, so that they can add together correctly. -- d36c8ac by Chenhao Jiang <chenhaoj@nvidia.com>: Fix the scatter determinism expander for various dimension numbers -- 678886f by Chenhao Jiang <chenhaoj@nvidia.com>: Add a flag for enabling the scatter_determinism_expander on GPU. Merging this change closes #19275 FUTURE_COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4 PiperOrigin-RevId: 696078761
- Loading branch information