[Nodes] Add Prebatch setting to ParallelMapper #1417

Merged Jan 2, 2025 (7 commits)
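This PR adds a prebatch option to ParallelMapper: when set, source items are grouped into batches of size prebatch (via Batcher), map_fn is still applied per item inside the worker via the new MapOverBatch wrapper, and results are flattened back out (via Unbatcher). For small items this amortizes per-item queue traffic at the cost of higher peak memory.

A minimal usage sketch, assuming IterableWrapper and Loader from torchdata.nodes (they are not part of this diff):

    from torchdata.nodes import IterableWrapper, Loader, ParallelMapper

    source = IterableWrapper(range(80))  # any BaseNode source works here
    node = ParallelMapper(
        source,
        map_fn=lambda x: x * 2,  # still a per-item function, even with prebatch
        num_workers=4,
        method="thread",
        prebatch=3,  # ship items to workers in groups of 3
    )
    print(list(Loader(node)))  # same output as with prebatch=None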
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -34,6 +34,6 @@ repos:
- usort == 1.0.0

- repo: https://github.com/pycqa/flake8
- rev: 5.0.4
+ rev: 6.1.0
hooks:
- id: flake8
37 changes: 29 additions & 8 deletions test/nodes/test_map.py
@@ -7,7 +7,7 @@
import itertools

import unittest
- from typing import List
+ from typing import List, Optional

from parameterized import parameterized
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_CUDA, TestCase
@@ -55,7 +55,7 @@ def test_exception_handling_mapper_multiprocess(self):
def test_exception_handling_mapper_multiprocess_cuda(self):
self._test_exception_handling_mapper(True, "process")

- def _test_map(self, in_order, method) -> None:
+ def _test_map(self, in_order, method, prebatch) -> None:
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
@@ -68,6 +68,7 @@ def _test_map(self, in_order, method) -> None:
in_order=in_order,
method=method,
multiprocessing_context=multiprocessing_context,
+ prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)

@@ -98,25 +99,40 @@ def _test_map(self, in_order, method) -> None:
)

def test_in_order_threads(self):
self._test_map(True, "thread")
self._test_map(True, "thread", None)

def test_out_of_order_threads(self):
self._test_map(False, "thread")
self._test_map(False, "thread", None)

def test_in_order_process(self):
self._test_map(True, "process")
self._test_map(True, "process", None)

def test_out_of_order_process(self):
self._test_map(False, "process")
self._test_map(False, "process", None)

+ def test_in_order_thread_prebatch(self):
+ self._test_map(True, "thread", 3)
+
+ def test_out_of_order_thread_prebatch(self):
+ self._test_map(False, "thread", 3)
+
+ def test_in_order_process_prebatch(self):
+ self._test_map(True, "process", 3)
+
+ def test_out_of_order_process_prebatch(self):
+ self._test_map(False, "process", 3)

@parameterized.expand(
itertools.product(
[0, 7, 13],
[True], # TODO: define and fix in_order = False
[0, 1, 9], # TODO: define and fix in_order = False
+ [None, 3], # prebatch
)
)
- def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+ def test_save_load_state_thread(
+ self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+ ):
method = "thread"
batch_size = 6
n = 80
@@ -129,6 +145,7 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
in_order=in_order,
method=method,
snapshot_frequency=snapshot_frequency,
+ prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)
run_test_save_load_state(self, node, midpoint)
@@ -138,9 +155,12 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
[0, 7, 13],
[True], # TODO: define and fix in_order = False
[0, 1, 9], # TODO: define and fix in_order = False
+ [None, 3], # prebatch
)
)
- def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_frequency: int):
+ def test_save_load_state_process(
+ self, midpoint: int, in_order: bool, snapshot_frequency: int, prebatch: Optional[int]
+ ):
method = "process"
batch_size = 6
n = 80
@@ -155,6 +175,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_f
method=method,
multiprocessing_context=multiprocessing_context,
snapshot_frequency=snapshot_frequency,
+ prebatch=prebatch,
)
node = Prefetcher(node, prefetch_factor=2)
run_test_save_load_state(self, node, midpoint)
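The prebatch test variants reuse _test_map, which runs the same 80-item pipeline and asserts identical results whether or not prebatching is enabled: prebatch is a throughput knob, not a semantic change. A condensed sketch of the invariant being exercised (make_source and fn are hypothetical placeholders, not the test's actual helpers; Loader is assumed from torchdata.nodes):

    # Output must match item for item; with in_order=True the order must match too.
    plain = list(Loader(ParallelMapper(make_source(), fn, num_workers=2, prebatch=None)))
    prebatched = list(Loader(ParallelMapper(make_source(), fn, num_workers=2, prebatch=3)))
    assert plain == prebatched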
142 changes: 115 additions & 27 deletions torchdata/nodes/map.py
@@ -7,10 +7,11 @@
import queue
import threading
import time
- from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Protocol, TypeVar, Union
+ from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, Protocol, Sequence, TypeVar, Union

import torch.multiprocessing as mp
from torchdata.nodes.base_node import BaseNode, T
+ from torchdata.nodes.batch import Batcher, Unbatcher
from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper
from torchdata.nodes.snapshot_store import QueueSnapshotStore, SnapshotStore

@@ -52,6 +53,18 @@ def Mapper(source: BaseNode[X], map_fn: Callable[[X], T]) -> "ParallelMapper[T]"
)


+ Xseq = Sequence[X]
+ Tseq = Sequence[T]
+
+
+ class MapOverBatch(Generic[X, T]):
+ def __init__(self, map_fn: Callable[[X], T]):
+ self.map_fn = map_fn
+
+ def __call__(self, xlist: Sequence[X]) -> Sequence[T]:
+ return [self.map_fn(x) for x in xlist]


def _sort_worker(in_q: Union[queue.Queue, mp.Queue], out_q: queue.Queue, stop_event: threading.Event):
buffer: Dict[int, Any] = {}
cur_idx = 0
@@ -272,6 +285,78 @@ def _shutdown(self):
t.join(timeout=QUEUE_TIMEOUT * 5)


+ class _ParallelMapperImpl(BaseNode[T]):
+ """This class implements _ParallelMapperIter and _InlineMapperIter as a BaseNode,
+ allowing them to be composed with other BaseNodes.
+
+ TODO: In the future, this class may go away once we implement reset() on
+ _ParallelMapperIter and _InlineMapperIter themselves so we don't need this
+ additional level of abstraction.
+ """
+
+ def __init__(
+ self,
+ source: BaseNode[X],
+ map_fn: Callable[[X], T],
+ num_workers: int,
+ in_order: bool = True,
+ method: Literal["thread", "process"] = "thread",
+ multiprocessing_context: Optional[str] = None,
+ max_concurrent: Optional[int] = None,
+ snapshot_frequency: int = 1,
+ ):
+ super().__init__()
+ assert method in ["thread", "process"]
+ self.source = source
+ self.map_fn = map_fn
+ self.num_workers = num_workers
+ self.in_order = in_order
+ self.method = method
+ self.multiprocessing_context = multiprocessing_context
+ self._mp_context: Any = mp
+ if self.method == "process" and self.multiprocessing_context is not None:
+ self._mp_context = mp.get_context(self.multiprocessing_context)
+
+ if max_concurrent is not None and num_workers > 0:
+ if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
+ raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
+ self.max_concurrent = max_concurrent
+ self.snapshot_frequency = snapshot_frequency
+ self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
+
+ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+ super().reset(initial_state)
+ if self._it is not None:
+ del self._it
+
+ if self.num_workers > 0:
+ self._it = self._parallel_reset(initial_state)
+ else:
+ self._it = self._inline_reset(initial_state)
+
+ def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
+ return _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state)
+
+ def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
+ return _ParallelMapperIter(
+ source=self.source,
+ map_fn=self.map_fn,
+ num_workers=self.num_workers,
+ in_order=self.in_order,
+ method=self.method,
+ mp_context=self._mp_context,
+ max_concurrent=self.max_concurrent,
+ snapshot_frequency=self.snapshot_frequency,
+ initial_state=initial_state,
+ )
+
+ def next(self) -> T:
+ return next(self._it) # type: ignore[arg-type, union-attr]
+
+ def get_state(self) -> Dict[str, Any]:
+ return self._it.get_state() # type: ignore[union-attr]


class ParallelMapper(BaseNode[T]):
"""ParallelMapper executes map_fn in parallel either in num_workers threads or
processes. For processes, multiprocessing_context can be spawn, forkserver, fork,
@@ -294,8 +379,12 @@ class ParallelMapper(BaseNode[T]):
multiprocessing_context (Optional[str]): The multiprocessing context to use for parallel processing. Default is None.
max_concurrent (Optional[int]): The maximum number of items to process at once. Default is None.
snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 1.
+ prebatch (Optional[int]): Optionally perform pre-batching of items from source before mapping.
+ For small items, this may improve throughput at the expense of peak memory.
"""

IT_STATE_KEY = "it_state"

def __init__(
self,
source: BaseNode[X],
@@ -306,58 +395,57 @@ def __init__(
multiprocessing_context: Optional[str] = None,
max_concurrent: Optional[int] = None,
snapshot_frequency: int = 1,
+ prebatch: Optional[int] = None,
):
super().__init__()
assert method in ["thread", "process"]
- self.source = source
- self.map_fn = map_fn
self.num_workers = num_workers
self.in_order = in_order
self.method = method
self.multiprocessing_context = multiprocessing_context
- self._mp_context: Any = mp
- if self.method == "process" and self.multiprocessing_context is not None:
- self._mp_context = mp.get_context(self.multiprocessing_context)

if max_concurrent is not None and num_workers > 0:
if not isinstance(max_concurrent, int) and max_concurrent > num_workers:
raise ValueError(f"{max_concurrent=} should be >= {num_workers=}!")
self.max_concurrent = max_concurrent
self.snapshot_frequency = snapshot_frequency
- self._it: Optional[Union[_InlineMapperIter[T], _ParallelMapperIter[T]]] = None
-
- def reset(self, initial_state: Optional[Dict[str, Any]] = None):
- super().reset(initial_state)
- if self._it is not None:
- self._it._shutdown()
- del self._it
-
- if self.num_workers > 0:
- self._parallel_reset(initial_state)
- else:
- self._inline_reset(initial_state)
-
- def _inline_reset(self, initial_state: Optional[Dict[str, Any]]):
- self._it = _InlineMapperIter(source=self.source, map_fn=self.map_fn, initial_state=initial_state)
+ self.prebatch = prebatch
+ if prebatch is None:
+ self.map_fn = map_fn
+ self.source = source
+ else:
+ if prebatch <= 0:
+ raise ValueError(f"{prebatch=} must be a positive integer!")
+ self.map_fn = MapOverBatch(map_fn=map_fn) # type: ignore[assignment]
+ self.source = Batcher(source, batch_size=prebatch, drop_last=False) # type: ignore[assignment]

- def _parallel_reset(self, initial_state: Optional[Dict[str, Any]]):
- self._it = _ParallelMapperIter(
+ _it = _ParallelMapperImpl(
source=self.source,
map_fn=self.map_fn,
num_workers=self.num_workers,
in_order=self.in_order,
method=self.method,
- mp_context=self._mp_context,
+ multiprocessing_context=self.multiprocessing_context,
max_concurrent=self.max_concurrent,
snapshot_frequency=self.snapshot_frequency,
- initial_state=initial_state,
)

- def next(self):
+ if self.prebatch is None:
+ self._it = _it
+ else:
+ self._it = Unbatcher(_it) # type: ignore[arg-type, assignment]
+
+ def reset(self, initial_state: Optional[Dict[str, Any]] = None):
+ super().reset(initial_state)
+ if initial_state is not None:
+ self._it.reset(initial_state[self.IT_STATE_KEY])
+ else:
+ self._it.reset()
+
+ def next(self) -> T:
return next(self._it) # type: ignore[arg-type, union-attr]

def get_state(self) -> Dict[str, Any]:
- return self._it.get_state() # type: ignore[union-attr]
+ return {self.IT_STATE_KEY: self._it.state_dict()} # type: ignore[union-attr]


_WorkerType = Callable[
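Net effect of the map.py changes: with prebatch set, ParallelMapper builds a Batcher, MapOverBatch, _ParallelMapperImpl, Unbatcher chain instead of driving the mapper iterators directly. A rough sketch of that composition, using only names from this diff (w and n are placeholders; the real constructor forwards the remaining options as well):

    # MapOverBatch lifts the per-item fn to operate on one prebatch at a time:
    #   MapOverBatch(map_fn=fn)([x1, x2, x3]) == [fn(x1), fn(x2), fn(x3)]
    inner = _ParallelMapperImpl(
        source=Batcher(source, batch_size=n, drop_last=False),  # group n items per unit of work
        map_fn=MapOverBatch(map_fn=fn),
        num_workers=w,
    )
    node = Unbatcher(inner)  # flatten groups so callers still see single items

State handling wraps the same way: get_state() stores the inner chain's state_dict() under IT_STATE_KEY ("it_state"), and reset() unwraps that key before delegating, so a checkpoint taken with prebatch enabled restores the Batcher/Unbatcher chain together with the mapper itself.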