Skip to content

Commit

Permalink
latest code change accept different input size image and resize it
Browse files Browse the repository at this point in the history
  • Loading branch information
GE ZHENG committed Aug 25, 2017
1 parent 7fd0049 commit 9e287d4
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 282 deletions.
294 changes: 101 additions & 193 deletions matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,219 +8,127 @@
from sys import getrefcount
import gc

# Side length, in pixels, that global_mean() assumes for every image: it sums
# images into an input_image_size x input_image_size x 3 buffer and divides by
# input_image_size squared when averaging.
input_image_size = 650

def unpool(pool, ind, ksize=(1, 2, 2, 1), scope='unpool'):
    '''
    Scatter pooled values back into an enlarged, otherwise zero tensor.

    Args:
        pool: 4-D tensor with a fully defined static shape (its shape is read
            with get_shape().as_list() and multiplied together).
        ind: integer tensor reshaped to the same number of elements as `pool`;
            each entry is used as the flat position of the matching `pool`
            value inside its own example of the output.
            Presumably the argmax returned by max-pooling -- confirm against
            the caller.
        ksize: enlargement factors; only ksize[1] (height) and ksize[2]
            (width) are read.
        scope: name of the variable scope the ops are created in.

    Returns:
        Tensor of shape (N, H * ksize[1], W * ksize[2], C).
    '''
    # The default for ksize is a tuple rather than a list so that it cannot be
    # mutated and shared between calls.
    with tf.variable_scope(scope):
        input_shape = pool.get_shape().as_list()
        output_shape = (
            input_shape[0],
            input_shape[1] * ksize[1],
            input_shape[2] * ksize[2],
            input_shape[3],
        )

        flat_input_size = np.prod(input_shape)
        flat_output_shape = [
            output_shape[0],
            output_shape[1] * output_shape[2] * output_shape[3],
        ]

        pool_ = tf.reshape(pool, [flat_input_size])

        # Build a (batch index, flat index) pair for every pooled value.
        batch_range = tf.reshape(
            tf.range(output_shape[0], dtype=ind.dtype),
            shape=[input_shape[0], 1, 1, 1],
        )
        b = tf.ones_like(ind) * batch_range
        b = tf.reshape(b, [flat_input_size, 1])
        ind_ = tf.reshape(ind, [flat_input_size, 1])
        ind_ = tf.concat([b, ind_], 1)

        ret = tf.scatter_nd(ind_, pool_, shape=flat_output_shape)
        ret = tf.reshape(ret, output_shape)
        return ret

def _augment_padded_sample(padded, center, crop_size, flip, image_size, g_mean):
    '''Crop a square around `center`, optionally mirror it, and normalise it.'''
    # 159 / 239 / 319 for the 320 / 480 / 640 crops.
    offset = crop_size // 2 - 1
    top = center[0] - offset
    left = center[1] - offset
    crop = padded[top:top + crop_size, left:left + crop_size, :]
    if flip:
        crop = crop[:, ::-1, :]

    if crop_size == 320:
        # Used at native resolution: only rescale trimap/alpha and centre RGB.
        crop[:, :, 3:5] = crop[:, :, 3:5] / 255.0
        crop[:, :, :3] -= g_mean
        return crop

    sample = np.zeros([image_size, image_size, 11])
    sample[:, :, :3] = misc.imresize(crop[:, :, :3], [image_size, image_size, 3]) - g_mean
    # Nearest-neighbour keeps the trimap free of interpolated in-between values.
    sample[:, :, 3] = misc.imresize(crop[:, :, 3], [image_size, image_size], interp='nearest') / 255.0
    sample[:, :, 4] = binarilize_alpha(misc.imresize(crop[:, :, 4], [image_size, image_size]), 60) / 255.0
    sample[:, :, 5:8] = misc.imresize(crop[:, :, 5:8], [image_size, image_size, 3])
    sample[:, :, 8:] = misc.imresize(crop[:, :, 8:], [image_size, image_size, 3])
    return sample


