diff --git a/examples/matrix_multiplication.py b/examples/matrix_multiplication.py
index f3d8219b..19aac4d9 100644
--- a/examples/matrix_multiplication.py
+++ b/examples/matrix_multiplication.py
@@ -20,7 +20,6 @@
 import jax
 import jax.numpy as jnp
-import math
 
 m=512
 n=512
 
diff --git a/jax_triton/triton_call.py b/jax_triton/triton_call.py
index e6b30b49..7c4977fe 100644
--- a/jax_triton/triton_call.py
+++ b/jax_triton/triton_call.py
@@ -26,6 +26,7 @@
 from jax.interpreters import mlir
 from jax import tree_util
 from jax._src import util
+from jax._src.lib import xla_bridge as xb
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import mhlo
 import numpy as np
@@ -33,13 +34,13 @@
 import triton
 import triton.language as tl
 
-from jax_triton import custom_call
+from jax_triton import triton_kernel_call
 
 os.environ["TRITON_CACHE_DIR"] = ""
 map, unsafe_map = util.safe_map, map
 zip, unsafe_zip = util.safe_zip, zip
 
-xc.register_custom_call_target("triton_call", custom_call.get_custom_call(), platform="CUDA")
+xc.register_custom_call_target("triton_kernel_call", triton_kernel_call.get_custom_call(), platform="CUDA")
 
 def get_triton_type(obj: Any) -> str:
   type_map = {
@@ -79,8 +80,6 @@ def get_triton_python_ir(aval):
 
 def compile(triton_function, constants, *, key, device=0, num_warps=4, num_stages=2):
   def lower(*args):
-    arg_types = [get_triton_python_ir(a) for a in args]
-    attributes = {i: 16 for i in range(len(args))}
     triton_function._warmup(arg_types=arg_types, device=device,
         attributes=attributes, constants=constants, num_warps=num_warps,
         num_stages=num_stages, key=key, is_manual_warmup=True,
@@ -133,18 +132,24 @@ def aval_to_layout(aval):
   arange = np.arange(aval.ndim, dtype='int64')[::-1].copy()
   return ir.DenseIntElementsAttr.get(arange, type=ir.IndexType.get())
 
-def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stages,
+def emit_triton_call(ctx, triton_func, grid, num_warps, num_stages,
                      dump_binary_path: Optional[str], **metaparams):
   metadata = {triton_func.arg_names.index(k) : v for k, v in metaparams.items()}
 
-  compile(triton_func, metadata, num_warps=num_warps, num_stages=num_stages, key="foo")(*avals_in, *avals_out)
-  loaded_binary = triton_func.bin_cache["foo"]
-  kernel_ptr = loaded_binary.kernel
-  shared_mem = loaded_binary.shared_mem
+  all_args = [*ctx.avals_in, *ctx.avals_out]
+  arg_types = [get_triton_python_ir(a) for a in all_args]
+  attributes = {i: 16 for i in range(len(all_args))}
+  # TODO(sharadmv): handle multiple devices, right now we assume device 0 which
+  # is fine when we have multiple of the same GPU but this won't work in
+  # general.
+  binary = triton_func._compile(arg_types=arg_types, device=0,
+      attributes=attributes, constants=metadata, num_warps=num_warps,
+      num_stages=num_stages, extern_libs={})
+  name, asm, shared_mem = binary.name, binary.asm, binary.shared_mem
   if dump_binary_path is not None:
     binary = dict(
-        asm=loaded_binary.asm,
+        asm=asm,
         shared_mem=shared_mem,
-        name=loaded_binary.bin.name)
+        name=name)
     with open(dump_binary_path, "wb") as fp:
       pickle.dump(binary, fp)
@@ -158,9 +163,10 @@ def emit_triton_call(triton_func, avals_in, avals_out, grid, num_warps, num_stag
     grid_1, grid_2 = grid_[1], grid_[2]
   else:
     assert False
-  arity = len(avals_in) + len(avals_out)
-  descriptor = custom_call.make_triton_call_descriptor(kernel_ptr, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
-  return descriptor
+  arity = len(ctx.avals_in) + len(ctx.avals_out)
+  descriptor, keepalive = triton_kernel_call.make_triton_call_descriptor(
+      name, asm, shared_mem, grid_0, grid_1, grid_2, num_warps, arity)
+  return descriptor, keepalive
 
 def triton_call_lowering(ctx, *args, kernel, out_shapes, grid, num_warps=4, num_stages=2,
                          dump_binary_path: Optional[str], **metaparams):
@@ -168,12 +174,13 @@ def triton_call_lowering(ctx, *args, kernel, out_shapes, grid, num_warps=4, num_
       ir.RankedTensorType.get(out_shape.shape, mlir.dtype_to_ir_type(out_shape.dtype))
       for out_shape in out_shapes])
   i32_type = ir.IntegerType.get_signless(32)
-  descriptor = emit_triton_call(kernel, ctx.avals_in, ctx.avals_out, grid,
-                                num_warps, num_stages, dump_binary_path,
-                                **metaparams)
+  descriptor, keepalive = emit_triton_call(ctx, kernel, grid,
+                                           num_warps, num_stages, dump_binary_path,
+                                           **metaparams)
+  ctx.module_context.add_keepalive(keepalive)
   out = mhlo.CustomCallOp(
             [out_type], args,
-            call_target_name=ir.StringAttr.get("triton_call"),
+            call_target_name=ir.StringAttr.get("triton_kernel_call"),
             has_side_effect=ir.BoolAttr.get(False),
             backend_config=ir.StringAttr.get(descriptor),
             api_version=ir.IntegerAttr.get(i32_type, 1),
diff --git a/lib/custom_call.cc b/lib/custom_call.cc
deleted file mode 100644
index c2d4e4c6..00000000
--- a/lib/custom_call.cc
+++ /dev/null
@@ -1,100 +0,0 @@
-/* Copyright 2022 Google LLC
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <iostream>
-#include <string>
-
-#include <pybind11/pybind11.h>
-#include "cuda.h"
-
-namespace py = pybind11;
-
-template <typename T>
-std::string PackDescriptorAsString(const T& descriptor) {
-  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
-}
-
-template <typename T>
-void UnpackDescriptor(T* descriptor_ptr, const char* opaque, std::size_t opaque_len) {
-  if (opaque_len != sizeof(T)) {
-    throw std::invalid_argument( "received negative value" );
-  }
-  std::memcpy(descriptor_ptr, opaque, opaque_len);
-}
-
-struct TritonCallDescriptor {
-  CUfunction kernel_ptr;
-  std::uint32_t shared_mem;
-  std::uint32_t grid_0;
-  std::uint32_t grid_1;
-  std::uint32_t grid_2;
-  std::uint32_t num_warps;
-  std::uint32_t arity;
-};
-
-void do_custom_call(CUstream stream, void** buffers,
-                    char* opaque, size_t opaque_len) {
-  TritonCallDescriptor descriptor;
-  UnpackDescriptor(&descriptor, opaque, opaque_len);
-  CUfunction kernel = descriptor.kernel_ptr;
-  int grid_0 = descriptor.grid_0;
-  int grid_1 = descriptor.grid_1;
-  int grid_2 = descriptor.grid_2;
-  int num_warps = descriptor.num_warps;
-  int arity = descriptor.arity;
-  std::string params;
-  params.resize(8 * arity);
-  char* params_ptr = &params[0];
-  for (int i = 0; i < arity; i++) {
-    params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
-    std::memcpy(params_ptr, &buffers[i], 8);
-    params_ptr += 8;
-  }
-  size_t params_size = (std::ptrdiff_t)(params_ptr - &params[0]);
-  void* config[] = {
-    CU_LAUNCH_PARAM_BUFFER_POINTER,
-    static_cast<void*>(const_cast<char*>(params.data())),
-    CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
-    CU_LAUNCH_PARAM_END
-  };
-  CUresult result = cuLaunchKernel(kernel, grid_0, grid_1, grid_2, num_warps * 32, 1, 1, descriptor.shared_mem, stream, nullptr, config);
-  if (result != 0) {
-    std::cout << "Failed launch: " << result << std::endl;
-  }
-  // cuStreamSynchronize(stream);
-}
-
-std::string MakeTritonCallDescriptor(uint64_t kernel_ptr, uint32_t shared_mem, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2, uint32_t num_warps, uint32_t arity) {
-  TritonCallDescriptor descriptor;
-  descriptor.kernel_ptr = reinterpret_cast<CUfunction>(kernel_ptr);
-  descriptor.shared_mem = shared_mem;
-  descriptor.grid_0 = grid_0;
-  descriptor.grid_1 = grid_1;
-  descriptor.grid_2 = grid_2;
-  descriptor.num_warps = num_warps;
-  descriptor.arity = arity;
-  return PackDescriptorAsString(descriptor);
-}
-
-template <typename T>
-pybind11::capsule EncapsulateFunction(T* fn) {
-  return pybind11::capsule(reinterpret_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
-}
-
-PYBIND11_MODULE(custom_call, m) {
-  m.def("make_triton_call_descriptor", [](uint64_t kernel_ptr, uint32_t shared_mem, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2, uint32_t num_warps, uint32_t arity){ return py::bytes(MakeTritonCallDescriptor(kernel_ptr, shared_mem, grid_0, grid_1, grid_2, num_warps, arity));
-  });
-  m.def("get_custom_call", [](){ return EncapsulateFunction(do_custom_call); });
-}
diff --git a/lib/triton_kernel_call.cc b/lib/triton_kernel_call.cc
new file mode 100644
index 00000000..43504892
--- /dev/null
+++ b/lib/triton_kernel_call.cc
@@ -0,0 +1,135 @@
+/* Copyright 2022 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "triton_kernel_call.h"
+
+#include <cassert>
+#include <iostream>
+#include <string>
+
+#include <pybind11/pybind11.h>
+#include "cuda.h"
+
+namespace py = pybind11;
+
+namespace jax_triton {
+
+const int TRITON_MAX_N_SHARED_BYTES = 49152;
+const int TRITON_MAX_SHARED_OPTIN = 49152;
+
+
+void TritonExecutable::launch(CUstream stream, void** buffers) {
+  CUdevice dev;
+  CUcontext ctx;
+  // Set the current context to the stream context so we can query the stream
+  // device
+  cuStreamGetCtx(stream, &ctx);
+  cuCtxSetCurrent(ctx);
+  /// Only load the kernel if it hasn't already been loaded for this device
+  cuCtxGetDevice(&dev);
+  CUfunction kernel = load(dev);
+  std::string params;
+  params.resize(8 * arity);
+  char* params_ptr = &params[0];
+  for (uint32_t i = 0; i < arity; i++) {
+    params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
+    std::memcpy(params_ptr, &buffers[i], 8);
+    params_ptr += 8;
+  }
+  size_t params_size = static_cast<size_t>(params_ptr - &params[0]);
+  void* config[] = {
+    CU_LAUNCH_PARAM_BUFFER_POINTER,
+    static_cast<void*>(const_cast<char*>(params.data())),
+    CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
+    CU_LAUNCH_PARAM_END
+  };
+  CUresult result = cuLaunchKernel(kernel, grid_0, grid_1, grid_2, num_warps * 32, 1, 1, shared_mem, stream, nullptr, config);
+  if (result != 0) {
+    std::cout << "Failed launch: " << result << std::endl;
+  }
+};
+
+CUfunction TritonExecutable::load(CUdevice device) {
+  const std::lock_guard<std::mutex> lock(mut);
+  if (is_loaded(device)) {
+    return kernels[device];
+  }
+  // Mimics Triton kernel loading
+  std::string assembly;
+  auto iter = asm_map.find("cubin");
+  if (iter != asm_map.end())
+    assembly = py::cast<std::string>(asm_map["cubin"]);
+  else {
+    assert(asm_map.contains("ptx"));
+    assembly = py::cast<std::string>(asm_map["ptx"]);
+  }
+  CUfunction fun;
+  CUmodule mod;
+  cuModuleLoadData(&mod, assembly.c_str());
+  cuModuleGetFunction(&fun, mod, name.c_str());
+  int n_regs = 0;
+  int n_spills = 0;
+  cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
+  cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
+  n_spills /= 4;
+  int shared_optin;
+  cuDeviceGetAttribute(&shared_optin,
+                       CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
+                       device);
+  if (shared_mem > TRITON_MAX_N_SHARED_BYTES &&
+      shared_optin > TRITON_MAX_SHARED_OPTIN) {
+    cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
+    int shared_total, shared_static;
+    cuDeviceGetAttribute(
+        &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
+        device);
+    cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES,
+                       fun);
+    cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
+                       shared_optin - shared_static);
+  }
+  kernels[device] = fun;
+  return fun;
+};
+
+void do_custom_call(CUstream stream, void** buffers,
+                    char* opaque, size_t opaque_len) {
+  uint64_t descriptor = std::strtoull(opaque, NULL, 0);
+  TritonExecutable* executable = TritonExecutable::from_descriptor(descriptor);
+  executable->launch(stream, buffers);
+}
+
+std::pair<std::string, py::object> MakeTritonExecutable(std::string name, asm_map_t asm_map, uint32_t shared_mem, uint32_t grid_0, uint32_t grid_1, uint32_t grid_2, uint32_t num_warps, uint32_t arity) {
+  auto triton_call = std::make_unique<TritonExecutable>(
+      name, asm_map, shared_mem, grid_0, grid_1, grid_2, num_warps, arity);
+  std::string descriptor = std::to_string(reinterpret_cast<std::uintptr_t>(triton_call.get()));
+  py::capsule callback_capsule(triton_call.release(), [](void* ptr) {
+    delete reinterpret_cast<TritonExecutable*>(ptr);
+  });
+  return std::make_pair(descriptor, py::object(std::move(callback_capsule)));
+}
+
+template <typename T>
+pybind11::capsule EncapsulateFunction(T* fn) {
+  return pybind11::capsule(reinterpret_cast<void*>(fn), "xla._CUSTOM_CALL_TARGET");
+}
+
+PYBIND11_MODULE(triton_kernel_call, m) {
+  m.def("make_triton_call_descriptor", &MakeTritonExecutable);
+  m.def("get_custom_call", [](){
+    return EncapsulateFunction(do_custom_call);
+  });
+}
+
+} // namespace jax_triton
diff --git a/lib/triton_kernel_call.h b/lib/triton_kernel_call.h
new file mode 100644
index 00000000..04e78a00
--- /dev/null
+++ b/lib/triton_kernel_call.h
@@ -0,0 +1,66 @@
+/* Copyright 2022 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <pybind11/pybind11.h>
+
+#include "cuda.h"
+
+namespace jax_triton {
+
+using asm_map_t = std::unordered_map<std::string, pybind11::object>;
+
+class TritonExecutable {
+ public:
+  explicit TritonExecutable(std::string name, asm_map_t asm_map, std::uint32_t shared_mem,
+                            std::uint32_t grid_0, std::uint32_t grid_1,
+                            std::uint32_t grid_2, std::uint32_t num_warps,
+                            std::uint32_t arity)
+    : name(std::move(name)),
+      asm_map(std::move(asm_map)),
+      shared_mem(shared_mem),
+      grid_0(grid_0),
+      grid_1(grid_1),
+      grid_2(grid_2),
+      num_warps(num_warps),
+      arity(arity),
+      kernels(),
+      mut() {}
+
+  static TritonExecutable* from_descriptor(uint64_t descriptor) {
+    return reinterpret_cast<TritonExecutable*>(static_cast<std::uintptr_t>(descriptor));
+  };
+  void launch(CUstream stream, void** buffers);
+
+ private:
+  bool is_loaded(CUdevice device) const {
+    return kernels.count(device) > 0;
+  }
+
+  CUfunction load(CUdevice device);
+  std::string name;
+  asm_map_t asm_map;
+  std::uint32_t shared_mem;
+  std::uint32_t grid_0;
+  std::uint32_t grid_1;
+  std::uint32_t grid_2;
+  std::uint32_t num_warps;
+  std::uint32_t arity;
+  std::unordered_map<CUdevice, CUfunction> kernels;
+  std::mutex mut;
+};
+
+} // namespace jax_triton
diff --git a/setup.py b/setup.py
index 5ba380b1..67ba6483 100644
--- a/setup.py
+++ b/setup.py
@@ -19,8 +19,8 @@
     packages = ["jax_triton"],
     ext_modules = [
      Extension(
-        name="jax_triton.custom_call",
-        sources=["lib/custom_call.cc"],
+        name="jax_triton.triton_kernel_call",
+        sources=["lib/triton_kernel_call.cc"],
         include_dirs = [
          "/usr/local/cuda/include",
          pybind11.get_include()],
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index 2c986cc9..d369d2b0 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -87,6 +87,8 @@ def matmul_kernel(
 def triton_call(*args, **kwargs):
   return jax.jit(lambda *args: jt.triton_call(*args, **kwargs))(*args)
 
+def triton_call_pmap(*args, **kwargs):
+  return jax.pmap(lambda *args: jt.triton_call(*args, **kwargs))(*args)
 
 class TritonKernelCallTest(parameterized.TestCase):
 
@@ -110,6 +112,31 @@ def test_add_vectors(self, size, dtype, block_size):
     expected = x + y
     np.testing.assert_allclose(out, expected)
 
+  @parameterized.named_parameters(*[
+      (f"size_{size}_dtype_{dtype}_blocksize_{block_size}", size, dtype, block_size)
+      for size in [1, 2, 5, 8, 100, 256, 1024]
+      for dtype in ['int32', 'float32', 'float16', 'int64', 'float64']
+      for block_size in [1, 8, 32, 256]
+  ])
+  def test_pmap_add_vectors(self, size, dtype, block_size):
+    n_devices = jax.local_device_count()
+    if n_devices < 2:
+      self.skipTest("Not enough devices")
+    grid = lambda meta: (size // meta["block_size"] + 1,)
+    k1, k2 = random.split(random.PRNGKey(0), 2)
+    if dtype in {"float32", "float16", "float64"}:
+      x, y = (random.normal(k1, [n_devices, size], dtype=dtype),
+              random.normal(k2, [n_devices, size], dtype=dtype))
+    elif dtype in {"int32", "int64"}:
+      x, y = (random.randint(k1, [n_devices, size], -100, 100, dtype=dtype),
+              random.randint(k2, [n_devices, size], -100, 100, dtype=dtype))
+
+    out = triton_call_pmap(x, y, kernel=add_kernel,
+                           out_shape=jax.ShapeDtypeStruct(x.shape[1:], x.dtype),
+                           grid=grid, block_size=block_size, n_elements=size)
+    expected = x + y
+    np.testing.assert_allclose(out, expected)
+
   @parameterized.named_parameters(*[
       (f"m_{m}_n_{n}_k_{k}_dtype_{dtype}_bm_{block_size_m}_"
       f"bn_{block_size_n}_bk_{block_size_k}_gm_{group_size_m}", m, n, k, dtype,