Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor bucket batch sampler to support DDP sampling #478

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ def __init__(
base_batch_sampler_individual_kwargs: Optional[Dict[str, Iterable]] = None,
shuffle: Optional[bool] = True,
generator: Optional[torch.Generator] = None,
sampler: Optional[Sampler[List[int]]] = None,
num_batches: Optional[int] = None,
) -> None:
"""Initializes the BucketBatchSampler.

Expand All @@ -390,12 +392,14 @@ def __init__(
Default to {}.
shuffle: A boolean indicating whether to shuffle the dataset and buckets. Defaults to True.
generator: Generator used in sampling. Defaults to None.
sampler: Sampler to get element indices for this batch sampler, required for distributed sampling. Defaults to None.
guoqing-zhou marked this conversation as resolved.
Show resolved Hide resolved
num_batches: Number of mini-batches, required for distributed sampling. Defaults to None.
guoqing-zhou marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: If `sizes` is not a 1D tensor of real numbers.
ValueError: If `bucket_boundaries` is not a 1D tensor of real numbers.
ValueError: If `base_batch_sampler_shared_kwargs` or `base_batch_sampler_individual_kwargs` is not a keyword argument dictionary.
ValueError: If the length of values in the dict of `base_batch_sampler_individual_kwargs` must be equal to len(bucket_boundaries) - 1.
ValueError: If the length of values in the dict of `base_batch_sampler_individual_kwargs` is not equal to len(bucket_boundaries) - 1.
RuntimeError: If there are no elements with sizes inside the ranges specified by `bucket_boundaries`.

"""
Expand Down Expand Up @@ -486,50 +490,97 @@ def __init__(
)

self.base_batch_sampler_class = base_batch_sampler_class
self.base_batch_sampler_shared_kwargs = (
{} if base_batch_sampler_shared_kwargs is None else base_batch_sampler_shared_kwargs
)
base_batch_sampler_individual_kwargs = (
{} if base_batch_sampler_individual_kwargs is None else base_batch_sampler_individual_kwargs
)
self.base_batch_sampler_individual_kwargs = [
{key: list(base_batch_sampler_individual_kwargs[key])[k] for key in base_batch_sampler_individual_kwargs}
for k in range(self.num_buckets)
]
self.base_batch_sampler_shared_kwargs = base_batch_sampler_shared_kwargs
self.base_batch_sampler_individual_kwargs = base_batch_sampler_individual_kwargs

self.bucket_sizes: torch.Tensor # number of elements in each bucket
self.bucket_element_indices: List[List[int]] # List of elements' indices for each bucket

# bucket index for each element
element_bucket_indices = torch.bucketize(sizes, bucket_boundaries, right=True)
self.element_bucket_indices = torch.bucketize(sizes, self.bucket_boundaries, right=True)

if sampler is not None and not isinstance(sampler, Sampler): # type: ignore
raise TypeError(
f"sampler should be a sampler class inherited from torch.utils.data.Sampler, but got sampler={sampler}"
guoqing-zhou marked this conversation as resolved.
Show resolved Hide resolved
)

# element indices reordered for each bucket
reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
self.sampler = sampler
self.bucket_element_indices = [[] for _ in range(self.num_buckets)]

if self.sampler is None:
self._compute_bucket_element_indices()
else:
self._compute_bucket_element_indices(list(self.sampler)) # type: ignore

self.base_batch_samplers = self._init_base_batch_samplers()

if num_batches is not None and (not isinstance(num_batches, int) or num_batches <= 0):
raise ValueError(f"num_batches should be a positive integer, but got num_batches = {num_batches}")

self.num_batches = num_batches

def setup_for_pytorch_lightning(self) -> Sampler:
    """Add the attributes PyTorch Lightning needs for distributed sampling.

    When used for distributed sampling, PyTorch Lightning reads the
    ``__pl_saved_*`` attributes set here to re-instantiate this batch sampler,
    replacing the saved ``sampler`` keyword argument with a
    ``torch.utils.data.DistributedSampler``.

    Returns:
        Sampler: this batch sampler instance (returned so calls can be chained).
    """
    # Lightning rebuilds the sampler as cls(*args, **kwargs); all state is
    # carried in kwargs, so positional args, defaults, and arg names stay empty.
    setattr(self, "__pl_saved_args", ())
    setattr(
        self,
        "__pl_saved_kwargs",
        {
            "sizes": self.sizes,
            "bucket_boundaries": self.bucket_boundaries,
            "base_batch_sampler_class": self.base_batch_sampler_class,
            "base_batch_sampler_shared_kwargs": self.base_batch_sampler_shared_kwargs,
            "base_batch_sampler_individual_kwargs": self.base_batch_sampler_individual_kwargs,
            "shuffle": self.shuffle,
            "generator": self.generator,
            "sampler": self.sampler,
            "num_batches": self.num_batches,
        },
    )
    setattr(self, "__pl_saved_default_kwargs", {})
    setattr(self, "__pl_saved_arg_names", ())
    return self

