Skip to content

Commit

Permalink
Added predictor fixes and test method
Browse files Browse the repository at this point in the history
  • Loading branch information
okupyn committed Nov 5, 2019
1 parent d565087 commit 441d0e5
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 2 deletions.
8 changes: 6 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True
def main(img_pattern: str,
mask_pattern: Optional[str] = None,
weights_path='best_fpn.h5',
out_dir='submit/'
):
out_dir='submit/',
side_by_side: bool = False):
def sorted_glob(pattern):
return sorted(glob(pattern))

Expand All @@ -87,8 +87,12 @@ def sorted_glob(pattern):
for name, pair in tqdm(zip(names, pairs), total=len(names)):
f_img, f_mask = pair
img, mask = map(cv2.imread, (f_img, f_mask))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

pred = predictor(img, mask)
if side_by_side:
pred = np.hstack((img, pred))
pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(out_dir, name),
pred)

Expand Down
90 changes: 90 additions & 0 deletions test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torchvision import models, transforms
from torch.autograd import Variable
import shutil
import glob
import tqdm
from util.metrics import PSNR
from albumentations import Compose, CenterCrop, PadIfNeeded
from PIL import Image
from ssim.ssimlib import SSIM
from models.networks import get_generator


def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
parser.add_argument('--weights_path', required=True, help='Weights path')

return parser.parse_args()


def prepare_dirs(path):
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)


def get_gt_image(path):
dir, filename = os.path.split(path)
base, seq = os.path.split(dir)
base, _ = os.path.split(base)
img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB)
return img


def test_image(model, image_path):
img_transforms = transforms.Compose([
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
size_transform = Compose([
PadIfNeeded(736, 1280)
])
crop = CenterCrop(720, 1280)
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_s = size_transform(image=img)['image']
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
img_tensor = img_transforms(img_tensor)
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
result_image = model(img_tensor)
result_image = result_image[0].cpu().float().numpy()
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
result_image = crop(image=result_image)['image']
result_image = result_image.astype('uint8')
gt_image = get_gt_image(image_path)
_, filename = os.path.split(image_path)
psnr = PSNR(result_image, gt_image)
pilFake = Image.fromarray(result_image)
pilReal = Image.fromarray(gt_image)
ssim = SSIM(pilFake).cw_ssim_value(pilReal)
return psnr, ssim


def test(model, files):
psnr = 0
ssim = 0
for file in tqdm.tqdm(files):
cur_psnr, cur_ssim = test_image(model, file)
psnr += cur_psnr
ssim += cur_ssim
print("PSNR = {}".format(psnr / len(files)))
print("SSIM = {}".format(ssim / len(files)))


if __name__ == '__main__':
args = get_args()
with open('config/config.yaml') as cfg:
config = yaml.load(cfg)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path)['model'])
model = model.cuda()
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
test(model, filenames)

0 comments on commit 441d0e5

Please sign in to comment.