From 90aa8bd24d92b0fc9060172694fd8c5ac3be068d Mon Sep 17 00:00:00 2001 From: Alexander Battig Date: Sat, 11 Mar 2023 14:47:36 +0100 Subject: [PATCH] Unify buffer_size, following the convention suggested in https://github.com/pytorch/data/issues/335 --- torchdata/datapipes/iter/transform/bucketbatcher.py | 9 +++++---- torchdata/datapipes/iter/util/combining.py | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py index 867068627..27ba48cd4 100644 --- a/torchdata/datapipes/iter/transform/bucketbatcher.py +++ b/torchdata/datapipes/iter/transform/bucketbatcher.py @@ -218,6 +218,7 @@ class MaxTokenBucketizerIterDataPipe(IterDataPipe[DataChunk[T_co]]): Note that batches are bucketized starting from the smallest size in a buffer. This can limit the variablity of batches if ``buffer_size`` is large. + If it is set to ``None``, the buffer size is treated as infinite. To increase variablity, apply ``torchdata.datapipes.iter.Shuffler`` before and after this DataPipe, and keep ``buffer_size`` small. 
@@ -256,7 +257,7 @@ def __init__( len_fn: Callable = _default_len_fn, min_len: int = 0, max_len: Optional[int] = None, - buffer_size: int = 1000, + buffer_size: int = 10000, include_padding: bool = False, ) -> None: if max_len is None: @@ -266,8 +267,8 @@ def __init__( raise ValueError("``min_len`` should be larger than 0 and equal to or smaller than ``max_len``.") if max_len > max_token_count: raise ValueError("``max_token_count`` must be equal to or greater than ``max_len``.") - if buffer_size <= 0: - raise ValueError("'buffer_size' is required to be a positive integer.") + if buffer_size is not None and buffer_size <= 0: + raise ValueError("'buffer_size' is required to be a positive integer or None.") self.datapipe = datapipe.map(partial(_token_len_fn, len_fn=len_fn)) self.datapipe = self.datapipe.filter(partial(_token_filter_fn, min_len=min_len, max_len=max_len)) self.max_token_count = max_token_count @@ -281,7 +282,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]: max_length: int = 0 for d in self.datapipe: heapq.heappush(buffer, d) - if len(buffer) == self.buffer_size: + if self.buffer_size is not None and len(buffer) == self.buffer_size: buffer, batch, batch_size, max_length, data_chunk = self._pop_buffer( buffer, batch, batch_size, max_length ) diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py index bd9d70769..1919c0b70 100644 --- a/torchdata/datapipes/iter/util/combining.py +++ b/torchdata/datapipes/iter/util/combining.py @@ -255,7 +255,7 @@ class RoundRobinDemultiplexerIterDataPipe(IterDataPipe): num_instances: number of instances of the DataPipe to create buffer_size: this defines the maximum number of inputs that the buffer can hold across all child DataPipes while waiting for their values to be yielded. - Defaults to ``1000``. Use ``-1`` for the unlimited buffer. + Defaults to ``10000``. Use ``None`` for the unlimited buffer. 
Examples: >>> from torchdata.datapipes.iter import IterableWrapper @@ -271,7 +271,7 @@ class RoundRobinDemultiplexerIterDataPipe(IterDataPipe): 2 """ - def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000): + def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 10000): if num_instances < 1: raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") if num_instances == 1: @@ -314,7 +314,7 @@ class UnZipperIterDataPipe(IterDataPipe[T]): source_datapipe: Iterable DataPipe with sequences of data sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length. buffer_size: this restricts how far ahead the leading child DataPipe can read relative - to the slowest child DataPipe. Use -1 for the unlimited buffer. + to the slowest child DataPipe. Use None for the unlimited buffer. columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be an integer from 0 to sequence_length - 1) @@ -334,7 +334,7 @@ def __new__( cls, source_datapipe: IterDataPipe[Sequence[T]], sequence_length: int, - buffer_size: int = 1000, + buffer_size: int = 10000, columns_to_skip: Optional[Sequence[int]] = None, ): if columns_to_skip is None: @@ -355,7 +355,7 @@ def __new__( class _UnZipperIterDataPipe(_ForkerIterDataPipe): - def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 1000): + def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 10000): super().__init__(datapipe, len(instance_ids), buffer_size) # type: ignore[arg-type] self.instance_ids = instance_ids