(This is more of an issue/feature request than a PR, but since I have a working prototype I figured I'd share it.)
Triton now has a CPU backend from triton-cpu, which compiles LLIR to assembly using LLVM. This PR adds support for it by using `jax.pure_callback` to wrap the Triton kernel call from Python (generating an XLA custom call); a minimal sketch of the idea is below. A proper implementation would add an OpenMP CPU launcher to jaxlib's `gpu_triton.py`, akin to `triton_kernels.cc` (unfortunately, the CPU backend doesn't seem to fit neatly into jaxlib's existing Triton abstractions; for example, it seems CUDA/ROCm are mutually exclusive since they overwrite the same names in `gpu_triton.py`, while CPU can co-exist with GPU). I don't have enough familiarity with C++/jaxlib/XLA to make this change myself, hence the feature request.
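For reference, here is a minimal sketch of the `jax.pure_callback` approach, assuming triton-cpu's usual launch path of calling `kernel[grid](...)` on host (torch CPU) tensors; the helper name and the numpy/torch conversion are illustrative, not the actual code in this PR:

```python
import jax
import numpy as np
import torch  # assumption: triton-cpu kernels are launched on torch CPU tensors

def cpu_triton_call(*args, kernel, out_shape, grid, **metaparams):
    """Illustrative sketch: launch a Triton CPU kernel via jax.pure_callback."""
    def launch(*host_args):
        # pure_callback hands us host numpy arrays; wrap them (and a fresh
        # output buffer) as torch CPU tensors for the triton-cpu launcher.
        out = np.zeros(out_shape.shape, dtype=out_shape.dtype)
        tensors = [torch.from_numpy(np.ascontiguousarray(a)) for a in host_args]
        kernel[grid](*tensors, torch.from_numpy(out), **metaparams)
        return out

    # To XLA the callback is a host custom call with a known output
    # shape/dtype, so the wrapped kernel can sit inside jit-compiled code.
    return jax.pure_callback(launch, out_shape, *args)
```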
The motivation for adding a CPU backend is that it's faster than `TRITON_INTERPRET=1` and allows `jax.jit`-ing Triton kernels just as on GPU. In addition, it would possibly allow Pallas kernels to be run on CPU without `interpret=True`, which is generally very slow. Pure JAX code can be run on either CPU or GPU with no code modifications, and it would be nice if this were also true for Triton/Pallas kernels (for debugging/prototyping, but also to run fast on CPU itself).
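Concretely, the goal is for something like the add-kernel example from the jax-triton README (sketched below; treat the exact `triton_call` keyword arguments as an approximation from memory) to work under `jax.jit` on CPU the same way it does on GPU:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
    # One program instance adds one block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    x = tl.load(x_ptr + offsets)
    y = tl.load(y_ptr + offsets)
    tl.store(out_ptr + offsets, x + y)

def add(x, y):
    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    block_size = 8  # assumes x.size is a multiple of block_size
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=out_shape,
        grid=(triton.cdiv(x.size, block_size),),
        block_size=block_size,
    )

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, 16, dtype=jnp.float32)
# Today this requires a GPU backend; with this PR the same code would also
# run when JAX's default device is the CPU.
print(jax.jit(add)(x, y))
```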
Known limitations of this PR:

- Since `jax.pure_callback` is used instead of a C++ XLA custom call, kernel launch overhead is relatively high.
- Doesn't use `triton_kernel_call` or the MLIR lowering (`triton_kernel_call_lowering`), so the behavior is slightly different.
- `zeroed_outputs` doesn't receive meta parameters from Triton configurations.
Passes (and is definitely completely overfit to) all tests except for those that count the number of compilations (as it doesn't use the MLIR lowering path) and `test_autotune_with_heuristics`, since Triton evaluates the configuration multiple times.

The first two commits are unrelated fixes to the tests which can be merged on their own; I've opened #321 with them verbatim.