[UT] Port and run operator tests #246

Merged 6 commits on Jan 16, 2024
Changes from 1 commit
19 changes: 19 additions & 0 deletions .github/workflows/build_and_test.yml
@@ -70,6 +70,25 @@ jobs:
           python3 assert_helper.py device_assert
           python3 print_helper.py device_print float 1> /dev/null
 
+      - name: Clear cache
+        run: |
+          rm -rf ~/.triton
+
+      - name: Run interpreter tests
+        env:
+          # TRITON_INTERPRET: "1"
+          CUDA_VISIBLE_DEVICES: ""
+        if: ${{ env.BACKEND == 'XPU'}}
+        run: |
+          cd python/test/unit
+          python3 -m pytest -vs operators/test_flash_attention.py
+
+      - name: Run partial operators tests
+        if: ${{ env.BACKEND == 'XPU'}}
+        run: |
+          cd python/test/unit
+          python3 -m pytest -n auto --verbose operators
+
       - name: Run XPU python tests
         if: ${{ env.BACKEND == 'XPU'}}
         run: |
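For reference, a minimal sketch (an assumption, not part of this diff) of how the suite detects interpreter mode; it mirrors the environment check in test_flash_attention.py further down, and the step above would turn it on by uncommenting TRITON_INTERPRET: "1".

```python
import os


def interpreter_mode_enabled() -> bool:
    # Hypothetical helper; the env-var convention comes from the tests themselves.
    return os.environ.get("TRITON_INTERPRET", "not found") in ["on", "true", "1"]
```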
17 changes: 17 additions & 0 deletions .github/workflows/build_and_test_2.yaml
@@ -161,6 +161,23 @@ jobs:
           python3 assert_helper.py device_assert
           python3 print_helper.py device_print float 1> /dev/null
 
+      - name: Clear cache
+        run: |
+          rm -rf ~/.triton
+
+      - name: Run interpreter tests
+        env:
+          # TRITON_INTERPRET: "1"
+          CUDA_VISIBLE_DEVICES: ""
+        run: |
+          cd python/test/unit
+          python3 -m pytest -vs operators/test_flash_attention.py
+
+      - name: Run partial operators tests
+        run: |
+          cd python/test/unit
+          python3 -m pytest -n auto --verbose operators
+
       - name: Run XPU python tests
         run: |
           cd python/test/backend/third_party_backends
12 changes: 8 additions & 4 deletions python/test/unit/operators/test_blocksparse.py
@@ -4,6 +4,9 @@
 import triton
 import triton.ops
 
+# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
+torch.xpu.enable_sync_mode()
+
 
 def sparsify_tensor(x, mask, block):
     ret = torch.empty((x.size(0), mask.sum(), block, block), dtype=x.dtype, device=x.device)
@@ -12,7 +15,7 @@ def sparsify_tensor(x, mask, block):
     return ret
 
 
-def make_pair(shape, device="cuda", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
+def make_pair(shape, device="xpu", alpha=1e-2, beta=0., trans=False, data=None, dtype=torch.float32):
     if data is None:
         data = torch.randn(shape, dtype=torch.float32, requires_grad=True, device=device)
     ref_ret = data
@@ -79,7 +82,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K=
     b_tri = do_sparsify(b_tri) if is_dds else b_tri
     a_tri.retain_grad()
     b_tri.retain_grad()
-    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="cuda")
+    op = triton.ops.blocksparse.matmul(layout, BLOCK, MODE, trans_a=TRANS_A, trans_b=TRANS_B, device="xpu")
     c_tri = op(a_tri, b_tri)
     c_tri.backward(dc_tri)
     da_tri = a_tri.grad
@@ -119,7 +122,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
     # compute [torch]
     a_ref = mask_tensor(a_ref, layout, BLOCK, value=float("-inf"))
     a_ref.retain_grad()
-    at_mask = torch.ones((M, N), device="cuda")
+    at_mask = torch.ones((M, N), device="xpu")
     if is_causal:
         at_mask = torch.tril(at_mask)
     M = at_mask[None, None, :, :] + torch.zeros_like(a_ref)
@@ -132,7 +135,7 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4):
     a_tri = sparsify_tensor(a_tri, layout, BLOCK)
     a_tri.retain_grad()
     dout_tri = sparsify_tensor(dout_tri, layout, BLOCK)
-    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="cuda", is_dense=is_dense)
+    op = triton.ops.blocksparse.softmax(layout, BLOCK, device="xpu", is_dense=is_dense)
     out_tri = op(a_tri, scale=scale, is_causal=is_causal)
     out_tri.backward(dout_tri)
     da_tri = a_tri.grad
@@ -152,6 +155,7 @@ def test_attention_fwd_bwd(
     batch_size=2,
     n_heads=2,
 ):
+    pytest.skip("FIXME: Port get_device_capability to XPU")
     capability = torch.cuda.get_device_capability()
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
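The torch.xpu.enable_sync_mode() workaround above runs unconditionally at import time. A minimal sketch, assuming one wanted the file to keep importing on machines without the IPEX XPU backend (the guard is an assumption, not part of this PR):

```python
import torch

if hasattr(torch, "xpu") and hasattr(torch.xpu, "enable_sync_mode"):
    # FIXME: remove once the Triton L0 queue and the IPEX SYCL queue
    # can be synchronized through events.
    torch.xpu.enable_sync_mode()
```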
4 changes: 4 additions & 0 deletions python/test/unit/operators/test_cross_entropy.py
@@ -4,6 +4,9 @@
 import triton
 import triton.ops
 
+# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
+torch.xpu.enable_sync_mode()
+
 
 @pytest.mark.parametrize("M, N, dtype, mode", [ #
     (M, N, dtype, mode)
@@ -13,6 +16,7 @@
     for mode in ['forward', 'backward']
 ])
 def test_op(M, N, dtype, mode):
+    pytest.skip("FIXME: Port get_device_capability to XPU")
     capability = torch.cuda.get_device_capability()
     if capability[0] < 8 and dtype == "bfloat16":
         pytest.skip("Only test bfloat16 on devices with sm >= 80")
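The same blanket pytest.skip appears in several files because torch.cuda.get_device_capability() has no XPU counterpart yet. A hedged sketch of how the gate could be narrowed once that port exists (helper name and structure are assumptions):

```python
import pytest
import torch


def require_min_sm(major: int) -> None:
    # Only CUDA exposes a compute-capability query today; skip everywhere else.
    if torch.cuda.is_available():
        if torch.cuda.get_device_capability()[0] < major:
            pytest.skip(f"requires compute capability >= {major}.0")
    else:
        pytest.skip("FIXME: Port get_device_capability to XPU")
```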
13 changes: 8 additions & 5 deletions python/test/unit/operators/test_flash_attention.py
@@ -4,6 +4,9 @@
 import triton
 import triton.ops
 
+# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
+torch.xpu.enable_sync_mode()
+
 
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ #
     (2, 4, 512, 16),
@@ -20,7 +23,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     if enable_tma in ["on", "true", "1"]:
         if dtype == torch.bfloat16:
             pytest.skip('bfloat16 tma not support currently')
-
+    pytest.skip("FIXME: Port get_device_capability to XPU")
     capability = torch.cuda.get_device_capability()
     interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
     if not interpreter and capability[0] < 8:
@@ -87,14 +90,14 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
 
 
 @triton.testing.perf_report(configs)
-def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="cuda"):
+def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provider, dtype=torch.float16, device="xpu"):
     assert mode in ['fwd', 'bwd']
     warmup = 25
     rep = 100
     sm_scale = 1.3
-    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-    v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
+    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
+    v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="xpu", requires_grad=True)
     if provider == "triton":
         fn = lambda: triton.ops.attention(q, k, v, casual, sm_scale, seq_par)
         if mode == 'bwd':
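The benchmark above simply swaps the hard-coded "cuda" device strings for "xpu". A minimal sketch of a fallback-aware variant (the helper is hypothetical; the PR itself hard-codes "xpu"):

```python
import torch


def default_device() -> str:
    # Prefer XPU when the IPEX backend is present, otherwise fall back.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


# Usage mirroring the q/k/v setup above (shape values are illustrative):
q = torch.randn((2, 4, 512, 16), dtype=torch.float16, device=default_device(), requires_grad=True)
```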
25 changes: 14 additions & 11 deletions python/test/unit/operators/test_inductor.py
@@ -4,6 +4,9 @@
 import triton
 import triton.language as tl
 
+# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
+torch.xpu.enable_sync_mode()
+
 
 def test_normalization_with_remat():
 
@@ -47,12 +50,12 @@ def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel
 
     torch.manual_seed(123)
 
-    buf14 = torch.rand(8, 64, 64, 64, device="cuda")
-    buf16 = torch.rand(8, 1, 64, device="cuda")
-    arg114_1 = torch.rand(64, device="cuda")
-    arg115_1 = torch.rand(64, device="cuda")
-    arg8_1 = torch.rand(64, device="cuda")
-    arg9_1 = torch.rand(64, device="cuda")
+    buf14 = torch.rand(8, 64, 64, 64, device="xpu")
+    buf16 = torch.rand(8, 1, 64, device="xpu")
+    arg114_1 = torch.rand(64, device="xpu")
+    arg115_1 = torch.rand(64, device="xpu")
+    arg8_1 = torch.rand(64, device="xpu")
+    arg9_1 = torch.rand(64, device="xpu")
     triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048)
     torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0)
 
@@ -146,7 +149,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
         tmp76 = tl.where(tmp74, tmp75, tmp71)
         tl.store(out_ptr0 + (x5 + tl.zeros([XBLOCK], tl.int32)), tmp76, None)
 
-    inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half)
+    inp = torch.ones(8, 2048, 8, 8, device="xpu", dtype=torch.half)
     out = torch.ones_like(inp) * 3
     numel = inp.numel()
     triton_[(numel // 1024, )](inp, out, 1024)
Expand All @@ -172,8 +175,8 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

XBLOCK = 4
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda')
input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='xpu')
output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='xpu')
fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK))
torch.testing.assert_close(output, ref)
@@ -192,7 +195,7 @@ def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr):
         tl.store(out_ptr0 + rindex, tmp6, rmask)
 
     RBLOCK = 8
-    out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64)
+    out0 = torch.empty(RBLOCK, device="xpu", dtype=torch.int64)
     fn[(1, )](out0, RBLOCK, RBLOCK)
-    ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1
+    ref = torch.arange(RBLOCK, device="xpu", dtype=torch.int64) + 1
     torch.testing.assert_close(out0, ref)
12 changes: 8 additions & 4 deletions python/test/unit/operators/test_matmul.py
@@ -7,6 +7,9 @@
 import triton.language as tl
 import triton.ops
 
+# FIXME remove this once Triton L0 queue and IPEX SYCL queue can be synchronized through events
+torch.xpu.enable_sync_mode()
+
 
 @pytest.mark.parametrize(
     "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE",
@@ -102,6 +105,7 @@
 )
 def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32,
             F8_FASTACCUM, ACC_DTYPE, OUTPUT_DTYPE):
+    pytest.skip("FIXME: Port get_device_capability to XPU")
     capability = torch.cuda.get_device_capability()
     if capability[0] < 7:
         pytest.skip("Only test tl.dot() on devices with sm >= 70")
@@ -152,15 +156,15 @@ def upcast_if_fp8(x, dtype):
     def init_input(m, n, dtype, acc_dtype):
         if 'float8' in dtype:
             ewidth = {'float8e4b15': 4, 'float8e4nv': 4, 'float8e5': 5}[dtype]
-            sign = torch.randint(2, size=(m, n), device="cuda", dtype=torch.int8) * 128
-            val = torch.randint(2**3 - 1, size=(m, n), device="cuda", dtype=torch.int8) << 7 - ewidth
+            sign = torch.randint(2, size=(m, n), device="xpu", dtype=torch.int8) * 128
+            val = torch.randint(2**3 - 1, size=(m, n), device="xpu", dtype=torch.int8) << 7 - ewidth
             return sign | val
         if dtype == "int8":
-            return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8)
+            return torch.randint(-128, 127, (m, n), device="xpu", dtype=torch.int8)
         # Use small range of values to prevent numerical issues.
         min_exp = -4 if acc_dtype == "float16" else -10
         exponents = torch.randint(min_exp, 0, size=(m, n))
-        ret = (2.**exponents).to(getattr(torch, dtype)).to("cuda")
+        ret = (2.**exponents).to(getattr(torch, dtype)).to("xpu")
         return ret
 
     # allocate/transpose inputs
9 changes: 7 additions & 2 deletions scripts/test-triton.sh
@@ -96,11 +96,11 @@ function run_core_tests {
     echo "***************************************************"
     echo "****** Running Triton Core tests ******"
     echo "***************************************************"
-    CORE_TEST_DIR=$TRITON_PROJ/python/test/unit/language
+    CORE_TEST_DIR=$TRITON_PROJ/python/test/unit
     if [ ! -d "${CORE_TEST_DIR}" ]; then
         echo "Not found '${CORE_TEST_DIR}'. Build Triton please" ; exit 3
     fi
-    cd $CORE_TEST_DIR
+    cd $CORE_TEST_DIR/language
     TRITON_DISABLE_LINE_INFO=1 python3 -m pytest --verbose --device xpu --ignore=test_line_info.py --ignore=test_subprocess.py
     if [ $? -ne 0 ]; then
         echo "FAILED: return code $?" ; exit $?
@@ -117,6 +117,11 @@ function run_core_tests {
     if [ $? -ne 0 ]; then
         echo "FAILED: return code $?" ; exit $?
     fi
+    cd $CORE_TEST_DIR/operators
+    TRITON_DISABLE_LINE_INFO=1 python3 -m pytest -n auto --verbose
+    if [ $? -ne 0 ]; then
+        echo "FAILED: return code $?" ; exit $?
+    fi
 }
 
 function run_tutorial_test {