Skip to content

Commit

Permalink
Create test_triton_windows.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Puiching-Memory committed Nov 20, 2024
1 parent ba2c6e6 commit bd283b2
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/test_triton_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
This test file checks whether Triton can correctly compile CUDA code on Windows.
---
Code support: https://github.com/woct0rdho/triton-windows
"""

import torch
import triton
import triton.language as tl

# Element-wise addition kernel: each program instance handles one contiguous
# chunk of BLOCK_SIZE elements and writes x + y into the output buffer.
@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Index of this program instance along the single launch axis.
    block_id = tl.program_id(axis=0)
    # Flat element indices covered by this instance.
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Indices at or beyond n_elements are masked out of every load and store.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds)
    rhs = tl.load(y_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, lhs + rhs, mask=in_bounds)

def add(x: torch.Tensor, y: torch.Tensor):
    """Return ``x + y`` computed element-wise by the Triton ``add_kernel``.

    Args:
        x: First operand; must be a CUDA tensor.
        y: Second operand; must be a CUDA tensor with the same shape as ``x``.

    Returns:
        A new tensor, allocated with ``torch.empty_like(x)``, holding the sum.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same shape.
        AssertionError: If either input is not a CUDA tensor.
    """
    # The kernel addresses both inputs with the same flat offsets and masks
    # only against the output's element count, so a shape mismatch would read
    # past the end of ``y`` or silently ignore part of it.
    if x.shape != y.shape:
        raise ValueError(
            "x and y must have the same shape, "
            f"got {tuple(x.shape)} and {tuple(y.shape)}"
        )
    # Kept as an assert so the exception type stays the same; checked before
    # allocating so an invalid call does not create the output tensor first.
    assert x.is_cuda and y.is_cuda, "add() requires CUDA tensors"

    output = torch.empty_like(x)
    n_elements = output.numel()

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized chunk, rounded up.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output

# Smoke test: add a small random CUDA tensor to itself with both PyTorch and
# the Triton-compiled kernel, then print the element-wise difference.
a = torch.rand(3, device="cuda")
b = torch.add(a, a)  # reference result from PyTorch
b_compiled = add(a, a)  # result from the Triton kernel
difference = b_compiled - b
print(difference)
print("If you see tensor([0., 0., 0.], device='cuda:0'), then it works")

0 comments on commit bd283b2

Please sign in to comment.