Error with device_put on TPUv3-32 pod with NamedSharding #14578
-
I'm trying to shard my JAX model on a TPUv3-32 pod using `device_put`. The code works fine on a single TPUv3-8 host, but on the TPUv3-32 pod I get the following error:

```
device_put's second argument must be a Device or a Sharding which represents addressable devices, but got NamedSharding(mesh={'model': 32, 'data': 1}, spec=PartitionSpec('model', 'data'))
```

It seems like the issue is with the `NamedSharding` object that I'm passing as the second argument to `device_put`. Here's a simplified version of my code:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec, NamedSharding

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
mesh_shape = (32, 1)  # 32 devices on the pod, matching mesh={'model': 32, 'data': 1}
devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=('model', 'data'))
y = jax.device_put(x, NamedSharding(mesh, PartitionSpec('model', 'data')))
```

I would appreciate any help in resolving this issue. Thank you!

Steps to reproduce: Run the code snippet above on a TPUv3-32 pod.

Expected behavior: The code should successfully shard the input array and place it on the TPUv3-32 devices using `device_put`.

Actual behavior: The code throws the error shown above.
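A quick way to see why the same code works on one TPUv3-8 host but not on the pod is to compare what the current process can address. The sketch below rebuilds the question's mesh with `mesh_shape = (jax.device_count(), 1)` (an assumption so it runs on any host count) and prints the process and device counts plus `NamedSharding.is_fully_addressable`. On a multi-host pod each process typically addresses only its own host's devices, so the flag should come out `False` there, which is the condition the `device_put` error is complaining about.

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Same ('model', 'data') mesh as in the question, sized to the devices present:
# (32, 1) on a TPUv3-32 pod, (8, 1) on a single TPUv3-8 host.
mesh_shape = (jax.device_count(), 1)
mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape), axis_names=('model', 'data'))
sharding = NamedSharding(mesh, PartitionSpec('model', 'data'))

print("processes:", jax.process_count())
print("global devices:", jax.device_count())
print("local devices:", jax.local_device_count())
# True only when every device in the mesh belongs to this process.
print("fully addressable:", sharding.is_fully_addressable)
```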
-
`device_put` does not work across multiple processes. You can use `jax.make_array_from_callback` or `jax.make_array_from_single_device_arrays` or an identity `pjit` to do this!
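A minimal sketch of the first two suggestions, assuming every process can produce the same global data. Here each process regenerates it from a fixed NumPy seed (so the values differ from the question's `jax.random.normal` output), and the shape is shrunk to `(1024, 1024)` to keep it light; the leading dimension must stay divisible by the `'model'` axis size. The mesh is sized from `jax.device_count()`, so the same script is meant to run on every host of the pod (or on a single host).

```python
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Mesh from the question, sized to the devices present:
# (32, 1) on a TPUv3-32 pod, (1, 1) on a single CPU device.
mesh_shape = (jax.device_count(), 1)
mesh = Mesh(mesh_utils.create_device_mesh(mesh_shape), axis_names=('model', 'data'))
sharding = NamedSharding(mesh, PartitionSpec('model', 'data'))

# Assumption: every process can rebuild the same global array.
# Shape shrunk from (8192, 8192); dim 0 must divide evenly by the 'model' axis.
global_shape = (1024, 1024)
global_data = np.random.default_rng(0).standard_normal(global_shape, dtype=np.float32)

# Option 1: make_array_from_callback. JAX calls the callback with a tuple of
# slices selecting a shard that lives on one of this process's devices.
def data_callback(index):
    return global_data[index]

y = jax.make_array_from_callback(global_shape, sharding, data_callback)

# Option 2: make_array_from_single_device_arrays. Each process puts its own
# shards on its local devices, then JAX assembles them into one global array.
local_shards = [
    jax.device_put(global_data[index], device)
    for device, index in sharding.addressable_devices_indices_map(global_shape).items()
]
z = jax.make_array_from_single_device_arrays(global_shape, sharding, local_shards)
```

With real data, option 1 lets each process read only the slices its devices need instead of holding the whole array; option 2 gives more control over how each local shard is produced.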