Commit fdb63be

Merge branch 'main' into gregory/windows-support

gshimansky committed Nov 12, 2024
2 parents f017395 + ee755e8
Showing 28 changed files with 336 additions and 178 deletions.
60 changes: 0 additions & 60 deletions README.md
@@ -6,7 +6,6 @@

This is the development repository of Intel® XPU Backend for Triton\*, a new [Triton](https://github.com/triton-lang/triton/) backend for Intel GPUs. Intel® XPU Backend for Triton\* is an out-of-tree backend module for [Triton](https://github.com/triton-lang/triton/) that provides best-in-class performance and productivity on Intel GPUs for [PyTorch](https://github.com/pytorch/pytorch) and standalone usage.

-<<<<<<< HEAD
# Compatibility

* Operating systems:
@@ -22,25 +21,11 @@ This is the development repository of Intel® XPU Backend for Triton\*, a new [T
* Latest [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html)

Note that Intel® XPU Backend for Triton\* is not compatible with Intel® Extension for PyTorch\* and Intel® oneAPI Base Toolkit\*.
-=======
-| **`Documentation`** | **`Nightly Wheels`** |
-|-------------------- | -------------------- |
-| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |
-
-# Triton
-
-This is the development repository of Triton, a language and compiler for writing highly efficient custom Deep-Learning primitives. The aim of Triton is to provide an open-source environment to write fast code at higher productivity than CUDA, but also with higher flexibility than other existing DSLs.
-
-The foundations of this project are described in the following MAPL2019 publication: [Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations](http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf). Please consider citing this work if you use Triton!
-
-The [official documentation](https://triton-lang.org) contains installation instructions and tutorials. See also these third-party [Triton puzzles](https://github.com/srush/Triton-Puzzles), which can all be run using the Triton interpreter -- no GPU required.
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597

# Quick Installation

## Prerequisites

-<<<<<<< HEAD
1. Latest [Rolling Release](https://dgpu-docs.intel.com/driver/installation-rolling.html) or [Long Term Support Release](https://dgpu-docs.intel.com/driver/installation.html) of GPU driver
2. Latest release of [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html)
3. Latest release of [Profiling Tools Interfaces for Intel GPU (PTI for GPU)](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html)
@@ -55,35 +40,18 @@ Extract the archive and in the extracted directory execute:
```shell
pip install torch-*.whl triton-*.whl
```
-=======
-```shell
-pip install triton
-```
-
-Binary wheels are available for CPython 3.8-3.12 and PyPy 3.8-3.9.
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597

Before using Intel® XPU Backend for Triton\*, you need to initialize the toolchain.
The default location is `/opt/intel/oneapi` (if installed as the `root` user) or `~/intel/oneapi` (if installed as a regular user).

```shell
-<<<<<<< HEAD
# replace /opt/intel/oneapi with the actual location of PyTorch Prerequisites for Intel GPUs
source /opt/intel/oneapi/setvars.sh
-=======
-pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597
```

# Install from source

-<<<<<<< HEAD
## Prerequisites
-=======
-```shell
-git clone https://github.com/triton-lang/triton.git;
-cd triton;
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597

1. Latest [Rolling Release](https://dgpu-docs.intel.com/driver/installation-rolling.html) or [Long Term Support Release](https://dgpu-docs.intel.com/driver/installation.html) of GPU driver
2. Latest release of [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html)
@@ -104,14 +72,9 @@ source /opt/intel/oneapi/setvars.sh
Clone this repository:

```shell
-<<<<<<< HEAD
git clone https://github.com/intel/intel-xpu-backend-for-triton.git
cd intel-xpu-backend-for-triton
```
-=======
-git clone https://github.com/triton-lang/triton.git;
-cd triton;
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597

To avoid potential conflicts with installed packages, it is recommended to create and activate a new Python virtual environment:

@@ -242,7 +205,6 @@ For detailed instructions on how to debug Triton's frontend, please refer to thi
# Usage Guide
-<<<<<<< HEAD
## Code Modifications
Intel® XPU Backend for Triton\* requires a special version of PyTorch that can be built from sources or installed from nightly wheels.
@@ -346,14 +308,6 @@ Note that the user needs to explicitly set `TRITON_XPU_PROFILE=1` when the user
```Bash
export TRITON_XPU_PROFILE=1
```
-=======
-Version 2.0 is out! New features include:
-- Many, many bug fixes
-- Performance improvements
-- Backend rewritten to use MLIR
-- Support for kernels that contain back-to-back matmuls (e.g., flash attention)
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597
# Contributing
@@ -363,24 +317,10 @@ Community contributions are more than welcome, whether it be to fix bugs or to a

_MIT License_. As found in [LICENSE](https://github.com/intel/intel-xpu-backend-for-triton/blob/main/LICENSE) file.

-<<<<<<< HEAD

## Security

See Intel's [Security Center](https://www.intel.com/content/www/us/en/security-center/default.html)
for information on how to report a potential security issue or vulnerability.
See also: [Security Policy](security.md)
-=======
-# Compatibility
-Supported Platforms:
-- Linux
-Supported Hardware:
-- NVIDIA GPUs (Compute Capability 8.0+)
-- AMD GPUs (ROCm 5.2+)
-- Under development: CPUs
->>>>>>> d6739d3c33dee481f2d4dee4f6ecd4123f671597
6 changes: 3 additions & 3 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -129,8 +129,8 @@ def matmul_kernel_with_block_pointers_batched(
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
-    bid = tl.program_id(axis=0)
-    pid = tl.program_id(axis=1)
+    bid = tl.program_id(axis=1)
+    pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -186,8 +186,8 @@ def matmul(a, b, c, transpose_a=False, transpose_b=False):
B = a.shape[0]
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
-        B,
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+        B,
)
matmul_kernel_with_block_pointers_batched[grid](
a, b, c, #
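The two hunks above swap which grid axis carries the batch index and which carries the tile index, so the launch tuple `(num_tiles, B)` now lines up with `tl.program_id(axis=0)` returning the tile id. A minimal sketch of that pairing (the kernel, shapes, and `device="xpu"` placement are invented for illustration, assuming an XPU-enabled PyTorch build):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def grid_demo_kernel(out_ptr, TILES: tl.constexpr):
    pid = tl.program_id(axis=0)  # tile index within one batch entry
    bid = tl.program_id(axis=1)  # batch index
    # Record which (batch, tile) pair this program instance handled.
    tl.store(out_ptr + bid * TILES + pid, pid)


B, TILES = 4, 8
out = torch.empty(B * TILES, dtype=torch.int32, device="xpu")
grid_demo_kernel[(TILES, B)](out, TILES=TILES)  # grid = (tiles, batch), matching the new order
```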
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -34,6 +34,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
"TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",
"TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
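Variables in `CACHE_INVALIDATING_ENV_VARS` feed the compilation cache key, so changing one (such as the newly added `TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B`) forces affected kernels to be recompiled rather than served from cache. A rough sketch of the idea (the function and hashing scheme here are invented for illustration, not Triton's actual cache-key code):

```python
import hashlib
import os


def cache_key(source: str, invalidating_vars: list[str]) -> str:
    # Fold the current values of cache-invalidating environment variables
    # into the key, so flipping any of them yields a different cache entry.
    h = hashlib.sha256(source.encode())
    for name in sorted(invalidating_vars):
        h.update(f"{name}={os.environ.get(name, '')}".encode())
    return h.hexdigest()


key = cache_key("...kernel IR...",
                ["TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B"])
```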
11 changes: 9 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
@@ -35,15 +35,21 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
}
}
llAssert(op, condition, adaptor.getMessage(), rewriter);
+    if (isa<RankedTensorType>(op.getCondition().getType())) {
+      // Add a barrier to avoid a race condition in case an assert is followed
+      // by an op that may trap if the assert condition is true. Since the
+      // tensor in those two operations may have different layout we need to
+      // make sure all the threads are done executing the assert before going to
+      // the next op.
+      barrier();
+    }
rewriter.eraseOp(op);
return success();
}
// op: the op at which the assert is inserted. Unlike printf, we need to
// know about the op to split the block.
void llAssert(Operation *op, Value condition, StringRef message,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter::InsertionGuard guard(rewriter);

auto ctx = rewriter.getContext();
auto loc = op->getLoc();

@@ -79,6 +85,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
rewriter.create<cf::BranchOp>(loc, thenBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
+    rewriter.setInsertionPointToStart(thenBlock);
}

protected:
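A hypothetical Triton kernel showing the hazard the new barrier guards against (the kernel and its names are invented; `tl.device_assert` is active only when the kernel is compiled in debug mode): without synchronization after the assert, some threads could reach a potentially trapping load while others are still evaluating the tensor-valued condition.

```python
import triton
import triton.language as tl


@triton.jit
def checked_gather(src_ptr, idx_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    idx = tl.load(idx_ptr + offs)
    # Tensor-valued condition: lowered through AssertOpConversion above.
    tl.device_assert(idx < N, "index out of bounds")
    # This load may trap exactly when the assert condition is violated,
    # which is why the lowering now inserts a barrier between the two ops.
    val = tl.load(src_ptr + idx)
    tl.store(out_ptr + offs, val)
```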
2 changes: 1 addition & 1 deletion lib/Target/SPIRV/spirv-llvm-translator.conf
@@ -1 +1 @@
-cf697333b60d2000509ab7e79869ecab5eda9e9c
+1a1bf17d9e8684cd826e4278e78f63aa80e2e2ca
4 changes: 0 additions & 4 deletions python/triton/language/semantic.py
@@ -1729,10 +1729,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
if not builder.options.debug:
return
-    cond_ty = cond.type
-    if not cond_ty.is_block():
-        cond_ty = tl.block_type(cond_ty.scalar, (1, ))
-        cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)


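With these lines removed, a scalar condition is handed to `create_assert` directly instead of first being splatted to a one-element block. Language-level usage is unchanged; a small sketch (the kernel is hypothetical, and device asserts fire only when debug is enabled, e.g. via `TRITON_DEBUG=1`):

```python
import triton
import triton.language as tl


@triton.jit
def assert_demo(x_ptr, N: tl.constexpr):
    pid = tl.program_id(0)
    tl.device_assert(pid < 4096, "scalar condition, no splat needed")
    x = tl.load(x_ptr + tl.arange(0, N))
    tl.device_assert(x == x, "tensor condition (NaN check)")
```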
2 changes: 1 addition & 1 deletion scripts/compile-pytorch-ipex.sh
@@ -100,7 +100,7 @@ fi
# Configure, build and install PyTorch from source.

if [[ $BUILD_PYTORCH = true ]]; then
-PYTORCH_PROJ=$BASE/pytorch
+PYTORCH_PROJ=$BASE/pytorch-stonepia

echo "**** Cleaning $PYTORCH_PROJ before build ****"
rm -rf $PYTORCH_PROJ
2 changes: 2 additions & 0 deletions scripts/skiplist/a770/language.txt
@@ -1,5 +1,7 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
2 changes: 2 additions & 0 deletions scripts/skiplist/conda/language.txt
@@ -115,6 +115,8 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-1
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
2 changes: 2 additions & 0 deletions scripts/skiplist/default/language.txt
@@ -1,2 +1,4 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
2 changes: 2 additions & 0 deletions scripts/skiplist/lts/language.txt
@@ -115,6 +115,8 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-1
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
2 changes: 2 additions & 0 deletions scripts/skiplist/mtl/language.txt
@@ -1,5 +1,7 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
2 changes: 2 additions & 0 deletions scripts/skiplist/xe2/language.txt
@@ -1,5 +1,7 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
+# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
+test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
10 changes: 5 additions & 5 deletions test/Analysis/test-liveness.mlir
@@ -19,11 +19,11 @@ module attributes {"triton_gpu.num-warps" = 8 : i32} {

// CHECK: scf.if
// CHECK-NEXT: LiveIntervals for block: ^bb0
-// CHECK-NEXT: [[[LOAD1:%.*]], [[LOAD1]]] for value: %arg0
-// CHECK-NEXT: [[[LOAD1]], scf.yield] for value: [[LOAD1]]
-// CHECK-NEXT: LiveIntervals for block: ^bb0
-// CHECK-NEXT: [[[LOAD2:%.*]], [[LOAD2]]] for value: %arg1
-// CHECK-NEXT: [[[LOAD2]], scf.yield] for value: [[LOAD2]]
+// CHECK-DAG: [[[LOAD1:%.*]], [[LOAD1]]] for value: %arg0
+// CHECK-DAG: [[[LOAD1]], scf.yield] for value: [[LOAD1]]
+// CHECK-DAG: LiveIntervals for block: ^bb0
+// CHECK-DAG: [[[LOAD2:%.*]], [[LOAD2]]] for value: %arg1
+// CHECK-DAG: [[[LOAD2]], scf.yield] for value: [[LOAD2]]

%c1024_i32 = arith.constant 1024 : i32
%c64_i32 = arith.constant 64 : i32
65 changes: 65 additions & 0 deletions test/Conversion/intel/intel-allocate-shared-memory.mlir
@@ -0,0 +1,65 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check no scratch memory is allocated for sub-group shuffle-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 0 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
// CHECK: tt.func @test_sub_group_shuffle
// CHECK-NOT: llvm.ptr<3>
tt.func @test_sub_group_shuffle(%arg0: tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return %0 : tensor<16xf16, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 512 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f16(%arg0: tensor<16x16xf16, #blocked>) -> tensor<16x16xf16, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #blocked> -> tensor<16x16xf16, #blocked1>
tt.return %0 : tensor<16x16xf16, #blocked1>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 1024 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f32(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16x16xf32, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked> -> tensor<16x16xf32, #blocked1>
tt.return %0 : tensor<16x16xf32, #blocked1>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 2], order = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [4, 2], order = [0, 1]}>

// Check scratch memory configuration for different sub-group transpose-like layout conversions.

// CHECK-LABEL: module attributes
// CHECK-SAME: triton_gpu.shared = 32768 : i32
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func @test_f32(%arg0: tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked1> {
%0 = triton_gpu.convert_layout %arg0 : tensor<128x64xf32, #blocked> -> tensor<128x64xf32, #blocked1>
tt.return %0 : tensor<128x64xf32, #blocked1>
}
}
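The three `triton_gpu.shared` sizes checked above are consistent with a simple model in which the transpose scratch buffer holds the full tensor once at its element width; this is an inference from the CHECK values, not necessarily the pass's exact formula:

```python
def scratch_bytes(rows: int, cols: int, elem_bytes: int) -> int:
    # Assumed model: the scratch buffer holds rows * cols elements once.
    return rows * cols * elem_bytes


assert scratch_bytes(16, 16, 2) == 512       # 16x16 f16 case
assert scratch_bytes(16, 16, 4) == 1024      # 16x16 f32 case
assert scratch_bytes(128, 64, 4) == 32768    # 128x64 f32, 8-warp case
```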