test.py
import argparse
import os
import time

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image

from network.Math_Module import P, Q
from network.decom import Decom
from utils import *


def one2three(x):
    # replicate a single-channel map into a 3-channel tensor,
    # matching the dtype and device of x via Tensor.to(x)
    return torch.cat([x, x, x], dim=1).to(x)
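
# Sketch of the three-stage inference pipeline implemented below, inferred
# from the three checkpoints loaded in Inference.__init__:
#   1. the Decom model initializes reflectance R0 and illumination L0
#      from the low-light input;
#   2. unfolding() alternately refines (R, L) for a fixed number of rounds;
#   3. illumination_adjust() brightens L by the user-chosen ratio, and the
#      result is recomposed as I_enhance = High_L * R.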
class Inference(nn.Module):
    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        # load the decomposition model used to initialize R0 and L0
        self.model_Decom_low = Decom()
        self.model_Decom_low = load_initialize(self.model_Decom_low, self.opts.Decom_model_low_path)
        # load the unfolding models (R-net and L-net) together with the
        # options they were trained with
        self.unfolding_opts, self.model_R, self.model_L = load_unfolding(self.opts.unfolding_model_path)
        # load the illumination-adjustment model
        self.adjust_model = load_adjustment(self.opts.adjust_model_path)
        # P and Q solve the data-term subproblems of the unfolding optimization
        self.P = P()
        self.Q = Q()
        transform = [
            transforms.ToTensor(),
        ]
        self.transform = transforms.Compose(transform)
        # print the loaded architectures for inspection
        print(self.model_Decom_low)
        print(self.model_R)
        print(self.model_L)
        print(self.adjust_model)
    def unfolding(self, input_low_img):
        for t in range(self.unfolding_opts.round):
            if t == 0:  # initialize R0 and L0 from the decomposition model
                P, Q = self.model_Decom_low(input_low_img)
            else:  # update the auxiliary variables P and Q
                w_p = self.unfolding_opts.gamma + self.unfolding_opts.Roffset * t
                w_q = self.unfolding_opts.lamda + self.unfolding_opts.Loffset * t
                P = self.P(I=input_low_img, Q=Q, R=R, gamma=w_p)
                Q = self.Q(I=input_low_img, P=P, L=L, lamda=w_q)
            R = self.model_R(r=P, l=Q)
            L = self.model_L(l=Q)
        return R, L
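
    # Reading the loop above as an optimization sketch (the actual energy
    # terms live in network/Math_Module.py, which is not shown in this file):
    # the Retinex model I = R * L is split with auxiliary variables P (a
    # proxy for R) and Q (a proxy for L). Each round solves the P- and
    # Q-subproblems with weights that grow linearly in t, w_p = gamma +
    # Roffset * t and w_q = lamda + Loffset * t, and the learned modules
    # model_R and model_L then map (P, Q) back to refined (R, L).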
    def illumination_adjust(self, L, ratio):
        # broadcast the scalar enhancement ratio to a map with the shape,
        # dtype, and device of L, then let the adjustment network rescale L
        ratio_map = torch.ones_like(L) * ratio
        return self.adjust_model(l=L, alpha=ratio_map)
    def forward(self, input_low_img):
        if torch.cuda.is_available():
            input_low_img = input_low_img.cuda()
        with torch.no_grad():
            start = time.time()
            R, L = self.unfolding(input_low_img)
            High_L = self.illumination_adjust(L, self.opts.ratio)
            I_enhance = High_L * R  # Retinex recombination: I = L * R
            p_time = time.time() - start
        return I_enhance, p_time
    def run(self, low_img_path):
        file_name = os.path.basename(low_img_path)
        name, _ = os.path.splitext(file_name)
        low_img = self.transform(Image.open(low_img_path)).unsqueeze(0)
        enhance, p_time = self.forward(input_low_img=low_img)
        os.makedirs(self.opts.output, exist_ok=True)
        save_path = os.path.join(self.opts.output, file_name.replace(name, "%s_%d_URetinexNet" % (name, self.opts.ratio)))
        np_save_TensorImg(enhance, save_path)
        print("================ time for %s: %f ================" % (file_name, p_time))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Configure')
    # specify your data path here!
    parser.add_argument('--img_path', type=str, default="./demo/input/3.png")
    parser.add_argument('--output', type=str, default="./demo/output")
    # a ratio of 3-5 is recommended; a larger ratio leads to over-exposure
    parser.add_argument('--ratio', type=int, default=5)
    # model paths
    parser.add_argument('--Decom_model_low_path', type=str, default="./ckpt/init_low.pth")
    parser.add_argument('--unfolding_model_path', type=str, default="./ckpt/unfolding.pth")
    parser.add_argument('--adjust_model_path', type=str, default="./ckpt/L_adjust.pth")
    parser.add_argument('--gpu_id', type=int, default=0)
    opts = parser.parse_args()
    for k, v in vars(opts).items():
        print(k, v)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(opts.gpu_id)
    model = Inference(opts)
    if torch.cuda.is_available():
        model = model.cuda()
    model.run(opts.img_path)
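
# Example invocation, assuming the pretrained checkpoints have been placed
# under ./ckpt as in the defaults above:
#
#   python test.py --img_path ./demo/input/3.png --output ./demo/output --ratio 3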