Offloading 2/3: generate multi-stream async copies using events on GPUs #2
Closed
Conversation
zhenying-liu force-pushed the offloading/annotator branch from 47a2ebd to 11148fe on March 20, 2024 06:23
…xcess_precision

Imported from GitHub PR openxla#10687

Several weeks ago, a change made the `simplify-fp-conversions` pass in cpu_compiler.cc unconditional for Intel CPUs: [PR-8402](openxla#8402) - [XLA:CPU] [oneDNN] Enable Dot op (MatMul) in BF16 Type.

I noticed the following issue with having the `simplify-fp-conversions` pass enabled unconditionally. My model uses bf16 operators (e.g. convolution), and I want to JIT-compile and run it on CPU while preserving intermediate bf16 accuracy. The CPU compiler runs the `float-normalization-bf16` pass, which converts a bf16 convolution to f32_convolution + convert_to_bf16 + convert_to_f32 (because a typical CPU does not support bf16 computation). The CPU compiler (on Xeon) also runs the `simplify-fp-conversions` pass, which simplifies `f32_convolution + convert_to_bf16 + convert_to_f32` to just `f32_convolution`. As a result, the whole model is computed in f32 precision internally, and conversion to bf16 happens only at the very end.

In some cases we want to execute a bf16 model on CPU but get results with accuracy similar to execution on bf16 hardware. The debug option `xla_allow_excess_precision` controls this. It is true by default, so the `simplify-fp-conversions` pass is enabled. To emulate bf16 computation on an Intel CPU, set `XLA_FLAGS="--xla_allow_excess_precision=false"`; then `simplify-fp-conversions` is not added to the cpu_compiler pipeline, f32 op results are converted to bf16 immediately, and bf16 accuracy is preserved internally. [gpu_compiler.cc](https://github.com/openxla/xla/blob/main/xla/service/gpu/gpu_compiler.cc#L1359) already adds the `SimplifyFPConversions` pass only if `debug_options.xla_allow_excess_precision()` is true.

Copybara import of the project:

-- 796dc83 by Alexander Pivovarov <pivovaa@amazon.com>: [CPU] Add SimplifyFPConversions only if xla_allow_excess_precision

Merging this change closes openxla#10687

COPYBARA_INTEGRATE_REVIEW=openxla#10687 from apivovarov:fix_cpu_SimplifyFPConversions 796dc83
PiperOrigin-RevId: 617460913
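A standalone illustration (not XLA code) of why removing the intermediate convert pair changes results: rounding every intermediate f32 value through bf16, which is what `--xla_allow_excess_precision=false` effectively preserves, accumulates differently than keeping full f32 precision throughout. The `to_bf16` helper below is a hypothetical truncation-based emulation of bf16 rounding.

```cpp
// Standalone sketch: emulate "convert to bf16 and back" after each op and
// compare against keeping excess f32 precision. to_bf16() is a hypothetical
// truncation-based emulation, not an XLA function.
#include <cstdint>
#include <cstdio>
#include <cstring>

float to_bf16(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFF0000u;  // keep sign + exponent + top 7 mantissa bits
  float y;
  std::memcpy(&y, &bits, sizeof(y));
  return y;
}

int main() {
  float excess = 0.0f, rounded = 0.0f;
  for (int i = 0; i < 1000; ++i) {
    float term = 1.001f * 0.0003f;               // some intermediate f32 result
    excess += term;                              // keep excess f32 precision
    rounded = to_bf16(rounded + to_bf16(term));  // round every step to bf16
  }
  std::printf("f32 accumulation:  %.6f\n", excess);
  std::printf("bf16-rounded path: %.6f\n", rounded);
}
```

The two sums typically diverge noticeably, which is the accuracy difference the flag is meant to control.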
While there, also fix a bug in the test for Slice. The slice operand should have the same layout as the slice. PiperOrigin-RevId: 617473515
PiperOrigin-RevId: 617476081
Imported from GitHub PR openxla#10626 This PR: 1. Adds a pattern for BF16 softmax. 2. Adds a test to verify rewrite and execution result. 3. Skips BF16 matmul / softmax tests on machines that do not support BF16 Copybara import of the project: -- f0370a1 by Akhil Goel <akhil.goel@intel.com>: Add Softmax BF16 pattern and test -- 62af424 by Akhil Goel <akhil.goel@intel.com>: Declare convert_instr identifier -- d030100 by Akhil Goel <akhil.goel@intel.com>: Fix formatting Merging this change closes openxla#10626 COPYBARA_INTEGRATE_REVIEW=openxla#10626 from Intel-tensorflow:akhil/bf16_softmax d030100 PiperOrigin-RevId: 617477998
This change looks up and erases a few more elements from `reverse_map_`, but overall it shouldn't make a dent in the compile time. With this change, we check whether the priority is negative before inserting into the queue, not after. This also doesn't affect compile time, but arguably makes the code easier to understand. PiperOrigin-RevId: 617478676
…lation on stream executor and cuDNN handle. Imported from GitHub PR openxla#10711 LocalCuDnnHandle can now be created only via CudnnAccess, which ensures that the main cuDNN handle is created first. DnnGraph methods using the cuDNN handle now require a DnnSupport&. Copybara import of the project: -- 44baf4b by Ilia Sergachev <isergachev@nvidia.com>: [GPU][NFC] Clarify dependency of cuDNN fusion compilation on stream executor and cuDNN handle. Merging this change closes openxla#10711 COPYBARA_INTEGRATE_REVIEW=openxla#10711 from openxla:cudnn_handle_cleanup 44baf4b PiperOrigin-RevId: 617480424
Imported from GitHub PR openxla#10536 Enables the fusion of layer norm patterns with degenerate input dimensions of size 1 and with additional optional type conversions or reshapes adding or removing degenerate dimensions. Copybara import of the project: -- 54e247d by Philipp Hack <phack@nvidia.com>: Layer norm fusion for inputs with degenerate dimensions. -- 7d766bf by Philipp Hack <phack@nvidia.com>: Layer norm fusion for inputs with degenerate dimensions. -- 14075f1 by Philipp Hack <phack@nvidia.com>: Layer norm fusion for inputs with degenerate dimensions. Merging this change closes openxla#10536 COPYBARA_INTEGRATE_REVIEW=openxla#10536 from philipphack:u_layer_size1_xla 14075f1 PiperOrigin-RevId: 617484341
PiperOrigin-RevId: 617491428
PiperOrigin-RevId: 617496344
…n related changes. Imported from GitHub PR openxla#10510 First commit of the series for enabling Triton in XLA for ROCm. Copybara import of the project: -- 832d725 by Zoran Jovanovic <zjovanov@amd.com>: [ROCm] Triton in XLA for ROCm - gemm rewriter triton related changes. Merging this change closes openxla#10510 COPYBARA_INTEGRATE_REVIEW=openxla#10510 from ROCm:rocm_triton_backend 832d725 PiperOrigin-RevId: 617500238
Imported from GitHub PR openxla#10489 CuDnnCmd is constructed before the DnnGraph in CuDnnThunk is initialized, so CuDnnCmd has to take a `unique_ptr<DnnGraph>&` instead of a `DnnGraph&` at initialization. Accordingly, cuDNN thunks have to be initialized before command buffer ones so that graphs are initialized before they get captured. The test CommandBuffersAreSupported previously did not demonstrate the use of command buffers because the corresponding command buffer call was inlined and no command buffers were created. This is now cleaned up and works as expected with the minimal CUDA graph size set to 1 via a flag. Copybara import of the project: -- 8547c67 by Ilia Sergachev <isergachev@nvidia.com>: [GPU] Fix command buffer support for cuDNN fusions. Merging this change closes openxla#10489 COPYBARA_INTEGRATE_REVIEW=openxla#10489 from openxla:fix_cudnn_cmd_buffers 8547c67 PiperOrigin-RevId: 617505164
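A standalone sketch (hypothetical Graph/Command types, not the actual CuDnnCmd/DnnGraph classes) of why the command takes `std::unique_ptr<Graph>&` rather than `Graph&`: the command object is constructed before the graph exists, and holding a reference to the owning pointer lets it observe the graph once the thunk's initialization step fills it in.

```cpp
// Hypothetical sketch of deferred graph binding: the command stores a
// reference to the unique_ptr so it can be constructed before the graph is.
#include <cstdio>
#include <memory>
#include <string>

struct Graph {
  std::string plan;
};

class Command {
 public:
  explicit Command(std::unique_ptr<Graph>& graph) : graph_(graph) {}

  void Record() const {
    if (graph_ == nullptr) {
      std::printf("graph not initialized yet\n");
      return;
    }
    std::printf("recording graph with plan: %s\n", graph_->plan.c_str());
  }

 private:
  std::unique_ptr<Graph>& graph_;  // non-owning view of a possibly empty slot
};

int main() {
  std::unique_ptr<Graph> graph;   // empty at command construction time
  Command cmd(graph);             // like a command built before the graph
  cmd.Record();                   // would dangle if we had bound a Graph&
  graph = std::make_unique<Graph>(Graph{"fused-attention"});  // "Initialize()"
  cmd.Record();                   // now sees the initialized graph
}
```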
Imported from GitHub PR openxla#10722 Moves the unit tests in `xla/service/gpu/gemm_rewriter_test.cc` into `xla/service/gpu/tests/gemm_rewrite_test.cc`. Copybara import of the project: -- e3056a1 by Philipp Hack <phack@nvidia.com>: Merge GEMM rewriter unit tests. Merging this change closes openxla#10722 COPYBARA_INTEGRATE_REVIEW=openxla#10722 from philipphack:u_merged_gemm_test_xla e3056a1 PiperOrigin-RevId: 617518997
…ner to pre-fusion. Previously, the Send-Recv pipeliner was invoked as a post-fusion pass. This is a problem for two reasons: the loop bound analysis in the pipeliner can't recognize loop bounds in fused computations, and performing loop pipelining before fusion can enable better fusion decisions. The Send-Recv pipeliner relies on the collective-permute-decomposer to generate Send-Recv operations. We now change the collective-permute-decomposer to process collective-permute operations accordingly and move this pass, along with the Send-Recv pipeliner, to pre-fusion. PiperOrigin-RevId: 617520340
Updates LLVM usage to match [d93cfd8dab57](llvm/llvm-project@d93cfd8dab57) PiperOrigin-RevId: 617522339
…done

Imported from GitHub PR openxla#10636

Add a stream id to the backend_config of the copy-start instruction. The stream id is obtained from hlo_query::NextChannelId(). The corresponding copy-done instruction, which is the user of the copy-start instruction, is also traversed and has the stream id added to its backend_config; this part is done automatically by the subsequent AnnotateStreamAttributesForUsers() already present in the function. The bool data member copy_start_done_ is used to differentiate copy-start/copy-done from other collective instructions so they go through two different paths. openxla#10450 has been split, and the current PR is the first of 3.

Copybara import of the project:

-- ff99c16 by Jane Liu <janeliu@nvidia.com>: Add annotation of stream id for copy-start and its use of copy-done instruction
-- 64c746a by Jane Liu <janeliu@nvidia.com>: Enable the annotator for copy-start/copy-done in gpu compiler
-- 7972b98 by Jane Liu <janeliu@nvidia.com>: Add the annotator pass after HLO rematerialization pass
-- 2a9e317 by Jane Liu <janeliu@nvidia.com>: Add the dependency in BUILD
-- db5ed79 by Jane Liu <janeliu@nvidia.com>: Use a function to annotate copy-start and add description
-- b4b3932 by Jane Liu <janeliu@nvidia.com>: remove the bool var copy_start from the StreamAttributeAnnotator class
-- 11148fe by Jane Liu <janeliu@nvidia.com>: Fixes according to the code review

Merging this change closes openxla#10636

COPYBARA_INTEGRATE_REVIEW=openxla#10636 from zhenying-liu:offloading/annotator 11148fe
PiperOrigin-RevId: 617550093
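A standalone sketch of the annotation flow described above (toy structs, not XLA's HloInstruction or backend_config API): the copy-start gets a fresh stream id, and the same id is propagated to its copy-done user, mirroring what the user-annotation step does in the real pass.

```cpp
// Toy sketch: assign a fresh stream id to a copy-start and propagate the same
// id to its copy-done user. Instr and NextStreamId() are hypothetical
// stand-ins for HloInstruction and hlo_query::NextChannelId().
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct Instr {
  std::string opcode;        // "copy-start", "copy-done", ...
  int64_t stream_id = 0;     // 0 = default stream (unannotated)
  Instr* operand = nullptr;  // copy-done points at its copy-start
};

int64_t NextStreamId() {
  static int64_t next = 1;
  return next++;
}

int main() {
  Instr copy_start{"copy-start"};
  Instr copy_done{"copy-done", 0, &copy_start};
  std::vector<Instr*> computation = {&copy_start, &copy_done};

  // Annotate the copy-start with a fresh stream id.
  copy_start.stream_id = NextStreamId();

  // Propagate the id to users of the copy-start (here, the copy-done).
  for (Instr* instr : computation) {
    if (instr->opcode == "copy-done" && instr->operand == &copy_start) {
      instr->stream_id = copy_start.stream_id;
    }
  }

  std::printf("copy-start on stream %ld, copy-done on stream %ld\n",
              static_cast<long>(copy_start.stream_id),
              static_cast<long>(copy_done.stream_id));
}
```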
…ization. Drive-by: Shard `onednn_softmax_test` test. PiperOrigin-RevId: 617597142
…sstable. In this case, we don't call hash/eq on the first insertion with SOO so the crash doesn't take place on every insertion anymore. PiperOrigin-RevId: 617606725
…n oneDNN MatMul Rewriter Imported from GitHub PR openxla#10747 This PR fixes a bug in the MAC (multiply-accumulate) / FMA (fused-multiply-add) count of the dot operation when the B matrix is transposed. Copybara import of the project: -- d989fdb by mdfaijul <md.faijul.amin@intel.com>: Fix MAC(Multiply Accumulate) count Merging this change closes openxla#10747 COPYBARA_INTEGRATE_REVIEW=openxla#10747 from Intel-tensorflow:amin/fix-mac-count d989fdb PiperOrigin-RevId: 617616266
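A standalone sketch of the count being fixed (hypothetical helper, not the actual rewriter code): for `C[M,N] = A[M,K] · B`, the MAC count is `M·N·K` either way, so the N extent must be read from the correct dimension of B when B is stored transposed.

```cpp
// Hypothetical sketch: MAC count of a dot, reading B's free dimension
// correctly whether B is stored as [K, N] or transposed as [N, K].
#include <array>
#include <cstdint>
#include <cstdio>

int64_t MacCount(std::array<int64_t, 2> a_shape,  // A is [M, K]
                 std::array<int64_t, 2> b_shape,  // B is [K, N] or [N, K]
                 bool b_transposed) {
  const int64_t m = a_shape[0];
  const int64_t k = a_shape[1];
  const int64_t n = b_transposed ? b_shape[0] : b_shape[1];
  return m * n * k;  // one multiply-accumulate per (m, n, k) triple
}

int main() {
  // 32x64 times 64x128 -> 32*128*64 MACs, regardless of B's storage order.
  std::printf("%ld\n", static_cast<long>(MacCount({32, 64}, {64, 128}, false)));
  std::printf("%ld\n", static_cast<long>(MacCount({32, 64}, {128, 64}, true)));
}
```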
zhenying-liu force-pushed the offloading/async_copy branch from 9f0be2c to f10e739 on March 21, 2024 05:22
Messy Part 2 of the stacked PR without a PR number.
zhenying-liu pushed a commit that referenced this pull request on May 2, 2024
PiperOrigin-RevId: 629867362
zhenying-liu pushed a commit that referenced this pull request on May 13, 2024
…d phase to Initialize()

Imported from GitHub PR openxla#12228

The first time that a NormThunk is executed, it will build a cudnn execution plan. This build step can hang if an NCCL collective is running at the same time. To fix this, I've moved the build step to take place during thunk initialization. We only observe this hang when using cudnn 9. Here's a backtrace from the hang that will be fixed:

```
Thread 585 (Thread 0x7fb9391ff640 (LWP 41364) "main.py"):
#0  0x00007fd3d17cffd9 in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#1  0x00007fd3d17da24f in pthread_rwlock_wrlock () from /lib/x86_64-linux-gnu/libc.so.6
#2  0x00007fd070967dfe in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#3  0x00007fd0709c928a in ?? () from /lib/x86_64-linux-gnu/libcuda.so.1
#4  0x00007f1970d76102 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#5  0x00007f1970f2c999 in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#6  0x00007f1970a7d4ab in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#7  0x00007f1970d0a9cb in ?? () from /lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
#8  0x00007fce60b2a98c in cudnn::backend::ExecutionPlan::finalize_internal() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#9  0x00007fce60aefbb1 in cudnn::backend::Descriptor::finalize() () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#10 0x00007fce60b15bec in cudnnBackendFinalize () from /lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
#11 0x00007fd2521b8f39 in cudnn_frontend::ExecutionPlanBuilder_v8::build() () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#12 0x00007fd2521734ba in stream_executor::gpu::(anonymous namespace)::GetExecPlanFromHeuristics(cudnn_frontend::OperationGraph_v8&&, stream_executor::gpu::(anonymous namespace)::CudnnHandle const&, bool) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#13 0x00007fd25216ff9b in stream_executor::gpu::CudnnSupport::NormRunnerFromDesc(stream_executor::Stream*, stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormKind, double, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, stream_executor::dnn::TensorDescriptor const&, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>, std::optional<stream_executor::dnn::TensorDescriptor>) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#14 0x00007fd24e36b88b in stream_executor::dnn::NormOp::RunnerFromAlgorithmDesc(stream_executor::dnn::AlgorithmDesc const&, stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#15 0x00007fd24e36ae37 in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}::operator()() const () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#16 0x00007fd24e36adbc in void absl::lts_20230802::base_internal::CallOnceImpl<stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}>(std::atomic<unsigned int>*, absl::lts_20230802::base_internal::SchedulingMode, stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*)::{lambda()#1}&&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#17 0x00007fd24e36a9bd in stream_executor::dnn::LazyOpRunner<stream_executor::dnn::NormOp>::GetOrCreateRunner(stream_executor::dnn::NormOp::Config, stream_executor::Stream*) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#18 0x00007fd24e369d29 in xla::gpu::RunGpuNorm(xla::gpu::GpuNormConfig const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, stream_executor::DeviceMemoryBase const&, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, std::optional<stream_executor::DeviceMemoryBase>, stream_executor::DeviceMemoryBase const&, stream_executor::Stream*, xla::gpu::RunNormOptions) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
#19 0x00007fd24e368be6 in xla::gpu::NormThunk::ExecuteOnStream(xla::gpu::Thunk::ExecuteParams const&) () from /usr/local/lib/python3.10/dist-packages/jaxlib/xla_extension.so
```

Copybara import of the project:

-- f535330 by Trevor Morris <tmorris@nvidia.com>: Fix hang with cudnn layer norm by moving cudnn init to Initialize()

Merging this change closes openxla#12228

COPYBARA_INTEGRATE_REVIEW=openxla#12228 from trevor-m:tmorris-norm-init f535330
PiperOrigin-RevId: 633220207
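A standalone sketch of the shape of the fix (a hypothetical thunk-like class, not the actual NormThunk): the expensive, lock-sensitive plan build runs once in `Initialize()`, so `ExecuteOnStream()` never triggers it while collectives are in flight.

```cpp
// Hypothetical sketch: build the execution plan eagerly during thunk
// initialization instead of lazily on the first execution.
#include <cstdio>
#include <mutex>
#include <optional>
#include <string>

struct ExecutionPlan {
  std::string engine;
};

class NormLikeThunk {
 public:
  // Called once, before any collectives are launched.
  void Initialize() {
    std::call_once(once_, [this] {
      plan_ = ExecutionPlan{"heuristic-plan"};  // expensive cudnn-style build
      std::printf("plan built during Initialize()\n");
    });
  }

  // Hot path: must not build the plan here, or the build can contend with
  // concurrently running collectives.
  void ExecuteOnStream() {
    if (!plan_.has_value()) {
      std::printf("error: plan missing; Initialize() was skipped\n");
      return;
    }
    std::printf("executing with %s\n", plan_->engine.c_str());
  }

 private:
  std::once_flag once_;
  std::optional<ExecutionPlan> plan_;
};

int main() {
  NormLikeThunk thunk;
  thunk.Initialize();
  thunk.ExecuteOnStream();
}
```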
zhenying-liu pushed a commit that referenced this pull request on Aug 6, 2024
Use std::aligned_storage_t trick to avoid default-initializing Node struct on a hot path.

```
name                                     old cpu/op   new cpu/op   delta
BM_SelectAndScatterF32/128/process_time  791µs ± 4%   720µs ± 2%   -8.93%
BM_SelectAndScatterF32/256/process_time  3.20ms ± 4%  2.96ms ± 2%  -7.46%
BM_SelectAndScatterF32/512/process_time  13.7ms ± 5%  12.8ms ± 2%  -6.80%

name                                     old time/op  new time/op  delta
BM_SelectAndScatterF32/128/process_time  790µs ± 5%   719µs ± 1%   -9.00%
BM_SelectAndScatterF32/256/process_time  3.20ms ± 3%  2.96ms ± 1%  -7.58%
BM_SelectAndScatterF32/512/process_time  13.2ms ± 4%  12.3ms ± 1%  -6.82%
```

PiperOrigin-RevId: 658139935
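A standalone sketch of the trick (a hypothetical Node, not the actual select-and-scatter code): raw `std::aligned_storage_t` slots skip the default constructor, and elements are placement-new'd only when they are actually written.

```cpp
// Hypothetical sketch: reserve raw aligned storage for Node objects and
// construct them lazily, avoiding default-initialization on the hot path.
#include <cstdio>
#include <new>
#include <type_traits>

struct Node {
  Node() : value(0), visited(false) { /* imagine nontrivial work here */ }
  explicit Node(int v) : value(v), visited(true) {}
  int value;
  bool visited;
};

int main() {
  constexpr int kCount = 4;
  // Raw storage: no Node constructors run here, unlike a Node array or
  // std::vector<Node>(kCount), which would default-initialize every element.
  using Slot = std::aligned_storage_t<sizeof(Node), alignof(Node)>;
  Slot storage[kCount];

  // Construct only the slots we actually touch.
  Node* second = ::new (&storage[1]) Node(42);
  std::printf("constructed slot 1 with value %d\n", second->value);

  // Destroy what we constructed before the storage goes away.
  second->~Node();
}
```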
Emit asynchronous memory copies between host and device on additional streams while the main stream performs the computation. The async copies are guarded by events: the copy-start thunk records an event with RecordEvent(), and the copy-done thunk waits for it with WaitForEvent(). A hash table maps copy-start instructions to events; at copy-done, the corresponding event is extracted from the hash table and waited on.
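The same structure, sketched with the plain CUDA runtime API rather than XLA's thunks and StreamExecutor wrappers (so this is an illustration, not this PR's code): the copy stream records an event after the async copy (the copy-start side), and the main stream waits on that event before consuming the data (the copy-done side), with a small map standing in for the copy-start-to-event hash table.

```cpp
// Sketch with the raw CUDA runtime API:
// "copy-start" = async D2H copy on a separate stream + cudaEventRecord;
// "copy-done"  = look up the event and make the main stream wait on it.
#include <cstdio>
#include <unordered_map>
#include <vector>
#include <cuda_runtime.h>

int main() {
  constexpr size_t kBytes = 1 << 20;
  void* device_buf = nullptr;
  std::vector<char> host_buf(kBytes);  // pinned memory would be used in practice
  cudaMalloc(&device_buf, kBytes);

  cudaStream_t main_stream, copy_stream;
  cudaStreamCreate(&main_stream);
  cudaStreamCreate(&copy_stream);

  // Hash table mapping a copy-start id to the event recorded on the copy stream.
  std::unordered_map<int, cudaEvent_t> copy_done_events;

  // "copy-start": launch the offload on the copy stream and record an event.
  const int copy_id = 0;
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaMemcpyAsync(host_buf.data(), device_buf, kBytes,
                  cudaMemcpyDeviceToHost, copy_stream);
  cudaEventRecord(event, copy_stream);
  copy_done_events[copy_id] = event;

  // ... main_stream keeps running unrelated computation here ...

  // "copy-done": extract the event and make the main stream wait for the copy.
  cudaStreamWaitEvent(main_stream, copy_done_events[copy_id], 0);
  copy_done_events.erase(copy_id);

  cudaStreamSynchronize(main_stream);
  std::printf("offloaded copy completed before dependent work ran\n");

  cudaEventDestroy(event);
  cudaStreamDestroy(copy_stream);
  cudaStreamDestroy(main_stream);
  cudaFree(device_buf);
}
```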