All DataAdapters can now create a native iterator for each backend. (#19041)

- Added `get_jax_iterator` and `get_torch_dataloader` to all `DataAdapter`s (see the interface sketch below this list).
- `GeneratorDataAdapter` and `PyDatasetAdapter` can now consume tensors from any backend (added support for JAX and Torch). As a result, any combination of input formats is supported by these two `DataAdapter`s.
- Made the `DataAdapter` unit tests consistent with one another.
- Fixed a gap where `shuffle="batch"` was not implemented for the numpy iterator (it fell back to a global shuffle).
- `GeneratorDataAdapter` no longer peeks twice at the first element.
- Removed the concept of `return_type` in `EpochIterator` since it is now always "auto".
- Each backend has a subclass of `EpochIterator` (this is not new), which is now in charge of retrieving the correct iterator for the backend.
- This prevents the double conversion that was happening in some cases (e.g. `TFDatasetAdapter` from `tf.Tensor` to numpy to JAX or Torch). Note that `ArrayDataAdapter` still needs some work to not have double conversions.
- The optimization of using `np.asarray` instead of `np.array` was moved from the `DataAdapter`s to the backend's `convert_to_numpy` since that is what is now used by the `DataAdapter`s.
- Also fixes #19038
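
For orientation, here is a minimal sketch (interface only, not the actual Keras source) of the iterator surface every `DataAdapter` now exposes; each backend's `EpochIterator` subclass calls exactly one of these methods:

```python
class DataAdapter:
    """Illustrative stub; method names match this commit, bodies are omitted."""

    def get_numpy_iterator(self):
        # Yields batches as structures of np.ndarray (numpy backend).
        raise NotImplementedError

    def get_tf_dataset(self):
        # Returns a tf.data.Dataset (TensorFlow backend).
        raise NotImplementedError

    def get_jax_iterator(self):
        # Yields batches consumable by JAX (JAX backend).
        raise NotImplementedError

    def get_torch_dataloader(self):
        # Returns a torch.utils.data.DataLoader (Torch backend).
        raise NotImplementedError
```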
hertschuh authored Jan 19, 2024
1 parent 9815ac1 commit b026ff7
Showing 21 changed files with 442 additions and 315 deletions.
4 changes: 2 additions & 2 deletions keras/backend/jax/core.py
@@ -61,11 +61,11 @@ def convert_to_tensor(x, dtype=None, sparse=None):
         if dtype and dtype != x.dtype:
             return x.value.astype(dtype)
         return x.value
-    return jnp.array(x, dtype=dtype)
+    return jnp.asarray(x, dtype=dtype)


 def convert_to_numpy(x):
-    return np.array(x)
+    return np.asarray(x)


 def is_tensor(x):
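For context on this change: `np.array` always copies its input, whereas `np.asarray` returns the input object itself whenever it is already an ndarray with a compatible dtype, which makes it the cheaper conversion. A standalone illustration:

```python
import numpy as np

x = np.arange(4)
y = np.array(x)    # unconditionally makes a copy
z = np.asarray(x)  # no copy: x already satisfies the request
assert y is not x
assert z is x
```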
9 changes: 4 additions & 5 deletions keras/backend/jax/trainer.py
@@ -904,11 +904,10 @@ def distribute_single_value(d):


 class JAXEpochIterator(EpochIterator):
-    def _get_iterator(self, return_type="auto"):
-        if return_type in ("np", "auto"):
-            # enable prefetching when using numpy_iterator
-            return self._prefetch_numpy_iterator(super()._get_iterator("np"))
-        return super()._get_iterator(return_type)
+    def _get_iterator(self):
+        return self._prefetch_numpy_iterator(
+            self.data_adapter.get_jax_iterator()
+        )

     def _prefetch_numpy_iterator(self, numpy_iterator):
         """Shard and prefetch batches on device.
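The `_prefetch_numpy_iterator` helper referenced here follows the same deque-based pattern the Torch trainer previously used (removed further down in this diff). Roughly, as a backend-independent sketch rather than the exact Keras code:

```python
import collections
import itertools


def prefetch(iterator, put_on_device, buffer_size=2):
    # Keep up to `buffer_size` batches staged on device so that host-to-device
    # transfer overlaps with compute.
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(iterator, n):
            queue.append(put_on_device(batch))

    enqueue(buffer_size)
    while queue:
        yield queue.popleft()
        enqueue(1)
```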
6 changes: 3 additions & 3 deletions keras/backend/numpy/trainer.py
@@ -198,7 +198,7 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
         outputs = None
-        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+        for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_predict_batch_begin(step)
             batch_outputs = self.predict_function(data)
             outputs = append_to_outputs(batch_outputs, outputs)
@@ -242,7 +242,7 @@ def evaluate(

         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
-            for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
+            for _, data in epoch_iterator.enumerate_epoch():
                 data_batch = data[0]
                 self._symbolic_build(data_batch)
                 break
@@ -264,7 +264,7 @@ def evaluate(
         callbacks.on_test_begin()
         logs = None
         self.reset_metrics()
-        for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
+        for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_test_batch_begin(step)
             logs = self.test_function(data)
             callbacks.on_test_batch_end(step, self._pythonify_logs(logs))
2 changes: 1 addition & 1 deletion keras/backend/tensorflow/core.py
@@ -126,7 +126,7 @@ def convert_to_numpy(x):
         x.set_shape(x_shape)
     elif isinstance(x, tf.IndexedSlices):
         x = tf.convert_to_tensor(x)
-    return np.array(x)
+    return np.asarray(x)


 def is_tensor(x):
5 changes: 4 additions & 1 deletion keras/backend/tensorflow/trainer.py
@@ -629,14 +629,17 @@ class TFEpochIterator(EpochIterator):
     def __init__(self, distribute_strategy=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._distribute_strategy = distribute_strategy
-        dataset = self.data_adapter.get_tf_dataset()
+        dataset = self._get_iterator()
         if not isinstance(dataset, tf.distribute.DistributedDataset):
             dataset = self._distribute_strategy.experimental_distribute_dataset(
                 dataset
             )
         self._distributed_dataset = dataset
         self._steps_seen = 0

+    def _get_iterator(self):
+        return self.data_adapter.get_tf_dataset()
+
     def enumerate_epoch(self):
         if self.steps_per_epoch:
             if not self._current_iterator:
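For readers unfamiliar with the distribution step performed in `__init__` above, this is standard `tf.distribute` usage; shown outside Keras to illustrate why the native `tf.data.Dataset` can be handed to the strategy without a numpy round trip:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(tf.range(8)).batch(4)
# A regular dataset gets wrapped; an already-distributed dataset would be
# passed through unchanged, mirroring the isinstance check above.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
for batch in dist_dataset:
    pass  # each element is a per-replica value spread across devices
```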
42 changes: 2 additions & 40 deletions keras/backend/torch/trainer.py
@@ -1,5 +1,3 @@
-import collections
-import itertools
 import warnings

 import numpy as np
@@ -10,7 +8,6 @@
 from keras import backend
 from keras import callbacks as callbacks_module
 from keras import optimizers as optimizers_module
-from keras.trainers import data_adapters
 from keras.trainers import trainer as base_trainer
 from keras.trainers.data_adapters import data_adapter_utils
 from keras.trainers.epoch_iterator import EpochIterator
@@ -496,40 +493,5 @@ def predict_on_batch(self, x):


 class TorchEpochIterator(EpochIterator):
-    def _get_iterator(self, return_type="auto"):
-        if return_type == "auto" and isinstance(
-            self.data_adapter, data_adapters.TorchDataLoaderAdapter
-        ):
-            return self.data_adapter.get_torch_dataloader()
-        elif return_type in ("np", "auto"):
-            # enable prefetching when using numpy_iterator
-            return self._prefetch_numpy_iterator(super()._get_iterator("np"))
-        return super()._get_iterator(return_type)
-
-    def _prefetch_numpy_data(self, data):
-        return tree.map_structure(backend.convert_to_tensor, data)
-
-    def _prefetch_numpy_iterator(self, numpy_iterator):
-        """Prefetch batches on device.
-
-        The idea has been borrowed from
-        `torchtnt.utils.data.CudaDataPrefetcher`
-
-        This utility takes an iterator and returns a new iterator which fills an
-        on device prefetch buffer. Eager prefetching can improve the performance
-        of training loops significantly by overlapping compute and data
-        transfer.
-        """
-        queue = collections.deque()
-
-        # If you're training on GPUs, 2 is generally the best choice because
-        # this guarantees that you can overlap a training step on GPU with a
-        # data prefetch step on CPU.
-        def enqueue(n=2):
-            for data in itertools.islice(numpy_iterator, n):
-                queue.append(self._prefetch_numpy_data(data))
-
-        enqueue(n=2)  # TODO: should we make `n` configurable?
-        while queue:
-            yield queue.popleft()
-            enqueue(1)
+    def _get_iterator(self):
+        return self.data_adapter.get_torch_dataloader()
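
A `torch.utils.data.DataLoader` already yields `torch.Tensor` batches, so the Torch iterator can return it directly instead of wrapping a numpy iterator in a prefetch buffer. A minimal standalone illustration:

```python
import torch

data = torch.arange(10, dtype=torch.float32).unsqueeze(-1)
dataset = torch.utils.data.TensorDataset(data)
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
for (batch,) in loader:
    print(batch.dtype, batch.shape)  # torch.float32, e.g. torch.Size([4, 1])
```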
14 changes: 10 additions & 4 deletions keras/layers/normalization/batch_normalization_test.py
@@ -92,8 +92,10 @@ def test_correctness(
         broadcast_shape = [1] * len(input_shape)
         broadcast_shape[axis] = input_shape[axis]
         out = backend.convert_to_numpy(out)
-        out -= np.reshape(backend.convert_to_numpy(layer.beta), broadcast_shape)
-        out /= np.reshape(
+        out = out - np.reshape(
+            backend.convert_to_numpy(layer.beta), broadcast_shape
+        )
+        out = out / np.reshape(
             backend.convert_to_numpy(layer.gamma), broadcast_shape
         )

@@ -200,8 +202,12 @@ def test_trainable_behavior(self):
         out = layer(x, training=True)

         out = backend.convert_to_numpy(out)
-        out -= np.reshape(backend.convert_to_numpy(layer.beta), (1, 1, 1, 3))
-        out /= np.reshape(backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3))
+        out = out - np.reshape(
+            backend.convert_to_numpy(layer.beta), (1, 1, 1, 3)
+        )
+        out = out / np.reshape(
+            backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3)
+        )

         self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)
         self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)
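
The likely motivation for dropping the in-place ops: now that `convert_to_numpy` uses `np.asarray`, the result can be a zero-copy, read-only view of the backend tensor (e.g. a CPU JAX array), and augmented assignment on a read-only array raises. Plain numpy reproduces the failure mode:

```python
import numpy as np

x = np.arange(3.0)
x.setflags(write=False)  # simulate the read-only view np.asarray can return
try:
    x -= 1.0
except ValueError as err:
    print("in-place update rejected:", err)
out = x - 1.0  # out-of-place arithmetic allocates a fresh, writeable array
```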
77 changes: 73 additions & 4 deletions keras/trainers/data_adapters/array_data_adapter.py
@@ -6,6 +6,7 @@
 from keras import backend
 from keras.trainers.data_adapters import data_adapter_utils
 from keras.trainers.data_adapters.data_adapter import DataAdapter
+from keras.utils.dataset_utils import is_torch_tensor
 from keras.utils.nest import lists_to_tuples

 try:
@@ -98,13 +99,23 @@ def __init__(

     def get_numpy_iterator(self):
         inputs = self._inputs
-        if self._shuffle:
+        if self._shuffle and self._shuffle != "batch":
             inputs = data_adapter_utils.sync_shuffle(
                 inputs, num_samples=self._num_samples
             )
         for i in range(self._size):
-            start, stop = i * self._batch_size, (i + 1) * self._batch_size
-            yield tree.map_structure(lambda x: x[start:stop], inputs)
+            start = i * self._batch_size
+            stop = min((i + 1) * self._batch_size, self._num_samples)
+            if self._shuffle == "batch":
+
+                def slice_and_shuffle(x):
+                    return data_adapter_utils.sync_shuffle(
+                        x[start:stop], num_samples=(stop - start)
+                    )
+
+                yield tree.map_structure(slice_and_shuffle, inputs)
+            else:
+                yield tree.map_structure(lambda x: x[start:stop], inputs)

     def get_tf_dataset(self):
         from keras.utils.module_utils import tensorflow as tf
@@ -237,6 +248,62 @@ def shuffle_batch(*batch):
         dataset = dataset.with_options(options)
         return dataset.prefetch(tf.data.AUTOTUNE)

+    def get_jax_iterator(self):
+        return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator())
+
+    def get_torch_dataloader(self):
+        import torch
+
+        from keras.backend.torch.core import convert_to_tensor
+
+        class ArrayDataset(torch.utils.data.Dataset):
+            def __init__(self, array):
+                self.array = array
+
+            def __getitem__(self, index):
+                def slice_and_convert(x):
+                    return convert_to_tensor(x[index])
+
+                return tree.map_structure(slice_and_convert, self.array)
+
+            def __len__(self):
+                return len(self.array[0])
+
+        class RandomBatchSampler(torch.utils.data.Sampler):
+            def __init__(self, sampler):
+                self.sampler = sampler
+
+            def __iter__(self):
+                for batch in self.sampler:
+                    yield [batch[i] for i in torch.randperm(len(batch))]
+
+            def __len__(self):
+                return len(self.sampler)
+
+        if self._shuffle == "batch":
+            batch_sampler = RandomBatchSampler(
+                torch.utils.data.BatchSampler(
+                    range(self._num_samples),
+                    batch_size=self._batch_size,
+                    drop_last=False,
+                )
+            )
+        elif self._shuffle:
+            batch_sampler = torch.utils.data.BatchSampler(
+                torch.utils.data.RandomSampler(range(self._num_samples)),
+                batch_size=self._batch_size,
+                drop_last=False,
+            )
+        else:
+            batch_sampler = torch.utils.data.BatchSampler(
+                torch.utils.data.SequentialSampler(range(self._num_samples)),
+                batch_size=self._batch_size,
+                drop_last=False,
+            )
+
+        dataset = ArrayDataset(self._inputs)
+        return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
+
     @property
     def num_batches(self):
         return self._size
@@ -315,7 +382,9 @@ def convert_single_array(x):
         # `torch.Tensor`, as well as any other tensor-like object that has
         # added numpy support.
         if hasattr(x, "__array__"):
-            x = backend.convert_to_numpy(x)
+            if is_torch_tensor(x):
+                x = x.cpu()
+            x = np.asarray(x)
         else:
             raise ValueError(
                 "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
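A standalone illustration of the `shuffle="batch"` sampler built in `get_torch_dataloader` above: `BatchSampler` yields contiguous index batches and the wrapper permutes each batch internally, so samples never move between batches:

```python
import torch


class RandomBatchSampler(torch.utils.data.Sampler):
    """Shuffle indices within each batch; batch membership stays fixed."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        for batch in self.sampler:
            yield [batch[i] for i in torch.randperm(len(batch))]

    def __len__(self):
        return len(self.sampler)


base = torch.utils.data.BatchSampler(range(10), batch_size=4, drop_last=False)
for batch in RandomBatchSampler(base):
    print(batch)  # e.g. [2, 0, 3, 1], then [6, 4, 7, 5], then [9, 8]
```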