Skip to content

Commit

Permalink
add load_test_data(test_alpha)
Browse files Browse the repository at this point in the history
  • Loading branch information
GE ZHENG committed Aug 25, 2017
1 parent e72703a commit 6fcc592
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,10 @@ def preprocessing_single(comp_RGB, alpha, BG, FG, batch_RGB_paths,i,image_size=3

if crop_size == 620:
h_start_index = i_UR_center[0] - 309
w_start_index = i_UR_center[1] - 309
#boundary security
if h_start_index<0:
h_start_index = 0
w_start_index = i_UR_center[1] - 309
if w_start_index<0:
w_start_index = 0
tmp = train_pre[h_start_index:h_start_index+620, w_start_index:w_start_index+620, :]
Expand All @@ -129,4 +130,26 @@ def preprocessing_single(comp_RGB, alpha, BG, FG, batch_RGB_paths,i,image_size=3
tmp1[:,:,8:] = misc.imresize(tmp[:,:,8:],[image_size,image_size,3])
train_data = tmp1
train_data = train_data.astype(np.float32)
return train_data
return train_data

def load_test_data(test_alpha):
rgb_path = os.path.join(test_alpha,'rgb')
trimap_path = os.path.join(test_alpha,'trimap')
alpha_path = os.path.join(test_alpha,'alpha')
images = os.listdir(trimap_path)
test_num = len(images)
all_shape = []
rgb_batch = []
tri_batch = []
alp_batch = []
for i in range(test_num):
rgb = misc.imread(os.path.join(rgb_path,images[i]))
trimap = misc.imread(os.path.join(trimap_path,images[i]),'L')
alpha = misc.imread(os.path.join(alpha_path,images[i]),'L')
all_shape.append(trimap.shape)
rgb_batch.append(misc.imresize(rgb,[320,320,3]))
trimap = misc.imresize(trimap,[320,320],interp = 'nearest')
tri_batch.append(np.expand_dims(trimap,2))
alp_batch.append(np.expand_dims(alpha,2))
return np.array(rgb_batch),np.array(tri_batch),np.array(alp_batch),all_shape,images

0 comments on commit 6fcc592

Please sign in to comment.