forked from VITA-Group/DeblurGANv2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added predictor fixes and test method
- Loading branch information
okupyn
committed
Nov 5, 2019
1 parent
d565087
commit 441d0e5
Showing
2 changed files
with
96 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from __future__ import print_function | ||
import argparse | ||
import numpy as np | ||
import torch | ||
import cv2 | ||
import yaml | ||
import os | ||
from torchvision import models, transforms | ||
from torch.autograd import Variable | ||
import shutil | ||
import glob | ||
import tqdm | ||
from util.metrics import PSNR | ||
from albumentations import Compose, CenterCrop, PadIfNeeded | ||
from PIL import Image | ||
from ssim.ssimlib import SSIM | ||
from models.networks import get_generator | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser('Test an image') | ||
parser.add_argument('--img_folder', required=True, help='GoPRO Folder') | ||
parser.add_argument('--weights_path', required=True, help='Weights path') | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def prepare_dirs(path): | ||
if os.path.exists(path): | ||
shutil.rmtree(path) | ||
os.makedirs(path) | ||
|
||
|
||
def get_gt_image(path): | ||
dir, filename = os.path.split(path) | ||
base, seq = os.path.split(dir) | ||
base, _ = os.path.split(base) | ||
img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB) | ||
return img | ||
|
||
|
||
def test_image(model, image_path): | ||
img_transforms = transforms.Compose([ | ||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | ||
]) | ||
size_transform = Compose([ | ||
PadIfNeeded(736, 1280) | ||
]) | ||
crop = CenterCrop(720, 1280) | ||
img = cv2.imread(image_path) | ||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
img_s = size_transform(image=img)['image'] | ||
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32')) | ||
img_tensor = img_transforms(img_tensor) | ||
with torch.no_grad(): | ||
img_tensor = Variable(img_tensor.unsqueeze(0).cuda()) | ||
result_image = model(img_tensor) | ||
result_image = result_image[0].cpu().float().numpy() | ||
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0 | ||
result_image = crop(image=result_image)['image'] | ||
result_image = result_image.astype('uint8') | ||
gt_image = get_gt_image(image_path) | ||
_, filename = os.path.split(image_path) | ||
psnr = PSNR(result_image, gt_image) | ||
pilFake = Image.fromarray(result_image) | ||
pilReal = Image.fromarray(gt_image) | ||
ssim = SSIM(pilFake).cw_ssim_value(pilReal) | ||
return psnr, ssim | ||
|
||
|
||
def test(model, files): | ||
psnr = 0 | ||
ssim = 0 | ||
for file in tqdm.tqdm(files): | ||
cur_psnr, cur_ssim = test_image(model, file) | ||
psnr += cur_psnr | ||
ssim += cur_ssim | ||
print("PSNR = {}".format(psnr / len(files))) | ||
print("SSIM = {}".format(ssim / len(files))) | ||
|
||
|
||
if __name__ == '__main__': | ||
args = get_args() | ||
with open('config/config.yaml') as cfg: | ||
config = yaml.load(cfg) | ||
model = get_generator(config['model']) | ||
model.load_state_dict(torch.load(args.weights_path)['model']) | ||
model = model.cuda() | ||
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True)) | ||
test(model, filenames) |