I'm curious about best practices for type casting with NumPy arrays and JAX NumPy arrays. How well does a NumPy array handle JAX dtypes? That is, can I use JAX types everywhere, both when handling NumPy arrays and when handling JAX NumPy arrays? E.g., would …
CPUs do have
In [1]: import ml_dtypes
In [2]: import numpy as np
In [3]: np.arange(10, dtype=ml_dtypes.bfloat16)
Out[3]: array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=bfloat16)
This is the same …
Short answer, yes!

Long answer: jnp.float16 is not a dtype, it's a JAX scalar array constructor. This mirrors the design decision of NumPy, where np.float16 is not a dtype, but rather a scalar object constructor. For convenience, NumPy lets you write things like dtype=np.float16 as a shorthand for dtype=np.dtype('float16'), and we wanted dtype=jnp.float16 to work similarly.

NumPy scalar types like np.float16 are recognized as valid dtypes due to specific code paths in the np.dtype constructor designed to support this shortcut.

JAX scalar types like jnp.float16 are recognized as valid dtypes because np.dtype looks for a dtype attribute on any unknown object passed to it, and JAX sca…
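The distinction between a scalar constructor and a dtype can be checked directly. The following is a minimal sketch, not code from the answer: HypotheticalScalarType is a made-up stand-in (not a real JAX or NumPy class) used only to illustrate the dtype-attribute lookup, and the JAX part is skipped if JAX is not installed.

```python
import numpy as np

# np.float16 is a scalar type (a class), not a dtype instance.
assert isinstance(np.float16, type)
assert not isinstance(np.float16, np.dtype)

# Calling it constructs a NumPy scalar object.
scalar = np.float16(1.5)
assert isinstance(scalar, np.generic)

# np.dtype() converts the scalar type into the real dtype instance,
# which is why dtype=np.float16 works as a shorthand.
assert np.dtype(np.float16) == np.dtype("float16")


class HypotheticalScalarType:
    """Made-up stand-in for a third-party scalar type (assumption, not JAX)."""

    # np.dtype() falls back to this attribute for objects it does not know.
    dtype = np.dtype("float16")


resolved = np.dtype(HypotheticalScalarType)
arr = np.zeros(3, dtype=HypotheticalScalarType)

# The same lookup applies to JAX scalar types, when JAX is available.
try:
    import jax.numpy as jnp
except ImportError:
    jnp = None

if jnp is not None:
    assert np.dtype(jnp.float16) == np.dtype("float16")
    jax_typed = np.zeros(3, dtype=jnp.float16)  # a plain NumPy array
    assert jax_typed.dtype == np.dtype("float16")
```

Here resolved and arr.dtype both compare equal to np.dtype("float16"), even though the stand-in class knows nothing about NumPy beyond exposing a dtype attribute.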