
jax + jaxlib in pyproject.toml -- or just jax? #16380

Answered by jakevdp
femtomc asked this question in Q&A

jaxlib is required to run any jax code, but jax does not pin jaxlib in its requirements because different jaxlib builds are needed for different hardware (CPU, GPU, TPU, etc.), and Python's dependency-tracking system cannot express that level of granularity. JAX itself addresses this through its installation instructions.

If you have a jax-dependent library that is meant to run on multiple platforms, you inherit these difficulties. If you list jaxlib as a hard dependency in your pyproject.toml, it may prevent users from installing the correct jaxlib version for their system. So I'd recommend following JAX's approach: don't list jaxlib in pyproject.toml, but rather be clear in your…
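One complementary option, offered here as an assumption rather than something stated in the thread: since jaxlib is left out of pyproject.toml, a library can check for it at import time and fail with installation guidance instead of a bare `ModuleNotFoundError`. The helper name `require_jaxlib` is hypothetical, and the sketch assumes the import name and the distribution name are the same (both `jaxlib`).

```python
import importlib.metadata
import importlib.util


def require_jaxlib(module_name: str = "jaxlib") -> str:
    """Hypothetical helper: return the installed version of `module_name`,
    or raise an ImportError that points users to the JAX install docs.

    Assumes the distribution name matches the import name, as it does for jaxlib.
    """
    # find_spec returns None when a top-level module cannot be found.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"{module_name} is not installed. jaxlib builds differ by hardware "
            "(CPU, GPU, TPU), so this package does not install it for you; "
            "please follow the JAX installation instructions for your platform."
        )
    try:
        return importlib.metadata.version(module_name)
    except importlib.metadata.PackageNotFoundError:
        # Importable but without distribution metadata (e.g. a source checkout).
        return "unknown"
```

A library could call `require_jaxlib()` from its top-level `__init__.py`, so the hardware-specific install step stays with the user while the failure message stays explicit.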

Answer selected by femtomc