diff --git a/jax/experimental/array.py b/jax/experimental/array.py index e6669d51b4fc..f89ae5b38c6b 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -314,8 +314,8 @@ def addressable_shards(self) -> Sequence[Shard]: device = db.device() # Wrap the device arrays in `Array` until C++ returns an Array instead # of a DA. - array = Array(db.aval, SingleDeviceSharding(device), [db], committed=True, - _skip_checks=True) + array = Array(db.aval, SingleDeviceSharding(device), [db], + committed=self._committed, _skip_checks=True) out.append(Shard(device, self.sharding, self.shape, array)) return out @@ -389,8 +389,9 @@ def _value(self) -> np.ndarray: def make_array_from_callback(shape: Shape, sharding: Sharding, data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array: + device_to_index_map = sharding.devices_indices_map(shape) arrays = [ - device_put(data_callback(sharding.device_indices(device, shape)), device) + device_put(data_callback(device_to_index_map[device]), device) for device in sharding.addressable_devices ] aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False) diff --git a/tests/array_test.py b/tests/array_test.py index eb643cafb701..d79cc7c61e86 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for GlobalDeviceArray.""" +import os from absl.testing import absltest from absl.testing import parameterized import numpy as np @@ -22,6 +23,7 @@ from jax._src import config as jax_config from jax._src import test_util as jtu from jax._src.lib import xla_client as xc +from jax._src.lib import xla_bridge as xb from jax._src.util import prod from jax.experimental import PartitionSpec as P from jax.experimental import sharding @@ -31,6 +33,29 @@ config.parse_flags_with_absl() +prev_xla_flags = None + +# Run all tests with 8 CPU devices. +def setUpModule(): + global prev_xla_flags + prev_xla_flags = os.getenv("XLA_FLAGS") + flags_str = prev_xla_flags or "" + # Don't override user-specified device count, or other XLA flags. + if "xla_force_host_platform_device_count" not in flags_str: + os.environ["XLA_FLAGS"] = (flags_str + + " --xla_force_host_platform_device_count=8") + # Clear any cached backends so new CPU backend will pick up the env var. + xb.get_backend.cache_clear() + +# Reset to previous configuration in case other test modules will be run. +def tearDownModule(): + if prev_xla_flags is None: + del os.environ["XLA_FLAGS"] + else: + os.environ["XLA_FLAGS"] = prev_xla_flags + xb.get_backend.cache_clear() + + def create_array(shape, sharding, global_data=None): if global_data is None: global_data = np.arange(prod(shape)).reshape(shape) @@ -233,7 +258,7 @@ def test_sharded_zeros_like(self): a, input_data = create_array( input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y'))) out = jnp.zeros_like(a) - expected = jnp.zeros(input_data.shape, dtype=np.int32) + expected = jnp.zeros(input_data.shape, dtype=np.int64) self.assertArraysEqual(out, expected) self.assertLen(out.addressable_shards, 8) for i in out.addressable_shards: @@ -318,6 +343,21 @@ def test_mismatch_dtype(self): "Got int32, expected float32"): array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True) + @jax_config.jax_array(True) + def test_array_shards_committed(self): + if jax.device_count() < 2: + self.skipTest('Test requires >= 2 devices.') + + x = jnp.array([1, 2, 3]) + for s in x.addressable_shards: + self.assertEqual(s.data._committed, x._committed) + self.assertFalse(s.data._committed) + + y = jax.device_put(x, jax.devices()[1]) + for s in y.addressable_shards: + self.assertEqual(s.data._committed, y._committed) + self.assertTrue(s.data._committed) + class ShardingTest(jtu.JaxTestCase):