
Error with device_put on TPUv3-32 pod with NamedSharding #14578

Answered by yashk2810
magicknight asked this question in Q&A

`device_put` does not work across multiple processes. To build a globally sharded array on a multi-host pod, use `jax.make_array_from_callback`, `jax.make_array_from_single_device_arrays`, or an identity `pjit` instead.
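A minimal sketch of the first option. It assumes a 1-D mesh over all global devices and that every process can cheaply build the full host-side array; `GLOBAL_SHAPE`, `global_data`, and `data_callback` are hypothetical names for illustration. `jax.make_array_from_callback` calls the callback only for shards held by the current process's local devices, so no process tries to place data on devices it cannot address.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical global shape: the leading axis is a multiple of the global
# device count so it splits evenly across the mesh.
GLOBAL_SHAPE = (jax.device_count() * 2, 4)

# Assumption: every process can produce the full array on the host.
# In practice the callback might read only the requested slice from storage.
global_data = np.arange(GLOBAL_SHAPE[0] * GLOBAL_SHAPE[1], dtype=np.float32).reshape(GLOBAL_SHAPE)

# 1-D mesh over all devices in the job (across every process).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # shard rows along the "data" axis

def data_callback(index):
    # `index` is a tuple of slices describing one local device's shard.
    return global_data[index]

# Each process supplies only the shards for its addressable devices.
arr = jax.make_array_from_callback(GLOBAL_SHAPE, sharding, data_callback)
```

On a pod, the same script would run on every host. Converting the result back with `np.asarray(arr)` only works when the array is fully addressable, for example in a single-process run.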

Answer selected by magicknight
Category: Q&A
2 participants