From 9e287d422bd5e17b1499858aaebf87015c863f5f Mon Sep 17 00:00:00 2001 From: GE ZHENG Date: Fri, 25 Aug 2017 11:23:11 +0900 Subject: [PATCH] latest code change accept different input size image and resize it --- matting.py | 294 ++++++++++++++++------------------------------ matting_deconv.py | 149 ++++++++++------------- 2 files changed, 161 insertions(+), 282 deletions(-) diff --git a/matting.py b/matting.py index 927e919..9568dde 100644 --- a/matting.py +++ b/matting.py @@ -8,135 +8,7 @@ from sys import getrefcount import gc -input_image_size = 650 - -def unpool(pool, ind, ksize=[1, 2, 2, 1], scope='unpool'): - - with tf.variable_scope(scope): - input_shape = pool.get_shape().as_list() - output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) - - flat_input_size = np.prod(input_shape) - flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]] - - pool_ = tf.reshape(pool, [flat_input_size]) - batch_range = tf.reshape(tf.range(output_shape[0], dtype=ind.dtype), shape=[input_shape[0], 1, 1, 1]) - b = tf.ones_like(ind) * batch_range - b = tf.reshape(b, [flat_input_size, 1]) - ind_ = tf.reshape(ind, [flat_input_size, 1]) - ind_ = tf.concat([b, ind_], 1) - - ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape) - ret = tf.reshape(ret, output_shape) - return ret - -def preprocessing(image_batch, GTmatte_batch, GTBG_batch, GTFG_batch, image_size=320): - - g_mean = np.array(([126.8898,120.2431,112.1959])).reshape([1,1,3]) - image_batch_shape = image_batch.shape - batch_size = image_batch_shape[0] - #generate trimap by random size (15~30) erosion and dilation - -# trimap_batch = copy.deepcopy(GTmatte_batch) - trimap_batch = np.copy(GTmatte_batch) - #trimap_batch = copy.deepcopy(GTmatte_batch) - trimap_batch = generate_trimap(trimap_batch,GTmatte_batch,batch_size) - - train_batch_pre = np.concatenate([image_batch,trimap_batch,GTmatte_batch,GTBG_batch,GTFG_batch],3) - train_batch = np.zeros([batch_size,image_size,image_size,11]) - for i in range(batch_size): - #print('%dth image is under processing...'%i) - crop_size = random.choice([320,480,640]) - flip = random.choice([0,1]) - # i_padding = center_padding(sess,tf.slice(train_batch_pre,[i,0,0,0],[1,image_batch_shape[1],image_batch_shape[2],11])) - i_padding = center_padding(train_batch_pre[i]) - #i_UR_center = UR_center(i_padding) - i_UR_center = UR_center(i_padding) - if crop_size == 320: - h_start_index = i_UR_center[0] - 159 - w_start_index = i_UR_center[1] - 159 - tmp = i_padding[h_start_index:h_start_index+320, w_start_index:w_start_index+320, :] - if flip: - tmp = tmp[:,::-1,:] - # tmp[:,:,:3] = tmp[:,:,:3] - mean - tmp[:,:,3:5] = tmp[:,:,3:5] / 255.0 - tmp[:,:,:3] -= g_mean - train_batch[i,:,:,:] = tmp - if crop_size == 480: - h_start_index = i_UR_center[0] - 239 - w_start_index = i_UR_center[1] - 239 - tmp = i_padding[h_start_index:h_start_index+480, w_start_index:w_start_index+480, :] - if flip: - tmp = tmp[:,::-1,:] - tmp1 = np.zeros([image_size,image_size,11]) - tmp1[:,:,:3] = misc.imresize(tmp[:,:,:3],[image_size,image_size,3]) - g_mean - tmp1[:,:,3] = misc.imresize(tmp[:,:,3],[image_size,image_size],interp = 'nearest') / 255.0 - tmp1[:,:,4] = binarilize_alpha(misc.imresize(tmp[:,:,4],[image_size,image_size]),60) / 255.0 - tmp1[:,:,5:8] = misc.imresize(tmp[:,:,5:8],[image_size,image_size,3]) - tmp1[:,:,8:] = misc.imresize(tmp[:,:,8:],[image_size,image_size,3]) - train_batch[i,:,:,:] = tmp1 - - if crop_size == 640: - h_start_index = i_UR_center[0] - 319 
- w_start_index = i_UR_center[1] - 319 - tmp = i_padding[h_start_index:h_start_index+640, w_start_index:w_start_index+640, :] - if flip: - tmp = tmp[:,::-1,:] - tmp1 = np.zeros([image_size,image_size,11]) - tmp1[:,:,:3] = misc.imresize(tmp[:,:,:3],[image_size,image_size,3]) - g_mean - tmp1[:,:,3] = misc.imresize(tmp[:,:,3],[image_size,image_size],interp = 'nearest') / 255.0 - tmp1[:,:,4] = binarilize_alpha(misc.imresize(tmp[:,:,4],[image_size,image_size]),60) / 255.0 - tmp1[:,:,5:8] = misc.imresize(tmp[:,:,5:8],[image_size,image_size,3]) - tmp1[:,:,8:] = misc.imresize(tmp[:,:,8:],[image_size,image_size,3]) - train_batch[i,:,:,:] = tmp1 - gc.collect() - # print('tmp %d' %getrefcount(tmp)) - # print('tmp %d' %getrefcount(tmp1)) - # print('tmp %d' %getrefcount(train_batch_pre)) - # print('tmp %d' %getrefcount(trimap_batch)) - train_batch = train_batch.astype(np.float32) - - return train_batch[:,:,:,:3],np.expand_dims(train_batch[:,:,:,3],3),np.expand_dims(train_batch[:,:,:,4],3),train_batch[:,:,:,5:8],train_batch[:,:,:,8:] #return input of CNN, and transformed GT alpha matte, GTBG,GTFG - -def binarilize_alpha(alpha, threshold): - alpha[np.where(alpha<=threshold)] = 0 - alpha[np.where(alpha>threshold)] = 255 - return alpha - -def center_padding(image): - ''' - image consists 11 channel (images, trimap, GT alpha matte, GTBG_batch, GTFG_batch) - padding images to 2000*2000 - ''' - - # image_shape = image.get_shape().as_list() - # print(image_shape) - # h_center = (image_shape[0]-1)//2 - # w_center = (image_shape[1]-1)//2 - # pad_image = np.zeros([2000,2000,11]) - # h_start_index = 999-h_center - # h_end_index = h_start_index + image_shape[0] - # w_start_index = 999-w_center - # w_end_index = w_start_index + image_shape[1] - # pad_image[h_start_index:h_end_index,w_start_index:w_end_index,:] = sess.run(image) - # return pad_image - - image_shape = image.shape - h_center = (image_shape[0]-1)//2 - w_center = (image_shape[1]-1)//2 - pad_image = np.zeros([1300,1300,11]) - h_start_index = 649-h_center - h_end_index = h_start_index + image_shape[0] - w_start_index = 649-w_center - w_end_index = w_start_index + image_shape[1] - pad_image[h_start_index:h_end_index,w_start_index:w_end_index,:] = image - return pad_image - -def show(image): - for i in range(image.shape[0]): - for j in range(image.shape[1]): - if image[i][j]!=0 and image[i][j] !=1 and image[i][j] !=0.5: - print([i,j]) +trimap_kernel = [val for val in range(20,35)] def UR_center(image): ''' @@ -144,83 +16,119 @@ def UR_center(image): centered on unknown region ''' trimap = image[:,:,3] -# UR = [[i,j] for i, j in itertools.product(range(trimap.shape[0]), range(trimap.shape[1])) if trimap[i,j] == 128] target = np.where(trimap==127.5) - # show(trimap) index = random.choice([i for i in range(len(target[0]))]) return np.array(target)[:,index][:2] -# return [int(i) for i in np.array(UR).mean(0)] -# return random.choice(UR) - - # trimap = tf.convert_to_tensor(image[:,:,3]) - # condition = tf.equal(trimap,128) - # indices = tf.where(condition) - # return random.choice(indices) def composition_RGB(BG,FG,p_matte): GB = tf.convert_to_tensor(BG) FG = tf.convert_to_tensor(FG) return p_matte * FG + (1 - p_matte) * BG -def global_mean(RGB_folder): - RGBs = os.listdir(RGB_folder) - num = len(RGBs) - ite = num // 100 - sum_tmp = [] - print(ite) - for i in range(ite): - print(i) - batch_tmp = np.zeros([input_image_size,input_image_size,3]) - RGBs_tmp = [os.path.join(RGB_folder,RGB_path) for RGB_path in RGBs[i*100:(i+1)*100]] - for RGB in RGBs_tmp: - batch_tmp 
+= misc.imread(RGB) - sum_tmp.append(batch_tmp.sum(axis = 0).sum(axis = 0) / (input_image_size*input_image_size*100)) - return np.array(sum_tmp).mean(axis = 0) +# def global_mean(RGB_folder): +# RGBs = os.listdir(RGB_folder) +# num = len(RGBs) +# ite = num // 100 +# sum_tmp = [] +# print(ite) +# for i in range(ite): +# print(i) +# batch_tmp = np.zeros([320,320,3]) +# RGBs_tmp = [os.path.join(RGB_folder,RGB_path) for RGB_path in RGBs[i*100:(i+1)*100]] +# for RGB in RGBs_tmp: +# batch_tmp += misc.imread(RGB) +# sum_tmp.append(batch_tmp.sum(axis = 0).sum(axis = 0) / (320*320*100)) +# return np.array(sum_tmp).mean(axis = 0) def load_path(dataset_RGB,alpha,FG,BG): RGBs_path = os.listdir(dataset_RGB) RGBs_abspath = [os.path.join(dataset_RGB,RGB) for RGB in RGBs_path] - alphas_abspath = [os.path.join(alpha,RGB.split('-')[0]+'.png') for RGB in RGBs_path] - FGs_abspath = [os.path.join(FG,RGB.split('-')[0]+'.png') for RGB in RGBs_path] + alphas_abspath = [os.path.join(alpha,RGB.split('-')[0],RGB.split('-')[1]) for RGB in RGBs_path] + FGs_abspath = [os.path.join(FG,RGB.split('-')[0],RGB.split('-')[1]) for RGB in RGBs_path] BGs_abspath = [os.path.join(BG,RGB.split('-')[0],RGB.split('-')[1][:-3]+'jpg') for RGB in RGBs_path] - return RGBs_abspath,alphas_abspath,FGs_abspath,BGs_abspath - -def load_data(batch_RGB_paths,batch_alpha_paths,batch_FG_paths,batch_BG_paths): - batch_RGBs = [] - - for path in batch_RGB_paths: - file_contents = tf.read_file(tf.convert_to_tensor(path)) - image = tf.image.decode_png(file_contents) - batch_RGBs.append(image) - batch_RGBs = tf.stack(batch_RGBs) + return np.array(RGBs_abspath),np.array(alphas_abspath),np.array(FGs_abspath),np.array(BGs_abspath) - batch_alphas = [] - for path in batch_alpha_paths: - file_contents = tf.read_file(tf.convert_to_tensor(path)) - image = tf.image.decode_png(file_contents) - batch_alphas.append(image) - batch_alphas = tf.stack(batch_alphas) - - batch_FGs = [] - for path in batch_FG_paths: - file_contents = tf.read_file(tf.convert_to_tensor(path)) - image = tf.image.decode_png(file_contents) - batch_FGs.append(image) - batch_FGs = tf.stack(batch_FGs) - - batch_BGs = [] - for path in batch_BG_paths: - file_contents = tf.read_file(tf.convert_to_tensor(path)) - image = tf.image.decode_jpeg(file_contents) - batch_BGs.append(image) - batch_BGs = tf.stack(batch_BGs) - return batch_RGBs,batch_alphas,batch_FGs,batch_BGs - -def generate_trimap(trimap_batch,GTmatte_batch,batch_size): - kernel = [val for val in range(15,31)] +def load_data(sess,batch_RGB_paths,batch_alpha_paths,batch_FG_paths,batch_BG_paths): + + batch_size = batch_RGB_paths.shape[0] + train_batch = [] for i in range(batch_size): - k_size = random.choice(kernel) - trimap_batch[i][np.where((ndimage.grey_dilation(GTmatte_batch[i][:,:,0],size=(k_size,k_size)) - ndimage.grey_erosion(GTmatte_batch[i][:,:,0],size=(k_size,k_size)))!=0)] = 127.5 - return trimap_batch - - + comp_RGB_content = tf.read_file(tf.convert_to_tensor(batch_RGB_paths[i])) + comp_RGB = tf.cast(tf.image.decode_png(comp_RGB_content),tf.float32) + + alpha_content = tf.read_file(tf.convert_to_tensor(batch_alpha_paths[i])) + alpha = tf.cast(tf.image.decode_png(alpha_content),tf.float32) + + FG_content = tf.read_file(tf.convert_to_tensor(batch_FG_paths[i])) + FG = tf.cast(tf.image.decode_png(FG_content),tf.float32) + + BG_content = tf.read_file(tf.convert_to_tensor(batch_BG_paths[i])) + BG = tf.cast(tf.image.decode_jpeg(BG_content),tf.float32) + print('##########') + batch_i = tf.py_func(preprocessing_single,[comp_RGB, alpha, BG, 
FG, 320],tf.float32) + batch_i.set_shape([320,320,11]) + train_batch.append(batch_i) + train_batch = sess.run(tf.stack(train_batch)) + return train_batch[:,:,:,:3],np.expand_dims(train_batch[:,:,:,3],3),np.expand_dims(train_batch[:,:,:,4],3),train_batch[:,:,:,5:8],train_batch[:,:,:,8:] + +def generate_trimap(trimap,alpha): + + k_size = random.choice(trimap_kernel) + trimap[np.where((ndimage.grey_dilation(alpha[:,:,0],size=(k_size,k_size)) - ndimage.grey_erosion(alpha[:,:,0],size=(k_size,k_size)))!=0)] = 127.5 + return trimap + +def preprocessing_single(comp_RGB, alpha, BG, FG, image_size=320): + + g_mean = np.array(([126.88987627,120.24313843,112.19594981])).reshape([1,1,3]) + trimap = np.copy(alpha) + #trimap_batch = copy.deepcopy(GTmatte_batch) + trimap = generate_trimap(trimap,alpha) + + train_pre = np.concatenate([comp_RGB,trimap,alpha,BG,FG],2) + train_data = np.zeros([image_size,image_size,11]) + crop_size = random.choice([320,480,640]) + print(crop_size) + flip = random.choice([0,1]) + print(flip) + i_UR_center = UR_center(train_pre) + + if crop_size == 320: + h_start_index = i_UR_center[0] - 159 + w_start_index = i_UR_center[1] - 159 + tmp = train_pre[h_start_index:h_start_index+320, w_start_index:w_start_index+320, :] + if flip: + tmp = tmp[:,::-1,:] + # tmp[:,:,:3] = tmp[:,:,:3] - mean + tmp[:,:,3:5] = tmp[:,:,3:5] / 255.0 + tmp[:,:,:3] -= g_mean + train_data = tmp + + if crop_size == 480: + h_start_index = i_UR_center[0] - 239 + w_start_index = i_UR_center[1] - 239 + tmp = train_pre[h_start_index:h_start_index+480, w_start_index:w_start_index+480, :] + if flip: + tmp = tmp[:,::-1,:] + tmp1 = np.zeros([image_size,image_size,11]) + tmp1[:,:,:3] = misc.imresize(tmp[:,:,:3],[image_size,image_size,3]) - g_mean + tmp1[:,:,3] = misc.imresize(tmp[:,:,3],[image_size,image_size],interp = 'nearest') / 255.0 + tmp1[:,:,4] = misc.imresize(tmp[:,:,4],[image_size,image_size]) / 255.0 + tmp1[:,:,5:8] = misc.imresize(tmp[:,:,5:8],[image_size,image_size,3]) + tmp1[:,:,8:] = misc.imresize(tmp[:,:,8:],[image_size,image_size,3]) + train_data = tmp1 + + if crop_size == 640: + h_start_index = i_UR_center[0] - 319 + w_start_index = i_UR_center[1] - 319 + tmp = train_pre[h_start_index:h_start_index+640, w_start_index:w_start_index+640, :] + if flip: + tmp = tmp[:,::-1,:] + tmp1 = np.zeros([image_size,image_size,11]) + tmp1[:,:,:3] = misc.imresize(tmp[:,:,:3],[image_size,image_size,3]) - g_mean + tmp1[:,:,3] = misc.imresize(tmp[:,:,3],[image_size,image_size],interp = 'nearest') / 255.0 + tmp1[:,:,4] = misc.imresize(tmp[:,:,4],[image_size,image_size]) / 255.0 + tmp1[:,:,5:8] = misc.imresize(tmp[:,:,5:8],[image_size,image_size,3]) + tmp1[:,:,8:] = misc.imresize(tmp[:,:,8:],[image_size,image_size,3]) + train_data = tmp1 + train_data = train_data.astype(np.float32) + return train_data \ No newline at end of file diff --git a/matting_deconv.py b/matting_deconv.py index 5c817e2..5b0bec4 100644 --- a/matting_deconv.py +++ b/matting_deconv.py @@ -1,21 +1,26 @@ +''' +deconv_simple_v2.py +change the input image size and rearrange all data. 
+ +''' + import tensorflow as tf import numpy as np -from matting import unpool,preprocessing,composition_RGB,load_path,load_data +from matting import composition_RGB,load_path,load_data import os from scipy import misc image_size = 320 -input_image_size = 650 batch_size = 25 max_epochs = 1000000 #pretrained_vgg_model_path model_path = './vgg16_weights.npz' log_dir = './tensor_log' -dataset_RGB = '/data/gezheng/data-matting/comp_RGB' -dataset_alpha = './alpha_final' -dataset_FG = './FG_final' -dataset_BG = '/data/gezheng/data-matting/BG' +dataset_RGB = '/data/gezheng/data-matting/new/comp_RGB' +dataset_alpha = '/data/gezheng/data-matting/new/alpha_final' +dataset_FG = '/data/gezheng/data-matting/new/FG_final' +dataset_BG = '/data/gezheng/data-matting/new/BG' paths_RGB,paths_alpha,paths_FG,paths_BG = load_path(dataset_RGB,dataset_alpha,dataset_FG,dataset_BG) @@ -26,35 +31,32 @@ index_queue = tf.train.range_input_producer(range_size, num_epochs=None,shuffle=True, seed=None, capacity=32) index_dequeue_op = index_queue.dequeue_many(batch_size, 'index_dequeue') -image_batch = tf.placeholder(tf.float32, shape=(batch_size,input_image_size,input_image_size,3)) -GT_matte_batch = tf.placeholder(tf.float32, shape = (batch_size,input_image_size,input_image_size,1)) -GTBG_batch = tf.placeholder(tf.float32, shape = (batch_size,input_image_size,input_image_size,3)) -GTFG_batch = tf.placeholder(tf.float32, shape = (batch_size,input_image_size,input_image_size,3)) +image_batch = tf.placeholder(tf.float32, shape=(batch_size,image_size,image_size,3)) +GT_matte_batch = tf.placeholder(tf.float32, shape = (batch_size,image_size,image_size,1)) +GT_trimap = tf.placeholder(tf.float32, shape = (batch_size,image_size,image_size,1)) +GTBG_batch = tf.placeholder(tf.float32, shape = (batch_size,image_size,image_size,3)) +GTFG_batch = tf.placeholder(tf.float32, shape = (batch_size,image_size,image_size,3)) is_train = tf.placeholder(tf.bool, name = 'is_train') en_parameters = [] -#if training: -#b_input, b_GTmatte ,b_GTBG, b_GTFG = preprocessing(image_batch,GT_matte_batch,GTBG_batch,GTFG_batch,image_size) -#[b_input, b_GTmatte ,b_GTBG, b_GTFG] = tf.py_func(preprocessing,[image_batch,GT_matte_batch,GTBG_batch,GTFG_batch,image_size],tf.float32) -if is_train: - b_RGB,b_trimap, b_GTmatte ,b_GTBG, b_GTFG = tf.py_func(preprocessing,[image_batch,GT_matte_batch,GTBG_batch,GTFG_batch,image_size],[tf.float32,tf.float32,tf.float32,tf.float32,tf.float32]) -else: - pass - -tf.summary.image('GT_alpha',b_GTmatte,max_outputs = 5) -tf.summary.image('trimap',b_trimap,max_outputs = 5) - -b_RGB.set_shape([batch_size,image_size,image_size,3]) -b_trimap.set_shape([batch_size,image_size,image_size,1]) + +tf.summary.image('GT_matte_batch',GT_matte_batch,max_outputs = 5) +tf.summary.image('trimap',GT_trimap,max_outputs = 5) +tf.summary.image('image_batch',image_batch,max_outputs = 5) + +# b_RGB.set_shape([batch_size,image_size,image_size,3]) +# b_trimap.set_shape([batch_size,image_size,image_size,1]) +# b_input = tf.concat([b_RGB,b_trimap],3) +# b_GTmatte.set_shape([batch_size,image_size,image_size,1]) +# b_GTBG.set_shape([batch_size,image_size,image_size,3]) +# b_GTFG.set_shape([batch_size,image_size,image_size,3]) + +b_RGB = tf.identity(image_batch,name = 'b_RGB') +b_trimap = tf.identity(GT_trimap,name = 'b_trimap') +b_GTmatte = tf.identity(GT_matte_batch,name = 'b_GTmatte') +b_GTBG = tf.identity(GTBG_batch,name = 'b_GTBG') +b_GTFG = tf.identity(GTFG_batch,name = 'b_GTFG') + b_input = tf.concat([b_RGB,b_trimap],3) 
-b_GTmatte.set_shape([batch_size,image_size,image_size,1]) -b_GTBG.set_shape([batch_size,image_size,image_size,3]) -b_GTFG.set_shape([batch_size,image_size,image_size,3]) - -b_RGB = tf.identity(b_RGB,name = 'b_RGB') -b_trimap = tf.identity(b_trimap,name = 'b_trimap') -b_GTmatte = tf.identity(b_GTmatte,name = 'b_GTmatte') -b_GTBG = tf.identity(b_GTBG,name = 'b_GTBG') -b_GTFG = tf.identity(b_GTFG,name = 'b_GTFG') #else: # preprocessing(image_batch) @@ -224,68 +226,33 @@ training = True -#deconv6_1 -with tf.variable_scope('deconv6_1') as scope: - outputs = tf.layers.conv2d_transpose(pool5, 512, [1, 1], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv6_1 = tf.nn.relu(outputs) - deconv6_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) #deconv6_2 with tf.variable_scope('deconv6_2') as scope: - outputs = tf.layers.conv2d_transpose(deconv6_1, 512, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) + outputs = tf.layers.conv2d_transpose(pool5, 512, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) #deconv6_2 = tf.nn.relu(outputs) deconv6_2 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) -#deconv5_1 -with tf.variable_scope('deconv5_1') as scope: - outputs = tf.layers.conv2d_transpose(deconv6_2, 512, [5, 5], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv5_1 = tf.nn.relu(outputs) - deconv5_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) - -#deconv5_2 -with tf.variable_scope('deconv5_2') as scope: - outputs = tf.layers.conv2d_transpose(deconv5_1, 512, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv5_2 = tf.nn.relu(outputs) - deconv5_2 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) - -#deconv4_1 -with tf.variable_scope('deconv4_1') as scope: - outputs = tf.layers.conv2d_transpose(deconv5_2, 256, [5, 5], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv4_1= tf.nn.relu(outputs) - deconv4_1 =tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) - #deconv4_2 -with tf.variable_scope('deconv4_2') as scope: - outputs = tf.layers.conv2d_transpose(deconv4_1, 256, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) +with tf.variable_scope('deconv4_1') as scope: + outputs = tf.layers.conv2d_transpose(deconv6_2, 256, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) #deconv4_2= tf.nn.relu(outputs) - deconv4_2 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) - -#deconv3_1 -with tf.variable_scope('deconv3_1') as scope: - outputs = tf.layers.conv2d_transpose(deconv4_2, 128, [5, 5], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv3_1 = tf.nn.relu(outputs) - deconv3_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) + deconv4_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) #deconv3_2 -with tf.variable_scope('deconv3_2') as scope: - outputs = tf.layers.conv2d_transpose(deconv3_1, 128, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) +with tf.variable_scope('deconv3_1') as scope: + outputs = tf.layers.conv2d_transpose(deconv4_1, 128, [3, 3], 
strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) #deconv3_2 = tf.nn.relu(outputs) - deconv3_2 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) - -#deconv2_1 -with tf.variable_scope('deconv2_1') as scope: - outputs = tf.layers.conv2d_transpose(deconv3_2, 64, [5, 5], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) - #deconv2_1 = tf.nn.relu(outputs) - deconv2_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) + deconv3_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) #deconv2_2 -with tf.variable_scope('deconv2_2') as scope: - outputs = tf.layers.conv2d_transpose(deconv2_1, 64, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) +with tf.variable_scope('deconv2_1') as scope: + outputs = tf.layers.conv2d_transpose(deconv3_1, 64, [3, 3], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) # deconv2_2 = tf.nn.relu(outputs) - deconv2_2 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) + deconv2_1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) #deconv1 with tf.variable_scope('deconv1') as scope: - outputs = tf.layers.conv2d_transpose(deconv2_2, 64, [5, 5], strides=(1, 1), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) + outputs = tf.layers.conv2d_transpose(deconv2_1, 32, [5, 5], strides=(2, 2), padding='SAME', kernel_initializer=tf.contrib.layers.xavier_initializer()) #deconv1 = tf.nn.relu(outputs) deconv1 = tf.nn.relu(tf.layers.batch_normalization(outputs,training=training)) @@ -320,14 +287,14 @@ total_loss = tf.reduce_sum(wl * alpha_diff + (2-wl) * c_diff) / batch_size tf.summary.scalar('total_loss',total_loss) global_step = tf.Variable(0,trainable=False) - train_op = tf.train.AdamOptimizer(learning_rate = 1e-6).minimize(total_loss,global_step = global_step) + train_op = tf.train.AdamOptimizer(learning_rate = 1e-5).minimize(total_loss,global_step = global_step) coord = tf.train.Coordinator() summary_op = tf.summary.merge_all() summary_writer = tf.summary.FileWriter(log_dir, tf.get_default_graph()) -gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.5) +#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.5) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) tf.train.start_queue_runners(coord=coord,sess=sess) @@ -345,22 +312,26 @@ else: sess.run(en_parameters[i].assign(weights[k])) print('finish loading vgg16 model') - #load train data + #load train data while epoch_num < max_epochs: - print('epoch %d' % epoch_num) + print('epoch %d' % epoch_num) while batch_num < batchs_per_epoch: print('batch %d, loading batch data...' 
% batch_num) batch_index = sess.run(index_dequeue_op) - batch_RGB_paths = np.array(paths_RGB)[batch_index] - batch_alpha_paths = np.array(paths_alpha)[batch_index] - batch_FG_paths = np.array(paths_FG)[batch_index] - batch_BG_paths = np.array(paths_BG)[batch_index] - - batch_RGBs,batch_alphas,batch_FGs,batch_BGs = load_data(batch_RGB_paths,batch_alpha_paths,batch_FG_paths,batch_BG_paths) - feed = {image_batch:batch_RGBs.eval(), GT_matte_batch:batch_alphas.eval(), GTBG_batch:batch_BGs.eval(), GTFG_batch:batch_FGs.eval(),is_train:True} + batch_RGB_paths = paths_RGB[batch_index] + batch_alpha_paths = paths_alpha[batch_index] + batch_FG_paths = paths_FG[batch_index] + batch_BG_paths = paths_BG[batch_index] + print('finish loading path') + batch_RGBs,batch_trimaps,batch_alphas,batch_FGs,batch_BGs = load_data(sess,batch_RGB_paths,batch_alpha_paths,batch_FG_paths,batch_BG_paths) + + # feed = {image_batch:batch_RGBs.eval(), GT_matte_batch:batch_alphas.eval(),GT_trimap:batch_trimaps, GTBG_batch:batch_BGs.eval(), GTFG_batch:batch_FGs.eval(),is_train:True} + feed = {image_batch:batch_RGBs, GT_matte_batch:batch_alphas,GT_trimap:batch_trimaps, GTBG_batch:batch_BGs, GTFG_batch:batch_FGs,is_train:True} _,loss,summary_str,step,p_mattes = sess.run([train_op,total_loss,summary_op,global_step,pred_mattes],feed_dict = feed) misc.imsave('./predict/alpha.jpg',p_mattes[0,:,:,0]) summary_writer.add_summary(summary_str,global_step = step) print('loss is %f' %loss) batch_num += 1 epoch_num += 1 + +
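Notes on the main techniques this patch introduces, with small standalone sketches (Python, using the same TF 1.x / NumPy / SciPy APIs the patch already relies on; these sketches restate ideas from matting.py under stated assumptions and are not part of the diff itself).

1. Trimap generation (matting.py, generate_trimap). The unknown region is the band where a grey dilation and a grey erosion of the ground-truth alpha disagree; the band width is randomized by drawing the kernel size from trimap_kernel = range(20, 35), and the band is marked 127.5. A minimal sketch, assuming alpha is an H x W x 1 array with values in [0, 255]:

import random
import numpy as np
from scipy import ndimage

trimap_kernel = [val for val in range(20, 35)]

def generate_trimap(trimap, alpha):
    # Random structuring-element size, as in the patch.
    k_size = random.choice(trimap_kernel)
    # Pixels where dilation and erosion of the alpha channel differ lie near
    # the matte boundary; mark them as the unknown region (value 127.5).
    dilated = ndimage.grey_dilation(alpha[:, :, 0], size=(k_size, k_size))
    eroded = ndimage.grey_erosion(alpha[:, :, 0], size=(k_size, k_size))
    trimap[np.where((dilated - eroded) != 0)] = 127.5
    return trimap

# As in preprocessing_single(): the trimap starts as a copy of the alpha matte.
# trimap = generate_trimap(np.copy(alpha), alpha)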
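2. Per-image preprocessing through tf.py_func (matting.py, load_data / preprocessing_single). Because input images now vary in size, each decoded sample is cropped and resized to a fixed 320 x 320 x 11 stack inside a NumPy function wrapped with tf.py_func; since py_func discards static shape information, the shape is restored with set_shape before the per-image tensors are stacked into a batch. A sketch of the pattern, with illustrative placeholders standing in for the decoded images and a stub in place of preprocessing_single:

import numpy as np
import tensorflow as tf

def preprocess_np(comp_rgb, alpha, bg, fg, size):
    # Stub standing in for preprocessing_single(): build the trimap, crop a
    # window centred on the unknown region, optionally flip, resize, and
    # return a size x size x 11 float32 stack (RGB, trimap, alpha, BG, FG).
    return np.zeros([size, size, 11], dtype=np.float32)

# Illustrative variable-sized inputs; in the patch these come from
# tf.image.decode_png / tf.image.decode_jpeg on the file contents.
comp_rgb = tf.placeholder(tf.float32, [None, None, 3])
alpha = tf.placeholder(tf.float32, [None, None, 1])
bg = tf.placeholder(tf.float32, [None, None, 3])
fg = tf.placeholder(tf.float32, [None, None, 3])

# tf.py_func runs the NumPy code at session time and returns a tensor with
# unknown static shape, so the fixed shape must be set explicitly.
sample = tf.py_func(preprocess_np, [comp_rgb, alpha, bg, fg, 320], tf.float32)
sample.set_shape([320, 320, 11])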
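3. Crops centred on the unknown region (matting.py, UR_center). Each training crop (320, 480 or 640 pixels, later resized to 320) is centred on a pixel drawn at random from the trimap's unknown region, so every training patch contains matte-boundary pixels. A slightly simplified sketch, assuming the trimap sits in channel 3 of the pre-normalization stack (so unknown pixels equal 127.5 exactly) and that the unknown region is non-empty, as the patch does:

import random
import numpy as np

def UR_center(stacked):
    # Channel 3 of the 11-channel stack holds the trimap (see
    # preprocessing_single); 127.5 marks its unknown region.
    trimap = stacked[:, :, 3]
    rows, cols = np.where(trimap == 127.5)
    idx = random.randrange(len(rows))
    return rows[idx], cols[idx]

# Usage: cut a crop-sized window centred on the chosen pixel, matching the
# patch's offsets (e.g. crop size 320 uses start index centre - 159).
# r, c = UR_center(stacked); crop = 320
# window = stacked[r - crop // 2 + 1 : r + crop // 2 + 1,
#                  c - crop // 2 + 1 : c + crop // 2 + 1, :]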