From fd414d61892bf2b750a6cd3262053146ca3d52c9 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 25 Jun 2024 10:30:11 -0700 Subject: [PATCH] [inductor] don't materialize the large sparse matrix in CE bwd (#129043) Inductor currently materialize a large sparse matrix in the backward pass for CrossEntropyLoss and load that to compute gradients of Softmax input. If we could fuse the sparse matrix computation to the consumer sides, we gonna have both perf and memory usage wins. The Fx graph snippets that construct this aforementioned sparse matrix looks like: ``` full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False) scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0); full_default_3 = where_2 = None ``` Leveraging the following observations: - the scatter is applied upon a all zero (or more generally a const tensor) - the index tensor for the scatter has a single element on the scatter dimension. In this case it's the label tensor allow us to lower this 'scatter_upon_const_tensor' pattern to a pointwise kernel that can be easily fused with downstream kernels: ``` def inner_fn(idx): selector_idx = list(idx) selector_idx[dim] = 0 # can do this since the index tensor has a single element on the scatter dimension selector = selector_loader(selector_idx) return ops.where( selector == ops.index_expr(idx[dim], torch.int64), ops.constant(val, dtype), ops.constant(background_val, dtype), ) ``` ## Test result on microbenchmark For the microbenchmark added as `test_cross_entropy_loss`, we improve latency from 47.340ms to 42.768ms, memory footprint from 10.524GB to 7.227GB on A100. (on H100, we improve latency from 27.54ms to 23.51ms, memory footprint from 10.574GB to 7.354GB). The saving matches the back-of-envelope calculation. We avoid storing a BF16 tensor with shape [30K, 50K] which is about 3GB in size. On A100, avoid loading and storing such a tensor can roughly save 3GB x 2 / 1.5TBGS = 4ms ## Test result on llm.c We also test this on llm.c and the saving is much larger especially for memory footprint. The reason is due to autotuning that allocates extra memory for benchmarking. (Check https://github.com/pytorch/pytorch/issues/129258 and https://github.com/pytorch/pytorch/pull/129399 for more details). For llm.c PyTorch implementation on A100, we improve from 171K tokens/s , 33.6G peak memory usage to 180K tokens/s, 18.6G peak memory usage. (A **45%** saving of peak memory) ## Test on PyTorch 2.0 Dashboard The optimization is quite general especially for transformers. We tested this on PyTorch2.0 dashboard. Here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71). TLDR, for Huggingface benchmark suite, we get **6%** geomean perf improvement and **10%** geomean memory footprint improvement. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129043 Approved by: https://github.com/jansel, https://github.com/Chillee --- test/inductor/test_scatter_optimization.py | 200 +++++++++++++++++++++ torch/_inductor/config.py | 10 ++ torch/_inductor/fx_passes/post_grad.py | 83 +++++++++ torch/_inductor/metrics.py | 3 + 4 files changed, 296 insertions(+) create mode 100644 test/inductor/test_scatter_optimization.py diff --git a/test/inductor/test_scatter_optimization.py b/test/inductor/test_scatter_optimization.py new file mode 100644 index 0000000000000..9929fb956d7a6 --- /dev/null +++ b/test/inductor/test_scatter_optimization.py @@ -0,0 +1,200 @@ +# Owner(s): ["module: inductor"] + +import copy +import os + +import torch +from torch import nn +from torch._dynamo.utils import counters, same +from torch._inductor import metrics +from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench +from torch._inductor.test_case import TestCase +from torch.testing._internal.inductor_utils import HAS_GPU + +DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" + + +class TestScatterOpt(TestCase): + def setUp(self): + super().setUp() + metrics.reset() + counters.clear() + + def check_metric(self, val=1): + self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor) + + def do_acc_test(self, f, *args): + expect = f(*args) + actual = torch.compile(f)(*args) + self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n") + + def test_3d_tensor(self): + L, M, N = 2, 1024, 2048 + + def f(x): + y = torch.full([L, M, N], 3.14, dtype=torch.float) + y.scatter_(2, x.unsqueeze(2), 2.718) + return y + + x = torch.randint(0, N, (L, M), dtype=torch.int64) + self.do_acc_test(f, x) + expected_num_bytes = ( + L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize + ) + self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) + + def test_non_last_dim(self): + """ + Test the case that the scatter dimension is not the last one. + """ + M, N = 1024, 2048 + + def f(x): + y = torch.full([M, N], 3.14, dtype=torch.float) + y.scatter_(0, x.unsqueeze(0), 2.718) + return y + + x = torch.randint(0, M, (N,), dtype=torch.int64) + self.do_acc_test(f, x) + expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize + self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) + + def test_neg_scatter_dim(self): + M, N = 1024, 2048 + + def f(x): + y = torch.full([M, N], 3.14, dtype=torch.float) + y.scatter_(-1, x.unsqueeze(1), 2.718) + return y + + x = torch.randint(0, N, (M,), dtype=torch.int64) + self.do_acc_test(f, x) + expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize + self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) + + def test_shorter_index_tensor(self): + M, N = 1024, 2048 + + def f(x): + y = torch.full([M, N], 3.14, dtype=torch.float) + y.scatter_(1, x.unsqueeze(1), 2.718) + return y + + x = torch.randint(0, N, (M // 2,), dtype=torch.int64) + self.do_acc_test(f, x) + + # no match since the index tensor is shorter. May support it in future. + self.assertEqual(0, counters["inductor"]["pattern_matcher_count"]) + + def test_nonzero_const_tensor(self): + M, N = 1024, 2048 + + def f(x): + y = torch.full([M, N], 3.14, dtype=torch.float) + y.scatter_(1, x.unsqueeze(1), 2.718) + return y + + x = torch.randint(0, N, (M,), dtype=torch.int64) + self.do_acc_test(f, x) + expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize + self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes) + + def test_can_not_optimize_due_to_dense(self): + M, N = 1024, 2048 + + def f(x): + y = torch.full([M, N], 0, dtype=torch.float) + y.scatter_(1, x, 0.618) + return y + + x = torch.randint(0, N, (M, N // 2), dtype=torch.int64) + self.do_acc_test(f, x) + expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * ( + torch.int64.itemsize + torch.float.itemsize + ) + # Use assertGreaterEqual rather than assertEqual due to the issue related + # to StarDep mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706 + self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) + + def test_can_not_optimize_due_to_non_const(self): + M, N = 1024, 2048 + + def f(x, y): + y.scatter_(1, x, 0.618) + return y + + x = torch.randint(0, N, (M, 1), dtype=torch.int64) + y = torch.randn([M, N]) + self.do_acc_test(f, x, y) + + # The generated code is quite in-efficient. + # There are 3 kernels + # 1. copy from arg to buf + # 2. scatter upon buf + # 3. copy buf back to arg + # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40 + expected_num_bytes = 4 * M * N * torch.float.itemsize + M * ( + torch.int64.itemsize + torch.float.itemsize + ) + self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes) + + # the second kernel and third kernel are both mutation kernel. So we + # overestimated the memory accessed + # Update the test once the overestimiation is fixed. + over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize + self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate) + + def test_cross_entropy_loss(self): + """ + Match full+scatter in CEL and replaces it with a pointwise. + + Perf data on an A100 GPU: + Without the scatter optimization: + ms=47.340, peak_mem=10.524 GB + With the scatter optimization: + ms=42.768, peak_mem=7.227 GB + """ + B, T, D, V = 32, 1024, 768, 50257 + if not DO_PERF_TEST: + # use a smaller V if not doing perf test to avoid OOM + # in CI + V = V // 100 + ref_model = nn.Linear(D, V).to(torch.bfloat16) + opt_model = copy.deepcopy(ref_model) + ce = nn.CrossEntropyLoss() + + def f(m, x, label): + ce(m(x).view(-1, V), label.view(-1)).backward() + + opt_f = torch.compile(f) + + x = torch.randn(B, T, D).to(torch.bfloat16) + label = torch.randint(0, V, (B, T)).to(torch.int64) + + f(ref_model, x, label) + ref_grad = ref_model.weight.grad + opt_f(opt_model, x, label) + act_grad = opt_model.weight.grad + assert torch.allclose( + ref_grad, act_grad, atol=1e-3, rtol=1e-3 + ), f"{ref_grad=}\n{act_grad=}" + + self.check_metric() + + if DO_PERF_TEST: + torch.cuda.reset_peak_memory_stats() + for _ in range(3): + opt_f(opt_model, x, label) + ms = do_bench(lambda: opt_f(opt_model, x, label)) + peak_mem = torch.cuda.max_memory_allocated() / 10**9 + print(f"{ms=:.3f}, {peak_mem=:.3f} GB") + + +if HAS_GPU: + torch.set_default_device("cuda") + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + if HAS_GPU: + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 6a26677244ee2..eadb79e9c10f7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -402,6 +402,16 @@ def fx_graph_remote_cache_default(): is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__ developer_warnings = is_fbcode() or is_nightly_or_source +# This pattern matches a special usage of scatter +# 1. It's applied to a constant tensor +# 2. The index tensor has size 1 in the scatter dimension +# Such pattern generates a sparse matrix when the const tensor is all-zero. +# We can lower this pattern to a pointwise kernel for more fusion opportunities +# and saving memory footprint. +optimize_scatter_upon_const_tensor = ( + os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1" +) + # The multiprocessing start method to use for inductor workers in the codecache. # "subprocess", "fork", or "spawn" diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index c67471c55ab7c..e6cd5a65a9df5 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -12,6 +12,7 @@ from torch import fx from torch._decomp import register_decomposition from torch._dynamo.utils import counters, optimus_scuba_log +from torch._inductor.virtualized import ops from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype @@ -216,6 +217,88 @@ def is_valid_mm_plus_mm(match: Match): return True +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedius. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_lowering_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. + """ + from torch._inductor import metrics + + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + selector_loader = selector.make_loader() + + def inner_fn(idx): + selector_idx = list(idx) + selector_idx[dim] = 0 + + selector = selector_loader(selector_idx) + return ops.where( + selector == ops.index_expr(idx[dim], torch.int64), + ops.constant(val, dtype), + ops.constant(background_val, dtype), + ) + + return ir.Pointwise.create( + device=selector.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=shape, + ) + + @register_lowering_pattern( CallFunction( aten.add, diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 3d8de535542e7..fc7d0e6a7ab70 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -46,6 +46,7 @@ cpp_outer_loop_fused_inner_counts: List[int] = [] num_comprehensive_padding = 0 +num_matches_for_scatter_upon_const_tensor = 0 # reset all counters @@ -57,6 +58,7 @@ def reset(): global cpp_to_dtype_count global cpp_outer_loop_fused_inner_counts global num_comprehensive_padding + global num_matches_for_scatter_upon_const_tensor generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 @@ -67,6 +69,7 @@ def reset(): cpp_to_dtype_count = 0 cpp_outer_loop_fused_inner_counts.clear() num_comprehensive_padding = 0 + num_matches_for_scatter_upon_const_tensor = 0 @dataclass