Faster host to device transfer on Jax backend (keras-team#20018)
* Faster host to device transfer

* Make layouts optional

* Change prefetch condition and allow distribution without passing layout

* Fix bug

* Return a generator instead of inlining it
Hilly12 authored Jul 23, 2024
1 parent 7c6d501 commit 7b516d0
Showing 2 changed files with 32 additions and 10 deletions.
35 changes: 27 additions & 8 deletions keras/src/backend/jax/trainer.py
@@ -967,25 +967,44 @@ def _get_jax_state(
         return tuple(state)


-def _distribute_data(data):
+def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
     if distribution is not None:
+        if layouts is None:
+            layouts = tree.map_structure(
+                lambda d: distribution.get_data_layout(d.shape),
+                data,
+            )
+        return tree.map_structure(
+            jax_distribution_lib.distribute_data_input, data, layouts
+        )

-        def distribute_single_value(d):
-            layout = distribution.get_data_layout(d.shape)
-            return jax_distribution_lib.distribute_data_input(d, layout)
-
-        return tree.map_structure(distribute_single_value, data)
-    else:
-        return tree.map_structure(jax.device_put, data)
+    return tree.map_structure(jax.device_put, data)


 class JAXEpochIterator(EpochIterator):
     def _get_iterator(self):
+        distribution = distribution_lib.distribution()
+        if distribution is not None:
+            return self._get_distributed_iterator(distribution)
+
         return self._prefetch_numpy_iterator(
             self.data_adapter.get_jax_iterator()
         )

+    def _get_distributed_iterator(self, distribution):
+        """Lazily compute layouts to reduce host to device transfer latency."""
+        layouts = None
+        for data in self.data_adapter.get_jax_iterator():
+            if layouts is None:
+                layouts = tree.map_structure(
+                    lambda d: jax_distribution_lib._to_jax_layout(
+                        distribution.get_data_layout(d.shape)
+                    ),
+                    data,
+                )
+            yield _distribute_data(data, layouts)
+
     def _prefetch_numpy_iterator(self, numpy_iterator):
         """Shard and prefetch batches on device.
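
For reference, a minimal standalone sketch of the lazy-layout pattern introduced in `_get_distributed_iterator` above: the sharding for the batch structure is computed once, from the first batch, and reused for every later batch, so subsequent iterations only pay for the host-to-device transfer. This uses plain JAX rather than the Keras internals; names such as `shard_batches` and `make_sharding` are illustrative only, and the sketch assumes the batch size is divisible by the number of devices.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_sharding(mesh):
    # Shard the leading (batch) dimension across all devices in the mesh.
    return NamedSharding(mesh, PartitionSpec("batch"))


def shard_batches(batch_iterator, mesh):
    # Generator mirroring the pattern in _get_distributed_iterator: the
    # sharding tree is built lazily from the first batch and cached, rather
    # than being recomputed per batch.
    sharding = None
    for batch in batch_iterator:
        if sharding is None:
            sharding = jax.tree_util.tree_map(
                lambda _: make_sharding(mesh), batch
            )
        yield jax.tree_util.tree_map(jax.device_put, batch, sharding)


mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
batches = ({"x": np.ones((8, 4)), "y": np.ones((8, 1))} for _ in range(3))
for sharded in shard_batches(batches, mesh):
    print(jax.tree_util.tree_map(lambda a: a.sharding, sharded))
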
7 changes: 5 additions & 2 deletions keras/src/distribution/distribution_lib.py
@@ -385,9 +385,11 @@ class DataParallel(Distribution):
     Args:
         device_mesh: Optional `DeviceMesh` instance.
         devices: Optional list of devices.
+        auto_shard_dataset: Automatically shard the dataset amongst processes.
+            Defaults to true.
     """

-    def __init__(self, device_mesh=None, devices=None):
+    def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         if device_mesh:
             self._initialize_with_device_mesh(device_mesh)
         elif devices:
@@ -400,6 +402,7 @@ def __init__(self, device_mesh=None, devices=None):
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
         self._is_multi_process = self._num_process > 1
+        self._auto_shard_dataset = auto_shard_dataset

     def _initialize_with_device_mesh(self, device_mesh):
         if not isinstance(device_mesh, DeviceMesh):
@@ -459,7 +462,7 @@ def distribute_dataset(self, dataset):
                 "Only `tf.data.Dataset` is supported for "
                 f"sharding, got {type(dataset)}"
             )
-        if not self._is_multi_process:
+        if not self._is_multi_process or not self._auto_shard_dataset:
             return dataset

         batch_size = tf_data_distribute.compute_batch_size(dataset)
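
The new `auto_shard_dataset` flag lets multi-process pipelines opt out of the automatic `tf.data` sharding in `distribute_dataset`. A hedged usage sketch, assuming each process already builds its own shard of the data (the surrounding setup is illustrative, not part of this commit):

import keras

# Each process is assumed to construct its own per-process dataset shard,
# so Keras should return the dataset unchanged instead of re-sharding it.
distribution = keras.distribution.DataParallel(auto_shard_dataset=False)
keras.distribution.set_distribution(distribution)
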
