Revert to adding avals on device buffers inside local_shards, and convert the cached property back to a normal property.

This slows down the pjit path because the cost of creating avals is now paid at runtime.

PiperOrigin-RevId: 429647845
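
To make the cost concrete, here is a minimal sketch (illustrative names only, not the commit's code or JAX's actual classes) of the tradeoff: a cached property builds its result once per object, while a plain property re-runs its body on every access, so any work done inside it is repeated each time the pjit path touches it.

# Illustrative sketch, not JAX code: make_aval and the two classes below are
# hypothetical stand-ins showing where the construction cost is paid.
import functools


def make_aval(shape, dtype):
  # Stand-in for building core.ShapedArray(shape, dtype); assume it is not free.
  return (tuple(shape), dtype)


class CachedShards:
  """Pays the construction cost once, on the first access."""

  def __init__(self, buffers):
    self._buffers = buffers

  @functools.cached_property
  def local_shards(self):
    return [make_aval(b["shape"], b["dtype"]) for b in self._buffers]


class UncachedShards:
  """Re-runs the body, and thus the construction work, on every access."""

  def __init__(self, buffers):
    self._buffers = buffers

  @property
  def local_shards(self):
    return [make_aval(b["shape"], b["dtype"]) for b in self._buffers]


buffers = [{"shape": (8, 128), "dtype": "float32"}]
cached, uncached = CachedShards(buffers), UncachedShards(buffers)
assert cached.local_shards is cached.local_shards          # same object, built once
assert uncached.local_shards is not uncached.local_shards  # rebuilt on each access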
yashk2810 authored and jax authors committed Feb 18, 2022
1 parent 1486be7 commit d1b6f5d
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions jax/experimental/global_device_array.py
@@ -337,16 +337,19 @@ def _create_local_shards(self) -> Sequence[Shard]:
     for db in self._device_buffers:
       device = db.device()
       index, rid = global_indices_rid[device]
-      if db.aval is None:
-        db.aval = core.ShapedArray(db.shape, db.dtype)
       out.append(Shard(device, index, rid, db))
     return out
 
-  @pxla.maybe_cached_property
+  @property
   def local_shards(self) -> Sequence[Shard]:
+    for s in self._local_shards:
+      # Ignore the type because mypy thinks data is None but local_shards
+      # cannot have data=None which is checked in `_create_local_shards`.
+      if s.data.aval is None:  # type: ignore
+        s.data.aval = core.ShapedArray(s.data.shape, s.data.dtype)  # type: ignore
     return self._local_shards
 
-  @pxla.maybe_cached_property
+  @property
   def global_shards(self) -> Sequence[Shard]:
     # Populating global_shards lazily (i.e. when requested) because populating
     # them eagerly leads to a performance regression when training on large