def preprocessing(image_batch, GTmatte_batch, GTBG_batch, GTFG_batch, image_size=320):
    '''
    Turn a batch of images into randomly cropped / flipped training samples.

    The inputs are concatenated along the last axis into 11 channels
    (image 0-2, trimap 3, alpha matte 4, background 5-7, foreground 8-10).
    Each example is centre-padded with center_padding(), a crop centre is
    drawn with UR_center(), and a 320, 480 or 640 pixel square is cut out.
    320 crops are used as they are; larger crops are resized to `image_size`
    (alpha is re-binarised with threshold 60 after resizing).

    Args:
        image_batch: images, batch on axis 0.
        GTmatte_batch: ground-truth alpha mattes.
        GTBG_batch: ground-truth backgrounds.
        GTFG_batch: ground-truth foregrounds.
        image_size: side length of the returned samples. A 320 crop is copied
            without resizing, so it only fits when image_size is 320.

    Returns:
        Tuple of float32 arrays: mean-subtracted image, trimap (N, H, W, 1),
        alpha (N, H, W, 1), background, foreground.

    NOTE(review): calls generate_trimap with three arguments; a two-argument
    generate_trimap defined later in this file shadows that version.
    '''
    g_mean = np.array([126.8898, 120.2431, 112.1959]).reshape([1, 1, 3])
    batch_size = image_batch.shape[0]

    # The trimap starts as a copy of the matte so the matte itself is untouched.
    trimap_batch = generate_trimap(np.copy(GTmatte_batch), GTmatte_batch, batch_size)

    train_batch_pre = np.concatenate(
        [image_batch, trimap_batch, GTmatte_batch, GTBG_batch, GTFG_batch], 3)
    train_batch = np.zeros([batch_size, image_size, image_size, 11])

    for i in range(batch_size):
        # Draw order (crop size, flip, centre) is kept so seeded runs match.
        crop_size = random.choice([320, 480, 640])
        flip = random.choice([0, 1])
        i_padding = center_padding(train_batch_pre[i])
        i_UR_center = UR_center(i_padding)
        train_batch[i, :, :, :] = _augment_padded_sample(
            i_padding, i_UR_center, crop_size, flip, image_size, g_mean)
        # Release the large padded canvas before the next example.
        gc.collect()

    train_batch = train_batch.astype(np.float32)
    return (
        train_batch[:, :, :, :3],
        np.expand_dims(train_batch[:, :, :, 3], 3),
        np.expand_dims(train_batch[:, :, :, 4], 3),
        train_batch[:, :, :, 5:8],
        train_batch[:, :, :, 8:],
    )

def binarilize_alpha(alpha, threshold):
    '''
    Binarise `alpha` in place: values <= threshold become 0, values above
    it become 255. The same (modified) array is returned.
    '''
    at_or_below = alpha <= threshold
    alpha[at_or_below] = 0
    # Evaluated after the first write, so freshly zeroed entries are re-tested.
    above = alpha > threshold
    alpha[above] = 255
    return alpha

def center_padding(image, pad_size=1300):
    '''
    Place an 11-channel image at the centre of a zero-filled square canvas.

    Args:
        image: array of shape (H, W, 11); the channel layout used by the
            callers is image, trimap, GT alpha matte, GT background,
            GT foreground.
        pad_size: side length of the canvas. The default, 1300, is the
            canvas size this function has always produced.

    Returns:
        float64 array of shape (pad_size, pad_size, 11).

    Raises:
        ValueError: if the image does not fit inside the canvas.
    '''
    height, width = image.shape[0], image.shape[1]

    # Canvas and image centres are both computed as (n - 1) // 2, which gives
    # 649 for the default 1300 canvas.
    canvas_center = (pad_size - 1) // 2
    h_start_index = canvas_center - (height - 1) // 2
    w_start_index = canvas_center - (width - 1) // 2
    h_end_index = h_start_index + height
    w_end_index = w_start_index + width

    if (h_start_index < 0 or w_start_index < 0
            or h_end_index > pad_size or w_end_index > pad_size):
        raise ValueError(
            'image of size %dx%d does not fit in a %dx%d canvas'
            % (height, width, pad_size, pad_size))

    pad_image = np.zeros([pad_size, pad_size, 11])
    pad_image[h_start_index:h_end_index, w_start_index:w_end_index, :] = image
    return pad_image

