fast_dataloader.py
import torch


class _InfiniteSampler(torch.utils.data.Sampler):
    """Wraps another Sampler to yield an infinite stream."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler whenever it is exhausted, so iteration
        # never terminates.
        while True:
            for batch in self.sampler:
                yield batch

class InfiniteDataLoader:
    """DataLoader wrapper that yields batches forever, for step-based
    (rather than epoch-based) training loops."""

    def __init__(self, dataset, weights, batch_size, num_workers):
        super().__init__()

        if weights is not None:
            # Sample indices in proportion to the given per-example weights;
            # num_samples=batch_size means each pass through the sampler
            # yields exactly one batch.
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, replacement=True, num_samples=batch_size)
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=True)

        # Wrapping the batch sampler in _InfiniteSampler keeps the underlying
        # DataLoader (and its worker processes) alive indefinitely.
        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler)))

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # An infinite loader has no meaningful length.
        raise ValueError("InfiniteDataLoader has no length")
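
# Usage sketch (illustrative only, not part of the original file): because the
# loader never exhausts, training is driven by a fixed step count instead of
# epochs. `dataset` and `num_steps` below are assumed to be defined elsewhere.
#
#     loader = InfiniteDataLoader(dataset, weights=None, batch_size=32,
#                                 num_workers=2)
#     iterator = iter(loader)
#     for step in range(num_steps):
#         x, y = next(iterator)
#         ...  # one optimization step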

class FastDataLoader:
    """DataLoader wrapper with slightly improved speed: worker processes are
    spawned once and reused, instead of being respawned at every epoch."""

    def __init__(self, dataset, batch_size, num_workers):
        super().__init__()

        batch_sampler = torch.utils.data.BatchSampler(
            torch.utils.data.RandomSampler(dataset, replacement=False),
            batch_size=batch_size,
            drop_last=False)

        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler)))

        self._length = len(batch_sampler)
        self.dataset = dataset

    def __iter__(self):
        # Draw exactly one epoch's worth of batches from the infinite
        # iterator, so the loader behaves like a standard epoch-based one.
        for _ in range(len(self)):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length
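
# -----------------------------------------------------------------------------
# Minimal runnable sketch (not part of the original file), assuming a toy
# TensorDataset; any torch Dataset works the same way.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(100, 3)
    y = torch.randint(0, 2, (100,))
    dataset = torch.utils.data.TensorDataset(x, y)

    loader = FastDataLoader(dataset, batch_size=16, num_workers=0)
    print(f"{len(loader)} batches per epoch")  # ceil(100 / 16) = 7
    for epoch in range(2):  # workers (if any) survive across epochs
        for xb, yb in loader:
            pass  # one training/eval step would go here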