PR #19275: [NVIDIA] Add fixes for supporting determinism expander for high-dimensional scatter operation and a flag to disable it #19384

Merged
merged 1 commit into from
Nov 15, 2024

Commits on Nov 15, 2024

  1. PR #19275: [NVIDIA] Add fixes for supporting determinism expander for…

    … high-dimensional scatter operation and a flag to disable it
    
    Imported from GitHub PR #19275
    
    This PR is the second of two steps to improve the performance of deterministic scatter. Previously, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc. Because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
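
To illustrate the general idea of a sort-plus-prefix-scan scatter, here is a minimal Python sketch. It is not the code of the XLA pass: it assumes a 1D operand, scalar indices, and addition as the combiner, and the function name `deterministic_scatter_add` is hypothetical. The sequential `itertools.accumulate` stands in for what would be a parallel scan on a GPU.

```python
from itertools import accumulate


def deterministic_scatter_add(operand, indices, updates):
    """Scatter-add `updates` into a copy of `operand` in a fixed order.

    Illustrative only: 1D operand, scalar indices, addition as combiner.
    """
    # Stable sort by target index: updates that hit the same slot become
    # adjacent and keep their original relative order.
    order = sorted(range(len(indices)), key=lambda i: indices[i])
    sorted_idx = [indices[i] for i in order]
    sorted_upd = [updates[i] for i in order]

    # Segmented inclusive prefix sum over (target index, running sum) pairs.
    # The running sum restarts whenever the target index changes.
    def combine(left, right):
        if left[0] == right[0]:
            return (right[0], left[1] + right[1])
        return right

    scanned = list(accumulate(zip(sorted_idx, sorted_upd), combine))

    # The last entry of each segment holds the total for that slot, so each
    # slot is written exactly once and no write order has to be resolved.
    result = list(operand)
    for pos, (idx, total) in enumerate(scanned):
        is_segment_end = pos + 1 == len(scanned) or scanned[pos + 1][0] != idx
        if is_segment_end:
            result[idx] += total
    return result


print(deterministic_scatter_add([10, 20, 30], [2, 0, 2, 1, 0], [1, 2, 3, 4, 5]))
# prints [17, 24, 34]
```

Because duplicate indices are combined in sorted order before the single write per slot, the result does not depend on the order in which parallel writes would otherwise land.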
    
    This change builds on #17886 and fixes the issues reported in the reverted PR #18326, which could not handle several complicated but realistic kinds of scatter dimension numbers. Specifically, this PR unifies the implementation of 1D and multi-dimensional scatter operations to make the code easier to maintain, adds multiple tests for various scatter dimension numbers, and handles all cases of the different kinds of dimension numbers.
    
    Moreover, this PR adds an option `xla_gpu_enable_scatter_determinism_expander`, which defaults to true. Should the changes in this PR cause a problem, however unlikely, the user can disable the `scatter_determinism_expander` pass with this option instead of being blocked.
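
As a rough sketch of how a user might turn the pass off, the snippet below sets the option through the `XLA_FLAGS` environment variable. It assumes the option is exposed there like other XLA debug options; the helper `set_xla_flag` is hypothetical and not part of XLA or JAX.

```python
import os

# Option name taken from this PR.
FLAG = "xla_gpu_enable_scatter_determinism_expander"


def set_xla_flag(name, value, env=os.environ):
    """Add or replace `--name=value` in the XLA_FLAGS variable of `env`."""
    existing = env.get("XLA_FLAGS", "").split()
    # Drop any earlier setting of the same flag so the new value is the only one.
    kept = [f for f in existing if not f.startswith(f"--{name}=")]
    kept.append(f"--{name}={str(value).lower()}")
    env["XLA_FLAGS"] = " ".join(kept)
    return env["XLA_FLAGS"]


# Assumption: this runs before the XLA client library (for example `jax`)
# is imported, since the variable is typically read at initialization.
set_xla_flag(FLAG, False)
print(os.environ["XLA_FLAGS"])
```

The same effect can be had by exporting `XLA_FLAGS=--xla_gpu_enable_scatter_determinism_expander=false` in the shell before launching the program.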
    
    Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
    
    Bugs resolved: jax-ml/jax#17844
    Copybara import of the project:
    
    --
    3b7b56a by Chenhao Jiang <chenhaoj@nvidia.com>:
    
    PR #18326: [NVIDIA] Complete the optimization of deterministic scatter operations
    
    Imported from GitHub PR #18326
    
    This PR is the second of two steps to improve the performance of deterministic scatter. Previously, the scatter op was expanded into a deterministic form in xla/service/ScatterExpander.cc. Because that expansion takes a while-loop-based approach, its performance is extremely poor. We designed and implemented a prefix-scan-based approach that rewrites the scatter operation into an efficient deterministic scatter. This PR completes the optimization of deterministic scatter operations with non-scalar indices and updates.
    
    This change builds on #17886.
    
    Design doc: https://docs.google.com/document/d/1K204VZR3OP0SUDOPsGUYgIIDf2ucTKEC4yQj8XRG2SA/edit
    
    Bugs resolved: jax-ml/jax#17844
    Copybara import of the project:
    
    --
    de647d4 by Chenhao Jiang <chenhaoj@nvidia.com>:
    
    Support scatter with non-scalar indices and updates
    
    Merging this change closes #18326
    
    PiperOrigin-RevId: 691023328
    
    --
    126c952 by Chenhao Jiang <chenhaoj@nvidia.com>:
    
    Add the scatter indices to operand space mapping
    and change the offset column-wise permutation
    based on scatter_dims_to_operand_dims, so that
    they can add together correctly.
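
The mapping described in this commit can be pictured with a small Python sketch. It follows the usual scatter semantics, where component `k` of an index vector addresses operand dimension `scatter_dims_to_operand_dims[k]`; it is not the code of the pass, and the helper names `to_operand_space` and `add_offset` are hypothetical.

```python
def to_operand_space(index_vector, scatter_dims_to_operand_dims, operand_rank):
    """Expand a scatter index vector into a full start index over the operand."""
    start = [0] * operand_rank
    for k, operand_dim in enumerate(scatter_dims_to_operand_dims):
        # Component k of the index vector addresses this operand dimension.
        start[operand_dim] = index_vector[k]
    return start


def add_offset(start, window_offset):
    """Add a window offset that is already laid out in operand dimension order."""
    return [s + o for s, o in zip(start, window_offset)]


start = to_operand_space([5, 7], scatter_dims_to_operand_dims=[2, 0], operand_rank=3)
print(start)                         # prints [7, 0, 5]
print(add_offset(start, [0, 1, 0]))  # prints [7, 1, 5]
```

Once both the indices and the offsets are expressed in operand dimension order, they can be added elementwise.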
    
    --
    1ecb608 by Chenhao Jiang <chenhaoj@nvidia.com>:
    
    Fix the scatter determinism expander for various dimension numbers
    
    --
    985079f by Chenhao Jiang <chenhaoj@nvidia.com>:
    
    Add a flag for enabling the scatter_determinism_expander on GPU.
    
    Merging this change closes #19275
    
    COPYBARA_INTEGRATE_REVIEW=#18326 from serach24:chenhao/opt_det_scatter_full de647d4
    PiperOrigin-RevId: 696956113
    serach24 authored and Google-ML-Automation committed Nov 15, 2024