forked from haomo-ai/SuperFusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
voxel.py
115 lines (103 loc) · 4.58 KB
/
voxel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import numpy as np
import torch
import torch_scatter
def pad_or_trim_to_np(x, shape, pad_val=0):
    """Pad or trim ``x`` so that the result has exactly ``shape``.

    Axes that are shorter than requested are extended at their end with
    ``pad_val``; axes that are longer are cut off at the requested length.
    Works for arrays of any rank >= 1 (the rank of ``x`` and the length of
    ``shape`` must agree).

    Args:
        x: Array-like input.
        shape: Target size for every axis of ``x``.
        pad_val: Constant used to fill padded positions.

    Returns:
        A ``numpy.ndarray`` whose shape equals ``shape``.

    Raises:
        ValueError: If ``shape`` does not have one entry per axis of ``x``.
    """
    shape = np.asarray(shape)
    current = np.shape(x)
    if shape.ndim != 1 or shape.size != len(current):
        raise ValueError(
            "shape must have one entry per axis of x: got shape "
            f"{shape.tolist()} for input of shape {tuple(current)}"
        )

    # Amount missing at the end of each axis (0 where x is already long enough).
    pad = shape - np.minimum(current, shape)
    zeros = np.zeros_like(pad)
    # pad_width rows are (before, after); only ever pad after the data.
    x = np.pad(x, np.stack([zeros, pad], axis=1), constant_values=pad_val)
    # Trim every axis, not just the first two.
    return x[tuple(slice(0, int(size)) for size in shape)]
def raval_index(coords, dims):
    """Flatten multi-dimensional coordinates into row-major linear indices.

    Args:
        coords: Tensor whose rows are coordinates, one column per grid axis
            (the sum below runs over dim 1).
        dims: 1-D tensor holding the grid size along each axis.

    Returns:
        One linear index per row of ``coords``. The strides are produced by a
        true division, so the result is a floating-point tensor; callers
        that need integer indices must cast it themselves.
    """
    # Drop the leading size and append 1: for dims (X, Y, Z) -> (Y, Z, 1).
    trailing = torch.cat((dims, torch.ones(1, device=dims.device)), dim=0)[1:]
    # Running product from the innermost axis outwards: (1, Z, Z * Y).
    reversed_strides = torch.cumprod(torch.flip(trailing, dims=[0]), dim=0)
    # The innermost entry is the appended 1, so this division leaves the
    # values untouched (it is kept to match the original computation).
    reversed_strides = reversed_strides / trailing[-1]
    # Back to axis order: (Y * Z, Z, 1).
    strides = torch.flip(reversed_strides, dims=[0])
    return (coords * strides).sum(dim=1)
