add InstanceNorm
HypoX64 committed Jan 16, 2020
1 parent 29458f1 commit a6994b5
Showing 8 changed files with 62 additions and 47 deletions.
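
The headline change is a configurable normalization layer in models/video_model.py (BatchNorm vs. InstanceNorm), plus retuned training defaults. For context, a minimal sketch of the practical difference between the two layers, using plain PyTorch rather than code from this repository: BatchNorm2d normalizes over the whole batch and learns affine parameters by default, while InstanceNorm2d normalizes each sample and channel independently, which often behaves better for image-to-image generation at small batch sizes.

import torch
import torch.nn as nn

x = torch.randn(4, 64, 32, 32)   # (batch, channels, H, W)

bn = nn.BatchNorm2d(64)          # statistics over the whole batch, affine=True by default
inorm = nn.InstanceNorm2d(64)    # statistics per sample and per channel, affine=False by default

y_bn, y_in = bn(x), inorm(x)
print(y_bn.shape, y_in.shape)    # both keep the input shape: torch.Size([4, 64, 32, 32])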
1 change: 1 addition & 0 deletions .gitignore
@@ -141,6 +141,7 @@ test*/
video_tmp/
result/
#./
/pix2pix
/pix2pixHD
/tmp
/to_make_show
2 changes: 1 addition & 1 deletion models/pix2pix_model.py
@@ -97,7 +97,7 @@ def init_func(m): # define the initialization function
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)

print('initialize network with %s' % init_type)
#print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>


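This hunk only silences the 'initialize network with ...' message; net.apply(init_func) still visits every submodule and re-initializes it. A generic illustration of how Module.apply works (standard PyTorch behaviour, with an illustrative conv init, not this repository's exact init_func):

import torch.nn as nn

def init_func(m):
    # called once for every submodule, including the container itself
    if isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.apply(init_func)   # recursively applies init_func to every module in net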
2 changes: 1 addition & 1 deletion models/unet_model.py
@@ -31,4 +31,4 @@ def forward(self, x):
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.outc(x)
return torch.sigmoid(x)
return torch.Tanh(x)
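
The output activation moves from sigmoid ([0, 1]) to tanh ([-1, 1]), which appears to pair with the new (x/255 - 0.5)/0.5 target scaling in train/add/train.py below. Note that the callable is the lowercase torch.tanh (or the nn.Tanh module). A small sketch of the intended output range and the inverse mapping, independent of this repository:

import torch

x = torch.randn(1, 3, 8, 8)
out = torch.tanh(x)              # values in [-1, 1]; torch.Tanh does not exist as a function
img = (out + 1) / 2 * 255.0      # map back to [0, 255] for saving as an image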
26 changes: 18 additions & 8 deletions models/video_model.py
@@ -4,13 +4,23 @@
from .unet_parts import *
from .pix2pix_model import *

Norm = 'batch'
if Norm == 'instance':
NormLayer_2d = nn.InstanceNorm2d
NormLayer_3d = nn.InstanceNorm3d
use_bias = False
else:
NormLayer_2d = nn.BatchNorm2d
NormLayer_3d = nn.BatchNorm3d
use_bias = True

class encoder_2d(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""

def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
@@ -55,7 +65,7 @@ class decoder_2d(nn.Module):
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""

def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=NormLayer_2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
@@ -114,8 +124,8 @@ class conv_3d(nn.Module):
def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1):
super(conv_3d, self).__init__()
self.conv = nn.Sequential(
nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
nn.BatchNorm3d(outchannel),
nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias),
NormLayer_3d(outchannel),
nn.ReLU(inplace=True),
)

@@ -128,8 +138,8 @@ def __init__(self,inchannel,outchannel,kernel_size=3,stride=1,padding=1):
super(conv_2d, self).__init__()
self.conv = nn.Sequential(
nn.ReflectionPad2d(padding),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(outchannel),
nn.Conv2d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=0, bias=use_bias),
NormLayer_2d(outchannel),
nn.ReLU(inplace=True),
)

@@ -145,8 +155,8 @@ def __init__(self,in_channel):
self.down2 = conv_3d(64, 128, 3, 2, 1)
self.down3 = conv_3d(128, 256, 3, 1, 1)
self.conver2d = nn.Sequential(
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.Conv2d(256*int(in_channel/4), 256, kernel_size=3, stride=1, padding=1, bias=use_bias),
NormLayer_2d(256),
nn.ReLU(inplace=True),
)

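The new module-level switch mirrors the get_norm_layer pattern from pix2pix. One detail worth noting: InstanceNorm2d has affine=False by default, so the upstream pix2pix convention enables conv bias when InstanceNorm is selected and disables it for BatchNorm, which is the reverse of the use_bias mapping in the hunk above. A sketch of the switch written as a helper under that upstream convention (an assumption about intent, not this repository's code):

import torch.nn as nn

def get_norm_layers(norm='batch'):
    """Return (2d norm, 3d norm, conv bias flag) for the chosen normalization."""
    if norm == 'instance':
        # InstanceNorm has no affine parameters by default, so keep bias on the convs.
        return nn.InstanceNorm2d, nn.InstanceNorm3d, True
    # BatchNorm learns its own shift, so a conv bias right before it is redundant.
    return nn.BatchNorm2d, nn.BatchNorm3d, False

NormLayer_2d, NormLayer_3d, use_bias = get_norm_layers('batch')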
4 changes: 2 additions & 2 deletions train/add/train.py
@@ -50,8 +50,8 @@ def Toinputshape(imgs,masks,finesize):
# print(imgs[i].shape,masks[i].shape)
img,mask = data.random_transform_image(imgs[i], masks[i], finesize)
# print(img.shape,mask.shape)
mask = mask.reshape(1,finesize,finesize)/255.0
img = img.transpose((2, 0, 1))/255.0
mask = (mask.reshape(1,finesize,finesize)/255.0-0.5)/0.5
img = (img.transpose((2, 0, 1))/255.0-0.5)/0.5
result_imgs.append(img)
result_masks.append(mask)
result_imgs = np.array(result_imgs)
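Inputs and masks now map from [0, 255] to [-1, 1] instead of [0, 1], matching a tanh-output network. Two small helpers illustrating the mapping and its inverse (names are illustrative, not from the repository):

import numpy as np

def to_tensor_range(img_uint8):
    """uint8 [0, 255] -> float32 [-1, 1]"""
    return (img_uint8.astype(np.float32) / 255.0 - 0.5) / 0.5

def to_image_range(x):
    """float [-1, 1] -> uint8 [0, 255]"""
    return np.clip((x * 0.5 + 0.5) * 255.0, 0, 255).astype(np.uint8)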
58 changes: 28 additions & 30 deletions train/clean/train.py
@@ -18,21 +18,21 @@

N = 25
ITER = 10000000
LR = 0.001
LR = 0.0002
beta1 = 0.5
use_gpu = True
use_gan = False
use_L2 = True
CONTINUE = True
lambda_L1 = 1.0#100.0
lambda_gan = 1.0
CONTINUE = False
lambda_L1 = 100.0
lambda_gan = 1

SAVE_FRE = 10000
start_iter = 0
finesize = 128
loadsize = int(finesize*1.1)
batchsize = 8
perload_num = 32
batchsize = 1
perload_num = 16
savename = 'MosaicNet_test'
dir_checkpoint = 'checkpoints/'+savename
util.makedirs(dir_checkpoint)
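
The new defaults (LR = 0.0002, beta1 = 0.5, lambda_L1 = 100, lambda_gan = 1) are the standard pix2pix recipe: Adam with a low beta1 and an L1 reconstruction term weighted 100x against the adversarial term. A hedged sketch of how that objective is typically wired up, with a stand-in generator rather than this repository's training loop:

import torch
import torch.nn as nn

LR, beta1, lambda_L1, lambda_gan = 0.0002, 0.5, 100.0, 1.0

netG = nn.Conv2d(3, 3, 3, padding=1)   # stand-in generator, illustration only
optimizer_G = torch.optim.Adam(netG.parameters(), lr=LR, betas=(beta1, 0.999))
criterion_L1 = nn.L1Loss()

pred = netG(torch.randn(1, 3, 8, 8))
target = torch.randn(1, 3, 8, 8)

optimizer_G.zero_grad()
loss_G = lambda_L1 * criterion_L1(pred, target)   # + lambda_gan * GAN term when use_gan
loss_G.backward()
optimizer_G.step()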
@@ -45,6 +45,7 @@
videos = os.listdir('./dataset')
videos.sort()
lengths = []
print('check dataset...')
for video in videos:
video_images = os.listdir('./dataset/'+video+'/ori')
lengths.append(len(video_images))
@@ -55,7 +56,8 @@
loadmodel.show_paramsnumber(netG,'netG')
# netG = unet_model.UNet(3*N+1, 3)
if use_gan:
netD = pix2pix_model.define_D(3*2+1, 64, 'basic', n_layers_D=3, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])
#netD = pix2pix_model.define_D(3*2+1, 64, 'pixel', norm='instance')
netD = pix2pix_model.define_D(3*2+1, 64, 'basic', norm='instance')
#netD = pix2pix_model.define_D(3*2+1, 64, 'n_layers', n_layers_D=5, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[])

if CONTINUE:
@@ -104,26 +106,19 @@ def loaddata():
return input_img,ground_true

print('preloading data, please wait 5s...')
# input_imgs=[]
# ground_trues=[]
input_imgs = torch.rand(batchsize,N*3+1,finesize,finesize).cuda()
ground_trues = torch.rand(batchsize,3,finesize,finesize).cuda()

if perload_num <= batchsize:
perload_num = batchsize*2
input_imgs = torch.rand(perload_num,N*3+1,finesize,finesize).cuda()
ground_trues = torch.rand(perload_num,3,finesize,finesize).cuda()
load_cnt = 0

def preload():
global load_cnt
while 1:
try:
# input_img,ground_true = loaddata()
# input_imgs.append(input_img)
# ground_trues.append(ground_true)
ran = random.randint(0, batchsize-1)
ran = random.randint(0, perload_num-1)
input_imgs[ran],ground_trues[ran] = loaddata()


# if len(input_imgs)>perload_num:
# del(input_imgs[0])
# del(ground_trues[0])
load_cnt += 1
# time.sleep(0.1)
except Exception as e:
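
The preload buffer now holds perload_num samples instead of just one batch, and the background thread keeps overwriting random slots with freshly loaded data. A stripped-down CPU version of that producer pattern (loaddata, shapes and sizes are stand-ins, not repository code):

import random, threading, time
import torch

perload_num, C, H = 16, 4, 32
buffer_x = torch.rand(perload_num, C, H, H)   # shared preload buffer
buffer_y = torch.rand(perload_num, C, H, H)
load_cnt = 0

def loaddata():                               # stand-in for the real loader
    return torch.rand(C, H, H), torch.rand(C, H, H)

def preload():
    global load_cnt
    while True:
        slot = random.randint(0, perload_num - 1)   # overwrite a random slot
        buffer_x[slot], buffer_y[slot] = loaddata()
        load_cnt += 1

t = threading.Thread(target=preload, daemon=True)
t.start()
while load_cnt < perload_num:                 # wait until the buffer is warm
    time.sleep(0.1)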
@@ -133,21 +128,24 @@ def preload():
t = threading.Thread(target=preload,args=()) # t is the newly created thread
t.daemon = True
t.start()
while load_cnt < batchsize*2:

time_start=time.time()
while load_cnt < perload_num:
time.sleep(0.1)
time_end=time.time()
print('load speed:',round((time_end-time_start)/perload_num,3),'s/it')


util.copyfile('./train.py', os.path.join(dir_checkpoint,'train.py'))
util.copyfile('../../models/video_model.py', os.path.join(dir_checkpoint,'model.py'))
netG.train()
time_start=time.time()
print("Begin training...")
for iter in range(start_iter+1,ITER):

# inputdata,target = loaddata()
# ran = random.randint(1, perload_num-2)
# inputdata = inputdatas[ran]
# target = targets[ran]

inputdata = input_imgs.clone()
target = ground_trues.clone()
ran = random.randint(0, perload_num-batchsize-1)
inputdata = input_imgs[ran:ran+batchsize].clone()
target = ground_trues[ran:ran+batchsize].clone()

pred = netG(inputdata)
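
On the consumer side, each iteration now copies a random contiguous window of batchsize samples out of the preload buffer; clone() matters because the preload thread keeps overwriting the shared tensors. A minimal sketch with stand-in buffer shapes:

import random
import torch

perload_num, batchsize = 16, 4
input_imgs = torch.rand(perload_num, 4, 32, 32)       # filled by the preload thread
ground_trues = torch.rand(perload_num, 3, 32, 32)

ran = random.randint(0, perload_num - batchsize - 1)  # random contiguous window
inputdata = input_imgs[ran:ran + batchsize].clone()   # clone() detaches from the shared buffer
target = ground_trues[ran:ran + batchsize].clone()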

@@ -262,13 +260,13 @@ def preload():
netG.eval()

test_names = os.listdir('./test')
test_names.sort()
result = np.zeros((finesize*2,finesize*len(test_names),3), dtype='uint8')

for cnt,test_name in enumerate(test_names,0):
img_names = os.listdir(os.path.join('./test',test_name,'image'))
img_names.sort()
inputdata = np.zeros((finesize,finesize,3*N+1), dtype='uint8')
img_names.sort()
for i in range(0,N):
img = impro.imread(os.path.join('./test',test_name,'image',img_names[i]))
img = impro.resize(img,finesize)
@@ -286,4 +284,4 @@ def preload():
result[finesize:finesize*2,finesize*cnt:finesize*(cnt+1),:] = pred

cv2.imwrite(os.path.join(dir_checkpoint,str(iter+1)+'_test.png'), result)
netG.train()
netG.train()
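
The periodic test pass toggles the generator between eval() and train(); wrapping the forward in torch.no_grad() is the usual companion so the test images do not build an autograd graph. A generic sketch of that pattern (standard PyTorch usage, with a stand-in model, not the repository's exact test loop):

import torch
import torch.nn as nn

netG = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.BatchNorm2d(3))   # stand-in generator
inputdata = torch.randn(1, 3, 16, 16)

netG.eval()                     # test pass, e.g. every SAVE_FRE iterations
with torch.no_grad():           # no autograd graph, lower memory while generating test images
    pred = netG(inputdata)
netG.train()                    # restore training behaviour of BatchNorm/Dropout layers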
8 changes: 4 additions & 4 deletions util/data.py
@@ -74,11 +74,11 @@ def random_transform_video(src,target,finesize,N):
target = target[:,::-1,:]

#random color
alpha = random.uniform(-0.2,0.2)
alpha = random.uniform(-0.3,0.3)
beta = random.uniform(-0.2,0.2)
b = random.uniform(-0.1,0.1)
g = random.uniform(-0.1,0.1)
r = random.uniform(-0.1,0.1)
b = random.uniform(-0.05,0.05)
g = random.uniform(-0.05,0.05)
r = random.uniform(-0.05,0.05)
for i in range(N):
src[:,:,i*3:(i+1)*3] = color_adjust(src[:,:,i*3:(i+1)*3],alpha,beta,b,g,r)
target = color_adjust(target,alpha,beta,b,g,r)
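The jitter ranges change to a stronger contrast term (alpha up to ±0.3) and weaker per-channel shifts (±0.05). color_adjust itself is not part of this diff, so the sketch below is only an assumption about what such an adjustment typically does (alpha as contrast, beta as brightness, b/g/r as per-channel offsets), not the repository's implementation:

import numpy as np

def color_adjust_sketch(img, alpha, beta, b, g, r):
    """Hypothetical color jitter: img is uint8 HxWx3 (BGR), factors are small floats."""
    out = img.astype(np.float32) / 255.0
    out = out * (1 + alpha) + beta                  # contrast and brightness
    out += np.array([b, g, r], dtype=np.float32)    # per-channel shift
    return np.clip(out * 255.0, 0, 255).astype(np.uint8)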
8 changes: 7 additions & 1 deletion util/util.py
@@ -79,4 +79,10 @@ def get_bar(percent,num = 25):
else:
bar += '-'
bar += ']'
return bar+' '+str(round(percent,2))+'%'
return bar+' '+str(round(percent,2))+'%'

def copyfile(scr,dst):
try:
shutil.copyfile(src, dst)
except Exception as e:
print(e)
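
The new copyfile helper is what lets train/clean/train.py snapshot itself and the model file into the checkpoint directory. As committed, the parameter is named scr while the body references src, and shutil has to be imported in util.py; a corrected sketch under those assumptions:

import shutil

def copyfile(src, dst):
    try:
        shutil.copyfile(src, dst)
    except Exception as e:
        print(e)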
