Is it safe to use JAX dtypes for type casting in numpy? #24844

Answered by jakevdp
uduse asked this question in Q&A

Short answer: yes!
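Here is a minimal sketch of what the short answer means in practice. It assumes recent JAX and NumPy releases are installed, and the array values are arbitrary illustration data. It passes JAX scalar types wherever a NumPy API expects a dtype.

```python
import numpy as np
import jax.numpy as jnp

x = np.array([1.0, 2.5, 3.75])

# Each call hands a JAX scalar type to a NumPy API that expects a dtype.
y = x.astype(jnp.float16)
z = np.zeros(3, dtype=jnp.int8)
w = np.asarray([1, 2, 3], dtype=jnp.float32)

# The results are ordinary NumPy arrays with the equivalent NumPy dtype.
print(type(y).__name__, y.dtype)  # ndarray float16
print(z.dtype, w.dtype)           # int8 float32
```

No JAX arrays are created here. JAX only supplies the dtype, and NumPy does the casting.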

Long answer: jnp.float16 is not a dtype; it's a JAX scalar array constructor. This mirrors a design decision in NumPy, where np.float16 is not a dtype but a scalar object constructor. For convenience, NumPy lets you write things like dtype=np.float16 as shorthand for dtype=np.dtype('float16'), and we wanted dtype=jnp.float16 to work the same way.
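The distinction between a scalar constructor and a dtype can be checked directly. This sketch assumes a recent JAX release, where calling a JAX scalar type returns a 0-d jax.Array. The value 1.5 is arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

# jnp.float16 is callable: it builds a 0-d JAX array, much as np.float16(1.5)
# builds a NumPy scalar.
s = jnp.float16(1.5)
print(isinstance(s, jax.Array), s.shape, s.dtype)  # True () float16

# jnp.float16 is not itself a dtype object...
print(isinstance(jnp.float16, np.dtype))  # False

# ...but np.dtype() resolves it to the same dtype as the string spelling.
print(np.dtype(jnp.float16) == np.dtype("float16"))  # True
```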

NumPy scalar types like np.float16 are recognized as valid dtypes due to specific code paths in the np.dtype constructor designed to support this shortcut.
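The NumPy side of the shortcut can be seen with NumPy alone. np.float16 is a class derived from np.generic, and np.dtype() has a dedicated branch for such classes. The sketch below only observes that behavior from Python. It does not reproduce NumPy's internal code paths.

```python
import numpy as np

# np.float16 is a scalar *type* (a subclass of np.generic), not a dtype instance.
print(isinstance(np.float16, type), issubclass(np.float16, np.generic))  # True True
print(isinstance(np.float16, np.dtype))  # False

# Calling it constructs a NumPy scalar value.
v = np.float16(1.5)
print(type(v).__name__)  # float16

# np.dtype() maps the scalar type to its dtype, and the dtype's .type attribute
# points back at the scalar type.
d = np.dtype(np.float16)
print(d == np.dtype("float16"), d.type is np.float16)  # True True
```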

JAX scalar types like jnp.float16 are recognized as valid dtypes because np.dtype looks for a dtype attribute on any unknown object passed to it, and JAX sca…
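The fallback described above, where np.dtype reads a dtype attribute from an object it does not otherwise recognize, can be illustrated with a hypothetical class MyFloat16. MyFloat16 is an illustration only, not part of JAX or NumPy. It has no NumPy ancestry and only exposes a dtype attribute. Here that attribute is a real np.dtype instance, because recent NumPy versions deprecate a dtype attribute that is not one.

```python
import numpy as np

# Hypothetical stand-in for a JAX scalar type: a plain class whose only link
# to NumPy is a class-level `dtype` attribute holding a real np.dtype.
class MyFloat16:
    dtype = np.dtype("float16")

# np.dtype() has no special case for MyFloat16, so it falls back to
# reading MyFloat16.dtype.
print(np.dtype(MyFloat16))  # float16

# NumPy creation functions such as np.ones accept it as a dtype argument too.
arr = np.ones(2, dtype=MyFloat16)
print(arr.dtype)  # float16

# An object with no usable `dtype` attribute is rejected.
try:
    np.dtype(object())
except TypeError as e:
    print("TypeError:", e)
```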

Answer selected by uduse