From d1b6f5d908fa9c516d9658e42004a005f9085471 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Fri, 18 Feb 2022 14:23:49 -0800
Subject: [PATCH] Revert back to adding aval on Device buffers inside local_shards and convert the cached property to just the normal property.

This slows down the pjit path because now you are paying the cost to create avals during runtime.

PiperOrigin-RevId: 429647845
---
 jax/experimental/global_device_array.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py
index fbcaabd98fff..e854a36a50dc 100644
--- a/jax/experimental/global_device_array.py
+++ b/jax/experimental/global_device_array.py
@@ -337,16 +337,19 @@ def _create_local_shards(self) -> Sequence[Shard]:
     for db in self._device_buffers:
       device = db.device()
       index, rid = global_indices_rid[device]
-      if db.aval is None:
-        db.aval = core.ShapedArray(db.shape, db.dtype)
       out.append(Shard(device, index, rid, db))
     return out
 
-  @pxla.maybe_cached_property
+  @property
   def local_shards(self) -> Sequence[Shard]:
+    for s in self._local_shards:
+      # Ignore the type because mypy thinks data is None but local_shards
+      # cannot have data=None which is checked in `_create_local_shards`.
+      if s.data.aval is None:  # type: ignore
+        s.data.aval = core.ShapedArray(s.data.shape, s.data.dtype)  # type: ignore
     return self._local_shards
 
-  @pxla.maybe_cached_property
+  @property
   def global_shards(self) -> Sequence[Shard]:
     # Populating global_shards lazily (i.e. when requested) because populating
     # sthem eagerly leads to a performance regression when training on large