Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make itertomap loading more lazy #1016

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,18 @@ def test_itertomap_mapdatapipe(self):
self.assertEqual(len(wa), 1)
self.assertRegex(str(wa[0].message), r"Found duplicate key")

# Lazy loading: the source is only consumed until the requested key has been found
source_dp = IterableWrapper(list(zip(keys, values)))
lazy_map_dp = source_dp.to_map_datapipe()
_ = lazy_map_dp["k" + str(4)]
self.assertEqual(len(lazy_map_dp._map), 5)
_ = lazy_map_dp["k" + str(7)]
self.assertEqual(len(lazy_map_dp._map), 8)
try:
_ = lazy_map_dp["k" + str(20)]
except IndexError:
self.assertEqual(len(lazy_map_dp._map), 10)

def test_mux_longest_iterdatapipe(self):

# Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted
Expand Down
113 changes: 72 additions & 41 deletions torchdata/datapipes/iter/util/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
# LICENSE file in the root directory of this source tree.

import warnings

from typing import Callable, Dict, Optional
from typing import Callable, Dict, Iterator, Optional

from torch.utils.data import IterDataPipe, MapDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, DILL_AVAILABLE
Expand Down Expand Up @@ -37,28 +36,43 @@ class IterToMapConverterMapDataPipe(MapDataPipe):
will be replaced by the new value.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> source_dp = IterableWrapper([(i, i) for i in range(10)])
>>> map_dp = source_dp.to_map_datapipe()
>>> list(map_dp)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)])
>>> map_dp2 = source_dp2.to_map_datapipe()
>>> map_dp2['a']
1
>>> def row_to_tuple(row):
>>> label = row[0]
>>> data = row[1:]
>>> return label, data
>>> source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)])
>>> map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple)
>>> map_dp3['a']

.. testsetup::

from torchdata.datapipes.iter import IterableWrapper

.. testcode::

source_dp = IterableWrapper([(i, i) for i in range(10)])
map_dp = source_dp.to_map_datapipe()
assert list(map_dp) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)])
map_dp2 = source_dp2.to_map_datapipe()
assert map_dp2['a'] == 1

.. testcode::

def row_to_tuple(row):
label = row[0]
data = row[1:]
return label, data
source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)])
map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple)
print(map_dp3['a'])

.. testoutput::

(1, 1, 1, 1, 1, 1)

"""

datapipe: IterDataPipe
key_value_fn: Optional[Callable]
_map: Optional[Dict]
_length: int
_itr: Optional[Iterator]
_depleted: bool

def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = None):
if not isinstance(datapipe, IterDataPipe):
Expand All @@ -68,32 +82,53 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
_check_unpickable_fn(key_value_fn)
self.key_value_fn = key_value_fn # type: ignore[assignment]
self._map = None
self._itr = None
self._depleted = False

def _load_map(self):
self._map = {}
for d in self.datapipe:
inp = d if self.key_value_fn is None else self.key_value_fn(d)
while not self._depleted:
try:
length = len(inp)
except TypeError:
raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence")
if length != 2:
raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
key, value = inp
if key in self._map:
warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
self._map[key] = value
self._load_next_item()
except StopIteration:
self._depleted = True
self._itr = None

def __getitem__(self, index):
try:
if self._map is None:
self._load_map()
return self._map[index] # type: ignore[index]
if self._map is not None:
return self._map[index]
except KeyError:
raise IndexError(f"Index {index} is invalid for IterToMapConverter.")
pass
while not self._depleted:
try:
key, value = self._load_next_item()
if key == index:
return value
except StopIteration:
self._depleted = True
SvenDS9 marked this conversation as resolved.
Show resolved Hide resolved
self._itr = None
raise IndexError(f"Index {index} is invalid for IterToMapConverter.")

def _load_next_item(self):
    """Read one element from the source datapipe and store it in ``self._map``.

    The map and the source iterator are created on demand. They are checked
    independently because ``__setstate__`` restores ``_map`` but not ``_itr``:
    an instance restored with a partially filled map has no iterator yet and
    would otherwise fail on ``next(self._itr)``. In that situation iteration
    restarts from the beginning of ``self.datapipe``, so keys that are already
    stored are read again (and reported as duplicates) before new ones appear.

    Returns:
        The ``(key, value)`` pair that was just stored.

    Raises:
        StopIteration: The source iterator has no more elements.
        TypeError: The (converted) element does not support ``len()``.
        ValueError: The (converted) element does not contain exactly two items.
    """
    if self._map is None:
        self._map = {}
    # ``getattr`` covers instances on which ``_itr`` was never assigned.
    if getattr(self, "_itr", None) is None:
        self._itr = iter(self.datapipe)
    elem = next(self._itr)  # type: ignore[arg-type]
    inp = elem if self.key_value_fn is None else self.key_value_fn(elem)
    try:
        length = len(inp)
    except TypeError as err:
        raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") from err
    if length != 2:
        raise ValueError(f"dictionary update sequence element has length {length}, 2 is required")
    key, value = inp
    if key in self._map:
        # A later element wins; the warning only makes the overwrite visible.
        warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`")
    self._map[key] = value
    return key, value

def __len__(self):
if self._map is not None:
if self._depleted:
return len(self._map) # type: ignore[arg-type]
try:
return len(self.datapipe)
Expand All @@ -112,14 +147,10 @@ def __getstate__(self):
dill_key_value_fn = dill.dumps(self.key_value_fn)
else:
dill_key_value_fn = self.key_value_fn
return (
self.datapipe,
dill_key_value_fn,
self._map,
)
return (self.datapipe, dill_key_value_fn, self._map, self._depleted)

def __setstate__(self, state):
(self.datapipe, dill_key_value_fn, self._map) = state
(self.datapipe, dill_key_value_fn, self._map, self._depleted) = state
if DILL_AVAILABLE:
self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment]
else:
Expand Down