Issue with dtype float and GPU/TPU #7194
kyrieheard asked this question in Q&A
Hi, thanks for the question and the clear reproduction!

The first warning you're getting is normal on a CPU backend: JAX prefers a GPU or TPU when one is available, and it warns you when it falls back to CPU unless you explicitly request the CPU backend (see #6805).

The second warning comes from requesting float64 types while running outside X64 mode; see 🔪 JAX - The Sharp Bits 🔪: Double (64bit) precision for more information.

With that in mind, I find that the expected output is produced if you enable X64 mode by putting these lines at the top of your script:

```python
import jax
jax.config.update("jax_enable_x64", True)
```
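For illustration, here is a minimal sketch of the effect of that flag, assuming a reasonably recent JAX release where `jax.config.update` is available at the top level. The arrays below are made up for the example, since the original script isn't shown. Without the flag, an explicit `dtype=jnp.float64` request is truncated to float32 (with a warning), and floating-point arrays default to float32; with the flag set before any arrays are created, both become float64.

```python
import jax
import jax.numpy as jnp

# Turn on 64-bit mode at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

# An explicit float64 request is now honored instead of being truncated.
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float64)
print(x.dtype)  # float64

# Default floating-point dtype also becomes float64 in X64 mode.
y = jnp.arange(3.0)
print(y.dtype)  # float64
```

The flag is global, which is why it belongs at the very top of the script rather than next to the code that needs double precision.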
I'm using code from an article for a project, but I can't get it to work.
The output is supposed to be:
But I'm getting this:
Here is the code: