-
Notifications
You must be signed in to change notification settings - Fork 11
/
custom_dataset_data_loader.py
executable file
·44 lines (36 loc) · 1.35 KB
/
custom_dataset_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
    """Instantiate and initialize the dataset selected by ``opt.dataset_mode``.

    Recognized modes are ``'temporal'``, ``'face'``, ``'pose'`` and ``'test'``.
    Each maps to a class in the ``data`` package. That class's module is
    imported only when its mode is the one requested.

    Args:
        opt: options object. ``opt.dataset_mode`` picks the dataset class, and
            the whole object is forwarded to the new dataset's ``initialize``.

    Returns:
        The dataset instance, after ``initialize(opt)`` has been called on it.

    Raises:
        ValueError: if ``opt.dataset_mode`` matches none of the known modes.
    """
    def _temporal():
        from data.temporal_dataset import TemporalDataset
        return TemporalDataset

    def _face():
        from data.face_dataset import FaceDataset
        return FaceDataset

    def _pose():
        from data.pose_dataset import PoseDataset
        return PoseDataset

    def _test():
        from data.test_dataset import TestDataset
        return TestDataset

    # Ordered (mode, loader) pairs checked with ``==`` in sequence.
    # A linear scan is used deliberately instead of a dict lookup, so any
    # value type (even an unhashable one) is treated the same as a plain
    # equality chain.
    registry = (
        ('temporal', _temporal),
        ('face', _face),
        ('pose', _pose),
        ('test', _test),
    )

    mode = opt.dataset_mode
    for key, load_class in registry:
        if mode == key:
            dataset_cls = load_class()
            break
    else:
        raise ValueError("Dataset [%s] not recognized." % mode)

    instance = dataset_cls()
    print("dataset [%s] was created" % (instance.name()))
    instance.initialize(opt)
    return instance
class CustomDatasetDataLoader(BaseDataLoader):
    """Loader that builds a dataset with ``CreateDataset`` and batches it.

    Batching is done by ``torch.utils.data.DataLoader``.
    """

    def name(self):
        """Return this loader's identifier string."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset for ``opt`` and wrap it in a torch ``DataLoader``.

        Args:
            opt: options object. It is passed to ``BaseDataLoader.initialize``
                and ``CreateDataset``. This method also reads:
                ``opt.batchSize`` (batch size),
                ``opt.serial_batches`` (a truthy value disables shuffling),
                ``opt.nThreads`` (converted with ``int`` for ``num_workers``).
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        loader_kwargs = dict(
            batch_size=opt.batchSize,
            # "Serial batches" means iterate in dataset order, i.e. no shuffling.
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_kwargs)

    def load_data(self):
        """Return the ``DataLoader`` built by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the dataset's sample count, capped at ``opt.max_dataset_size``.

        ``self.opt`` is presumably set by ``BaseDataLoader.initialize``, which
        is not visible here.

        NOTE(review): this counts samples, not batches. The ``DataLoader``
        built in ``initialize`` also receives the full dataset with no sampler
        limit. If callers rely on this value matching the number of items
        iterated, confirm whether the dataset itself honors
        ``max_dataset_size``.
        """
        total = len(self.dataset)
        limit = self.opt.max_dataset_size
        return min(total, limit)