def show(image):
    '''
    Debug helper: print the [row, column] position of every entry of a 2-D
    array whose value is not 0, 1 or 0.5.
    '''
    n_rows = image.shape[0]
    n_cols = image.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            value = image[row][col]
            if value == 0 or value == 1 or value == 0.5:
                continue
            print([row, col])
# Candidate structuring-element sizes (20..34) drawn from by generate_trimap().
trimap_kernel = list(range(20, 35))

def UR_center(image):
    '''
    Pick a random pixel of the unknown region to centre a crop on.

    Args:
        image: array whose channel 3 is the trimap; unknown pixels are the
            ones equal to 127.5 (the value written by generate_trimap).

    Returns:
        Integer array [row, column] of one randomly chosen unknown pixel.

    Raises:
        IndexError: if the trimap contains no unknown pixel.
    '''
    trimap = image[:, :, 3]
    target = np.where(trimap == 127.5)
    candidates = len(target[0])
    if candidates == 0:
        raise IndexError('trimap has no unknown (127.5) pixel to centre on')
    # Choosing from a range draws the same random number as choosing from a
    # list of indices, without materialising one entry per unknown pixel.
    index = random.choice(range(candidates))
    return np.array([target[0][index], target[1][index]])

def composition_RGB(BG, FG, p_matte):
    '''
    Composite foreground over background with a predicted matte.

    Args:
        BG: background, convertible to a tensor.
        FG: foreground, convertible to a tensor.
        p_matte: blending weight applied to FG; (1 - p_matte) is applied to BG.

    Returns:
        p_matte * FG + (1 - p_matte) * BG
    '''
    # The converted background used to be bound to a misspelled name ("GB"),
    # so BG itself was never converted like FG is.
    BG = tf.convert_to_tensor(BG)
    FG = tf.convert_to_tensor(FG)
    return p_matte * FG + (1 - p_matte) * BG

def global_mean(RGB_folder):
    '''
    Compute the per-channel mean pixel value over every image in a folder.

    Images are read with misc.imread and summed into an
    input_image_size x input_image_size x 3 buffer, so each file must decode
    to exactly that shape.

    Args:
        RGB_folder: directory whose entries are all image files.

    Returns:
        Array of 3 per-channel means.

    Raises:
        ValueError: if the folder contains no files.
    '''
    RGBs = os.listdir(RGB_folder)
    num = len(RGBs)
    if num == 0:
        raise ValueError('no images found in %r' % (RGB_folder,))

    batch = 100
    # Round up so the final, partial batch is included; it used to be skipped,
    # and folders with fewer than 100 images produced NaN.
    ite = (num + batch - 1) // batch
    print(ite)

    channel_sum = np.zeros(3)
    for i in range(ite):
        print(i)
        batch_tmp = np.zeros([input_image_size, input_image_size, 3])
        for RGB_path in RGBs[i * batch:(i + 1) * batch]:
            batch_tmp += misc.imread(os.path.join(RGB_folder, RGB_path))
        channel_sum += batch_tmp.sum(axis=0).sum(axis=0)

    return channel_sum / (input_image_size * input_image_size * num)
