From b026ff7005003b46f165e9bfe49e9caf063729cf Mon Sep 17 00:00:00 2001
From: hertschuh <1091026+hertschuh@users.noreply.github.com>
Date: Fri, 19 Jan 2024 14:57:18 -0800
Subject: [PATCH] All DataAdapters can now create a native iterator for each
 backend. (#19041)

- Added `get_jax_iterator` and `get_torch_dataloader` to all
  `DataAdapter`s (see the usage sketch after the diff).
- `GeneratorDataAdapter` and `PyDatasetAdapter` can now consume tensors
  from any backend (added support for JAX and Torch). As a result, these
  two `DataAdapter`s support any combination of input formats.
- Made the `DataAdapter` unit tests consistent with one another.
- Fixed a gap where `shuffle="batch"` was not implemented for the numpy
  iterator (it fell back to a global shuffle).
- `GeneratorDataAdapter` no longer peeks twice at the first element.
- Removed the concept of `return_type` in `EpochIterator`, since it is
  now always "auto".
- Each backend has a subclass of `EpochIterator` (this is not new),
  which is now in charge of retrieving the correct iterator for the
  backend.
  - This prevents the double conversion that was happening in some cases
    (e.g. `TFDatasetAdapter` from `tf.Tensor` to numpy to JAX or Torch).
    Note that `ArrayDataAdapter` still needs some work to avoid double
    conversions.
- The optimization of using `np.asarray` instead of `np.array` was moved
  from the `DataAdapter`s to each backend's `convert_to_numpy`, since
  that is what the `DataAdapter`s now use.
- Also fixes https://github.com/keras-team/keras/issues/19038
---
 keras/backend/jax/core.py | 4 +-
 keras/backend/jax/trainer.py | 9 +-
 keras/backend/numpy/trainer.py | 6 +-
 keras/backend/tensorflow/core.py | 2 +-
 keras/backend/tensorflow/trainer.py | 5 +-
 keras/backend/torch/trainer.py | 42 +-------
 .../normalization/batch_normalization_test.py | 14 ++-
 .../data_adapters/array_data_adapter.py | 77 ++++++++++++++-
 .../data_adapters/array_data_adapter_test.py | 88 ++++++++++-------
 keras/trainers/data_adapters/data_adapter.py | 19 +++-
 .../data_adapters/data_adapter_utils.py | 42 ++++++++
 .../data_adapters/generator_data_adapter.py | 32 +++---
 .../generator_data_adapter_test.py | 98 ++++++++++---------
 .../data_adapters/py_dataset_adapter.py | 17 +++-
 .../data_adapters/py_dataset_adapter_test.py | 96 +++++++++++-------
 .../data_adapters/tf_dataset_adapter.py | 17 ++--
 .../data_adapters/tf_dataset_adapter_test.py | 47 +++++----
 .../torch_data_loader_adapter.py | 16 +--
 .../torch_data_loader_adapter_test.py | 77 ++++++---------
 keras/trainers/epoch_iterator.py | 21 ++--
 keras/trainers/epoch_iterator_test.py | 28 +----
 21 files changed, 442 insertions(+), 315 deletions(-)

diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py
index 39c51d441b9..e543da526ca 100644
--- a/keras/backend/jax/core.py
+++ b/keras/backend/jax/core.py
@@ -61,11 +61,11 @@ def convert_to_tensor(x, dtype=None, sparse=None):
         if dtype and dtype != x.dtype:
             return x.value.astype(dtype)
         return x.value
-    return jnp.array(x, dtype=dtype)
+    return jnp.asarray(x, dtype=dtype)


 def convert_to_numpy(x):
-    return np.array(x)
+    return np.asarray(x)


 def is_tensor(x):
diff --git a/keras/backend/jax/trainer.py b/keras/backend/jax/trainer.py
index 7041491b223..394054225cb 100644
--- a/keras/backend/jax/trainer.py
+++ b/keras/backend/jax/trainer.py
@@ -904,11 +904,10 @@ def distribute_single_value(d):


 class JAXEpochIterator(EpochIterator):
-    def _get_iterator(self, return_type="auto"):
-        if return_type in ("np", "auto"):
-            # enable prefetching when using numpy_iterator
-            return self._prefetch_numpy_iterator(super()._get_iterator("np"))
-        return 
super()._get_iterator(return_type) + def _get_iterator(self): + return self._prefetch_numpy_iterator( + self.data_adapter.get_jax_iterator() + ) def _prefetch_numpy_iterator(self, numpy_iterator): """Shard and prefetch batches on device. diff --git a/keras/backend/numpy/trainer.py b/keras/backend/numpy/trainer.py index ba5abcbbdf3..b82599a8eba 100644 --- a/keras/backend/numpy/trainer.py +++ b/keras/backend/numpy/trainer.py @@ -198,7 +198,7 @@ def append_to_outputs(batch_outputs, outputs): self.stop_predicting = False callbacks.on_predict_begin() outputs = None - for step, data in epoch_iterator.enumerate_epoch(return_type="np"): + for step, data in epoch_iterator.enumerate_epoch(): callbacks.on_predict_batch_begin(step) batch_outputs = self.predict_function(data) outputs = append_to_outputs(batch_outputs, outputs) @@ -242,7 +242,7 @@ def evaluate( if not all(layer.built for layer in self._flatten_layers()): # Build the model on one batch of data. - for _, data in epoch_iterator.enumerate_epoch(return_type="np"): + for _, data in epoch_iterator.enumerate_epoch(): data_batch = data[0] self._symbolic_build(data_batch) break @@ -264,7 +264,7 @@ def evaluate( callbacks.on_test_begin() logs = None self.reset_metrics() - for step, data in epoch_iterator.enumerate_epoch(return_type="np"): + for step, data in epoch_iterator.enumerate_epoch(): callbacks.on_test_batch_begin(step) logs = self.test_function(data) callbacks.on_test_batch_end(step, self._pythonify_logs(logs)) diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 9385ceac28d..9555b612fae 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -126,7 +126,7 @@ def convert_to_numpy(x): x.set_shape(x_shape) elif isinstance(x, tf.IndexedSlices): x = tf.convert_to_tensor(x) - return np.array(x) + return np.asarray(x) def is_tensor(x): diff --git a/keras/backend/tensorflow/trainer.py b/keras/backend/tensorflow/trainer.py index f42f812892d..0bcef10a089 100644 --- a/keras/backend/tensorflow/trainer.py +++ b/keras/backend/tensorflow/trainer.py @@ -629,7 +629,7 @@ class TFEpochIterator(EpochIterator): def __init__(self, distribute_strategy=None, *args, **kwargs): super().__init__(*args, **kwargs) self._distribute_strategy = distribute_strategy - dataset = self.data_adapter.get_tf_dataset() + dataset = self._get_iterator() if not isinstance(dataset, tf.distribute.DistributedDataset): dataset = self._distribute_strategy.experimental_distribute_dataset( dataset @@ -637,6 +637,9 @@ def __init__(self, distribute_strategy=None, *args, **kwargs): self._distributed_dataset = dataset self._steps_seen = 0 + def _get_iterator(self): + return self.data_adapter.get_tf_dataset() + def enumerate_epoch(self): if self.steps_per_epoch: if not self._current_iterator: diff --git a/keras/backend/torch/trainer.py b/keras/backend/torch/trainer.py index 8ac19f27548..1557e18f1cb 100644 --- a/keras/backend/torch/trainer.py +++ b/keras/backend/torch/trainer.py @@ -1,5 +1,3 @@ -import collections -import itertools import warnings import numpy as np @@ -10,7 +8,6 @@ from keras import backend from keras import callbacks as callbacks_module from keras import optimizers as optimizers_module -from keras.trainers import data_adapters from keras.trainers import trainer as base_trainer from keras.trainers.data_adapters import data_adapter_utils from keras.trainers.epoch_iterator import EpochIterator @@ -496,40 +493,5 @@ def predict_on_batch(self, x): class TorchEpochIterator(EpochIterator): - def _get_iterator(self, 
return_type="auto"): - if return_type == "auto" and isinstance( - self.data_adapter, data_adapters.TorchDataLoaderAdapter - ): - return self.data_adapter.get_torch_dataloader() - elif return_type in ("np", "auto"): - # enable prefetching when using numpy_iterator - return self._prefetch_numpy_iterator(super()._get_iterator("np")) - return super()._get_iterator(return_type) - - def _prefetch_numpy_data(self, data): - return tree.map_structure(backend.convert_to_tensor, data) - - def _prefetch_numpy_iterator(self, numpy_iterator): - """Prefetch batches on device. - - The idea has been borrowed from - `torchtnt.utils.data.CudaDataPrefetcher` - - This utility takes an iterator and returns a new iterator which fills an - on device prefetch buffer. Eager prefetching can improve the performance - of training loops significantly by overlapping compute and data - transfer. - """ - queue = collections.deque() - - # If you're training on GPUs, 2 is generally the best choice because - # this guarantees that you can overlap a training step on GPU with a - # data prefetch step on CPU. - def enqueue(n=2): - for data in itertools.islice(numpy_iterator, n): - queue.append(self._prefetch_numpy_data(data)) - - enqueue(n=2) # TODO: should we make `n` configurable? - while queue: - yield queue.popleft() - enqueue(1) + def _get_iterator(self): + return self.data_adapter.get_torch_dataloader() diff --git a/keras/layers/normalization/batch_normalization_test.py b/keras/layers/normalization/batch_normalization_test.py index 27a36d0e9a3..7ea3c4bbba4 100644 --- a/keras/layers/normalization/batch_normalization_test.py +++ b/keras/layers/normalization/batch_normalization_test.py @@ -92,8 +92,10 @@ def test_correctness( broadcast_shape = [1] * len(input_shape) broadcast_shape[axis] = input_shape[axis] out = backend.convert_to_numpy(out) - out -= np.reshape(backend.convert_to_numpy(layer.beta), broadcast_shape) - out /= np.reshape( + out = out - np.reshape( + backend.convert_to_numpy(layer.beta), broadcast_shape + ) + out = out / np.reshape( backend.convert_to_numpy(layer.gamma), broadcast_shape ) @@ -200,8 +202,12 @@ def test_trainable_behavior(self): out = layer(x, training=True) out = backend.convert_to_numpy(out) - out -= np.reshape(backend.convert_to_numpy(layer.beta), (1, 1, 1, 3)) - out /= np.reshape(backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3)) + out = out - np.reshape( + backend.convert_to_numpy(layer.beta), (1, 1, 1, 3) + ) + out = out / np.reshape( + backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3) + ) self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3) self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3) diff --git a/keras/trainers/data_adapters/array_data_adapter.py b/keras/trainers/data_adapters/array_data_adapter.py index aa5b7910fc9..1f393f59fe9 100644 --- a/keras/trainers/data_adapters/array_data_adapter.py +++ b/keras/trainers/data_adapters/array_data_adapter.py @@ -6,6 +6,7 @@ from keras import backend from keras.trainers.data_adapters import data_adapter_utils from keras.trainers.data_adapters.data_adapter import DataAdapter +from keras.utils.dataset_utils import is_torch_tensor from keras.utils.nest import lists_to_tuples try: @@ -98,13 +99,23 @@ def __init__( def get_numpy_iterator(self): inputs = self._inputs - if self._shuffle: + if self._shuffle and self._shuffle != "batch": inputs = data_adapter_utils.sync_shuffle( inputs, num_samples=self._num_samples ) for i in range(self._size): - start, stop = i * self._batch_size, (i + 1) * self._batch_size - yield 
tree.map_structure(lambda x: x[start:stop], inputs) + start = i * self._batch_size + stop = min((i + 1) * self._batch_size, self._num_samples) + if self._shuffle == "batch": + + def slice_and_shuffle(x): + return data_adapter_utils.sync_shuffle( + x[start:stop], num_samples=(stop - start) + ) + + yield tree.map_structure(slice_and_shuffle, inputs) + else: + yield tree.map_structure(lambda x: x[start:stop], inputs) def get_tf_dataset(self): from keras.utils.module_utils import tensorflow as tf @@ -237,6 +248,62 @@ def shuffle_batch(*batch): dataset = dataset.with_options(options) return dataset.prefetch(tf.data.AUTOTUNE) + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator()) + + def get_torch_dataloader(self): + import torch + + from keras.backend.torch.core import convert_to_tensor + + class ArrayDataset(torch.utils.data.Dataset): + def __init__(self, array): + self.array = array + + def __getitem__(self, index): + def slice_and_convert(x): + return convert_to_tensor(x[index]) + + return tree.map_structure(slice_and_convert, self.array) + + def __len__(self): + return len(self.array[0]) + + class RandomBatchSampler(torch.utils.data.Sampler): + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + for batch in self.sampler: + yield [batch[i] for i in torch.randperm(len(batch))] + + def __len__(self): + return len(self.sampler) + + if self._shuffle == "batch": + batch_sampler = RandomBatchSampler( + torch.utils.data.BatchSampler( + range(self._num_samples), + batch_size=self._batch_size, + drop_last=False, + ) + ) + elif self._shuffle: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.RandomSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + else: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.SequentialSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + + dataset = ArrayDataset(self._inputs) + return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler) + @property def num_batches(self): return self._size @@ -315,7 +382,9 @@ def convert_single_array(x): # `torch.Tensor`, as well as any other tensor-like object that has # added numpy support. 
if hasattr(x, "__array__"): - x = backend.convert_to_numpy(x) + if is_torch_tensor(x): + x = x.cpu() + x = np.asarray(x) else: raise ValueError( "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, " diff --git a/keras/trainers/data_adapters/array_data_adapter_test.py b/keras/trainers/data_adapters/array_data_adapter_test.py index 4a0c94d8bd3..8d500437d04 100644 --- a/keras/trainers/data_adapters/array_data_adapter_test.py +++ b/keras/trainers/data_adapters/array_data_adapter_test.py @@ -1,36 +1,39 @@ +import jax import numpy as np import pandas import pytest import tensorflow as tf +import torch from absl.testing import parameterized from keras import backend from keras import testing +from keras.testing.test_utils import named_product from keras.trainers.data_adapters import array_data_adapter class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase): def make_array(self, array_type, shape, dtype="float32"): + x = np.array([[i] * shape[1] for i in range(shape[0])], dtype=dtype) if array_type == "np": - return np.ones(shape, dtype=dtype) + return x elif array_type == "tf": - return tf.ones(shape, dtype=dtype) - elif array_type == "backend": - if backend.backend() == "jax": - import jax - - return jax.numpy.ones(shape, dtype=dtype) - elif backend.backend() == "torch": - import torch - - return torch.tensor(np.ones(shape, dtype=dtype)) - else: - return tf.ones(shape, dtype=dtype) + return tf.constant(x) + elif array_type == "jax": + return jax.numpy.array(x) + elif array_type == "torch": + return torch.as_tensor(x) elif array_type == "pandas": - return pandas.DataFrame(np.ones(shape, dtype=dtype)) + return pandas.DataFrame(x) - @parameterized.parameters([("np",), ("tf",), ("backend",), ("pandas",)]) - def test_basic_flow(self, array_type): + @parameterized.named_parameters( + named_product( + array_type=["np", "tf", "jax", "torch", "pandas"], + iterator_type=["np", "tf", "jax", "torch"], + shuffle=[False, "batch", True], + ) + ) + def test_basic_flow(self, array_type, iterator_type, shuffle): x = self.make_array(array_type, (34, 4)) y = self.make_array(array_type, (34, 2)) @@ -40,41 +43,46 @@ def test_basic_flow(self, array_type): sample_weight=None, batch_size=16, steps=None, - shuffle=False, + shuffle=shuffle, ) self.assertEqual(adapter.num_batches, 3) self.assertEqual(adapter.batch_size, 16) self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - gen = adapter.get_numpy_iterator() - for i, batch in enumerate(gen): + if iterator_type == "np": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif iterator_type == "tf": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif iterator_type == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array + elif iterator_type == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + sample_order = [] + for i, batch in enumerate(it): self.assertEqual(len(batch), 2) bx, by = batch - self.assertIsInstance(bx, np.ndarray) - self.assertIsInstance(by, np.ndarray) + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, backend.floatx()) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") if i < 2: self.assertEqual(bx.shape, (16, 4)) self.assertEqual(by.shape, (16, 2)) else: self.assertEqual(bx.shape, (2, 4)) self.assertEqual(by.shape, (2, 2)) - ds = adapter.get_tf_dataset() - for i, batch in enumerate(ds): - 
self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, tf.Tensor) - self.assertIsInstance(by, tf.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, backend.floatx()) - if i < 2: - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) - else: - self.assertEqual(tuple(bx.shape), (2, 4)) - self.assertEqual(tuple(by.shape), (2, 2)) + for i in range(by.shape[0]): + sample_order.append(by[i, 0]) + if shuffle: + self.assertNotAllClose(sample_order, list(range(34))) + else: + self.assertAllClose(sample_order, list(range(34))) def test_multi_inputs_and_outputs(self): x1 = np.random.random((34, 1)) @@ -153,7 +161,9 @@ def test_multi_inputs_and_outputs(self): self.assertEqual(tuple(bw[0].shape), (2,)) self.assertEqual(tuple(bw[1].shape), (2,)) - @parameterized.parameters([("int",), ("categorical",)]) + @parameterized.named_parameters( + named_product(target_encoding=["int", "categorical"]) + ) def test_class_weights(self, target_encoding): x = np.random.random((4, 2)) if target_encoding == "int": @@ -186,7 +196,9 @@ def test_errors(self): # TODO pass - @parameterized.parameters([("np",), ("tf",), ("backend",), ("pandas",)]) + @parameterized.named_parameters( + named_product(array_type=["np", "tf", "jax", "torch", "pandas"]) + ) def test_integer_inputs(self, array_type): x1 = self.make_array(array_type, (4, 4), dtype="float64") x2 = self.make_array(array_type, (4, 4), dtype="int32") diff --git a/keras/trainers/data_adapters/data_adapter.py b/keras/trainers/data_adapters/data_adapter.py index 269ced4cc63..23d82dce0ee 100644 --- a/keras/trainers/data_adapters/data_adapter.py +++ b/keras/trainers/data_adapters/data_adapter.py @@ -7,7 +7,8 @@ class DataAdapter(object): """ def get_numpy_iterator(self): - """Get a Python iterable for the DataAdapter, that yields NumPy arrays. + """Get a Python iterable for the `DataAdapter`, that yields NumPy + arrays. Returns: A Python iterator. @@ -28,6 +29,22 @@ def get_tf_dataset(self): """ raise NotImplementedError + def get_jax_iterator(self): + """Get a Python iterable for the `DataAdapter`, that yields JAX arrays. + + Returns: + A Python iterator. + """ + raise NotImplementedError + + def get_torch_dataloader(self): + """Get a Torch `DataLoader` for the `DataAdapter`. + + Returns: + A Torch `DataLoader`. + """ + raise NotImplementedError + @property def num_batches(self): """Return the size (number of batches) for the dataset created. diff --git a/keras/trainers/data_adapters/data_adapter_utils.py b/keras/trainers/data_adapters/data_adapter_utils.py index 82261e26bd6..ff8a273a4d2 100644 --- a/keras/trainers/data_adapters/data_adapter_utils.py +++ b/keras/trainers/data_adapters/data_adapter_utils.py @@ -5,6 +5,7 @@ from keras import backend from keras.api_export import keras_export +from keras.utils.dataset_utils import is_torch_tensor try: import pandas @@ -215,3 +216,44 @@ def class_weight_to_sample_weights(y, class_weight): for i in range(y.shape[0]): sample_weight[i] = class_weight.get(int(y[i]), 1.0) return sample_weight + + +def get_jax_iterator(iterable): + from keras.backend.jax.core import convert_to_tensor + + for batch in iterable: + yield tree.map_structure(convert_to_tensor, batch) + + +def get_numpy_iterator(iterable): + def convert_to_numpy(x): + if not isinstance(x, np.ndarray): + # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`, + # `torch.Tensor`, as well as any other tensor-like object that + # has added numpy support. 
+ if hasattr(x, "__array__"): + if is_torch_tensor(x): + x = x.cpu() + x = np.asarray(x) + return x + + for batch in iterable: + yield tree.map_structure(convert_to_numpy, batch) + + +def get_torch_dataloader(iterable): + import torch.utils.data as torch_data + + from keras.backend.torch.core import convert_to_tensor + + class ConverterIterableDataset(torch_data.IterableDataset): + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + for batch in self.iterable: + yield tree.map_structure(convert_to_tensor, batch) + + dataset = ConverterIterableDataset(iterable) + # `batch_size=None` indicates that we should not re-batch + return torch_data.DataLoader(dataset, batch_size=None) diff --git a/keras/trainers/data_adapters/generator_data_adapter.py b/keras/trainers/data_adapters/generator_data_adapter.py index 1fbdc37326a..986fafe037d 100644 --- a/keras/trainers/data_adapters/generator_data_adapter.py +++ b/keras/trainers/data_adapters/generator_data_adapter.py @@ -3,6 +3,8 @@ import numpy as np import tree +from keras import backend +from keras.trainers.data_adapters import data_adapter_utils from keras.trainers.data_adapters.data_adapter import DataAdapter @@ -10,24 +12,22 @@ class GeneratorDataAdapter(DataAdapter): """Adapter for Python generators.""" def __init__(self, generator): - data, generator = peek_and_restore(generator) + first_batch, generator = peek_and_restore(generator) self.generator = generator + self._first_batch = first_batch self._output_signature = None - if not isinstance(data, tuple): + if not isinstance(first_batch, tuple): raise ValueError( "When passing a Python generator to a Keras model, " "the generator must return a tuple, either " "(input,) or (inputs, targets) or " "(inputs, targets, sample_weights). " - f"Received: {data}" + f"Received: {first_batch}" ) def _set_tf_output_signature(self): from keras.utils.module_utils import tensorflow as tf - data, generator = peek_and_restore(self.generator) - self.generator = generator - def get_tensor_spec(x): shape = x.shape if len(shape) < 1: @@ -39,18 +39,23 @@ def get_tensor_spec(x): ) shape = list(shape) shape[0] = None # The batch size is not guaranteed to be static. 
+ dtype = backend.standardize_dtype(x.dtype) if isinstance(x, tf.RaggedTensor): - return tf.RaggedTensorSpec(shape=shape, dtype=x.dtype.name) + return tf.RaggedTensorSpec(shape=shape, dtype=dtype) if isinstance(x, tf.SparseTensor) or is_scipy_sparse(x): - return tf.SparseTensorSpec(shape=shape, dtype=x.dtype.name) + return tf.SparseTensorSpec(shape=shape, dtype=dtype) else: - return tf.TensorSpec(shape=shape, dtype=x.dtype.name) + return tf.TensorSpec(shape=shape, dtype=dtype) - self._output_signature = tree.map_structure(get_tensor_spec, data) + self._output_signature = tree.map_structure( + get_tensor_spec, self._first_batch + ) def get_numpy_iterator(self): - for batch in self.generator: - yield batch + return data_adapter_utils.get_numpy_iterator(self.generator) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self.generator) def get_tf_dataset(self): from keras.utils.module_utils import tensorflow as tf @@ -74,6 +79,9 @@ def get_tf_iterator(): ds = ds.prefetch(tf.data.AUTOTUNE) return ds + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self.generator) + @property def num_batches(self): return None diff --git a/keras/trainers/data_adapters/generator_data_adapter_test.py b/keras/trainers/data_adapters/generator_data_adapter_test.py index aeb4c375bdb..8b3da56c514 100644 --- a/keras/trainers/data_adapters/generator_data_adapter_test.py +++ b/keras/trainers/data_adapters/generator_data_adapter_test.py @@ -1,11 +1,15 @@ import math +import jax import numpy as np import scipy import tensorflow as tf +import torch from absl.testing import parameterized +from jax import numpy as jnp from keras import testing +from keras.testing.test_utils import named_product from keras.trainers.data_adapters import generator_data_adapter @@ -25,18 +29,31 @@ def make(): class GeneratorDataAdapterTest(testing.TestCase, parameterized.TestCase): - @parameterized.parameters( - [ - (True,), - (False,), - ] + @parameterized.named_parameters( + named_product( + [ + {"testcase_name": "use_weight", "use_sample_weight": True}, + {"testcase_name": "no_weight", "use_sample_weight": False}, + ], + generator_type=["np", "tf", "jax", "torch"], + iterator_type=["np", "tf", "jax", "torch"], + ) ) - def test_basic_flow(self, use_sample_weight): - x = np.random.random((64, 4)) - y = np.array([[i, i] for i in range(64)], dtype="float64") - if use_sample_weight: - sw = np.random.random((64,)) - else: + def test_basic_flow(self, use_sample_weight, generator_type, iterator_type): + x = np.random.random((34, 4)).astype("float32") + y = np.array([[i, i] for i in range(34)], dtype="float32") + sw = np.random.random((34,)).astype("float32") + if generator_type == "tf": + x, y, sw = tf.constant(x), tf.constant(y), tf.constant(sw) + elif generator_type == "jax": + x, y, sw = jnp.array(x), jnp.array(y), jnp.array(sw) + elif generator_type == "torch": + x, y, sw = ( + torch.as_tensor(x), + torch.as_tensor(y), + torch.as_tensor(sw), + ) + if not use_sample_weight: sw = None make_generator = example_generator( x, @@ -44,53 +61,46 @@ def test_basic_flow(self, use_sample_weight): sample_weight=sw, batch_size=16, ) + adapter = generator_data_adapter.GeneratorDataAdapter(make_generator()) + if iterator_type == "np": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif iterator_type == "tf": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif iterator_type == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array + elif iterator_type == 
"torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor - gen = adapter.get_numpy_iterator() sample_order = [] - for batch in gen: + for i, batch in enumerate(it): if use_sample_weight: self.assertEqual(len(batch), 3) bx, by, bsw = batch else: self.assertEqual(len(batch), 2) bx, by = batch - - self.assertIsInstance(bx, np.ndarray) - self.assertIsInstance(by, np.ndarray) + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.shape, (16, 4)) - self.assertEqual(by.shape, (16, 2)) - if use_sample_weight: - self.assertIsInstance(bsw, np.ndarray) - for i in range(by.shape[0]): - sample_order.append(by[i, 0]) - self.assertAllClose(sample_order, list(range(64))) - - adapter = generator_data_adapter.GeneratorDataAdapter( - make_generator(), - ) - ds = adapter.get_tf_dataset() - sample_order = [] - for batch in ds: - if use_sample_weight: - self.assertEqual(len(batch), 3) - bx, by, bsw = batch + self.assertContainsExactSubsequence(str(bx.dtype), "float32") + if i < 2: + self.assertEqual(bx.shape, (16, 4)) + self.assertEqual(by.shape, (16, 2)) else: - self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, tf.Tensor) - self.assertIsInstance(by, tf.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) if use_sample_weight: - self.assertIsInstance(bsw, tf.Tensor) + self.assertIsInstance(bsw, expected_class) for i in range(by.shape[0]): sample_order.append(by[i, 0]) - self.assertAllClose(sample_order, list(range(64))) + self.assertAllClose(sample_order, list(range(34))) - def test_tf_sparse_tensors(self): + def test_tf_sparse_tensors_with_tf_dataset(self): def generate_tf(): for i in range(4): x = tf.SparseTensor( @@ -115,7 +125,7 @@ def generate_tf(): self.assertEqual(bx.shape, (2, 4)) self.assertEqual(by.shape, (2, 2)) - def test_scipy_sparse_tensors(self): + def test_scipy_sparse_tensors_with_tf_dataset(self): def generate_scipy(): for i in range(4): x = scipy.sparse.coo_matrix( diff --git a/keras/trainers/data_adapters/py_dataset_adapter.py b/keras/trainers/data_adapters/py_dataset_adapter.py index 851033cc024..a3704cd7f1f 100644 --- a/keras/trainers/data_adapters/py_dataset_adapter.py +++ b/keras/trainers/data_adapters/py_dataset_adapter.py @@ -10,6 +10,7 @@ import numpy as np import tree +from keras import backend from keras.api_export import keras_export from keras.trainers.data_adapters import data_adapter_utils from keras.trainers.data_adapters.data_adapter import DataAdapter @@ -200,7 +201,8 @@ def get_tensor_spec(x): ) shape = list(shape) shape[0] = None # The batch size is not guaranteed to be static. 
- return tf.TensorSpec(shape=shape, dtype=x.dtype.name) + dtype = backend.standardize_dtype(x.dtype) + return tf.TensorSpec(shape=shape, dtype=dtype) # Grab the first example batch = self.py_dataset[0] @@ -267,7 +269,7 @@ def generator_fn(): return generator_fn - def get_numpy_iterator(self): + def _get_iterator(self): gen_fn = self._make_multiprocessed_generator_fn() for i, batch in enumerate(gen_fn()): batch = self._standardize_batch(batch) @@ -275,6 +277,12 @@ def get_numpy_iterator(self): if i >= len(self.py_dataset) - 1 and self.enqueuer: self.enqueuer.stop() + def get_numpy_iterator(self): + return data_adapter_utils.get_numpy_iterator(self._get_iterator()) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self._get_iterator()) + def get_tf_dataset(self): from keras.utils.module_utils import tensorflow as tf @@ -282,7 +290,7 @@ def get_tf_dataset(self): self._set_tf_output_signature() ds = tf.data.Dataset.from_generator( - self.get_numpy_iterator, + self._get_iterator, output_signature=self._output_signature, ) if self.shuffle: @@ -290,6 +298,9 @@ def get_tf_dataset(self): ds = ds.prefetch(tf.data.AUTOTUNE) return ds + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self._get_iterator()) + def on_epoch_end(self): if self.enqueuer: self.enqueuer.stop() diff --git a/keras/trainers/data_adapters/py_dataset_adapter_test.py b/keras/trainers/data_adapters/py_dataset_adapter_test.py index 7bd05fadf82..f9ff419ecf1 100644 --- a/keras/trainers/data_adapters/py_dataset_adapter_test.py +++ b/keras/trainers/data_adapters/py_dataset_adapter_test.py @@ -1,11 +1,14 @@ import math import time +import jax import numpy as np import tensorflow as tf +import torch from absl.testing import parameterized from keras import testing +from keras.testing.test_utils import named_product from keras.trainers.data_adapters import py_dataset_adapter from keras.utils.rng_utils import set_random_seed @@ -61,22 +64,51 @@ def __getitem__(self, idx): class PyDatasetAdapterTest(testing.TestCase, parameterized.TestCase): - @parameterized.parameters( - [ - (True, 2, True, 10), - (False, 2, True, 10), - (True, 2, False, 10), - (False, 2, False, 10), - (True, 0, False, 0), - (False, 0, False, 0), - ] + @parameterized.named_parameters( + named_product( + [ + { + "testcase_name": "multi_on", + "workers": 2, + "use_multiprocessing": True, + "max_queue_size": 10, + }, + { + "testcase_name": "multi_off", + "workers": 2, + "use_multiprocessing": False, + "max_queue_size": 10, + }, + { + "testcase_name": "multi_off_zero", + "workers": 0, + "use_multiprocessing": False, + "max_queue_size": 0, + }, + ], + shuffle=[True, False], + dataset_type=["np", "tf", "jax", "torch"], + iterator_type=["np", "tf", "jax", "torch"], + ) ) def test_basic_flow( - self, shuffle, workers, use_multiprocessing, max_queue_size + self, + shuffle, + workers, + use_multiprocessing, + max_queue_size, + dataset_type, + iterator_type, ): set_random_seed(1337) - x = np.random.random((64, 4)) - y = np.array([[i, i] for i in range(64)], dtype="float64") + x = np.random.random((64, 4)).astype("float32") + y = np.array([[i, i] for i in range(64)], dtype="float32") + if dataset_type == "tf": + x, y = tf.constant(x), tf.constant(y) + elif dataset_type == "jax": + x, y = jax.numpy.array(x), jax.numpy.array(y) + elif dataset_type == "torch": + x, y = torch.as_tensor(x), torch.as_tensor(y) py_dataset = ExamplePyDataset( x, y, @@ -89,37 +121,33 @@ def test_basic_flow( py_dataset, shuffle=shuffle ) - gen = 
adapter.get_numpy_iterator() + if iterator_type == "np": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif iterator_type == "tf": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif iterator_type == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array + elif iterator_type == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + sample_order = [] - for batch in gen: + for batch in it: self.assertEqual(len(batch), 2) bx, by = batch - self.assertIsInstance(bx, np.ndarray) - self.assertIsInstance(by, np.ndarray) + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) self.assertEqual(bx.dtype, by.dtype) + self.assertContainsExactSubsequence(str(bx.dtype), "float32") self.assertEqual(bx.shape, (16, 4)) self.assertEqual(by.shape, (16, 2)) for i in range(by.shape[0]): sample_order.append(by[i, 0]) if shuffle: - self.assertFalse(sample_order == list(range(64))) - else: - self.assertAllClose(sample_order, list(range(64))) - - ds = adapter.get_tf_dataset() - sample_order = [] - for batch in ds: - self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, tf.Tensor) - self.assertIsInstance(by, tf.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) - for i in range(by.shape[0]): - sample_order.append(by[i, 0]) - if shuffle: - self.assertFalse(sample_order == list(range(64))) + self.assertNotAllClose(sample_order, list(range(64))) else: self.assertAllClose(sample_order, list(range(64))) diff --git a/keras/trainers/data_adapters/tf_dataset_adapter.py b/keras/trainers/data_adapters/tf_dataset_adapter.py index 1b8d59a0693..c6c96be1af8 100644 --- a/keras/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/trainers/data_adapters/tf_dataset_adapter.py @@ -1,4 +1,3 @@ -import numpy as np import tree from keras.trainers.data_adapters import data_adapter_utils @@ -37,20 +36,22 @@ def __init__(self, dataset, class_weight=None, distribution=None): self._dataset = dataset def get_numpy_iterator(self): - from keras.utils.module_utils import tensorflow as tf - - def convert_to_numpy(x): - if isinstance(x, tf.SparseTensor): - x = tf.sparse.to_dense(x) - # shared memory using `np.asarray` - return np.asarray(x) + from keras.backend.tensorflow.core import convert_to_numpy for batch in self._dataset: yield tree.map_structure(convert_to_numpy, batch) + def get_jax_iterator(self): + # We use numpy as an intermediary because the conversion + # tf -> numpy -> jax is more than 2x faster than tf -> jax. 
+ return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator()) + def get_tf_dataset(self): return self._dataset + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self._dataset) + @property def num_batches(self): cardinality = self._dataset.cardinality diff --git a/keras/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/trainers/data_adapters/tf_dataset_adapter_test.py index d1dd94d1dc9..74208ba0395 100644 --- a/keras/trainers/data_adapters/tf_dataset_adapter_test.py +++ b/keras/trainers/data_adapters/tf_dataset_adapter_test.py @@ -1,14 +1,21 @@ from unittest import mock +import jax import numpy as np import tensorflow as tf +import torch +from absl.testing import parameterized from keras import testing +from keras.testing.test_utils import named_product from keras.trainers.data_adapters import tf_dataset_adapter -class TestTFDatasetAdapter(testing.TestCase): - def test_basic_flow(self): +class TestTFDatasetAdapter(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + named_product(iterator_type=["np", "tf", "jax", "torch"]) + ) + def test_basic_flow(self, iterator_type): x = tf.random.normal((34, 4)) y = tf.random.normal((34, 2)) base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16) @@ -19,34 +26,32 @@ def test_basic_flow(self): self.assertEqual(adapter.has_partial_batch, None) self.assertEqual(adapter.partial_batch_size, None) - gen = adapter.get_numpy_iterator() - for i, batch in enumerate(gen): + if iterator_type == "np": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif iterator_type == "tf": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif iterator_type == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array + elif iterator_type == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + for i, batch in enumerate(it): self.assertEqual(len(batch), 2) bx, by = batch - self.assertIsInstance(bx, np.ndarray) - self.assertIsInstance(by, np.ndarray) + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, "float32") + self.assertContainsExactSubsequence(str(bx.dtype), "float32") if i < 2: self.assertEqual(bx.shape, (16, 4)) self.assertEqual(by.shape, (16, 2)) else: self.assertEqual(bx.shape, (2, 4)) self.assertEqual(by.shape, (2, 2)) - ds = adapter.get_tf_dataset() - for i, batch in enumerate(ds): - self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, tf.Tensor) - self.assertIsInstance(by, tf.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, "float32") - if i < 2: - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) - else: - self.assertEqual(tuple(bx.shape), (2, 4)) - self.assertEqual(tuple(by.shape), (2, 2)) def _test_class_weights(self, target_encoding="int"): x = np.random.random((4, 2)) diff --git a/keras/trainers/data_adapters/torch_data_loader_adapter.py b/keras/trainers/data_adapters/torch_data_loader_adapter.py index 4cde43a65d0..32747b85093 100644 --- a/keras/trainers/data_adapters/torch_data_loader_adapter.py +++ b/keras/trainers/data_adapters/torch_data_loader_adapter.py @@ -1,6 +1,8 @@ import numpy as np import tree +from keras import backend +from keras.trainers.data_adapters import data_adapter_utils from keras.trainers.data_adapters.data_adapter import DataAdapter @@ -28,8 +30,10 @@ def get_numpy_iterator(self): 
tree.map_structure(lambda x: np.asarray(x.cpu()), batch) ) - def get_torch_dataloader(self): - return self._dataloader + def get_jax_iterator(self): + # We use numpy as an intermediary because the conversion + # torch -> numpy -> jax is faster than torch -> jax. + return data_adapter_utils.get_jax_iterator(self.get_numpy_iterator()) def get_tf_dataset(self): from keras.utils.module_utils import tensorflow as tf @@ -40,6 +44,9 @@ def get_tf_dataset(self): output_signature=output_signature, ) + def get_torch_dataloader(self): + return self._dataloader + def peek_and_get_tensor_spec(self): from keras.utils.module_utils import tensorflow as tf @@ -56,10 +63,7 @@ def get_tensor_spec(x): ) shape = list(shape) shape[0] = None # The batch size is not guaranteed to be static. - - # No easy way to get string representation of dtype in torch - # TODO: Figure out a better way to achieve this - dtype = str(x.dtype).replace("torch.", "") + dtype = backend.standardize_dtype(x.dtype) return tf.TensorSpec(shape=shape, dtype=dtype) return tuple(tree.map_structure(get_tensor_spec, batch_data)) diff --git a/keras/trainers/data_adapters/torch_data_loader_adapter_test.py b/keras/trainers/data_adapters/torch_data_loader_adapter_test.py index 00539b13cd7..80932dc14e6 100644 --- a/keras/trainers/data_adapters/torch_data_loader_adapter_test.py +++ b/keras/trainers/data_adapters/torch_data_loader_adapter_test.py @@ -1,28 +1,25 @@ +import jax import numpy as np -import pytest import tensorflow as tf +import torch +from absl.testing import parameterized -from keras import backend from keras import testing +from keras.testing.test_utils import named_product from keras.trainers.data_adapters.torch_data_loader_adapter import ( TorchDataLoaderAdapter, ) -@pytest.mark.skipif( - backend.backend() != "torch", - reason="Backend does not support TorchDataLoaderAdapter.", -) -class TestTorchDataLoaderAdapter(testing.TestCase): - def test_basic_dataloader(self): - import torch - from torch.utils.data import DataLoader - from torch.utils.data import TensorDataset - +class TestTorchDataLoaderAdapter(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + named_product(iterator_type=["np", "tf", "jax", "torch"]) + ) + def test_basic_dataloader(self, iterator_type): x = torch.normal(2, 3, size=(34, 4)) y = torch.normal(1, 3, size=(34, 2)) - base_ds = TensorDataset(x, y) - base_dataloader = DataLoader(base_ds, batch_size=16) + base_ds = torch.utils.data.TensorDataset(x, y) + base_dataloader = torch.utils.data.DataLoader(base_ds, batch_size=16) adapter = TorchDataLoaderAdapter(base_dataloader) self.assertEqual(adapter.num_batches, 3) @@ -30,47 +27,29 @@ def test_basic_dataloader(self): self.assertEqual(adapter.has_partial_batch, True) self.assertEqual(adapter.partial_batch_size, 2) - gen = adapter.get_numpy_iterator() - for i, batch in enumerate(gen): + if iterator_type == "np": + it = adapter.get_numpy_iterator() + expected_class = np.ndarray + elif iterator_type == "tf": + it = adapter.get_tf_dataset() + expected_class = tf.Tensor + elif iterator_type == "jax": + it = adapter.get_jax_iterator() + expected_class = jax.Array + elif iterator_type == "torch": + it = adapter.get_torch_dataloader() + expected_class = torch.Tensor + + for i, batch in enumerate(it): self.assertEqual(len(batch), 2) bx, by = batch - self.assertIsInstance(bx, np.ndarray) - self.assertIsInstance(by, np.ndarray) + self.assertIsInstance(bx, expected_class) + self.assertIsInstance(by, expected_class) self.assertEqual(bx.dtype, by.dtype) - 
self.assertEqual(bx.dtype, "float32") + self.assertContainsExactSubsequence(str(bx.dtype), "float32") if i < 2: self.assertEqual(bx.shape, (16, 4)) self.assertEqual(by.shape, (16, 2)) else: self.assertEqual(bx.shape, (2, 4)) self.assertEqual(by.shape, (2, 2)) - - ds = adapter.get_torch_dataloader() - for i, batch in enumerate(ds): - self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, torch.Tensor) - self.assertIsInstance(by, torch.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, torch.float32) - if i < 2: - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) - else: - self.assertEqual(tuple(bx.shape), (2, 4)) - self.assertEqual(tuple(by.shape), (2, 2)) - - ds = adapter.get_tf_dataset() - for i, batch in enumerate(ds): - self.assertEqual(len(batch), 2) - bx, by = batch - self.assertIsInstance(bx, tf.Tensor) - self.assertIsInstance(by, tf.Tensor) - self.assertEqual(bx.dtype, by.dtype) - self.assertEqual(bx.dtype, tf.float32) - if i < 2: - self.assertEqual(tuple(bx.shape), (16, 4)) - self.assertEqual(tuple(by.shape), (16, 2)) - else: - self.assertEqual(tuple(bx.shape), (2, 4)) - self.assertEqual(tuple(by.shape), (2, 2)) diff --git a/keras/trainers/epoch_iterator.py b/keras/trainers/epoch_iterator.py index b5d7606c077..b2877b25f39 100644 --- a/keras/trainers/epoch_iterator.py +++ b/keras/trainers/epoch_iterator.py @@ -71,23 +71,14 @@ def __init__( ) self._num_batches = self.data_adapter.num_batches - def _get_iterator(self, return_type="auto"): - if return_type not in ("np", "tf", "auto"): - raise ValueError( - "Argument `return_type` must be one of `{'np', 'tf', 'auto'}`. " - f"Received instead: return_type={return_type}" - ) - if return_type == "tf": - iterator = self.data_adapter.get_tf_dataset() - else: - iterator = self.data_adapter.get_numpy_iterator() - return iterator + def _get_iterator(self): + return self.data_adapter.get_numpy_iterator() - def enumerate_epoch(self, return_type="auto"): + def enumerate_epoch(self): buffer = [] if self.steps_per_epoch: - if not self._current_iterator: - self._current_iterator = self._get_iterator(return_type) + if self._current_iterator is None: + self._current_iterator = iter(self._get_iterator()) self._insufficient_data = False for step in range(self.steps_per_epoch): @@ -114,7 +105,7 @@ def enumerate_epoch(self, return_type="auto"): if buffer: yield step - len(buffer) + 1, buffer else: - for step, data in enumerate(self._get_iterator(return_type)): + for step, data in enumerate(self._get_iterator()): buffer.append(data) if len(buffer) == self.steps_per_execution: yield step - len(buffer) + 1, buffer diff --git a/keras/trainers/epoch_iterator_test.py b/keras/trainers/epoch_iterator_test.py index 5299c102d0a..0c731b0bc79 100644 --- a/keras/trainers/epoch_iterator_test.py +++ b/keras/trainers/epoch_iterator_test.py @@ -9,7 +9,7 @@ class TestEpochIterator(testing.TestCase): - def _test_basic_flow(self, return_type): + def test_basic_flow(self): x = np.random.random((100, 16)) y = np.random.random((100, 4)) sample_weight = np.random.random((100,)) @@ -23,22 +23,13 @@ def _test_basic_flow(self, return_type): shuffle=shuffle, ) steps_seen = [] - for step, batch in iterator.enumerate_epoch(return_type=return_type): + for step, batch in iterator.enumerate_epoch(): batch = batch[0] steps_seen.append(step) self.assertEqual(len(batch), 3) - if return_type == "np": - self.assertIsInstance(batch[0], np.ndarray) - else: - self.assertIsInstance(batch[0], tf.Tensor) + 
self.assertIsInstance(batch[0], np.ndarray) self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6]) - def test_basic_flow_np(self): - self._test_basic_flow("np") - - def test_basic_flow_tf(self): - self._test_basic_flow("tf") - def test_insufficient_data(self): batch_size = 8 steps_per_epoch = 6 @@ -97,7 +88,7 @@ def __getitem__(self, idx): torch_dataset, batch_size=8, shuffle=True ) iterator = epoch_iterator.EpochIterator(torch_dataloader) - for _, batch in iterator.enumerate_epoch(return_type="np"): + for _, batch in iterator.enumerate_epoch(): batch = batch[0] self.assertEqual(batch[0].shape, (8, 2)) self.assertEqual(batch[1].shape, (8, 1)) @@ -180,14 +171,3 @@ def test_unrecognized_data_type(self): x = "unsupported_data" with self.assertRaisesRegex(ValueError, "Unrecognized data type"): _ = epoch_iterator.EpochIterator(x=x) - - def test_invalid_return_type_in_get_iterator(self): - x = np.random.random((100, 16)) - y = np.random.random((100, 4)) - epoch_iter = epoch_iterator.EpochIterator(x=x, y=y) - - with self.assertRaisesRegex( - ValueError, - "Argument `return_type` must be one of `{'np', 'tf', 'auto'}`", - ): - _ = epoch_iter._get_iterator("unsupported")
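
-- 

Usage sketch (illustration only, not part of the patch): the new
per-backend iterator API on a `DataAdapter`. The adapter arguments
mirror the ones used in the tests above; the input data and shapes are
made up.

import numpy as np
import torch

from keras.trainers.data_adapters import array_data_adapter
from keras.trainers.data_adapters import generator_data_adapter

x = np.random.random((34, 4)).astype("float32")
y = np.random.random((34, 2)).astype("float32")

# Every DataAdapter now exposes one native iterator per backend:
# get_numpy_iterator(), get_jax_iterator(), get_tf_dataset() and
# get_torch_dataloader().
adapter = array_data_adapter.ArrayDataAdapter(
    x,
    y=y,
    sample_weight=None,
    batch_size=16,
    steps=None,
    shuffle=False,
)
for bx, by in adapter.get_numpy_iterator():
    # Two full batches of 16 samples, then a partial batch of 2.
    print(bx.shape, by.shape)

# GeneratorDataAdapter now consumes tensors from any backend (torch
# here) and can still serve any backend (JAX here).
def gen():
    for i in range(0, 34, 16):
        yield torch.as_tensor(x[i : i + 16]), torch.as_tensor(y[i : i + 16])

adapter = generator_data_adapter.GeneratorDataAdapter(gen())
for bx, by in adapter.get_jax_iterator():
    print(type(bx))  # jax.Array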