def _compute_bucket_element_indices(self, indices: Optional[List[int]] = None):
    """Fill `self.bucket_element_indices` in place and refresh bucket statistics.

    Args:
        indices: Element indices to distribute into buckets (e.g. the subset
            produced by a distributed sampler). If None, all elements in
            `self.element_bucket_indices` are used in their original order.

    Raises:
        RuntimeError: If no element falls inside the provided bucket ranges.
    """
    if indices is None:
        element_bucket_indices = self.element_bucket_indices
        # element indices reordered for each bucket
        reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
    else:
        indices: torch.Tensor = torch.tensor(indices)  # type: ignore
        element_bucket_indices = self.element_bucket_indices[indices]
        # element indices reordered for each bucket
        reordered_element_indices = indices[torch.argsort(element_bucket_indices, stable=True)]

    # bucket sizes, including the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
    bucket_sizes = torch.bincount(element_bucket_indices, minlength=len(self.bucket_boundaries) + 1)

    # bucket segments: start offset of each bucket in the reordered index list
    bucket_segments = torch.cumsum(bucket_sizes, dim=0)[:-1]

    # exclude the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
    for bucket_idx in range(self.num_buckets):
        # in-place slice assignment so base batch samplers that hold references
        # to these lists observe the updated contents
        self.bucket_element_indices[bucket_idx][:] = reordered_element_indices[
            bucket_segments[bucket_idx] : bucket_segments[bucket_idx + 1]
        ].tolist()

    self.bucket_sizes = bucket_sizes[1 : (self.num_buckets + 1)]

    self.num_samples = torch.sum(self.bucket_sizes).item()
    if self.num_samples == 0:
        raise RuntimeError("The sizes of all elements in the dataset are outside the bucket ranges provided")
    if self.num_samples < len(element_bucket_indices):
        warnings.warn(
            f"{len(element_bucket_indices) - self.num_samples} elements are outside the buckets provided and will be skipped"
        )

def _init_base_batch_samplers(self) -> list[Sampler[List[int]]]:
    """Initialize one base batch sampler per bucket.

    Returns:
        List of batch samplers, one per bucket, each constructed from the
        shared kwargs plus the k-th value of every individual kwarg.
    """
    base_batch_samplers = []
    for k in range(self.num_buckets):
        # pick the k-th value of each per-bucket keyword argument
        individual_kwargs = {
            key: list(self.base_batch_sampler_individual_kwargs[key])[k]
            for key in self.base_batch_sampler_individual_kwargs
        }
        base_batch_samplers.append(
            self.base_batch_sampler_class(
                self.bucket_element_indices[k],
                **self.base_batch_sampler_shared_kwargs,
                **individual_kwargs,
            )
        )
    return base_batch_samplers
def __len__(self) -> int:
    """Get the number of batches.

    Can only be called if the `base_batch_sampler_class` has __len__() implemented.

    Returns:
        int: Number of batches
    """
    # Cache the count: computed lazily from the per-bucket samplers unless a
    # fixed num_batches was supplied (required for distributed sampling).
    if self.num_batches is None:
        self.num_batches = sum(len(sampler) for sampler in self.base_batch_samplers)  # type: ignore
    return self.num_batches

def __iter__(self) -> Iterator[List[int]]:
    """Iterate over batches of indices.

    Batches come from the per-bucket base batch samplers. With shuffling, the
    bucket to draw each batch from is sampled with probability proportional to
    the number of elements remaining in that bucket; otherwise buckets are
    drained in order. The base samplers are restarted (a new "epoch") until
    `self.num_batches` batches have been yielded.

    Yields:
        List[int]: A batch of indices of elements with sizes from each bucket range.
    """
    if self.num_batches is None:
        self.num_batches = len(self)

    if self.sampler is not None:
        # Materialize the (possibly distributed) sampler exactly once and
        # re-bucket its element indices. (Previously the sampler was iterated
        # twice — once into an unused local — which is wasteful and can
        # consume an extra pass of a stateful sampler.)
        self._compute_bucket_element_indices(list(self.sampler))  # type: ignore
    elif self.shuffle:
        for indices in self.bucket_element_indices:
            idx = torch.randperm(len(indices), generator=self.generator)
            indices[:] = torch.tensor(indices)[idx].tolist()

    num_yield_batches = 0
    while num_yield_batches < self.num_batches:
        base_batch_sampler_iters = [iter(batch_sampler) for batch_sampler in self.base_batch_samplers]
        bucket_remaining_elements = self.bucket_sizes.clone()
        total_remaining_elements = self.num_samples

        while total_remaining_elements > 0 and num_yield_batches < self.num_batches:
            if self.shuffle:
                # sample a bucket proportionally to its remaining element count
                bucket_idx = torch.multinomial(
                    bucket_remaining_elements / total_remaining_elements, 1, generator=self.generator
                )
            else:
                # first bucket that still has elements
                bucket_idx = torch.argmax((bucket_remaining_elements > 0).to(int))  # type: ignore

            try:
                batch = next(base_batch_sampler_iters[bucket_idx])
                # clamp in case the base sampler yields more elements than tracked
                batch_size = min(len(batch), bucket_remaining_elements[bucket_idx].item())
                bucket_remaining_elements[bucket_idx] -= batch_size
                total_remaining_elements -= batch_size
                num_yield_batches += 1
                yield batch
            except StopIteration:
                # bucket exhausted early; recompute totals and move on
                bucket_remaining_elements[bucket_idx] = 0
                total_remaining_elements = torch.sum(bucket_remaining_elements)
                continue

