diff --git a/test/test_triton_windows.py b/test/test_triton_windows.py
new file mode 100644
index 0000000..fb600d2
--- /dev/null
+++ b/test/test_triton_windows.py
@@ -0,0 +1,37 @@
+"""
+This test file checks whether Triton can correctly compile CUDA code on the Windows platform.
+---
+Windows support: https://github.com/woct0rdho/triton-windows
+"""
+
+import torch
+import triton
+import triton.language as tl
+
+@triton.jit
+def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+    # Each program instance processes one contiguous block of BLOCK_SIZE elements.
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Mask off out-of-range offsets so the last block never reads or writes past the end.
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    y = tl.load(y_ptr + offsets, mask=mask)
+    output = x + y
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+def add(x: torch.Tensor, y: torch.Tensor):
+    output = torch.empty_like(x)
+    assert x.is_cuda and y.is_cuda and output.is_cuda
+    n_elements = output.numel()
+    # Launch one kernel instance per BLOCK_SIZE elements.
+    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
+    return output
+
+a = torch.rand(3, device="cuda")
+b = a + a
+b_compiled = add(a, a)
+print(b_compiled - b)
+print("If you see tensor([0., 0., 0.], device='cuda:0'), then it works")
\ No newline at end of file