# def global_mean(RGB_folder):
# RGBs = os.listdir(RGB_folder)
# num = len(RGBs)
# ite = num // 100
# sum_tmp = []
# print(ite)
# for i in range(ite):
# print(i)
# batch_tmp = np.zeros([320,320,3])
# RGBs_tmp = [os.path.join(RGB_folder,RGB_path) for RGB_path in RGBs[i*100:(i+1)*100]]
# for RGB in RGBs_tmp:
# batch_tmp += misc.imread(RGB)
# sum_tmp.append(batch_tmp.sum(axis = 0).sum(axis = 0) / (320*320*100))
# return np.array(sum_tmp).mean(axis = 0)

def load_path(dataset_RGB, alpha, FG, BG):
    '''
    Build matching file paths for every composited image in `dataset_RGB`.

    Each file name is split on '-': the first piece names a sub-directory and
    the second piece the file inside it. The background path uses the same
    file name with its last three characters replaced by 'jpg'.

    Args:
        dataset_RGB: directory of composited images.
        alpha: root directory of the alpha mattes.
        FG: root directory of the foregrounds.
        BG: root directory of the backgrounds.

    Returns:
        Four parallel lists: RGB paths, alpha paths, FG paths, BG paths.

    Raises:
        IndexError: if a file name contains no '-'.
    '''
    RGBs_path = os.listdir(dataset_RGB)
    RGBs_abspath = [os.path.join(dataset_RGB, RGB) for RGB in RGBs_path]

    # Split each name once. The earlier "<prefix>.png" alpha/FG paths were
    # dead stores, overwritten before use, so only the nested layout is built.
    pieces = [RGB.split('-') for RGB in RGBs_path]
    alphas_abspath = [os.path.join(alpha, p[0], p[1]) for p in pieces]
    FGs_abspath = [os.path.join(FG, p[0], p[1]) for p in pieces]
    BGs_abspath = [os.path.join(BG, p[0], p[1][:-3] + 'jpg') for p in pieces]
    return RGBs_abspath, alphas_abspath, FGs_abspath, BGs_abspath

def _decode_image_batch(paths, decode):
    '''Read every file in `paths`, decode it with `decode`, and stack the results.'''
    images = []
    for path in paths:
        file_contents = tf.read_file(tf.convert_to_tensor(path))
        images.append(decode(file_contents))
    return tf.stack(images)


def load_data(batch_RGB_paths, batch_alpha_paths, batch_FG_paths, batch_BG_paths):
    '''
    Build tensors that read and decode one batch of images.

    RGB, alpha and foreground files are decoded as PNG; background files are
    decoded as JPEG.

    Args:
        batch_RGB_paths: paths of the composited images.
        batch_alpha_paths: paths of the alpha mattes.
        batch_FG_paths: paths of the foregrounds.
        batch_BG_paths: paths of the backgrounds.

    Returns:
        Tuple of four stacked tensors: RGBs, alphas, FGs, BGs.

    NOTE(review): a five-argument load_data(sess, ...) defined later in this
    file shadows this function at import time.
    '''
    # A misplaced "return np.array(RGBs_abspath), ..." used to sit after the
    # RGB loop; it referenced names that do not exist in this function and
    # made everything after it unreachable.
    batch_RGBs = _decode_image_batch(batch_RGB_paths, tf.image.decode_png)
    batch_alphas = _decode_image_batch(batch_alpha_paths, tf.image.decode_png)
    batch_FGs = _decode_image_batch(batch_FG_paths, tf.image.decode_png)
    batch_BGs = _decode_image_batch(batch_BG_paths, tf.image.decode_jpeg)
    return batch_RGBs, batch_alphas, batch_FGs, batch_BGs

