[inductor] don't materialize the large sparse matrix in CE bwd (pytorch#129043)

Inductor currently materializes a large sparse matrix in the backward pass for CrossEntropyLoss and loads it back to compute the gradient of the Softmax input. If we instead fuse the computation of the sparse matrix into its consumers, we get both a performance and a memory-usage win.

The FX graph snippet that constructs this sparse matrix looks like:

```
full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0);  full_default_3 = where_2 = None
```

Leveraging the following observations:
- the scatter is applied to an all-zero (or, more generally, a constant) tensor;
- the index tensor for the scatter has a single element along the scatter dimension (in this case it is the label tensor);

allows us to lower this `scatter_upon_const_tensor` pattern to a pointwise kernel that can easily be fused with downstream kernels (an eager-mode equivalence sketch is appended at the end of this description):

```
def inner_fn(idx):
    selector_idx = list(idx)
    # Safe because the index tensor has a single element along the scatter dimension.
    selector_idx[dim] = 0
    selector = selector_loader(selector_idx)

    return ops.where(
        selector == ops.index_expr(idx[dim], torch.int64),
        ops.constant(val, dtype),
        ops.constant(background_val, dtype),
    )
```

## Test result on microbenchmark

For the microbenchmark added as `test_cross_entropy_loss`, latency improves from 47.340ms to 42.768ms and memory footprint from 10.524GB to 7.227GB on A100 (on H100, latency improves from 27.54ms to 23.51ms and memory footprint from 10.574GB to 7.354GB).

The saving matches a back-of-the-envelope calculation: we avoid storing a BF16 tensor of shape [32768, 50257], which is about 3GB. On A100, skipping one store and one load of such a tensor saves roughly 3GB x 2 / 1.5TB/s ≈ 4ms.

## Test result on llm.c

We also tested this on llm.c, where the saving is much larger, especially in memory footprint. The reason is autotuning, which allocates extra memory for benchmarking (see pytorch#129258 and pytorch#129399 for more details).

For the llm.c PyTorch implementation on A100, we improve from 171K tokens/s and 33.6GB peak memory to 180K tokens/s and 18.6GB peak memory (a **45%** reduction in peak memory).

## Test on PyTorch 2.0 Dashboard

The optimization is quite general, especially for transformers. We tested it on the PyTorch 2.0 dashboard; here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71). TL;DR: on the Huggingface benchmark suite, we get a **6%** geomean perf improvement and a **10%** geomean memory-footprint improvement.

Pull Request resolved: pytorch#129043
Approved by: https://github.com/jansel, https://github.com/Chillee
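For illustration only (not part of the PR), here is a minimal eager-mode sketch of the equivalence the lowering relies on. The sizes and names are made up for the demo, and plain `torch.where` stands in for Inductor's `ops.where`:

```
import torch

# Small stand-ins for the [32768, 50257] tensors in the FX graph above.
num_rows, num_classes = 8, 16
labels = torch.randint(num_classes, (num_rows, 1))  # one index per row along dim=1

# Pattern Inductor used to materialize: scatter -1.0 into an all-zero tensor.
scattered = torch.zeros(num_rows, num_classes, dtype=torch.bfloat16)
scattered = scattered.scatter(1, labels, -1.0)

# Pointwise equivalent: -1.0 where the column index equals the label, else 0.
cols = torch.arange(num_classes)
pointwise = torch.where(
    labels == cols,  # [num_rows, 1] broadcast against [num_classes]
    torch.tensor(-1.0, dtype=torch.bfloat16),
    torch.tensor(0.0, dtype=torch.bfloat16),
)

assert torch.equal(scattered, pointwise)
```

The broadcasted comparison plays the role of `selector == ops.index_expr(idx[dim], torch.int64)` in `inner_fn`: every output element is computed independently, so nothing needs to be materialized once the expression is fused into its consumers.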
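And a quick check of the back-of-the-envelope numbers above, assuming the ~1.5 TB/s effective A100 memory bandwidth used in the estimate:

```
# BF16 tensor of shape [32768, 50257]: 2 bytes per element.
tensor_bytes = 32768 * 50257 * 2              # ~3.3e9 bytes, i.e. about 3GB
bandwidth = 1.5e12                            # assumed ~1.5 TB/s A100 HBM bandwidth
saved_seconds = 2 * tensor_bytes / bandwidth  # one store + one load avoided
print(f"~{saved_seconds * 1e3:.1f} ms")       # ~4.4 ms, in line with the ~4ms quoted
```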