Adding UnZipperIterDataPipe (pytorch#198)
Summary: Pull Request resolved: pytorch#198

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D34014294

Pulled By: NivekT

fbshipit-source-id: 36fee95cd517b3050f0241e800618aa2a872f5cd
NivekT authored and facebook-github-bot committed Feb 8, 2022
1 parent 8942cf1 commit 7c7bafc
Showing 4 changed files with 159 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
@@ -121,6 +121,7 @@ These DataPipes provide utility functions (e.g. caching, CSV parsing, filtering)
    StreamReader
    TarArchiveReader
    UnBatcher
    UnZipper
    XzFileReader
    ZipArchiveReader
    Zipper
98 changes: 97 additions & 1 deletion test/test_datapipe.py
@@ -28,6 +28,7 @@
    ParagraphAggregator,
    Rows2Columnar,
    SampleMultiplexer,
    UnZipper,
)


@@ -567,7 +568,7 @@ def _return_self(x):
        with self.assertRaises(TypeError):
            len(batch_dp)

-    def test_flatmap_datapipe(self):
+    def test_flatmap_iterdatapipe(self):
        source_dp = IterableWrapper(list(range(20)))

        def fn(e):
@@ -589,6 +590,101 @@ def fn(e):
        with self.assertRaisesRegex(TypeError, "length relies on the output of its function."):
            len(flatmapped_dp)

    def test_unzipper_iterdatapipe(self):
        source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(10)])

        # Functional Test: unzips each sequence, invoking the class constructor directly
        dp1, dp2, dp3 = UnZipper(source_dp, sequence_length=3)
        self.assertEqual(list(range(10)), list(dp1))
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        # Functional Test: unzips each sequence, using the functional form `.unzip()`
        dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        self.assertEqual(list(range(10)), list(dp1))
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        # Functional Test: skipping over specified columns
        dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0])
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2])
        self.assertEqual(list(range(10, 20)), list(dp2))

        source_dp = IterableWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)])
        dp2, dp3 = source_dp.unzip(sequence_length=4, columns_to_skip=[0, 3])
        self.assertEqual(list(range(10, 20)), list(dp2))
        self.assertEqual(list(range(20, 30)), list(dp3))

        # Functional Test: one child DataPipe reads ahead of the other, but `buffer_size=5` is too small,
        # so a BufferError is raised
        source_dp = IterableWrapper([(i, i + 10) for i in range(10)])
        dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=5)
        it1 = iter(dp1)
        for _ in range(5):
            next(it1)
        with self.assertRaises(BufferError):
            next(it1)
        with self.assertRaises(BufferError):
            list(dp2)

        # Reset Test: reset the DataPipe after reading part of it
        dp1, dp2 = source_dp.unzip(sequence_length=2)
        i1, i2 = iter(dp1), iter(dp2)
        output2 = []
        for i, n2 in enumerate(i2):
            output2.append(n2)
            if i == 4:
                i1 = iter(dp1)  # Doesn't reset because i1 hasn't been read
        self.assertEqual(list(range(10, 20)), output2)

        # Reset Test: DataPipes reset when some of them have been read
        dp1, dp2 = source_dp.unzip(sequence_length=2)
        i1, i2 = iter(dp1), iter(dp2)
        output1, output2 = [], []
        for i, (n1, n2) in enumerate(zip(i1, i2)):
            output1.append(n1)
            output2.append(n2)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    i1 = iter(dp1)  # Resets all child DataPipes
                self.assertEqual(len(wa), 1)
                self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
        self.assertEqual(list(range(5)) + list(range(10)), output1)
        self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2)

        # Reset Test: DataPipe resets even when some other child DataPipes have not been read
        source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(10)])
        dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        output1, output2 = list(dp1), list(dp2)
        self.assertEqual(list(range(10)), output1)
        self.assertEqual(list(range(10, 20)), output2)
        with warnings.catch_warnings(record=True) as wa:
            self.assertEqual(list(range(10)), list(dp1))  # Resets even though dp3 has not been read
        self.assertEqual(len(wa), 1)
        self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
        output3 = []
        for i, n3 in enumerate(dp3):
            output3.append(n3)
            if i == 4:
                with warnings.catch_warnings(record=True) as wa:
                    output1 = list(dp1)  # Resets even though dp3 is only partially read
                self.assertEqual(len(wa), 1)
                self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted")
                self.assertEqual(list(range(20, 25)), output3)
                self.assertEqual(list(range(10)), output1)
                break
        self.assertEqual(list(range(20, 30)), list(dp3))  # dp3 has to read from the start again

        # __len__ Test: each child DataPipe inherits the source DataPipe's length
        dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        self.assertEqual(len(source_dp), len(dp1))
        self.assertEqual(len(source_dp), len(dp2))
        self.assertEqual(len(source_dp), len(dp3))

        # TODO: Add testing for different stages of pickling for UnZipper


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions torchdata/datapipes/iter/__init__.py
@@ -73,6 +73,7 @@
from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer
from torchdata.datapipes.iter.util.saver import SaverIterDataPipe as Saver
from torchdata.datapipes.iter.util.tararchivereader import TarArchiveReaderIterDataPipe as TarArchiveReader
from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
from torchdata.datapipes.iter.util.xzfilereader import XzFileReaderIterDataPipe as XzFileReader
from torchdata.datapipes.iter.util.ziparchivereader import ZipArchiveReaderIterDataPipe as ZipArchiveReader

@@ -134,6 +135,7 @@
"StreamReader",
"TarArchiveReader",
"UnBatcher",
"UnZipper",
"XzFileReader",
"ZipArchiveReader",
"Zipper",
59 changes: 59 additions & 0 deletions torchdata/datapipes/iter/util/unzipper.py
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import Optional, Sequence, TypeVar

from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _ForkerIterDataPipe
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


T = TypeVar("T")


@functional_datapipe("unzip")
class UnZipperIterDataPipe(IterDataPipe[T]):
r"""
Takes in a DataPipe of Sequences, unpacks each Sequence, and return the elements in separate DataPipes
based on their position in the Sequence. The number of instances produced equals to the sequence legnth
minus the number of columns to skip.
Note:
Each sequence within the DataPipe should have the same length, specified by
the input argument `sequence_length`.
Args:
source_datapipe: Iterable DataPipe with sequences of data
sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length.
buffer_size: this restricts how far ahead the leading child DataPipe can read relative
to the slowest child DataPipe. Use -1 for the unlimited buffer.
columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be
an integer from 0 to sequence_length - 1)
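
    Example:
        A minimal usage sketch (mirroring the tests added in this PR):

        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(3)])
        >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
        >>> list(dp1)
        [0, 1, 2]
        >>> list(dp2)
        [10, 11, 12]
        >>> list(dp3)
        [20, 21, 22]
        >>> # Columns can be skipped while unzipping
        >>> dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0])
        >>> list(dp2)
        [10, 11, 12]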
"""

    def __new__(
        cls,
        source_datapipe: IterDataPipe[Sequence[T]],
        sequence_length: int,
        buffer_size: int = 1000,
        columns_to_skip: Optional[Sequence[int]] = None,
    ):
        if columns_to_skip is None:
            instance_ids = list(range(sequence_length))
        else:
            skips = set(columns_to_skip)
            instance_ids = [i for i in range(sequence_length) if i not in skips]

        if len(instance_ids) == 0:
            raise RuntimeError(
                "All instances are being filtered out in UnZipperIterDataPipe. Please check "
                "the input `sequence_length` and `columns_to_skip`."
            )

        # The implementation essentially uses Forker, but each child only yields a specific element within the sequence
        container = _UnZipperIterDataPipe(source_datapipe, sequence_length, buffer_size)
        return [_ChildDataPipe(container, i) for i in instance_ids]


class _UnZipperIterDataPipe(_ForkerIterDataPipe):
    def get_next_element_by_instance(self, instance_id: int):
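        # The parent `_ForkerIterDataPipe` yields the full sequence to each child DataPipe;
        # indexing by `instance_id` hands this child only the column at its own position.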
        for return_val in super().get_next_element_by_instance(instance_id):
            yield return_val[instance_id]
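
For intuition, the behavior is close to combining the existing `fork` and `map` DataPipes; the actual implementation instead subclasses `_ForkerIterDataPipe` so that each child yields only its own column straight from the shared buffer. A rough functional equivalent, as a sketch for illustration only (not the code in this PR):

    from torchdata.datapipes.iter import IterableWrapper

    source_dp = IterableWrapper([(i, i + 10) for i in range(10)])
    branch1, branch2 = source_dp.fork(num_instances=2)
    dp1 = branch1.map(lambda t: t[0])  # yields 0..9, like unzip's first child
    dp2 = branch2.map(lambda t: t[1])  # yields 10..19, like unzip's second child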