def points_to_voxels(
    points_xyz,
    points_mask,
    grid_range_x,
    grid_range_y,
    grid_range_z
):
    """Assign each point of a batched point cloud to a cell of a regular grid.

    Points that are masked out (``points_mask < 1``) or that fall outside the
    grid are treated as padding: their voxel index, voxel coordinates and
    ``voxel_xyz`` are set to zero and they are excluded from the per-voxel
    point counts and centroids.

    Args:
        points_xyz: Tensor of shape ``(batch_size, num_points, 3)``.
        points_mask: Per-point validity, values below 1 mark a point as
            invalid. Assumed to be of shape ``(batch_size, num_points)`` --
            TODO confirm against callers.
        grid_range_x: Sequence ``(min, max, voxel_size)`` for the x axis.
        grid_range_y: Sequence ``(min, max, voxel_size)`` for the y axis.
        grid_range_z: Sequence ``(min, max, voxel_size)`` for the z axis.

    Returns:
        A dict with the keys ``local_points_xyz``, ``shifted_points_xyz``,
        ``point_centroids``, ``points_xyz``, ``grid_offset``,
        ``voxel_coords``, ``voxel_centers``, ``voxel_indices``,
        ``voxel_paddings``, ``points_mask``, ``num_voxels``, ``grid_size``,
        ``voxel_xyz``, ``voxel_size``, ``voxel_point_count`` and
        ``points_per_voxel``.
    """
    batch_size, num_points, _ = points_xyz.shape
    device = points_xyz.device

    voxel_size_x = grid_range_x[2]
    voxel_size_y = grid_range_y[2]
    voxel_size_z = grid_range_z[2]

    # NOTE(review): the cast truncates, so a ratio that lands just below an
    # integer because of floating-point error (e.g. 0.6 / 0.2) loses a cell.
    # Left unchanged so grid sizes computed elsewhere stay consistent.
    grid_size = np.asarray([
        (grid_range_x[1] - grid_range_x[0]) / voxel_size_x,
        (grid_range_y[1] - grid_range_y[0]) / voxel_size_y,
        (grid_range_z[1] - grid_range_z[0]) / voxel_size_z
    ]).astype('int32')
    num_voxels = grid_size[0] * grid_size[1] * grid_size[2]

    voxel_size = np.asarray([voxel_size_x, voxel_size_y, voxel_size_z])
    voxel_size = torch.Tensor(voxel_size).to(device)
    grid_offset = torch.Tensor(
        [grid_range_x[0], grid_range_y[0], grid_range_z[0]]).to(device)

    # Continuous position of every point measured in voxel units.
    shifted_points_xyz = points_xyz - grid_offset
    voxel_xyz = shifted_points_xyz / voxel_size
    # Floor before the integer cast: a plain cast rounds toward zero, which
    # would map points up to one voxel *below* the grid minimum to cell 0
    # and let them slip through the lower-bound check.
    voxel_coords = torch.floor(voxel_xyz).int()

    grid_size = torch.from_numpy(grid_size).to(device).int()
    outside_grid = torch.any(
        (voxel_coords >= grid_size) | (voxel_coords < 0), dim=-1)
    padding_mask = (points_mask < 1.0) | outside_grid
    padding_mask_xyz = torch.unsqueeze(padding_mask, dim=-1)

    # Linear voxel index per point; padded points are redirected to index 0.
    voxel_indices = raval_index(
        torch.reshape(voxel_coords, [batch_size * num_points, 3]), grid_size)
    voxel_indices = torch.reshape(voxel_indices, [batch_size, num_points])
    voxel_indices = torch.where(padding_mask,
                                torch.zeros_like(voxel_indices),
                                voxel_indices).long()

    # Centres are derived from the coordinates before padded entries are
    # zeroed out.
    voxel_centers = (0.5 + voxel_coords.float()) * voxel_size + grid_offset
    voxel_coords = torch.where(padding_mask_xyz,
                               torch.zeros_like(voxel_coords),
                               voxel_coords)
    voxel_xyz = torch.where(padding_mask_xyz,
                            torch.zeros_like(voxel_xyz),
                            voxel_xyz)

    voxel_paddings = padding_mask.float()
    valid_mask = 1 - voxel_paddings

    # Number of valid points per voxel, then looked up again for every point.
    points_per_voxel = torch_scatter.scatter_sum(
        valid_mask,
        voxel_indices,
        dim=1,
        dim_size=num_voxels
    )
    voxel_point_count = torch.gather(points_per_voxel,
                                     dim=1,
                                     index=voxel_indices)

    # Per-voxel centroid over valid points only. Padded points share index 0
    # with the real voxel 0, so an unmasked mean would bias that voxel's
    # centroid; zero their contribution and divide by the valid count.
    valid_points_xyz = torch.where(padding_mask_xyz,
                                   torch.zeros_like(points_xyz),
                                   points_xyz)
    voxel_sums = torch_scatter.scatter_sum(
        valid_points_xyz,
        voxel_indices,
        dim=1,
        dim_size=num_voxels)
    # clamp avoids 0 / 0 for empty voxels (their centroid stays at zero).
    voxel_counts = points_per_voxel.clamp(min=1).unsqueeze(-1)
    voxel_centroids = voxel_sums / voxel_counts.to(voxel_sums.dtype)

    point_centroids = torch.gather(voxel_centroids, dim=1, index=torch.unsqueeze(
        voxel_indices, dim=-1).repeat(1, 1, 3))
    local_points_xyz = points_xyz - point_centroids

    result = {
        'local_points_xyz': local_points_xyz,
        'shifted_points_xyz': shifted_points_xyz,
        'point_centroids': point_centroids,
        'points_xyz': points_xyz,
        'grid_offset': grid_offset,
        'voxel_coords': voxel_coords,
        'voxel_centers': voxel_centers,
        'voxel_indices': voxel_indices,
        'voxel_paddings': voxel_paddings,
        'points_mask': valid_mask,
        'num_voxels': num_voxels,
        'grid_size': grid_size,
        'voxel_xyz': voxel_xyz,
        'voxel_size': voxel_size,
        'voxel_point_count': voxel_point_count,
        'points_per_voxel': points_per_voxel
    }
    return result