Refactor bucket batch sampler to support DDP sampling #478

Draft
wants to merge 10 commits into
base: main
@@ -17,6 +17,7 @@
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Type, TypeVar, Union

import torch
import torch.distributed as dist
from torch.utils.data import Sampler


@@ -371,6 +372,8 @@ def __init__(
base_batch_sampler_individual_kwargs: Optional[Dict[str, Iterable]] = None,
shuffle: Optional[bool] = True,
generator: Optional[torch.Generator] = None,
sampler: Optional[Union[Sampler[int], Iterable[int]]] = None,
num_batches: Optional[int] = None,
) -> None:
"""Initializes the BucketBatchSampler.

@@ -390,13 +393,24 @@ def __init__(
Default to {}.
shuffle: A boolean indicating whether to shuffle the dataset and buckets. Defaults to True.
generator: Generator used in sampling. Defaults to None.
sampler: Sampler providing the data sample indices consumed by this batch sampler. When training with a distributed data
parallel strategy, the user should provide a torch.utils.data.distributed.DistributedSampler (or an equivalent)
here to partition the data sample indices across devices. Defaults to None.
num_batches: Number of mini-batches. When training with a distributed data parallel strategy, this value should be
consistent across training devices so that they all get the same number of batches. Defaults to None.

Raises:
ValueError: If `sizes` is not a 1D tensor of real numbers.
ValueError: If `bucket_boundaries` is not a 1D tensor of real numbers.
TypeError: If `sizes` is not torch tensor.
ValueError: If `sizes` is not a 1D array, or its entries are not real numbers.
TypeError: If `bucket_boundaries` is not torch tensor.
ValueError: If `bucket_boundaries` is not a 1D array, or its entries are not real numbers, or it has only 1 entry, or it has duplicate entries.
TypeError: If `shuffle` is not bool.
TypeError: If `generator` is not NoneType and not torch generator.
TypeError: If `base_batch_sampler_class` is not a subclass of `torch.utils.data.Sampler`
ValueError: If `base_batch_sampler_shared_kwargs` or `base_batch_sampler_individual_kwargs` is not a keyword argument dictionary.
ValueError: If the length of values in the dict of `base_batch_sampler_individual_kwargs` must be equal to len(bucket_boundaries) - 1.
RuntimeError: If there are no elements with sizes inside the ranges specified by `bucket_boundaries`.
ValueError: If the length of values in the dict of `base_batch_sampler_individual_kwargs` is not equal to len(bucket_boundaries) - 1.
TypeError: If `sampler` is not None and not an instance of `torch.utils.data.Sampler` or is not iterable.
ValueError: If `num_batches` is None but distributed sampling is enabled.

"""
if not torch.is_tensor(sizes):
@@ -440,11 +454,22 @@ def __init__(
if not isinstance(shuffle, bool):
raise TypeError(f"shuffle should be a boolean value, but got shuffle={shuffle}")

if generator is not None and not isinstance(generator, torch.Generator):
raise TypeError(f"generator must be a torch generator or None value, but got generator={type(generator)}")

self.sizes = sizes
self.bucket_boundaries = bucket_boundaries
self.num_buckets = len(bucket_boundaries) - 1

num_outliers = (self.sizes < self.bucket_boundaries[0]).sum() + (
self.sizes >= self.bucket_boundaries[-1]
).sum()
if num_outliers > 0:
warnings.warn(f"{num_outliers} elements are outside the buckets provided and will be skipped")

self.shuffle = shuffle
self.generator = generator

if self.shuffle and self.generator is None:
self.generator = torch.Generator().manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))

@@ -486,50 +511,108 @@ def __init__(
)

self.base_batch_sampler_class = base_batch_sampler_class
self.base_batch_sampler_shared_kwargs = (
{} if base_batch_sampler_shared_kwargs is None else base_batch_sampler_shared_kwargs
)
base_batch_sampler_individual_kwargs = (
{} if base_batch_sampler_individual_kwargs is None else base_batch_sampler_individual_kwargs
)
self.base_batch_sampler_individual_kwargs = [
{key: list(base_batch_sampler_individual_kwargs[key])[k] for key in base_batch_sampler_individual_kwargs}
for k in range(self.num_buckets)
]
self.base_batch_sampler_shared_kwargs = base_batch_sampler_shared_kwargs
self.base_batch_sampler_individual_kwargs = base_batch_sampler_individual_kwargs

self.bucket_sizes: torch.Tensor # number of elements in each bucket
self.bucket_element_indices: List[List[int]] # List of elements' indices for each bucket

# bucket index for each element
element_bucket_indices = torch.bucketize(sizes, bucket_boundaries, right=True)
self.element_bucket_indices = torch.bucketize(sizes, self.bucket_boundaries, right=True)

if sampler is not None and not isinstance(sampler, (Sampler, Iterable)): # type: ignore
raise TypeError(
f"sampler should be an instance of torch.utils.data.Sampler or an iterable, but got sampler={type(sampler)}"
)

self.sampler = sampler
self.bucket_element_indices = [[] for _ in range(self.num_buckets)]

if self.sampler is None:
self._compute_bucket_element_indices()
else:
self._compute_bucket_element_indices(list(self.sampler)) # type: ignore

self.base_batch_samplers = self._init_base_batch_samplers()

if num_batches is not None and (not isinstance(num_batches, int) or num_batches <= 0):
raise ValueError(f"num_batches should be a positive integer, but got num_batches = {num_batches}")

if dist.is_initialized() and num_batches is None:
raise ValueError("num_batches is not set but distributed sampling is enabled")

self.num_batches = num_batches

def setup_for_pytorch_lightning(self) -> Sampler:
"""Add additional attributes necessary to be compatible with pytorch lightning distributed sampling.

When used for distributed sampling, PyTorch Lightning will use these attributes to re-instantiate this batch sampler and
replace its sampler with torch.utils.data.DistributedSampler if the Lightning trainer is instantiated with
use_distributed_sampler=True.

# element indices reordered for each bucket
reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
Returns:
Sampler: this batch sampler instance.
"""
setattr(self, "__pl_saved_args", ())
setattr(
self,
"__pl_saved_kwargs",
{
"sizes": self.sizes,
"bucket_boundaries": self.bucket_boundaries,
"base_batch_sampler_class": self.base_batch_sampler_class,
"base_batch_sampler_shared_kwargs": self.base_batch_sampler_shared_kwargs,
"base_batch_sampler_individual_kwargs": self.base_batch_sampler_individual_kwargs,
"shuffle": self.shuffle,
"generator": self.generator,
"sampler": self.sampler,
"num_batches": self.num_batches,
},
)
setattr(self, "__pl_saved_default_kwargs", {})
setattr(self, "__pl_saved_arg_names", ())
return self
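A hedged sketch of how setup_for_pytorch_lightning might be wired into a Lightning run; the dataset, model, and Trainer settings are assumptions for illustration. In this path the batch sampler could be constructed without an explicit sampler, since Lightning injects the DistributedSampler when it re-instantiates the batch sampler.

from torch.utils.data import DataLoader

# Lightning reads the saved constructor args/kwargs to re-instantiate this
# batch sampler with a DistributedSampler when use_distributed_sampler=True.
dataloader = DataLoader(dataset, batch_sampler=batch_sampler.setup_for_pytorch_lightning())
# trainer = lightning.Trainer(use_distributed_sampler=True)
# trainer.fit(model, train_dataloaders=dataloader)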

def _compute_bucket_element_indices(self, indices: Optional[List[int]] = None):
"""Compute and update the bucket element indices with the given element indices.

Args:
indices: Subset of data indices to use; all indices are used if it is None. Defaults to None.

Raises:
RuntimeError: If no elements from the given indices have sizes
inside the ranges specified by `self.bucket_boundaries`.
"""
if indices is None:
element_bucket_indices = self.element_bucket_indices
# element indices reordered for each bucket
reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
else:
indices_subst: torch.Tensor = torch.tensor(indices) # type: ignore
element_bucket_indices = self.element_bucket_indices[indices_subst]
# element indices reordered for each bucket
reordered_element_indices = indices_subst[torch.argsort(element_bucket_indices, stable=True)]

# bucket sizes, including the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
bucket_sizes = torch.bincount(element_bucket_indices, minlength=len(bucket_boundaries) + 1)
bucket_sizes = torch.bincount(element_bucket_indices, minlength=len(self.bucket_boundaries) + 1)

# bucket segments
bucket_segments = torch.cumsum(bucket_sizes, dim=0)[:-1]
bucket_sizes_cumsum = torch.cumsum(bucket_sizes, dim=0)[:-1]

self.bucket_element_indices = []
# exclude the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
for bucket_idx in range(self.num_buckets):
self.bucket_element_indices.append(
reordered_element_indices[bucket_segments[bucket_idx] : bucket_segments[bucket_idx + 1]].tolist()
)
self.bucket_element_indices[bucket_idx][:] = reordered_element_indices[
bucket_sizes_cumsum[bucket_idx] : bucket_sizes_cumsum[bucket_idx + 1]
].tolist()

self.bucket_sizes = bucket_sizes[1 : (self.num_buckets + 1)]

self.num_samples = torch.sum(self.bucket_sizes).item()
if self.num_samples == 0:
raise RuntimeError("The sizes of all elements in the dataset are outside the bucket ranges provided")
if self.num_samples < len(self.sizes):
warnings.warn(
f"{len(self.sizes) - self.num_samples} elements are outside the buckets provided and will be skipped"
raise RuntimeError(
"The sizes of the elements with given indices in the dataset are outside the bucket ranges provided"
)

self.base_batch_samplers: List[Sampler] = self._init_base_batch_samplers()

def _init_base_batch_samplers(self) -> list[Sampler[List[int]]]:
"""Initialize batch samplers for each bucket.

@@ -538,25 +621,30 @@ def _init_base_batch_samplers(self) -> list[Sampler[List[int]]]:
"""
base_batch_samplers = []
for k in range(self.num_buckets):
individual_kwargs = {
key: list(self.base_batch_sampler_individual_kwargs[key])[k]
for key in self.base_batch_sampler_individual_kwargs
}
base_batch_samplers.append(
self.base_batch_sampler_class(
self.bucket_element_indices[k],
**self.base_batch_sampler_shared_kwargs,
**self.base_batch_sampler_individual_kwargs[k],
**individual_kwargs,
)
)
return base_batch_samplers
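To illustrate the per-bucket expansion performed above (a hedged, self-contained example; the kwargs values are made up), each entry of base_batch_sampler_individual_kwargs is indexed by bucket:

# With three buckets, a hypothetical individual-kwargs dict expands per bucket.
individual_kwargs = {"batch_size": [16, 8, 4]}
per_bucket = [
    {key: list(individual_kwargs[key])[k] for key in individual_kwargs}
    for k in range(3)
]
assert per_bucket == [{"batch_size": 16}, {"batch_size": 8}, {"batch_size": 4}]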

def __len__(self) -> int:
"""Get the number of batches.

Can only be called if the `base_batch_sampler_class` has __len__() implemented
If num_batches is not provided, this can only be called when `base_batch_sampler_class` has __len__() implemented.

Returns:
int: Number of batches
"""
num_batches = sum(len(sampler) for sampler in self.base_batch_samplers) # type: ignore
return num_batches
if self.num_batches is None:
self.num_batches = sum(len(sampler) for sampler in self.base_batch_samplers) # type: ignore
return self.num_batches

def __iter__(self) -> Iterator[List[int]]:
"""Iterate over batches of indices.
@@ -566,29 +654,52 @@ def __iter__(self) -> Iterator[List[int]]:
Yields:
List[int]: A batch of indices of elements with sizes from each bucket range.
"""
if self.shuffle:
if self.num_batches is None:
self.num_batches = len(self)

if self.sampler is not None:
self._compute_bucket_element_indices(list(self.sampler)) # type: ignore
elif self.shuffle:
for indices in self.bucket_element_indices:
idx = torch.randperm(len(indices), generator=self.generator)
indices[:] = torch.tensor(indices)[idx].tolist()

base_batch_sampler_iters = [iter(batch_sampler) for batch_sampler in self.base_batch_samplers]
bucket_remaining_elements = self.bucket_sizes.clone()
total_remaining_elements = self.num_samples

while total_remaining_elements > 0:
if self.shuffle:
bucket_idx = torch.multinomial(
bucket_remaining_elements / total_remaining_elements, 1, generator=self.generator
)
else:
bucket_idx = torch.argmax((bucket_remaining_elements > 0).to(int)) # type: ignore
num_yield_batches = 0
while num_yield_batches < self.num_batches:
base_batch_sampler_iters = [iter(batch_sampler) for batch_sampler in self.base_batch_samplers]
bucket_remaining_elements = self.bucket_sizes.clone()
total_remaining_elements = self.num_samples

while total_remaining_elements > 0 and num_yield_batches < self.num_batches:
if self.shuffle:
bucket_idx = torch.multinomial(
bucket_remaining_elements / total_remaining_elements, 1, generator=self.generator
)
else:
bucket_idx = torch.argmax((bucket_remaining_elements > 0).to(int)) # type: ignore

try:
batch = next(base_batch_sampler_iters[bucket_idx])
batch_size = min(len(batch), bucket_remaining_elements[bucket_idx].item())
bucket_remaining_elements[bucket_idx] -= batch_size
total_remaining_elements -= batch_size
num_yield_batches += 1
yield batch
except StopIteration:
bucket_remaining_elements[bucket_idx] = 0
total_remaining_elements = torch.sum(bucket_remaining_elements)
continue
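The reworked loop above yields exactly num_batches batches per pass, restarting the per-bucket base samplers if a rank's partition is exhausted early. A hedged check, reusing the batch_sampler from the earlier sketch:

# Each rank produces the same fixed number of index lists per epoch.
epoch_batches = list(batch_sampler)
assert len(epoch_batches) == batch_sampler.num_batches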

def set_epoch(self, epoch: int):
"""Set the epoch for `self.sampler`.

Args:
epoch (int): Epoch number.
"""
if self.sampler is not None:
try:
batch = next(base_batch_sampler_iters[bucket_idx])
bucket_remaining_elements[bucket_idx] -= len(batch)
total_remaining_elements -= len(batch)
yield batch
except StopIteration:
bucket_remaining_elements[bucket_idx] = 0
total_remaining_elements = torch.sum(bucket_remaining_elements)
continue
self.sampler.set_epoch(epoch) # type: ignore
except AttributeError:
raise RuntimeError(
f"Trying to call set_epoch on {self.sampler.__class__} which does not have such method"
)
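A hedged sketch of a manual DDP training loop using the new set_epoch hook; dataloader, batch_sampler, and num_epochs are assumed to exist from the earlier sketches:

# Calling set_epoch once per epoch lets the wrapped DistributedSampler
# reshuffle its per-rank partition of the indices.
num_epochs = 3  # illustrative
for epoch in range(num_epochs):
    batch_sampler.set_epoch(epoch)
    for batch in dataloader:
        pass  # forward / backward / optimizer step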