def set_epoch(self, epoch: int):
    """Set the epoch for `self.sampler`.

    Args:
        epoch (int): Epoch number.

    Raises:
        RuntimeError: If `self.sampler` is set but does not implement `set_epoch`.
    """
    if self.sampler is not None:
        try:
            self.sampler.set_epoch(epoch)  # type: ignore
        except AttributeError:
            # NOTE(review): an AttributeError raised *inside* a sampler's own
            # set_epoch would also land here — acceptable for this error path.
            raise RuntimeError(
                f"Trying to call set_epoch on {self.sampler.__class__} which does not have such method"
            )
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,64 @@ def test_init_bucket_batch_sampler_with_invalid_base_batch_sampler_kwargs(sample
base_batch_sampler_shared_kwargs={},
base_batch_sampler_individual_kwargs={"batch_size": [2, 3, 5]},
)
# base_batch_sampler_shared_kwargs and base_batch_sampler_individual_kwargs should be keyword argument dictionary
with pytest.raises(TypeError):
BucketBatchSampler(
sizes=sizes,
bucket_boundaries=bucket_boundaries,
base_batch_sampler_class=base_batch_sampler_class,
base_batch_sampler_shared_kwargs={1: False}, # type: ignore
base_batch_sampler_individual_kwargs=base_batch_sampler_individual_kwargs,
)


def test_bucket_batch_sampler_with_invalid_sampler(sample_data):
    """A `sampler` that is not a torch.utils.data.Sampler instance is rejected."""
    sizes, boundaries, sampler_cls, shared_kwargs, individual_kwargs = sample_data
    # Passing a class object (DataLoader) instead of a Sampler instance must raise.
    with pytest.raises(TypeError):
        BucketBatchSampler(
            sizes=sizes,
            bucket_boundaries=boundaries,
            base_batch_sampler_class=sampler_cls,
            base_batch_sampler_shared_kwargs=shared_kwargs,
            base_batch_sampler_individual_kwargs=individual_kwargs,
            sampler=DataLoader,  # type: ignore
        )


def test_bucket_batch_sampler_with_invalid_num_batches(sample_data):
    """`num_batches` must be a positive integer; other values raise ValueError."""
    sizes, boundaries, sampler_cls, shared_kwargs, individual_kwargs = sample_data
    # Negative integers and non-int numbers are both invalid.
    for bad_value in (-1, 2.0):
        with pytest.raises(ValueError):
            BucketBatchSampler(
                sizes=sizes,
                bucket_boundaries=boundaries,
                base_batch_sampler_class=sampler_cls,
                base_batch_sampler_shared_kwargs=shared_kwargs,
                base_batch_sampler_individual_kwargs=individual_kwargs,
                num_batches=bad_value,  # type: ignore
            )


def test_bucket_batch_sampler_attributes(sample_data):
Expand Down Expand Up @@ -528,6 +586,7 @@ def cost_of_element(index):
base_batch_sampler_shared_kwargs={"sizeof": cost_of_element},
base_batch_sampler_individual_kwargs={"max_total_size": [10, 30, 50]},
shuffle=False,
num_batches=11,
)
batch_lists_first_iter = list(iter(batch_sampler))
ref_batch_lists = [
Expand Down Expand Up @@ -556,6 +615,7 @@ def cost_of_element(index):
base_batch_sampler_individual_kwargs={"max_total_size": [10, 30, 50]},
shuffle=True,
generator=torch.Generator(),
num_batches=20,
)
batch_lists_first_iter = list(iter(batch_sampler))
assert set(functools.reduce(operator.iadd, batch_lists_first_iter, [])) == set(range(25))
Expand Down
Loading