face_dataset.py
import os
import os.path as osp

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

class CelebAMask_HQ_Dataset(Dataset):
    """CelebAMask-HQ face-parsing dataset returning (image, mask) pairs."""

    def __init__(self,
                 root_dir,
                 sample_indices,
                 mode,
                 tr_transform=None):
        assert mode in ('train', 'val', 'test')
        self.root_dir = root_dir
        self.mode = mode
        self.tr_transform = tr_transform
        # Scale images to [0, 1] with ToTensor, then normalize to [-1, 1].
        self.to_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.image_dir = osp.join(root_dir, 'CelebA-HQ-img')  # image folder
        self.mask_dir = osp.join(root_dir, 'mask')            # mask folder
        self.sample_indices = sample_indices
        self.train_dataset = []
        self.test_dataset = []
        self.preprocess()

    def preprocess(self):
        # Images are assumed to be named 0.jpg .. N-1.jpg; count the files
        # in the image folder to determine N, then pair each image with its
        # same-index mask.
        num_images = len([name for name in os.listdir(self.image_dir)
                          if osp.isfile(osp.join(self.image_dir, name))])
        for i in range(num_images):
            img_path = osp.join(self.image_dir, f'{i}.jpg')
            label_path = osp.join(self.mask_dir, f'{i}.png')
            if self.mode != 'test':
                self.train_dataset.append([img_path, label_path])
            else:
                self.test_dataset.append([img_path, label_path])

    def __getitem__(self, idx):
        # Map the split-local index to the global sample index.
        idx = self.sample_indices[idx]
        if self.mode != 'test':
            img_pth, mask_pth = self.train_dataset[idx]
        else:
            img_pth, mask_pth = self.test_dataset[idx]
        # Read the image and its segmentation mask.
        image = Image.open(img_pth).convert('RGB')
        image = image.resize((512, 512), Image.BILINEAR)
        mask = Image.open(mask_pth).convert('L')
        # Joint data augmentation (train split only); guard against a
        # missing transform so mode='train' works without one.
        if self.mode == 'train' and self.tr_transform is not None:
            image, mask = self.tr_transform(image, mask)
        image = self.to_tensor(image)
        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask

    def __len__(self):
        return len(self.sample_indices)
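

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes a
# CelebAMask-HQ root containing 'CelebA-HQ-img/<i>.jpg' and 'mask/<i>.png',
# and illustrates the tr_transform contract implied by __getitem__: a
# callable taking an (image, mask) PIL pair and returning both. The flip
# transform, root path, and split size below are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import random

    from PIL import ImageOps
    from torch.utils.data import DataLoader

    def joint_hflip(image, mask, p=0.5):
        # Flip image and mask together so they stay spatially aligned.
        if random.random() < p:
            image = ImageOps.mirror(image)
            mask = ImageOps.mirror(mask)
        return image, mask

    indices = list(range(28000))  # hypothetical train split of the dataset
    train_ds = CelebAMask_HQ_Dataset(
        root_dir='./CelebAMask-HQ',  # assumed dataset location
        sample_indices=indices,
        mode='train',
        tr_transform=joint_hflip)
    loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
    images, masks = next(iter(loader))
    # Expected shapes: images [8, 3, 512, 512] in [-1, 1]; masks [8, 512, 512].
    print(images.shape, masks.shape)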