Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Revert "[frontend] Remove Complex Regex for MLIR Parsing (#4924)"" #2681

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion python/src/ir.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include <optional>
#include <optional>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
Expand All @@ -24,6 +24,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/LocationSnapshot.h"

#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
Expand Down Expand Up @@ -502,6 +503,16 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("get_entry_func_name",
[](ModuleOp &self) -> std::string {
for (auto &op : self.getOps()) {
if (auto func = dyn_cast<FuncOp>(op)) {
if (LLVM::isKernel(func))
return func.getName().str();
}
}
return "";
})
.def("has_function",
[](ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
Expand All @@ -512,6 +523,43 @@ void init_triton_ir(py::module &&m) {
[](ModuleOp &self, std::string &funcName) -> FuncOp {
return self.lookupSymbol<FuncOp>(funcName);
})
/*
* def ty_to_cpp(ty) is the consumer of this function.
* If the type is a ptr it expects ty[0] == '*', else the type itself.
*/

.def("get_function_signature",
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
std::vector<std::string> strVec;

auto type = func.getFunctionType();
unsigned numArgs = type.getNumInputs();
for (unsigned i = 0; i != numArgs; ++i) {
std::string tempType;
llvm::raw_string_ostream os(tempType);

auto ty = type.getInput(i);
if (auto attributes = func.getCallableArgAttrs()) {
Attribute attr = attributes[i];
// Check for tt.nv_tma_desc = 1
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
if (dAttr.contains("tt.nv_tma_desc")) {
strVec.push_back("nvTmaDesc");
continue;
}
}
}
if (auto ptrType = dyn_cast<PointerType>(ty)) {
auto pType = ptrType.getPointeeType();
os << "*";
pType.print(os);
} else {
ty.print(os);
}
strVec.push_back(tempType);
}
return strVec;
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
Expand Down
94 changes: 94 additions & 0 deletions python/test/unit/tools/test_irsource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import tempfile
import triton
from triton.compiler import IRSource, make_backend
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()
backend = make_backend(target)


def test_mlir_attribute_parsing() -> None:
'''
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.
Checks for the following:
1. Name and type signature are parsed correctly
2. _get_num_warps_from_ir_str() works
3. tt.nv_tma_desc attribute is parsed correctly
'''

sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32},
%arg4: i32 {tt.divisibility = 16 : i32},
%arg5: i32 {tt.divisibility = 16 : i32},
%arg6: i32 {tt.divisibility = 16 : i32},
%arg7: i32 {tt.divisibility = 16 : i32},
%arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
%desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
src = IRSource(f.name, context, backend)

# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
assert src.parse_options()['num_warps'] == 8

sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32})
attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
%13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
%14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
tt.return
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir_vector_add)
f.flush()
context = ir.context()
src = IRSource(f.name, context, backend)

# now test compilation
triton.compile(f.name, target=target)
7 changes: 5 additions & 2 deletions python/triton/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
from .errors import CompilationError

__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
__all__ = [
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
"LazyDict"
]
63 changes: 33 additions & 30 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
"ttgir": mlir_prototype_pattern,
"ptx": ptx_prototype_pattern,
}

mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
"ttgir": mlir_arg_type_pattern,
"ptx": ptx_arg_type_pattern,
}

Expand All @@ -55,16 +49,6 @@ def convert_type_repr(x):
return x


def _get_num_warps_from_ir_str(src: str):
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
# e.g. someone has an instruction (not module) attribute named "num-warps".
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
num_warps = int(num_warps_matches[0])
return num_warps


class ASTSource:

def __init__(self, fn, signature, constants=None, attrs=None) -> None:
Expand Down Expand Up @@ -107,28 +91,42 @@ def parse_options(self):

class IRSource:

def __init__(self, path):
def __init__(self, path, context, backend):
self.path = path
path = Path(path)
self.ext = path.suffix[1:]
self.src = path.read_text()
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
ir.load_dialects(context)
backend.load_dialects(context)

# We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
# TODO - replace with a proper parser
if self.ext == "ptx":
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
self.name = match.group(1)
signature = match.group(2)
types = re.findall(arg_type_pattern[self.ext], signature)
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
else:
self.module = ir.parse_mlir_module(self.path, context)
fn_name = self.module.get_entry_func_name()
self.name = "@" + fn_name
funcOp = self.module.get_function(fn_name)
func_ty = self.module.get_function_signature(funcOp)
self.signature = {k: ty for k, ty in enumerate(func_ty)}

def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
self.module.context = context
return self.module

def parse_options(self):
if self.ext == "ttgir":
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
return {'num_warps': num_warps}
return dict()


Expand Down Expand Up @@ -225,7 +223,9 @@ def compile(src, target=None, options=None):
# create backend
if ir_source:
assert isinstance(src, str), "source must be either AST or a filepath"
src = IRSource(src)
context = ir.context()
src = IRSource(src, context, backend)

extra_options = src.parse_options()
options = backend.parse_options(dict(options or dict(), **extra_options))
# create cache manager
Expand Down Expand Up @@ -266,9 +266,12 @@ def compile(src, target=None, options=None):
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
if ir_source:
first_stage += 1
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)

if not isinstance(src, IRSource):
context = ir.context()
ir.load_dialects(context)
backend.load_dialects(context)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is different from upstream code:

image

The "else" part is missing.... perhaps it belongs to a more recent upstream commit than the one we are trying to merge ?

Also, would adding the else part eliminate the need to add lines 100-103 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is different from upstream code:

Yes, I made this change on purpose. Now, as far as I understand, the backend dialects should be initialized when calling the IRSource class constructor.

Also, would adding the else part eliminate the need to add lines 100-103 ?

No, because IRSource constructor also calls function self.module = ir.parse_mlir_module(self.path, context), which requires the backend dialects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One example, we use dpas layout that requires an additional dialect:

DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,

codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
Expand Down