[inductor] don't materialize the large sparse matrix in CE bwd (pytorch#129043)

Inductor currently materializes a large sparse matrix in the backward pass for CrossEntropyLoss and loads it to compute the gradient of the Softmax input. If we can instead fuse the sparse-matrix computation into its consumers, we get both performance and memory-usage wins.

The FX graph snippet that constructs this sparse matrix looks like:
```
       full_default_3: "bf16[32768, 50257]" = torch.ops.aten.full.default([32768, 50257], 0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
       scatter: "bf16[32768, 50257]" = torch.ops.aten.scatter.value(full_default_3, 1, where_2, -1.0);  full_default_3 = where_2 = None
```
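
For context, a minimal sketch (shapes chosen to match the snippet above; the helper name `loss_fn` is illustrative) of a workload that produces such a graph is the backward pass of a bf16 cross-entropy loss under `torch.compile`:

```
# Minimal sketch exercising the CE-backward full+scatter pattern above.
# Shapes match the FX snippet; requires a CUDA device.
import torch
import torch.nn.functional as F

B_T, V = 32768, 50257  # flattened batch*seq_len, vocab size

def loss_fn(logits, label):
    return F.cross_entropy(logits, label)

logits = torch.randn(B_T, V, dtype=torch.bfloat16, device="cuda", requires_grad=True)
label = torch.randint(0, V, (B_T,), device="cuda")

# The backward graph contains the full+scatter shown above.
torch.compile(loss_fn)(logits, label).backward()
```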
Leveraging the following observations:
- the scatter is applied to an all-zero (or, more generally, constant) tensor
- the index tensor has a single element along the scatter dimension (in this case it's the label tensor)

allows us to lower this 'scatter_upon_const_tensor' pattern to a pointwise kernel that can easily be fused with downstream kernels:

```
    def inner_fn(idx):
        selector_idx = list(idx)
        selector_idx[dim] = 0  # can do this since the index tensor has a single element on the scatter dimension

        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )
```
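
As a sanity check, the equivalence can be expressed in plain PyTorch (this is just an illustration of the rewrite, not the Inductor IR):

```
import torch

M, N = 4, 8
label = torch.randint(0, N, (M,))

# Original pattern: materialize the constant tensor, then scatter into it.
scattered = torch.full((M, N), 0.0).scatter(1, label.unsqueeze(1), -1.0)

# Pointwise equivalent: compare each column index against the label.
cols = torch.arange(N).unsqueeze(0)  # shape [1, N], broadcasts over rows
pointwise = torch.where(cols == label.unsqueeze(1), torch.tensor(-1.0), torch.tensor(0.0))

assert torch.equal(scattered, pointwise)
```

Because the rewritten form is pointwise, Inductor can fuse it into its consumers instead of writing the full [32768, 50257] buffer to memory.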

## Test result on microbenchmark

For the microbenchmark added as `test_cross_entropy_loss`, latency improves from 47.340ms to 42.768ms and memory footprint from 10.524GB to 7.227GB on A100 (on H100, latency improves from 27.54ms to 23.51ms and memory footprint from 10.574GB to 7.354GB).

The saving matches a back-of-envelope calculation. We avoid storing a BF16 tensor of shape [30K, 50K], which is about 3GB. On A100, avoiding one load and one store of such a tensor saves roughly 3GB x 2 / 1.5TB/s ≈ 4ms.
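
A quick script to reproduce the arithmetic (numbers taken from the description; the ~1.5TB/s figure is an assumed effective A100 memory bandwidth):

```
# Back-of-envelope check for the memory and latency savings.
bytes_bf16 = 2
numel = 32768 * 50257                 # the sparse matrix from the FX snippet
tensor_gb = numel * bytes_bf16 / 1e9  # ~3.3 GB
a100_bw = 1.5e12                      # assumed ~1.5 TB/s effective bandwidth
ms_saved = numel * bytes_bf16 * 2 / a100_bw * 1e3  # one store + one load
print(f"{tensor_gb:.1f} GB, ~{ms_saved:.1f} ms")   # 3.3 GB, ~4.4 ms
```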

## Test result on llm.c

We also tested this on llm.c, and the saving is much larger, especially for memory footprint. The extra memory saving comes from autotuning, which allocates additional memory for benchmarking (see pytorch#129258 and pytorch#129399 for more details).

For the llm.c PyTorch implementation on A100, we improve from 171K tokens/s and 33.6GB peak memory to 180K tokens/s and 18.6GB peak memory (a **45%** reduction in peak memory).

## Test on PyTorch 2.0 Dashboard

The optimization is quite general, especially for transformers. We tested it on the PyTorch 2.0 dashboard; here is the [result](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2017%20Jun%202024%2018%3A07%3A51%20GMT&stopTime=Mon%2C%2024%20Jun%202024%2018%3A07%3A51%20GMT&granularity=hour&suite=torchbench&mode=training&dtype=amp&lBranch=gh/shunting314/158/head&lCommit=c62c55e29c65497d495217b6574bb36b0c4da7d4&rBranch=main&rCommit=0d25f096c1beaf8749932a3d6083ad653405ed71).

TL;DR: for the HuggingFace benchmark suite, we get a **6%** geomean perf improvement and a **10%** geomean memory footprint improvement.

Pull Request resolved: pytorch#129043
Approved by: https://github.com/jansel, https://github.com/Chillee
shunting314 authored and pytorchmergebot committed Jun 25, 2024
1 parent e1499f6 commit fd414d6
Showing 4 changed files with 296 additions and 0 deletions.
200 changes: 200 additions & 0 deletions test/inductor/test_scatter_optimization.py
@@ -0,0 +1,200 @@
# Owner(s): ["module: inductor"]

import copy
import os

import torch
from torch import nn
from torch._dynamo.utils import counters, same
from torch._inductor import metrics
from torch._inductor.runtime.runtime_utils import do_bench_gpu as do_bench
from torch._inductor.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_GPU

DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1"


class TestScatterOpt(TestCase):
    def setUp(self):
        super().setUp()
        metrics.reset()
        counters.clear()

    def check_metric(self, val=1):
        self.assertEqual(val, metrics.num_matches_for_scatter_upon_const_tensor)

    def do_acc_test(self, f, *args):
        expect = f(*args)
        actual = torch.compile(f)(*args)
        self.assertTrue(same(expect, actual, tol=1e-3), f"{expect=}\n{actual=}\n")

    def test_3d_tensor(self):
        L, M, N = 2, 1024, 2048

        def f(x):
            y = torch.full([L, M, N], 3.14, dtype=torch.float)
            y.scatter_(2, x.unsqueeze(2), 2.718)
            return y

        x = torch.randint(0, N, (L, M), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = (
            L * M * N * torch.float.itemsize + L * M * torch.int64.itemsize
        )
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_non_last_dim(self):
        """
        Test the case where the scatter dimension is not the last one.
        """
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(0, x.unsqueeze(0), 2.718)
            return y

        x = torch.randint(0, M, (N,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + N * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_neg_scatter_dim(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(-1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_shorter_index_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M // 2,), dtype=torch.int64)
        self.do_acc_test(f, x)

        # no match since the index tensor is shorter. May support it in future.
        self.assertEqual(0, counters["inductor"]["pattern_matcher_count"])

    def test_nonzero_const_tensor(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 3.14, dtype=torch.float)
            y.scatter_(1, x.unsqueeze(1), 2.718)
            return y

        x = torch.randint(0, N, (M,), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * torch.int64.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_dense(self):
        M, N = 1024, 2048

        def f(x):
            y = torch.full([M, N], 0, dtype=torch.float)
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, N // 2), dtype=torch.int64)
        self.do_acc_test(f, x)
        expected_num_bytes = M * N * torch.float.itemsize + M * (N // 2) * (
            torch.int64.itemsize + torch.float.itemsize
        )
        # Use assertGreaterEqual rather than assertEqual due to the issue related
        # to StarDep mentioned here: https://github.com/pytorch/pytorch/pull/129043#discussion_r1651699706
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

    def test_can_not_optimize_due_to_non_const(self):
        M, N = 1024, 2048

        def f(x, y):
            y.scatter_(1, x, 0.618)
            return y

        x = torch.randint(0, N, (M, 1), dtype=torch.int64)
        y = torch.randn([M, N])
        self.do_acc_test(f, x, y)

        # The generated code is quite inefficient.
        # There are 3 kernels:
        # 1. copy from arg to buf
        # 2. scatter upon buf
        # 3. copy buf back to arg
        # Link to the wrapper: https://gist.github.com/shunting314/d43b74e680b3e5b514f7c28160c39f40
        expected_num_bytes = 4 * M * N * torch.float.itemsize + M * (
            torch.int64.itemsize + torch.float.itemsize
        )
        self.assertGreaterEqual(metrics.num_bytes_accessed, expected_num_bytes)

        # The second and third kernels are both mutation kernels, so we
        # overestimate the memory accessed.
        # Update the test once the overestimation is fixed.
        over_estimate = M * torch.float.itemsize + M * N * torch.float.itemsize
        self.assertEqual(metrics.num_bytes_accessed, expected_num_bytes + over_estimate)

    def test_cross_entropy_loss(self):
        """
        Match full+scatter in CEL and replace it with a pointwise op.
        Perf data on an A100 GPU:
          Without the scatter optimization:
            ms=47.340, peak_mem=10.524 GB
          With the scatter optimization:
            ms=42.768, peak_mem=7.227 GB
        """
        B, T, D, V = 32, 1024, 768, 50257
        if not DO_PERF_TEST:
            # use a smaller V if not doing a perf test, to avoid OOM in CI
            V = V // 100
        ref_model = nn.Linear(D, V).to(torch.bfloat16)
        opt_model = copy.deepcopy(ref_model)
        ce = nn.CrossEntropyLoss()

        def f(m, x, label):
            ce(m(x).view(-1, V), label.view(-1)).backward()

        opt_f = torch.compile(f)

        x = torch.randn(B, T, D).to(torch.bfloat16)
        label = torch.randint(0, V, (B, T)).to(torch.int64)

        f(ref_model, x, label)
        ref_grad = ref_model.weight.grad
        opt_f(opt_model, x, label)
        act_grad = opt_model.weight.grad
        assert torch.allclose(
            ref_grad, act_grad, atol=1e-3, rtol=1e-3
        ), f"{ref_grad=}\n{act_grad=}"

        self.check_metric()

        if DO_PERF_TEST:
            torch.cuda.reset_peak_memory_stats()
            for _ in range(3):
                opt_f(opt_model, x, label)
            ms = do_bench(lambda: opt_f(opt_model, x, label))
            peak_mem = torch.cuda.max_memory_allocated() / 10**9
            print(f"{ms=:.3f}, {peak_mem=:.3f} GB")


if HAS_GPU:
    torch.set_default_device("cuda")

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests()
10 changes: 10 additions & 0 deletions torch/_inductor/config.py
@@ -402,6 +402,16 @@ def fx_graph_remote_cache_default():
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
developer_warnings = is_fbcode() or is_nightly_or_source

# This pattern matches a special usage of scatter:
# 1. It's applied to a constant tensor.
# 2. The index tensor has size 1 in the scatter dimension.
# Such a pattern generates a sparse matrix when the const tensor is all-zero.
# We can lower this pattern to a pointwise kernel for more fusion opportunities
# and a smaller memory footprint.
optimize_scatter_upon_const_tensor = (
os.environ.get("TORCHINDUCTOR_OPTIMIZE_SCATTER_UPON_CONST_TENSOR", "1") == "1"
)


# The multiprocessing start method to use for inductor workers in the codecache.
# "subprocess", "fork", or "spawn"
83 changes: 83 additions & 0 deletions torch/_inductor/fx_passes/post_grad.py
@@ -12,6 +12,7 @@
from torch import fx
from torch._decomp import register_decomposition
from torch._dynamo.utils import counters, optimus_scuba_log
from torch._inductor.virtualized import ops

from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype

@@ -216,6 +217,88 @@ def is_valid_mm_plus_mm(match: Match):
    return True


def scatter_upon_const_tensor_extra_check(m):
    if not config.optimize_scatter_upon_const_tensor:
        return False
    full_shape = m.kwargs["shape"]
    selector = m.kwargs["selector"]
    dim = m.kwargs["dim"]
    if dim < 0:
        dim += len(full_shape)

    selector_ft = selector.meta["val"]
    assert selector_ft.dim() == len(full_shape)

    for idx, select_sz, full_sz in zip(
        itertools.count(), selector_ft.shape, full_shape
    ):
        if idx == dim:
            continue

        # TODO: the pattern can be updated to support the case where the index
        # tensor is shorter. But that would need a more complex condition
        # expression, especially for multi-dimensional tensors.
        # Skip it for now.
        if isinstance(full_sz, fx.Node):
            full_sz = full_sz.meta["val"]
        if select_sz < full_sz:
            return False

    # Actually we can support a small size larger than 1. It would be a bit
    # tedious. E.g., we load all the index values (not many) and compare
    # them with the position in the tensor to decide what value to return.
    return selector_ft.size(dim) == 1


@register_lowering_pattern(
    CallFunction(
        aten.scatter.value,
        CallFunction(
            aten.full,
            KeywordArg("shape"),
            KeywordArg("background_val"),
            dtype=KeywordArg("dtype"),
        ),
        KeywordArg("dim"),
        KeywordArg("selector"),
        KeywordArg("val"),  # scalar value
    ),
    extra_check=scatter_upon_const_tensor_extra_check,
)
def scatter_upon_const_tensor(
    match: Match, shape, background_val, dtype, dim, selector, val
):
    """
    Match the pattern of full+scatter into a pointwise.
    TODO: Right now the scatter value must be a scalar. But we could support
    the case where it is a tensor as well.
    """
    from torch._inductor import metrics

    metrics.num_matches_for_scatter_upon_const_tensor += 1

    selector_loader = selector.make_loader()

    def inner_fn(idx):
        selector_idx = list(idx)
        selector_idx[dim] = 0

        selector = selector_loader(selector_idx)
        return ops.where(
            selector == ops.index_expr(idx[dim], torch.int64),
            ops.constant(val, dtype),
            ops.constant(background_val, dtype),
        )

    return ir.Pointwise.create(
        device=selector.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=shape,
    )


@register_lowering_pattern(
    CallFunction(
        aten.add,
3 changes: 3 additions & 0 deletions torch/_inductor/metrics.py
@@ -46,6 +46,7 @@
cpp_outer_loop_fused_inner_counts: List[int] = []

num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0


# reset all counters
@@ -57,6 +58,7 @@ def reset():
    global cpp_to_dtype_count
    global cpp_outer_loop_fused_inner_counts
    global num_comprehensive_padding
    global num_matches_for_scatter_upon_const_tensor

    generated_kernel_count = 0
    generated_cpp_vec_kernel_count = 0
@@ -67,6 +69,7 @@
    cpp_to_dtype_count = 0
    cpp_outer_loop_fused_inner_counts.clear()
    num_comprehensive_padding = 0
    num_matches_for_scatter_upon_const_tensor = 0


@dataclass
