Speed up batched PyTorch DataLoader (huggingface#5512)
* speed up batched torch dataloader

* use latest torch

* style

* fix

* update torchaudio as well

* dont use latest torch in CI

* Update tests/test_arrow_dataset.py

Co-authored-by: Mario Šaško <[email protected]>

---------

Co-authored-by: Mario Šaško <[email protected]>
lhoestq and mariosasko authored Feb 19, 2023
1 parent 29de617 commit f401758
Showing 3 changed files with 30 additions and 5 deletions.

docs/source/use_with_pytorch.mdx (7 changes: 5 additions & 2 deletions)

@@ -170,9 +170,10 @@ Reloading the dataset inside a worker doesn't fill up your RAM, since it simply
 >>> dataloader = DataLoader(ds, batch_size=32, num_workers=4)
 ```
 
-#### Use a BatchSampler
+#### Use a BatchSampler (torch<=1.12.1)
 
-By default, the PyTorch `DataLoader` loads batches of data from a dataset one by one like this:
+For old versions of PyTorch, using a `BatchSampler` can speed up data loading.
+Indeed, if you are using `torch<=1.12.1`, the PyTorch `DataLoader` loads batches of data from a dataset one by one like this:
 
 ```py
 batch = [dataset[idx] for idx in range(start, end)]
@@ -198,6 +199,8 @@ For the PyTorch `DataLoader` to query batches using a list, you can use a `BatchSampler`:
 Moreover, this is particularly useful if you used [`set_transform`] to apply a transform on-the-fly when examples are accessed.
 You must use a `BatchSampler` if you want the transform to be given full batches instead of receiving `batch_size` times one single element.
 
+Recent versions of PyTorch use a list of indices, so a `BatchSampler` is not needed to get the best speed even if you used [`set_transform`].
+
 ### Stream data
 
 Stream a dataset by loading it as an [`IterableDataset`]. This allows you to progressively iterate over a remote dataset without downloading it to disk, or over local data files.
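For context, the pattern the documentation above recommends on `torch<=1.12.1` looks like the sketch below (assuming an already-loaded `ds` dataset; this snippet is illustrative and not part of the diff):

```py
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# The BatchSampler yields lists of indices, and batch_size=None disables
# automatic batching, so each list is passed straight to ds.__getitem__:
# the dataset is queried once per batch instead of once per example.
sampler = BatchSampler(SequentialSampler(ds), batch_size=32, drop_last=False)
dataloader = DataLoader(ds, sampler=sampler, batch_size=None)
```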

src/datasets/arrow_dataset.py (10 changes: 7 additions & 3 deletions)

@@ -2645,9 +2645,13 @@ def __getitem__(self, key: str) -> List:  # noqa: F811

     def __getitem__(self, key):  # noqa: F811
         """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-        return self._getitem(
-            key,
-        )
+        return self._getitem(key)
+
+    def __getitems__(self, keys: List) -> List:
+        """Can be used to get a batch using a list of integer indices."""
+        batch = self.__getitem__(keys)
+        n_examples = len(batch[next(iter(batch))])
+        return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
 
     def cleanup_cache_files(self) -> int:
         """Clean up all cache files in the dataset cache directory, except the currently used cache file if there is
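To see what the new `__getitems__` hook changes, a toy sketch (the `"text"` column and values are hypothetical, for illustration only): indexing with a list returns a single columnar batch, while `__getitems__` re-splits that batch into one dict per example, which is the shape PyTorch's default collate function expects.

```py
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"]})  # hypothetical toy dataset

print(ds[[0, 2]])               # one columnar batch: {'text': ['a', 'c']}
print(ds.__getitems__([0, 2]))  # per-example dicts: [{'text': 'a'}, {'text': 'c'}]
```

Either way the underlying Arrow table is queried once for the whole list of indices, rather than once per index.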

tests/test_arrow_dataset.py (18 changes: 18 additions & 0 deletions)

@@ -18,6 +18,7 @@
 import pyarrow as pa
 import pytest
 from absl.testing import parameterized
+from packaging import version
 
 import datasets.arrow_dataset
 from datasets import concatenate_datasets, interleave_datasets, load_from_disk
@@ -4240,6 +4241,23 @@ def test_dataset_to_iterable_dataset(dataset):
         dataset.to_iterable_dataset(num_shards=len(dataset) + 1)
 
 
+@pytest.mark.parametrize("batch_size", [1, 4])
+@require_torch
+def test_dataset_with_torch_dataloader(dataset, batch_size):
+    from torch.utils.data import DataLoader
+
+    from datasets import config
+
+    dataloader = DataLoader(dataset, batch_size=batch_size)
+    with patch.object(dataset, "_getitem", wraps=dataset._getitem) as mock_getitem:
+        out = list(dataloader)
+        getitem_call_count = mock_getitem.call_count
+    assert len(out) == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
+    # calling dataset[list_of_indices] is much more efficient than [dataset[idx] for idx in list_of_indices]
+    if config.TORCH_VERSION >= version.parse("1.13.0"):
+        assert getitem_call_count == len(dataset) // batch_size + int(len(dataset) % batch_size > 0)
+
+
 @pytest.mark.parametrize("return_lazy_dict", [True, False, "mix"])
 def test_map_cases(return_lazy_dict):
     def f(x):
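Outside the test suite, the behavior this test exercises can be reproduced in a few lines (a sketch assuming `torch>=1.13.0`, where the `DataLoader` fetches through `__getitems__` so each batch costs a single underlying `_getitem` call):

```py
from torch.utils.data import DataLoader

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(10))})  # toy dataset, for illustration
dl = DataLoader(ds, batch_size=4)
for batch in dl:
    print(batch["x"])  # tensors of shape (4,), (4,), then (2,)
```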
