Commit fa6cc70: fix splitk assertion issue

LiyangLingIntel committed Nov 15, 2024
1 parent 532728c
Showing 2 changed files with 9 additions and 8 deletions.
9 changes: 5 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -128,6 +128,7 @@ def forward(ctx, a, b, c, acc_dtype=None):
             [512, 32768, 8192],
             [1024, 28672, 8192],
             [3072, 4096, 3072],
+            [4096, 4096, 4096],
         ],
         line_arg='provider',
         # argument name whose value corresponds to a different line in the plot
@@ -152,17 +153,17 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles, kernel_name='_kernel')
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

         name = f'gemm_splitk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)
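The switch from torch.empty to torch.zeros matters because a split-K kernel treats the output as an accumulator: each K-slice adds its partial product into c (typically via atomic adds), so any stale bytes left in an uninitialized buffer leak into the result and trip benchmark_suit.assert_close. Below is a minimal CPU sketch of that failure mode; the splitk_matmul helper is illustrative only, not part of the benchmark or the Triton kernel.

import torch

def splitk_matmul(a, b, c, split_k):
    # Each split adds its partial K-slice product into c, mimicking the
    # cross-block accumulation a split-K GPU kernel performs.
    k_chunk = a.shape[1] // split_k
    for i in range(split_k):
        s = slice(i * k_chunk, (i + 1) * k_chunk)
        c += a[:, s] @ b[s, :]
    return c

M, N, K = 64, 64, 256
a, b = torch.randn(M, K), torch.randn(K, N)

c_ok = splitk_matmul(a, b, torch.zeros(M, N), split_k=4)
c_bad = splitk_matmul(a, b, torch.empty(M, N), split_k=4)

print(torch.allclose(c_ok, a @ b, atol=1e-4))   # True
print(torch.allclose(c_bad, a @ b, atol=1e-4))  # often False: stale memory leaks in

Note that torch.empty can happen to return zeroed pages, which is exactly why the original assertion failed only intermittently.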
8 changes: 4 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -275,17 +275,17 @@ def benchmark(M, N, K, provider):
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles)
     elif provider == 'triton':
-        c = torch.empty((M, N), device=a.device, dtype=torch.float32)
+        c = torch.zeros((M, N), device=a.device, dtype=torch.float32)
         triton_fn = lambda: matmul(a, b, c)
         torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
         _, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
                                                                  quantiles=quantiles,
                                                                  kernel_name=['first_wave', 'full_tiles'])
     elif provider == 'xetla':
-        c = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        acc = torch.empty((M, N), device='xpu', dtype=torch.float32)
-        cnt = torch.empty((M, N), device='xpu', dtype=torch.int32)
+        c = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        acc = torch.zeros((M, N), device='xpu', dtype=torch.float32)
+        cnt = torch.zeros((M, N), device='xpu', dtype=torch.int32)

         name = f'gemm_streamk_shape_{M}_{K}_{N}'
         func = getattr(xetla_kernel, name)
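The same zero-initialization applies on the stream-K side. In the xetla branches, acc and cnt are scratch buffers handed to the prebuilt kernel; judging from their names (their exact role inside xetla_kernel is not shown in this diff), acc accumulates partial output tiles and cnt counts per-tile contributions across work-groups, so both reductions depend on starting from zero just as c does.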
