-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
76 lines (52 loc) · 2.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow as tf
import skimage.measure as measure
import scipy.io
import numpy as np
import os
import model
import utils
import data
def _prepare_result_dirs(result_root, benchmark_list, scales):
    """Ensure ``<result_root>/<benchmark>/<scale>`` exists for every pair."""
    for benchmark in benchmark_list:
        for scale in scales:
            # exist_ok=True makes this safe to re-run and fills in any
            # sub-directory missing from a partially created result tree
            # (intermediate directories are created as well).
            os.makedirs(os.path.join(result_root, benchmark, str(scale)),
                        exist_ok=True)


def _evaluate_benchmark(sess, vdsr, lr, gt, s, benchmark_dir):
    """Run the model on every image of one benchmark and score the outputs.

    Writes ``<benchmark_dir>/quality_<s>.csv`` with a ``file name, psnr, ssim``
    header and one row per ground-truth image. Saves each SR output (before
    ``utils.shave``) as ``<benchmark_dir>/<s>/<name>.mat`` under the key
    ``'sr'``.

    Args:
        sess: open TensorFlow session in which ``vdsr`` has been loaded.
        vdsr: model exposing the ``lr`` input placeholder and ``inference``.
        lr: sequence of dicts with a ``'data'`` array, index-aligned with *gt*.
        gt: sequence of dicts with ``'data'`` and ``'name'`` entries.
        s: scale factor; used for file names and passed to ``utils.shave``.
        benchmark_dir: output directory for this benchmark.

    Returns:
        (psnr_list, ssim_list): per-image scores in the order of *gt*.
    """
    psnr_list = []
    ssim_list = []
    csv_path = os.path.join(benchmark_dir, 'quality_%d.csv' % s)
    # `with` guarantees the CSV is closed even if inference or scoring raises.
    with open(csv_path, 'w') as quality_result:
        quality_result.write('file name, psnr, ssim\n')
        for i, gt_item in enumerate(gt):
            lr_image = lr[i]['data']
            gt_image = gt_item['data']
            # Add a leading batch axis and a trailing channel axis.
            # Assumes lr_image is a 2-D single-channel array
            # (TODO confirm against data.load_lr_gt_mat).
            sr = sess.run(vdsr.inference, feed_dict={
                vdsr.lr: lr_image.reshape((1,) + lr_image.shape + (1,))
            })
            # Keep only axes 1 and 2, i.e. drop the batch and channel axes.
            sr = sr.reshape(sr.shape[1: 3])
            # Presumably utils.shave crops an s-dependent border before scoring;
            # verify in utils.
            sr_ = utils.shave(sr, s).astype(np.float64)
            gt_image_ = utils.shave(gt_image, s)
            # NOTE: measure.compare_psnr / compare_ssim were removed in
            # scikit-image 0.18. Their replacements are
            # skimage.metrics.peak_signal_noise_ratio / structural_similarity.
            _psnr = measure.compare_psnr(gt_image_, sr_)
            _ssim = measure.compare_ssim(gt_image_, sr_)
            quality_result.write(
                '%s, %f, %f\n' % (gt_item['name'], _psnr, _ssim))
            psnr_list.append(_psnr)
            ssim_list.append(_ssim)
            # Save the full (un-shaved) SR output.
            scipy.io.savemat(
                os.path.join(benchmark_dir, str(s), gt_item['name'] + '.mat'),
                {'sr': sr}
            )
    return psnr_list, ssim_list


def run(config):
    """Evaluate a trained model on the Set5/Set14/B100/Urban100 benchmarks.

    For each benchmark, loads the LR/GT pairs from
    ``./data/test_data/mat/<benchmark>`` at scale ``config.scale`` and runs
    the model. It then writes per-image PSNR/SSIM to
    ``./result/<benchmark>/quality_<scale>.csv`` and each SR output to
    ``./result/<benchmark>/<scale>/<name>.mat``. Finally, it prints the mean
    PSNR/SSIM of every non-empty benchmark.

    Args:
        config: run configuration. This function reads ``config.scale``,
            ``config.checkpoint_path`` and ``config.model_name``, and passes
            the whole object to ``model.Model``.
    """
    test_data_path = './data/test_data/mat/'
    result_root = './result/'
    benchmark_list = ['Set5', 'Set14', 'B100', 'Urban100']
    scale = [2, 3, 4]
    s = config.scale
    # Also create a directory for config.scale when it is not one of the
    # default scales; otherwise savemat would hit a missing directory.
    scales_to_prepare = scale if s in scale else scale + [s]
    _prepare_result_dirs(result_root, benchmark_list, scales_to_prepare)

    with tf.Session() as sess:
        vdsr = model.Model(config)
        vdsr.load(sess, config.checkpoint_path, config.model_name)
        for benchmark in benchmark_list:
            print(benchmark)
            test_benchmark_path = os.path.join(test_data_path, benchmark)
            lr, gt = data.load_lr_gt_mat(test_benchmark_path, s)
            psnr_list, ssim_list = _evaluate_benchmark(
                sess, vdsr, lr, gt, s, os.path.join(result_root, benchmark))
            # Report the benchmark means. An empty benchmark is skipped,
            # because np.mean([]) yields NaN with a RuntimeWarning.
            if psnr_list:
                print('mean psnr: %f, mean ssim: %f'
                      % (np.mean(psnr_list), np.mean(ssim_list)))