Different behavior when loading .npy and .npz files using jnp.load #16029
I am loading data from a … To reproduce:

```python
>>> import jax.numpy as jnp
>>> a=jnp.ones(5)
>>> jnp.savez('test',a=a)
>>> type(a)
<class 'jaxlib.xla_extension.ArrayImpl'>
>>> data=jnp.load('test.npz')
>>> type(data['a'])
<class 'numpy.ndarray'>
>>> jnp.save('test2',a)
>>> type(jnp.load('test2.npy'))
<class 'jaxlib.xla_extension.ArrayImpl'>
```
Thanks for the question. JAX's `jnp.save` and `jnp.savez` are just aliases of NumPy's functions, and `jnp.load` is a light wrapper around NumPy's `load` that only does special handling for the single-array (`.npy`) case. Given that, the behavior you observe is expected.

It would be nice if we could make `jnp.savez`/`jnp.load` work seamlessly with JAX arrays, but I suspect that would be a bit of a project because it would involve re-implementing NumPy's `NpzFile` and other structures from scratch. I think it's unlikely that this work will be done any time soon. In the meantime, I'd suggest that you anticipate that arrays loaded from `npz` files will be NumPy arrays rather than JAX arrays, and use `jnp.asarray` to co…
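A minimal sketch of that workaround, assuming you want every entry of an `.npz` archive as a JAX array. The helper name `load_npz_as_jax` is hypothetical and not part of JAX's API; it relies only on `np.load` (which returns an `NpzFile` for `.npz` archives), the `NpzFile.files` list of entry names, and `jnp.asarray`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def load_npz_as_jax(path):
    """Load every array stored in an .npz archive and return them as JAX arrays.

    Hypothetical helper for illustration only; not part of JAX's API.
    """
    # np.load returns a lazy NpzFile for .npz archives; the `with` block
    # closes the underlying zip file once every entry has been read.
    with np.load(path) as npz:
        # npz.files lists the names passed as keywords to savez.
        return {name: jnp.asarray(npz[name]) for name in npz.files}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "test.npz")
    # jnp.savez forwards to NumPy, which converts the JAX arrays to ndarrays.
    jnp.savez(path, a=jnp.ones(5), b=jnp.arange(3))
    data = load_npz_as_jax(path)

print(type(data["a"]))  # a JAX array type rather than numpy.ndarray
```

Converting eagerly transfers every entry to the default device; for large archives it may be preferable to keep the `NpzFile` open and call `jnp.asarray` only on the entries you actually need.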