
Add CPU backend #322

Draft: wants to merge 4 commits into main
stephen-huan commented Dec 21, 2024

(This is more of an issue/feature request than a PR, but since I have a working prototype I figured I'd share it.)

Triton now has a CPU backend via triton-cpu, which compiles LLIR to assembly using LLVM. This PR adds support for it by using jax.pure_callback to wrap Triton kernel launches from Python (generating an XLA custom call). A proper implementation would add an OpenMP CPU launcher to jaxlib's gpu_triton.py, akin to triton_kernels.cc. Unfortunately, the CPU backend doesn't fit neatly into jaxlib's existing Triton abstractions: for example, cuda and rocm appear to be mutually exclusive since they overwrite the same names in gpu_triton.py, while cpu can co-exist with gpu. I don't have enough familiarity with C++/jaxlib/XLA to make this change myself, hence the feature request.
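The pure_callback wiring can be sketched roughly as follows. This is a minimal illustration, not code from this PR: a plain NumPy function stands in for the compiled triton-cpu kernel launcher, and the names `_cpu_kernel` and `add` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _cpu_kernel(x, y):
    # Stand-in for launching a compiled triton-cpu kernel from Python.
    # A real launcher would invoke the kernel's compiled entry point here;
    # pure_callback hands us concrete NumPy arrays, not tracers.
    return np.asarray(x) + np.asarray(y)

def add(x, y):
    # Describe the output so XLA can treat the callback as an opaque
    # custom call with a known shape/dtype.
    out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_cpu_kernel, out_shape, x, y)

x = jnp.arange(4, dtype=jnp.float32)
y = jnp.ones(4, dtype=jnp.float32)
result = jax.jit(add)(x, y)  # the callback composes with jax.jit
```

Because the launch goes through a Python callback rather than a C++ custom-call target, each kernel invocation pays the Python round-trip cost, which is the launch-overhead limitation noted below.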

The motivation for adding a CPU backend is that it's faster than TRITON_INTERPRET=1 and allows jax.jit'ing Triton kernels just as on GPU. In addition, it could allow Pallas kernels to be run on CPU without interpret=True, which is generally very slow. Pure JAX code runs on either CPU or GPU with no code modifications, and it would be nice if the same were true for Triton/Pallas kernels (for debugging/prototyping, but also to run fast on CPU itself).

Known limitations of this PR:

  • Since jax.pure_callback is used instead of a C++ XLA custom call, kernel launch overhead is relatively high
    • The codepath for the cpu backend is completely separate from gpu (doesn't use the custom call target triton_kernel_call or MLIR lowering triton_kernel_call_lowering) so the behavior is slightly different
  • Although implemented, input-output aliases are probably not handled correctly
  • zeroed_outputs doesn't receive meta parameters from Triton configurations
  • Autotuning has a runtime dependency on torch
  • The matmul tests are extremely slow (which might be triton-cpu's fault as we're just dispatching to it)

Passes (and is admittedly overfit to) all tests except those that count the number of compilations (since it doesn't use the MLIR lowering path) and test_autotune_with_heuristics, since Triton evaluates each configuration multiple times.

========================================================================== short test summary info ===========================================================================
FAILED tests/triton_call_test.py::TritonKernelCallTest::test_autotune_with_heuristics - AssertionError: Lists differ: [True, True, True, True, True, True, True, True, True, True,[147 chars]True] != [True, True, True, False]
FAILED tests/triton_call_test.py::TritonKernelCallTest::test_kernel_cache_equivalent_kernels - AssertionError: 0 != 1
FAILED tests/triton_call_test.py::TritonKernelCallTest::test_kernel_cache_same_kernel_different_params - AssertionError: 0 != 1
FAILED tests/triton_call_test.py::TritonKernelCallTest::test_specialization - AssertionError: Expected 'ast_to_ttir' to have been called once. Called 0 times.
=========================================================== 4 failed, 155 passed, 6 skipped in 2965.08s (0:49:25) ============================================================

The first two commits are unrelated fixes to the tests which can be merged, and I've opened #321 with them verbatim.
