
Different behavior when loading .npy and .npz files using jnp.load #16029

Answered by jakevdp
Xin-yang-Liu asked this question in Q&A

Thanks for the question. JAX's jnp.save and jnp.savez are just aliases of NumPy's functions, and jnp.load is a light wrapper around NumPy's load that special-cases only the single-array (.npy) case. Given that, the behavior you observe is expected.
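A minimal sketch of the difference described above, assuming a recent JAX where jnp.load converts a plain np.ndarray result with jnp.asarray and returns anything else unchanged. It writes files with np.save/np.savez (which, per the answer, jnp.save/jnp.savez alias); the helper name `load_both` and the file names are illustrative only.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp


def load_both(tmpdir):
    """Save one array as .npy and two as .npz, then load each with jnp.load."""
    x = np.arange(4, dtype=np.float32)
    npy_path = os.path.join(tmpdir, "x.npy")
    npz_path = os.path.join(tmpdir, "xs.npz")
    np.save(npy_path, x)              # single array -> .npy file
    np.savez(npz_path, a=x, b=2 * x)  # named arrays -> .npz archive

    # Single-array case: jnp.load converts the loaded ndarray to a JAX array.
    single = jnp.load(npy_path)

    # Archive case: jnp.load passes NumPy's NpzFile through unchanged,
    # so indexing it yields plain NumPy arrays.
    with jnp.load(npz_path) as archive:
        a = archive["a"]  # read fully into memory before the archive closes
    return single, a


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        single, a = load_both(tmpdir)
        print(type(single).__name__, isinstance(single, jax.Array))
        print(type(a).__name__, isinstance(a, jax.Array))
```

The first line should report a JAX array type and `True`; the second should report `ndarray` and `False`.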

It would be nice if we could make jnp.savez/jnp.load work seamlessly with JAX arrays, but I suspect that would be a bit of a project because it would involve re-implementing numpy's NpzFile and other structures from scratch. I think it's unlikely that that work will be done any time soon. In the meantime, I'd suggest that you anticipate that arrays loaded from npz files will be NumPy arrays rather than JAX arrays, and use jnp.asarray to co…
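One way to follow that advice is to convert every entry of the archive right after loading it. This is a sketch, not a JAX API: `load_npz_as_jax` is a hypothetical helper, and it calls np.load directly because, for .npz files, jnp.load returns the same NpzFile object anyway.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp


def load_npz_as_jax(path):
    """Hypothetical helper: read every array in an .npz archive as a JAX array.

    np.load returns an NpzFile for .npz archives, whose entries are NumPy
    arrays, so each one is converted explicitly with jnp.asarray.
    """
    with np.load(path) as archive:
        # archive.files lists the stored names (without the .npy suffix).
        return {name: jnp.asarray(archive[name]) for name in archive.files}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "params.npz")
        np.savez(path, w=np.ones((2, 2), dtype=np.float32),
                 b=np.zeros(2, dtype=np.float32))
        params = load_npz_as_jax(path)
        for name, value in params.items():
            print(name, value.shape, isinstance(value, jax.Array))
```

Using float32 data keeps the example independent of JAX's 64-bit setting; with x64 disabled, jnp.asarray would otherwise downcast float64 entries to float32.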


Answer selected by Xin-yang-Liu