forked from zwx8981/DBCNN-PyTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathLIVEChallengeFolder.py
112 lines (85 loc) · 3.32 KB
/
LIVEChallengeFolder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch.utils.data as data
from PIL import Image
import os
import os.path
#import math
import scipy.io
import numpy as np
import random
def getFileName(path, suffix):
    """Return the names of all entries in a directory that have a given extension.

    Args:
        path: Directory to list.
        suffix: Extension to match, compared exactly (case-sensitively)
            against the second element of ``os.path.splitext``; pass it
            with its leading dot, e.g. ``'.bmp'``.

    Returns:
        list: Bare entry names (not joined with ``path``), in the order
        ``os.listdir`` yields them.
    """
    # os.path.splitext() separates a file name from its extension.
    return [
        entry
        for entry in os.listdir(path)
        if os.path.splitext(entry)[1] == suffix
    ]
def getDistortionTypeFileName(path, num):
    """Build the paths ``img1.bmp`` ... ``img<num>.bmp`` under a directory.

    Args:
        path: Directory prefix joined in front of every generated name.
        num: How many names to generate; numbering starts at 1.

    Returns:
        list: ``num`` path strings in ascending numeric order. The file
        system is not consulted, so the files need not exist.
    """
    return [
        os.path.join(path, 'img' + str(number) + '.bmp')
        for number in range(1, num + 1)
    ]
class LIVEChallengeFolder(data.Dataset):
    """Dataset of (image, MOS) pairs read from a LIVE Challenge style folder.

    Expected layout under ``root``:
        Data/AllImages_release.mat  -- variable ``AllImages_release``
        Data/AllMOS_release.mat     -- variable ``AllMOS_release``
        Images/                     -- the image files themselves
    """

    def __init__(self, root, loader, index, transform=None, target_transform=None):
        """Load the file list and scores, then keep the entries selected by ``index``.

        Args:
            root: Dataset root directory (see the class docstring for layout).
            loader: Callable that takes an image path and returns the sample.
            index: Iterable of integer positions into the arrays *after* the
                ``[7:1169]`` slice, so position 0 refers to entry 7 of the
                ``.mat`` files.
            transform: Optional callable applied to the loaded sample.
            target_transform: Optional callable applied to the score.
        """
        self.root = root
        self.loader = loader

        # NOTE(review): entries 0-6 are dropped; presumably they are not part
        # of the evaluation set -- confirm against the dataset README.
        image_mat = scipy.io.loadmat(
            os.path.join(self.root, 'Data', 'AllImages_release.mat'))
        self.imgpath = image_mat['AllImages_release'][7:1169]

        self.mos = scipy.io.loadmat(
            os.path.join(self.root, 'Data', 'AllMOS_release.mat'))
        all_scores = self.mos['AllMOS_release'].astype(np.float32)
        # Scores are read from row 0 of the array and sliced like the names.
        self.labels = all_scores[0][7:1169]

        # Each file-name entry is unwrapped with [0][0] to reach the string.
        self.samples = [
            (
                os.path.join(self.root, 'Images', self.imgpath[position][0][0]),
                self.labels[position],
            )
            for position in index
        ]
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Position within the selected samples.
        Returns:
            tuple: (sample, target) where sample is the loader's output for the
            image path and target is its float32 MOS value, each passed
            through the corresponding transform when one was supplied.
        """
        image_path, score = self.samples[index]
        image = self.loader(image_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            score = self.target_transform(score)
        return image, score

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.samples)
def pil_loader(path):
    """Open the image at ``path`` with PIL and return it converted to RGB."""
    # Open the path as a file object ourselves to avoid a ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; retry with PIL.Image instead.
        image = pil_loader(path)
    return image
def default_loader(path):
    """Load ``path`` using the loader that matches torchvision's image backend.

    ``accimage_loader`` is used when ``get_image_backend()`` reports
    ``'accimage'``; every other backend goes through ``pil_loader``.
    """
    from torchvision import get_image_backend

    load = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return load(path)
if __name__ == '__main__':
    # Smoke run: build a random 80/20 train/test split of the dataset.
    #
    # The path is a raw string because the plain literal contained the
    # unrecognized escapes '\z', '\d' and '\C'. Python keeps those backslashes
    # today but warns about them and plans to reject them, so the raw string
    # yields the identical value without the warning.
    liveroot = r'D:\zwx_Project\dbcnn_pytorch\dataset\ChallengeDB_release'

    # 1162 matches the number of entries kept by the [7:1169] slice in
    # LIVEChallengeFolder.__init__.
    num_images = 1162
    split = round(0.8 * num_images)

    index = list(range(num_images))
    random.shuffle(index)
    train_index = index[:split]
    test_index = index[split:]

    trainset = LIVEChallengeFolder(root=liveroot, loader=default_loader, index=train_index)
    testset = LIVEChallengeFolder(root=liveroot, loader=default_loader, index=test_index)