[inductor] don't materialize the large sparse matrix in CE bwd (pytorch#129043)

Inductor currently materializes a large sparse matrix in the backward pass for CrossEntropyLoss and loads it to compute the gradient of the Softmax input. If we can fuse the sparse matrix computation into its consumers, we get both perf and memory-usage wins.

The FX graph snippet that constructs this sparse matrix looks like:

```
full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0);  full_default_3 = where_2 = None
```

Leveraging the following observations:
- the scatter is applied to an all-zero (or, more generally, a constant) tensor
- the index tensor for the scatter has a single element along the scatter dimension; in this case it's the label tensor

allows us to lower this `scatter_upon_const_tensor` pattern to a pointwise kernel that can easily be fused with downstream kernels:

```
def inner_fn(idx):
    selector_idx = list(idx)
    selector_idx[dim] = 0  # can do this since the index tensor has a single element on the scatter dimension

    selector = selector_loader(selector_idx)
    return ops.where(
        selector == ops.index_expr(idx[dim], torch.int64),
        ops.constant(val, dtype),
        ops.constant(background_val, dtype),
    )
```

## Test result on microbenchmark

For the microbenchmark added as `test_cross_entropy_loss`, we improve latency from 47.340ms to 42.768ms and memory footprint from 10.524GB to 7.227GB on A100 (on H100, latency improves from 27.54ms to 23.51ms and memory footprint from 10.574GB to 7.354GB).

The saving matches the back-of-envelope calculation: we avoid storing a BF16 tensor of shape [30K, 50K], which is about 3GB. On A100, avoiding loading and storing such a tensor saves roughly 3GB x 2 / 1.5TB/s ≈ 4ms.

## Test result on llm.c

We also tested this on llm.c, where the saving is much larger, especially for memory footprint. The reason is autotuning, which allocates extra memory for benchmarking (see pytorch#129258 and pytorch#129399 for more details).

For the llm.c PyTorch implementation on A100, we improve from 171K tokens/s and 33.6GB peak memory to 180K tokens/s and 18.6GB peak memory (a **45%** saving of peak memory).

## Test on PyTorch 2.0 Dashboard

The optimization is quite general, especially for transformers. We tested it on the PyTorch 2.0 dashboard; here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71). TL;DR: for the Huggingface benchmark suite, we get a **6%** geomean perf improvement and a **10%** geomean memory footprint improvement.

Pull Request resolved: pytorch#129043
Approved by: https://github.com/jansel, https://github.com/Chillee
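For intuition, here is a small, self-contained sketch of the same equivalence in plain PyTorch (this snippet is not part of the PR; `M`, `N`, and `label` are made-up illustrative values). Because the index has a single element along the scatter dimension, the full+scatter result can be computed pointwise with a `where`, which is the form Inductor can fuse into the consumers:

```python
import torch

# Hedged illustration (not from the PR): with a single index element along the
# scatter dimension, full+scatter is equivalent to a pointwise `where` against
# the index tensor, so no dense intermediate needs to be materialized once the
# `where` is fused into its consumer.
M, N = 8, 16
label = torch.randint(0, N, (M,))

# The pattern as it appears in the CrossEntropyLoss backward graph:
# start from a constant tensor and scatter -1.0 along dim=1 at the labels.
scattered = torch.full((M, N), 0.0, dtype=torch.bfloat16)
scattered.scatter_(1, label.unsqueeze(1), -1.0)

# Pointwise equivalent: compare each column index against the row's label.
cols = torch.arange(N)
pointwise = torch.where(
    label.unsqueeze(1) == cols,
    torch.tensor(-1.0, dtype=torch.bfloat16),
    torch.tensor(0.0, dtype=torch.bfloat16),
)

assert torch.equal(scattered, pointwise)
```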
1 parent e1499f6 · commit fd414d6
Showing 4 changed files with 296 additions and 0 deletions.
@@ -0,0 +1,200 @@

```python
# Owner(s): ["module: inductor"]

import copy
import os

import torch
from torch import nn
from torch._dynamo.utils import counters, same
from torch._inductor import metrics
from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_GPU

DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"


class TestScatterOpt(TestCase):
    def setUp(self):
        super().setUp()
        metrics.reset()
        counters.clear()

    def check_metric(self, val=1):
        self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor)

    def do_acc_test(self, f, *args):
        expect = f(*args)
        actual = torch.compile(f)(*args)
        self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")

    def test_3d_tensor(self):
        L, M, N = 2, 1024, 2048

        def f(x):
            y = torch.full([L, M, N], 3.14, dtype=torch.float)
            y.scatter_(2, x.unsqueeze(2), 2.718)
            return y

        x = torch.randint(0, N, (L, M), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = (
            L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize
        )
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_non_last_dim(self):
        """
        Test the case that the scatter dimension is not the last one.
        """
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(0, x.unsqueeze(0), 2.718)
            return y

        x = torch.randint(0, M, (N,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_neg_scatter_dim(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(-1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_shorter_index_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M // 2,), dtype=torch.int64)
        self.do_acc_test(f, x)

        # no match since the index tensor is shorter. May support it in future.
        self.assertEqual(0, counters["inductor"]["pattern_matcher_count"])

    def test_nonzero_const_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_dense(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 0, dtype=torch.float)
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, N // 2), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * (
            torch.int64.itemsize + torch.float.itemsize
        )
        # Use assertGreaterEqual rather than assertEqual due to the issue related
        # to StarDep mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_non_const(self):
        M, N = 1024, 2048

        def f(x, y):
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, 1), dtype=torch.int64)
        y = torch.randn([M, N])
        self.do_acc_test(f, x, y)

        # The generated code is quite inefficient. There are 3 kernels:
        # 1. copy from arg to buf
        # 2. scatter upon buf
        # 3. copy buf back to arg
        # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40
        expected_num_bytes = 4 * M * N * torch.float.itemsize + M * (
            torch.int64.itemsize + torch.float.itemsize
        )
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

        # The second and third kernels are both mutation kernels, so we
        # overestimate the memory accessed.
        # Update the test once the overestimation is fixed.
        over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate)

    def test_cross_entropy_loss(self):
        """
        Match full+scatter in CEL and replace it with a pointwise op.

        Perf data on an A100 GPU:
        Without the scatter optimization:
            ms=47.340, peak_mem=10.524 GB
        With the scatter optimization:
            ms=42.768, peak_mem=7.227 GB
        """
        B, T, D, V = 32, 1024, 768, 50257
        if not DO_PERF_TEST:
            # use a smaller V if not doing perf test to avoid OOM in CI
            V = V // 100
        ref_model = nn.Linear(D, V).to(torch.bfloat16)
        opt_model = copy.deepcopy(ref_model)
        ce = nn.CrossEntropyLoss()

        def f(m, x, label):
            ce(m(x).view(-1, V), label.view(-1)).backward()

        opt_f = torch.compile(f)

        x = torch.randn(B, T, D).to(torch.bfloat16)
        label = torch.randint(0, V, (B, T)).to(torch.int64)

        f(ref_model, x, label)
        ref_grad = ref_model.weight.grad
        opt_f(opt_model, x, label)
        act_grad = opt_model.weight.grad
        assert torch.allclose(
            ref_grad, act_grad, atol=1e-3, rtol=1e-3
        ), f"{ref_grad=}\n{act_grad=}"

        self.check_metric()

        if DO_PERF_TEST:
            torch.cuda.reset_peak_memory_stats()
            for _ in range(3):
                opt_f(opt_model, x, label)
            ms = do_bench(lambda: opt_f(opt_model, x, label))
            peak_mem = torch.cuda.max_memory_allocated() / 10**9
            print(f"{ms=:.3f}, {peak_mem=:.3f} GB")


if HAS_GPU:
    torch.set_default_device("cuda")

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()
```
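As a usage note (not part of the diff): one way to check whether this lowering fired on your own workload is to read the same counter the tests above assert on. A minimal sketch, assuming a CUDA GPU and arbitrary illustrative shapes:

```python
import torch
from torch import nn
from torch._inductor import metrics

# Minimal sketch (assumes a CUDA GPU; shapes are illustrative only): compile a
# CrossEntropyLoss forward+backward and read the counter the tests above use.
metrics.reset()

V = 512  # vocab size, arbitrary
model = nn.Linear(768, V, dtype=torch.bfloat16, device="cuda")
ce = nn.CrossEntropyLoss()

def f(x, label):
    ce(model(x).view(-1, V), label.view(-1)).backward()

x = torch.randn(4, 128, 768, dtype=torch.bfloat16, device="cuda")
label = torch.randint(0, V, (4, 128), device="cuda")
torch.compile(f)(x, label)

# A non-zero value means the full+scatter in the CE backward was replaced by
# the pointwise lowering instead of materializing the [batch, V] sparse tensor.
print(metrics.num_matches_for_scatter_upon_const_tensor)
```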