-
Notifications
You must be signed in to change notification settings - Fork 102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: jax integration #590
Conversation
Since sliding and patching use sliding_window_view that is currently not available in JAX, we will not support them in the jax backend for the moment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey, great job :-)
I left just some nit ;-)
Thanks! |
@cako you may want to look at this as supporting document https://github.com/PyLops/pylops_notebooks/blob/master/developement-cupy/Timing_CupyJAX.ipynb. It contains timing for most of the methods ported to the Jax backend and a comparison with numpy and cupy |
* Uses positional arguments instead of `n` as int | tuple, which is the correct usage with `np.random.randn` * Corrects input/output types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mrava87 nicely done! Going to leave it as approved, but please have a look at some of the comments and my commits.
By the way I ran the notebook... seems like Jax is generally slower than CuPy? Am I reading this wrong?
This is also what I see when running this both locally and on colab... my guess/suspicion is that when the operator has a limited number of steps all calling np/cp, cupy is already very well optimized so the I read a lot about jax being very optimized for GPUs/TPUs so this was also an exercise to compare it with cupy, but so far what I observe is somehow that cupy is better ;) |
Motivation
This PR introduces a new backend in PyLops to enable using JAX arrays.
As a by-product of JAX-enabled operators, we inherit JAX features like jit, automatic differentiation, and automatic vectorization.
Highlights
JaxOperator
backend
module with new logic to detect whether np,cp, or jnp methods should be used based on the input typejaxop
gpu.rst
documentation page