Skip to content

Commit

Permalink
add provider.py, debug for data_prepare.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ShixuanGu committed Oct 4, 2021
1 parent fe1282e commit a5d92d8
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 23 deletions.
49 changes: 26 additions & 23 deletions data_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,86 +14,89 @@ def main():
for data in [x for x in os.listdir('./data/ribfrac/ribfrac-train-images-1/Part1/')]:
source = nib.load('./data/ribfrac/ribfrac-train-images-1/Part1/'+data)
source = source.get_fdata()
source[source != 0] = 1
source[source >= 200] = 1
source[source != 1] = 0

label = nib.load('./data/RibSeg/nii/'+data[:-12]+'rib-seg.nii.gz')
label = label.get_fdata()
label[label !=0] = 1

temp = np.argwhere(source == 1)
choice = np.random.choice(temp.shape[0], 30000, replace=False)
# downsample
points = temp[choice, :]

label = []
label_selected_points = []
for i in points:
label.append(label[i[0]][i[1]][i[2]])
temp = np.array(temp)
label_selected_points.append(label[i[0]][i[1]][i[2]])
label_selected_points = np.array(label_selected_points)
np.save('./data/pn/data_pn/train'+data[:-13], points)
np.save('./data/pn/label_pn/train' + data[:-13], temp)
np.save('./data/pn/label_pn/train' + data[:-13], label_selected_points)


for data in [x for x in os.listdir('./data/ribfrac/ribfrac-train-images-2/Part2/')]:
source = nib.load('./data/ribfrac/ribfrac-train-images-2/Part2/'+data)
source = source.get_fdata()
source[source != 0] = 1
source[source >= 200] = 1
source[source != 1] = 0

label = nib.load('./data/RibSeg/nii/'+data[:-12]+'rib-seg.nii.gz')
label = label.get_fdata()
label[label !=0] = 1

temp = np.argwhere(source == 1)
choice = np.random.choice(temp.shape[0], 30000, replace=False)
# downsample
points = temp[choice, :]

label = []
label_selected_points = []
for i in points:
label.append(label[i[0]][i[1]][i[2]])
temp = np.array(temp)
label_selected_points.append(label[i[0]][i[1]][i[2]])
label_selected_points = np.array(label_selected_points)
np.save('./data/pn/data_pn/train'+data[:-13], points)
np.save('./data/pn/label_pn/train' + data[:-13], temp)
np.save('./data/pn/label_pn/train' + data[:-13], label_selected_points)


for data in [x for x in os.listdir('./data/ribfrac/ribfrac-val-images/')]:
source = nib.load('./ribfrac/ribfrac-val-images/' + data)
source = source.get_fdata()
source[source != 0] = 1
source[source >= 200] = 1
source[source != 1] = 0

label = nib.load('./data/RibSeg/nii/' + data[:-12] + 'rib-seg.nii.gz')
label = label.get_fdata()
label[label != 0] = 1

temp = np.argwhere(source == 1)
choice = np.random.choice(temp.shape[0], 30000, replace=False)
# downsample
points = temp[choice, :]

label = []
label_selected_points = []
for i in points:
label.append(label[i[0]][i[1]][i[2]])
temp = np.array(temp)
label_selected_points.append(label[i[0]][i[1]][i[2]])
label_selected_points = np.array(label_selected_points)
np.save('./data/pn/data_pn/val' + data[:-13], points)
np.save('./data/pn/label_pn/val' + data[:-13], temp)
np.save('./data/pn/label_pn/val' + data[:-13], label_selected_points)


for data in [x for x in os.listdir('./data/ribfrac/ribfrac-test-images/')]:
source = nib.load('./data/ribfrac/ribfrac-test-images/'+data)
source = source.get_fdata()
source[source != 0] = 1
source[source >= 200] = 1
source[source != 1] = 0

label = nib.load('./data/RibSeg/nii/'+data[:-12]+'rib-seg.nii.gz')
label = label.get_fdata()
label[label !=0] = 1

temp = np.argwhere(source == 1)
choice = np.random.choice(temp.shape[0], 30000, replace=False)
# downsample
points = temp[choice, :]

label = []
label_selected_points = []
for i in points:
label.append(label[i[0]][i[1]][i[2]])
label_selected_points.append(label[i[0]][i[1]][i[2]])
temp = np.array(temp)
np.save('./data/pn/data_pn/test'+data[:-13], points)
np.save('./data/pn/label_pn/test' + data[:-13], temp)
np.save('./data/pn/label_pn/test' + data[:-13], label_selected_points)

if __name__ == '__main__':
main()
Expand Down
251 changes: 251 additions & 0 deletions provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
import numpy as np

def normalize_data(batch_data):
""" Normalize the batch data, use coordinates of the block centered at origin,
Input:
BxNxC array
Output:
BxNxC array
"""
B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C))
for b in range(B):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data


def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx

def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:,idx,:]

def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data

def rotate_point_cloud_z(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data

def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud.
Input:
batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal
Output:
B,N,6, rotated XYZ, normal point cloud
'''
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k,:,0:3]
shape_normal = batch_xyz_normal[k,:,3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
return batch_xyz_normal

def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data


def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data

def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
BxNx6 array, rotated batch of point clouds iwth normal
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
return rotated_data



def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
jittered_data += batch_data
return jittered_data

def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
return batch_data


def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
return batch_data

def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 '''
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
return batch_pc



0 comments on commit a5d92d8

Please sign in to comment.