Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify buffer_size across TorchData DataPipes #1077

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions torchdata/datapipes/iter/transform/bucketbatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class MaxTokenBucketizerIterDataPipe(IterDataPipe[DataChunk[T_co]]):

Note that batches are bucketized starting from the smallest size in a buffer.
This can limit the variability of batches if ``buffer_size`` is large.
If it's specified as ``None``, the buffer size is set as infinite.
To increase variability, apply ``torchdata.datapipes.iter.Shuffler`` before and after this DataPipe,
and keep ``buffer_size`` small.

Expand Down Expand Up @@ -256,7 +257,7 @@ def __init__(
len_fn: Callable = _default_len_fn,
min_len: int = 0,
max_len: Optional[int] = None,
buffer_size: int = 1000,
buffer_size: int = 10000,
include_padding: bool = False,
) -> None:
if max_len is None:
Expand All @@ -266,8 +267,8 @@ def __init__(
raise ValueError("``min_len`` should be larger than 0 and equal to or smaller than ``max_len``.")
if max_len > max_token_count:
raise ValueError("``max_token_count`` must be equal to or greater than ``max_len``.")
if buffer_size <= 0:
raise ValueError("'buffer_size' is required to be a positive integer.")
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be a positive integer or None.")
self.datapipe = datapipe.map(partial(_token_len_fn, len_fn=len_fn))
self.datapipe = self.datapipe.filter(partial(_token_filter_fn, min_len=min_len, max_len=max_len))
self.max_token_count = max_token_count
Expand All @@ -281,7 +282,7 @@ def __iter__(self) -> Iterator[DataChunk[T_co]]:
max_length: int = 0
for d in self.datapipe:
heapq.heappush(buffer, d)
if len(buffer) == self.buffer_size:
if self.buffer_size is not None and len(buffer) == self.buffer_size:
buffer, batch, batch_size, max_length, data_chunk = self._pop_buffer(
buffer, batch, batch_size, max_length
)
Expand Down
10 changes: 5 additions & 5 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ class RoundRobinDemultiplexerIterDataPipe(IterDataPipe):
num_instances: number of instances of the DataPipe to create
buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
DataPipes while waiting for their values to be yielded.
Defaults to ``1000``. Use ``-1`` for the unlimited buffer.
Defaults to ``1000``. Use ``None`` for the unlimited buffer.

Examples:
>>> from torchdata.datapipes.iter import IterableWrapper
Expand All @@ -271,7 +271,7 @@ class RoundRobinDemultiplexerIterDataPipe(IterDataPipe):
2
"""

def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000):
def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 10000):
if num_instances < 1:
raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")
if num_instances == 1:
Expand Down Expand Up @@ -314,7 +314,7 @@ class UnZipperIterDataPipe(IterDataPipe[T]):
source_datapipe: Iterable DataPipe with sequences of data
sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length.
buffer_size: this restricts how far ahead the leading child DataPipe can read relative
to the slowest child DataPipe. Use -1 for the unlimited buffer.
to the slowest child DataPipe. Use None for the unlimited buffer.
columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be
an integer from 0 to sequence_length - 1)

Expand All @@ -334,7 +334,7 @@ def __new__(
cls,
source_datapipe: IterDataPipe[Sequence[T]],
sequence_length: int,
buffer_size: int = 1000,
buffer_size: int = 10000,
columns_to_skip: Optional[Sequence[int]] = None,
):
if columns_to_skip is None:
Expand All @@ -355,7 +355,7 @@ def __new__(


class _UnZipperIterDataPipe(_ForkerIterDataPipe):
def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 1000):
def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 10000):
super().__init__(datapipe, len(instance_ids), buffer_size) # type: ignore[arg-type]
self.instance_ids = instance_ids

Expand Down