Skip to content

Commit

Permalink
Some minor changes to make_array_from_callback to use the device_indi…
Browse files Browse the repository at this point in the history
…ces_map method and calculate the indices just once. Also set the `_committed` attribute of shards to what the parent Array has.

PiperOrigin-RevId: 471167295
  • Loading branch information
yashk2810 authored and jax authors committed Aug 31, 2022
1 parent c26c7fd commit da24b99
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
7 changes: 4 additions & 3 deletions jax/experimental/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ def addressable_shards(self) -> Sequence[Shard]:
device = db.device()
# Wrap the device arrays in `Array` until C++ returns an Array instead
# of a DA.
array = Array(db.aval, SingleDeviceSharding(device), [db], committed=True,
_skip_checks=True)
array = Array(db.aval, SingleDeviceSharding(device), [db],
committed=self._committed, _skip_checks=True)
out.append(Shard(device, self.sharding, self.shape, array))
return out

Expand Down Expand Up @@ -389,8 +389,9 @@ def _value(self) -> np.ndarray:

def make_array_from_callback(shape: Shape, sharding: Sharding,
data_callback: Callable[[Optional[Index]], ArrayLike]) -> Array:
device_to_index_map = sharding.devices_indices_map(shape)
arrays = [
device_put(data_callback(sharding.device_indices(device, shape)), device)
device_put(data_callback(device_to_index_map[device]), device)
for device in sharding.addressable_devices
]
aval = core.ShapedArray(shape, arrays[0].dtype, weak_type=False)
Expand Down
42 changes: 41 additions & 1 deletion tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for GlobalDeviceArray."""

import os
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
Expand All @@ -22,6 +23,7 @@
from jax._src import config as jax_config
from jax._src import test_util as jtu
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.util import prod
from jax.experimental import PartitionSpec as P
from jax.experimental import sharding
Expand All @@ -31,6 +33,29 @@
config.parse_flags_with_absl()


prev_xla_flags = None

# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xb.get_backend.cache_clear()

# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xb.get_backend.cache_clear()


def create_array(shape, sharding, global_data=None):
if global_data is None:
global_data = np.arange(prod(shape)).reshape(shape)
Expand Down Expand Up @@ -233,7 +258,7 @@ def test_sharded_zeros_like(self):
a, input_data = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
out = jnp.zeros_like(a)
expected = jnp.zeros(input_data.shape, dtype=np.int32)
expected = jnp.zeros(input_data.shape, dtype=np.int64)
self.assertArraysEqual(out, expected)
self.assertLen(out.addressable_shards, 8)
for i in out.addressable_shards:
Expand Down Expand Up @@ -318,6 +343,21 @@ def test_mismatch_dtype(self):
"Got int32, expected float32"):
array.Array(jax.ShapedArray(shape, np.float32), s, bufs, committed=True)

@jax_config.jax_array(True)
def test_array_shards_committed(self):
if jax.device_count() < 2:
self.skipTest('Test requires >= 2 devices.')

x = jnp.array([1, 2, 3])
for s in x.addressable_shards:
self.assertEqual(s.data._committed, x._committed)
self.assertFalse(s.data._committed)

y = jax.device_put(x, jax.devices()[1])
for s in y.addressable_shards:
self.assertEqual(s.data._committed, y._committed)
self.assertTrue(s.data._committed)


class ShardingTest(jtu.JaxTestCase):

Expand Down

0 comments on commit da24b99

Please sign in to comment.