def generate_trimap(trimap_batch, GTmatte_batch, batch_size):
    '''
    Mark the unknown region of every trimap in a batch, in place.

    For each example a structuring-element size is drawn from 15..30; pixels
    where grey dilation and grey erosion of the matte's channel 0 differ are
    set to 127.5.

    Args:
        trimap_batch: array to modify, indexed by example on axis 0.
        GTmatte_batch: ground-truth mattes; channel 0 of each is used.
        batch_size: number of examples to process.

    Returns:
        The modified `trimap_batch`.

    NOTE(review): a two-argument generate_trimap defined later in this file
    shadows this function at import time.
    '''
    kernel = range(15, 31)
    for i in range(batch_size):
        k_size = random.choice(kernel)
        matte = GTmatte_batch[i][:, :, 0]
        dilated = ndimage.grey_dilation(matte, size=(k_size, k_size))
        eroded = ndimage.grey_erosion(matte, size=(k_size, k_size))
        trimap_batch[i][np.where((dilated - eroded) != 0)] = 127.5
    return trimap_batch


def load_data(sess, batch_RGB_paths, batch_alpha_paths, batch_FG_paths, batch_BG_paths):
    '''
    Read, decode and preprocess one batch, returning NumPy arrays.

    Every example is decoded to float32 (RGB, alpha and FG as PNG, BG as
    JPEG), passed through preprocessing_single via tf.py_func with an output
    size of 320, stacked, and evaluated with `sess`.

    Args:
        sess: session used to run the stacked batch.
        batch_RGB_paths: paths of the composited images; its length sets the
            batch size (a list or an array both work).
        batch_alpha_paths: paths of the alpha mattes.
        batch_FG_paths: paths of the foregrounds.
        batch_BG_paths: paths of the backgrounds.

    Returns:
        Tuple of arrays: image (N, 320, 320, 3), trimap (N, 320, 320, 1),
        alpha (N, 320, 320, 1), background (N, 320, 320, 3),
        foreground (N, 320, 320, 3).
    '''
    # NOTE(review): new read/decode ops are created on every call; in graph
    # mode this presumably keeps growing the default graph -- confirm.
    batch_size = len(batch_RGB_paths)
    train_batch = []
    for i in range(batch_size):
        comp_RGB_content = tf.read_file(tf.convert_to_tensor(batch_RGB_paths[i]))
        comp_RGB = tf.cast(tf.image.decode_png(comp_RGB_content), tf.float32)

        alpha_content = tf.read_file(tf.convert_to_tensor(batch_alpha_paths[i]))
        alpha = tf.cast(tf.image.decode_png(alpha_content), tf.float32)

        FG_content = tf.read_file(tf.convert_to_tensor(batch_FG_paths[i]))
        FG = tf.cast(tf.image.decode_png(FG_content), tf.float32)

        BG_content = tf.read_file(tf.convert_to_tensor(batch_BG_paths[i]))
        BG = tf.cast(tf.image.decode_jpeg(BG_content), tf.float32)
        print('##########')
        batch_i = tf.py_func(preprocessing_single, [comp_RGB, alpha, BG, FG, 320], tf.float32)
        # py_func output has no static shape; restore it so stacking works.
        batch_i.set_shape([320, 320, 11])
        train_batch.append(batch_i)
    train_batch = sess.run(tf.stack(train_batch))
    return (
        train_batch[:, :, :, :3],
        np.expand_dims(train_batch[:, :, :, 3], 3),
        np.expand_dims(train_batch[:, :, :, 4], 3),
        train_batch[:, :, :, 5:8],
        train_batch[:, :, :, 8:],
    )

def generate_trimap(trimap, alpha, kernel_sizes=None):
    '''
    Mark the unknown region of a trimap, in place.

    A structuring-element size is drawn at random; every pixel where grey
    dilation and grey erosion of channel 0 of `alpha` differ is set to 127.5
    in `trimap`.

    Args:
        trimap: array to modify, indexed (row, column, ...).
        alpha: alpha matte with at least 3 dimensions; channel 0 is used.
        kernel_sizes: sequence of candidate sizes. Defaults to the
            module-level `trimap_kernel`.

    Returns:
        The modified `trimap` (same object that was passed in).
    '''
    if kernel_sizes is None:
        kernel_sizes = trimap_kernel
    k_size = random.choice(kernel_sizes)

    matte = alpha[:, :, 0]
    dilated = ndimage.grey_dilation(matte, size=(k_size, k_size))
    eroded = ndimage.grey_erosion(matte, size=(k_size, k_size))
    # Where the two disagree, the pixel lies within k_size of a matte edge.
    trimap[np.where((dilated - eroded) != 0)] = 127.5
    return trimap

