Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoyuzhi authored Nov 29, 2019
1 parent 648a3a9 commit e1417b7
Showing 1 changed file with 75 additions and 49 deletions.
124 changes: 75 additions & 49 deletions deepfillv2/test.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,81 @@
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 7 12:03:52 2018
@author: yzzhao2
"""

import argparse
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

def forward(size, root, model):
    """Inpaint a single image with the given generator and return the result.

    Parameters:
        size (int): side length of the square input; must be 144, 200 or 256
            (each has a matching hard-coded center mask).
        root (str): path of the image file to read.
        model: trained generator, called as ``model(x)`` on a 4-channel CUDA
            tensor (masked RGB image concatenated with the binary mask).

    Returns:
        numpy.ndarray: (size, size, 3) uint8 image produced by the model.

    Raises:
        ValueError: if ``size`` is not one of the supported values.
    """
    # Pre-processing: force the image into RGB color space at the target size.
    img = Image.open(root)
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img = img.resize((size, size), Image.LANCZOS).convert('RGB')
    img = np.array(img).astype(np.float64)
    # Binary mask: ones in the center region that is kept from the image.
    mask = np.zeros([size, size, 1], dtype = np.float64)
    if size == 144:
        mask[22:122, 22:122, :] = 1.0
    elif size == 200:
        mask[28:172, 28:172, :] = 1.0
    elif size == 256:
        mask[28:228, 28:228, :] = 1.0
    else:
        # Previously an unsupported size silently produced an all-zero mask
        # (and therefore an all-black network input); fail loudly instead.
        raise ValueError('size must be one of 144, 200 or 256, got %d' % size)
    # Masked image, normalized to [0, 1].
    maskimg = (img * mask) / 255
    maskimg = maskimg.astype(np.float32)
    maskimg = transforms.ToTensor()(maskimg)
    maskimg = maskimg.reshape([1, 3, size, size])
    mask = mask.astype(np.float32)
    mask = transforms.ToTensor()(mask)
    mask = mask.reshape([1, 1, size, size])
    # Concatenate image and mask into the 4-channel network input.
    maskimg = torch.cat((maskimg, mask), 1).cuda()
    # Get the output; no gradients are needed at test time.
    with torch.no_grad():
        output = model(maskimg)
    # Transfer the CHW tensor back to an HWC uint8 image.
    output = output.cpu().detach().numpy().reshape([3, size, size])
    output = output.transpose(1, 2, 0)
    # Clip before the uint8 cast: values >= 1.0 would otherwise wrap around.
    output = np.clip(output * 255, 0, 255)
    output = np.array(output, dtype = np.uint8)
    return output
from torch.utils.data import DataLoader

import dataset

if __name__ == "__main__":

    # ----------------------------------------
    # Initialize the parameters
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    # Dataset parameters
    parser.add_argument('--baseroot', type = str, default = "/home/alien/Documents/LINTingyu/inpainting/test", help = 'the testing folder')
    parser.add_argument('--mask_type', type = str, default = 'free_form', help = 'mask type')
    parser.add_argument('--imgsize', type = int, default = 256, help = 'size of image')
    parser.add_argument('--margin', type = int, default = 10, help = 'margin of image')
    parser.add_argument('--mask_num', type = int, default = 15, help = 'number of mask')
    parser.add_argument('--bbox_shape', type = int, default = 30, help = 'margin of image for bbox mask')
    parser.add_argument('--max_angle', type = int, default = 4, help = 'parameter of angle for free form mask')
    parser.add_argument('--max_len', type = int, default = 40, help = 'parameter of length for free form mask')
    parser.add_argument('--max_width', type = int, default = 10, help = 'parameter of width for free form mask')
    # Other parameters
    parser.add_argument('--batch_size', type = int, default = 1, help = 'test batch size, always 1')
    parser.add_argument('--load_name', type = str, default = 'deepfillNet_epoch4_batchsize4.pth', help = 'test model name')
    opt = parser.parse_args()
    print(opt)

    # ----------------------------------------
    # Initialize testing dataset
    # ----------------------------------------

    # Define the dataset
    testset = dataset.InpaintDataset(opt)
    print('The overall number of images equals to %d' % len(testset))

    # Define the dataloader
    dataloader = DataLoader(testset, batch_size = opt.batch_size, pin_memory = True)

    # ----------------------------------------
    # Testing
    # ----------------------------------------

    model = torch.load(opt.load_name)
    # Put BatchNorm / Dropout layers into inference mode for testing.
    model.eval()

    def to_uint8_image(tensor):
        """Convert a single [-1, 1] CHW tensor to an HWC uint8 RGB array."""
        arr = tensor.detach().cpu().numpy().reshape(3, opt.imgsize, opt.imgsize).transpose(1, 2, 0)
        # 127.5 (not 128) maps [-1, 1] exactly onto [0, 255]; the clip guards
        # against uint8 wrap-around for values slightly outside the range.
        return np.clip((arr + 1) * 127.5, 0, 255).astype(np.uint8)

    for batch_idx, (img, mask) in enumerate(dataloader):

        # Load mask (shape: [B, 1, H, W]) and img (shape: [B, 3, H, W]) and put them on cuda
        img = img.cuda()
        mask = mask.cuda()

        # Generator output; gradients are not needed at test time.
        with torch.no_grad():
            masked_img = img * (1 - mask)
            fake1, fake2 = model(masked_img, mask)

        # Keep the known pixels from the input and fill the hole with the
        # generator outputs of stage 1 and stage 2 (all in range [-1, 1]).
        fusion_fake = img * (1 - mask) + fake2 * mask
        fusion_fake_1 = img * (1 - mask) + fake1 * mask

        # Side-by-side comparison: ground truth | stage-1 fill | stage-2 fill
        show_img = np.concatenate((to_uint8_image(img), to_uint8_image(fusion_fake_1), to_uint8_image(fusion_fake)), axis = 1)
        # OpenCV works in BGR channel order; our arrays are RGB.
        show_img = cv2.cvtColor(show_img, cv2.COLOR_RGB2BGR)
        cv2.imshow('comparison.jpg', show_img)
        # imshow never renders without a waitKey; 0 waits for a key press.
        cv2.waitKey(0)
        # Save one file per batch instead of overwriting 'result.jpg' each time.
        cv2.imwrite('result_%d.jpg' % batch_idx, show_img)

0 comments on commit e1417b7

Please sign in to comment.