Example notebook with attention isn't working #180
Comments
Sorry about that, I don't expect that that particular example will work. The examples don't have regression tests on them and the notebook hasn't been touched for a while. I haven't had the bandwidth to keep the examples up to date with Triton changes.
Do you mean flash attention? If so, the implementation in jax_triton/pallas/ops/attention.py should always be working.
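For anyone sanity-checking outputs from that implementation: flash attention computes ordinary scaled dot-product attention, just tiled for memory efficiency. A minimal plain-JAX reference (my own sketch for comparison, not the ops/attention.py API) looks like:

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # q, k, v: [seq_len, head_dim]. Scaled dot-product attention:
    # softmax(q @ k^T / sqrt(head_dim)) @ v.
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (128, 64))
           for k_ in jax.random.split(key, 3))
out = reference_attention(q, k, v)
print(out.shape)  # (128, 64)
```

A fused kernel should agree with this reference up to floating-point tolerance, which makes it a handy check when an example notebook misbehaves.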
Thanks for the quick answer!
Oops. I guess I've been thinking a lot about flax lately.
I don't fully understand what pallas is and what pallas ops are. Are they supposed to be as performant as native triton kernels, but more powerful because they can be vmapped etc.?
Pallas is an extension to JAX that lets you write Triton kernels directly in JAX. As you surmised, one of the core benefits is compatibility with JAX transformations (vmap just works; AD is a work in progress). Personally, I also find it a friendlier front-end than Triton's.
That is the goal, though since we are generating Triton kernels, we might hit some unoptimized code paths. The main difference between the two is that Pallas handles much of the pointer-arithmetic indexing logic that you normally write by hand in Triton. As a result, we might not generate exactly the indexing logic that the fastest hand-written Triton kernel would. If you find any performance gaps, please let us know and we can investigate!
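To make the "write Triton kernels in JAX" point concrete, here is a minimal sketch of a Pallas kernel. Note the import path is an assumption: current JAX ships Pallas as jax.experimental.pallas, while at the time of this thread it lived under jax_triton.pallas, so adjust to your install. `interpret=True` runs the kernel on CPU for testing; drop it to lower to Triton on GPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl  # thread-era: jax_triton.pallas

def add_kernel(x_ref, y_ref, o_ref):
    # Refs index like arrays; Pallas generates the pointer arithmetic
    # you would otherwise write by hand in raw Triton.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # CPU interpreter mode; remove on GPU
    )(x, y)

x = jnp.arange(8.0)
y = jnp.ones(8)
print(add(x, y))                               # elementwise sum
print(jax.vmap(add)(x[None], y[None]).shape)   # vmap "just works": (1, 8)
```

The last line is the transformation-compatibility benefit mentioned above: the same kernel composes with jax.vmap without any batching logic inside the kernel itself.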
I've tried running the example notebook using jax_triton and jaxlib installed from head. Unfortunately, it doesn't seem to work: running the cell with
test_triton_jax(2, 32, 2048, 64)
hangs indefinitely. Thanks!