def _clamped_crop_start(center, crop_size, length):
    '''Return the crop start on one axis, shifted so the window stays inside the image.'''
    start = center - (crop_size // 2 - 1)
    return int(min(max(start, 0), max(length - crop_size, 0)))


def preprocessing_single(comp_RGB, alpha, BG, FG, image_size=320):
    '''
    Turn one example into a randomly cropped / flipped 11-channel sample.

    The inputs are concatenated along axis 2 (image 0-2, trimap 3, alpha 4,
    background 5-7, foreground 8-10). A 320, 480 or 640 pixel square around a
    pixel chosen by UR_center() is cut out and optionally mirrored. A full
    320 x 320 crop is used as it is; any other crop is resized to
    `image_size` with misc.imresize.

    Args:
        comp_RGB: composited image, (H, W, 3).
        alpha: alpha matte, (H, W, 1).
        BG: background, (H, W, 3).
        FG: foreground, (H, W, 3).
        image_size: side length of resized samples.

    Returns:
        float32 array with 11 channels: mean-subtracted image, trimap / 255,
        alpha / 255, background, foreground.
    '''
    g_mean = np.array([126.88987627, 120.24313843, 112.19594981]).reshape([1, 1, 3])

    # The trimap starts as a copy so the matte itself is left untouched.
    trimap = generate_trimap(np.copy(alpha), alpha)
    train_pre = np.concatenate([comp_RGB, trimap, alpha, BG, FG], 2)

    # Draw order (crop size, flip, centre) is kept so seeded runs match.
    crop_size = random.choice([320, 480, 640])
    print(crop_size)
    flip = random.choice([0, 1])
    print(flip)
    i_UR_center = UR_center(train_pre)

    # The image is not padded here, so a centre near the border used to give
    # a negative start (which wraps around when slicing) or a truncated crop.
    # Shift the window back inside the image instead.
    h_start_index = _clamped_crop_start(i_UR_center[0], crop_size, train_pre.shape[0])
    w_start_index = _clamped_crop_start(i_UR_center[1], crop_size, train_pre.shape[1])
    tmp = train_pre[h_start_index:h_start_index + crop_size,
                    w_start_index:w_start_index + crop_size, :]
    if flip:
        tmp = tmp[:, ::-1, :]

    if crop_size == 320 and tmp.shape[0] == 320 and tmp.shape[1] == 320:
        # Native-resolution crop: only rescale trimap/alpha and centre RGB.
        tmp[:, :, 3:5] = tmp[:, :, 3:5] / 255.0
        tmp[:, :, :3] -= g_mean
        train_data = tmp
    else:
        # NOTE(review): misc.imresize goes through an 8-bit image; confirm the
        # value scaling it applies to float inputs is what is wanted here.
        train_data = np.zeros([image_size, image_size, 11])
        train_data[:, :, :3] = misc.imresize(tmp[:, :, :3], [image_size, image_size, 3]) - g_mean
        # Nearest-neighbour keeps the trimap free of interpolated values.
        train_data[:, :, 3] = misc.imresize(tmp[:, :, 3], [image_size, image_size], interp='nearest') / 255.0
        train_data[:, :, 4] = misc.imresize(tmp[:, :, 4], [image_size, image_size]) / 255.0
        train_data[:, :, 5:8] = misc.imresize(tmp[:, :, 5:8], [image_size, image_size, 3])
        train_data[:, :, 8:] = misc.imresize(tmp[:, :, 8:], [image_size, image_size, 3])

    train_data = train_data.astype(np.float32)
    return train_data
Loading

0 comments on commit 9e287d4

Please sign in to comment.