Skip to content

Commit

Permalink
Fix empty mesh size and abstract_mesh
Browse files Browse the repository at this point in the history
* Fix `size` to return 0 rather than 1 for the empty mesh.
* Fix `abstract_mesh` to return an empty abstract mesh.

PiperOrigin-RevId: 665408468
  • Loading branch information
Google-ML-Automation authored and jax authors committed Aug 20, 2024
1 parent 7758a9c commit 16eb13e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
11 changes: 7 additions & 4 deletions jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def shape_tuple(self):

@property
def size(self):
return math.prod(self.shape.values())
return math.prod(self.shape.values()) if self.devices.ndim else 0

@property
def empty(self):
return self.devices.ndim == 0
return self.size == 0

@functools.cached_property
def is_multi_process(self):
Expand Down Expand Up @@ -337,7 +337,10 @@ class AbstractMesh:

def __init__(self, shape_tuple: tuple[tuple[str, int], ...]):
self.shape_tuple = shape_tuple
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
if self.shape_tuple:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()

def __hash__(self):
return hash(self.shape_tuple)
Expand All @@ -358,7 +361,7 @@ def axis_names(self):

@functools.cached_property
def size(self):
return math.prod(self._axis_sizes)
return math.prod(self._axis_sizes) if self._axis_sizes else 0

@functools.cached_property
def shape(self):
Expand Down
13 changes: 13 additions & 0 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,19 @@ def f(x):
y_ref1 = f(jax.device_put(x, jax.devices()[0]))
self.assertArraysEqual(y, y_ref1)

def test_empty_mesh_creation(self):
mesh = jax.sharding.Mesh(devices=np.empty([]), axis_names=[])
self.assertTrue(mesh.empty)
self.assertEqual(mesh.size, 0)

abstract_mesh = mesh.abstract_mesh
self.assertTrue(abstract_mesh.empty)
self.assertEqual(abstract_mesh.size, 0)

abstract_mesh2 = jax.sharding.AbstractMesh(())
self.assertTrue(abstract_mesh2.empty)
self.assertEqual(abstract_mesh2.size, 0)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 16eb13e

Please sign in to comment.