Commit 8288a9d: update cuda raymarching
ashawkey committed Feb 18, 2022 (1 parent: cb3a0e4)

Showing 11 changed files with 301 additions and 284 deletions.
13 changes: 7 additions & 6 deletions hashencoder/hashgrid.py
@@ -19,7 +19,7 @@ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution,

         inputs = inputs.contiguous()
         embeddings = embeddings.contiguous()
-        offsets = offsets.contiguous().to(inputs.device)
+        offsets = offsets.contiguous()

         B, D = inputs.shape # batch size, coord dim
         L = offsets.shape[0] - 1 # level
@@ -95,19 +95,20 @@ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, b
             print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward if fp16 is also enabled! (maybe fix later)')

         # allocate parameters
-        self.offsets = []
+        offsets = []
         offset = 0
         self.max_params = 2 ** log2_hashmap_size
         for i in range(num_levels):
             resolution = int(np.ceil(base_resolution * per_level_scale ** i))
             params_in_level = min(self.max_params, (resolution + 1) ** input_dim) # limit max number
             params_in_level = int(params_in_level / 8) * 8 # make divisible
-            self.offsets.append(offset)
+            offsets.append(offset)
             offset += params_in_level
-        self.offsets.append(offset)
-        self.offsets = torch.from_numpy(np.array(self.offsets, dtype=np.int32))
+        offsets.append(offset)
+        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
+        self.register_buffer('offsets', offsets)

-        self.n_params = self.offsets[-1] * level_dim
+        self.n_params = offsets[-1] * level_dim

         # parameters
         self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
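Note: the point of this hunk is the switch from a plain `self.offsets` attribute to a registered buffer. A buffer travels with the module on `.to()` / `.cuda()` and is serialized in `state_dict`, which is why the per-call `offsets.contiguous().to(inputs.device)` in the first hunk could be simplified. A minimal sketch of the difference (illustrative module, not repo code):

    import torch
    import torch.nn as nn

    class WithBuffer(nn.Module):
        def __init__(self):
            super().__init__()
            # registered: moves with the module, included in state_dict
            self.register_buffer('offsets', torch.arange(4, dtype=torch.int32))

    class PlainAttr(nn.Module):
        def __init__(self):
            super().__init__()
            # plain attribute: left behind on CPU when the module is moved
            self.offsets = torch.arange(4, dtype=torch.int32)

    m1, m2 = WithBuffer(), PlainAttr()
    if torch.cuda.is_available():
        m1.cuda(); m2.cuda()
        print(m1.offsets.device)         # cuda:0
        print(m2.offsets.device)         # cpu: would need a manual .to(...) at every call
    print('offsets' in m1.state_dict())  # True
    print('offsets' in m2.state_dict())  # False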
89 changes: 45 additions & 44 deletions nerf/network.py
@@ -102,7 +102,7 @@ def __init__(self,
                  geo_feat_dim=15,
                  num_layers_color=3,
                  hidden_dim_color=64,
-                 density_grid_size=-1, # density grid size
+                 cuda_raymarching=False,
                  ):
         super().__init__()

@@ -150,15 +150,19 @@ def __init__(self,

         self.color_net = nn.ModuleList(color_net)

-        # density grid
-        if density_grid_size > 0:
-            # buffer is like parameter but never requires_grad
-            density_grid = torch.zeros([density_grid_size + 1] * 3) # +1 because we save values at grid
+        # extra state for cuda raymarching
+        self.cuda_raymarching = cuda_raymarching
+        if cuda_raymarching:
+            # density grid
+            density_grid = torch.zeros([128 + 1] * 3) # +1 because we save values at grid
             self.register_buffer('density_grid', density_grid)
             self.mean_density = 0
             self.iter_density = 0
-        else:
-            self.density_grid = None
+            # step counter
+            step_counter = torch.zeros(64, 2, dtype=torch.int32) # 64 is hardcoded for averaging...
+            self.register_buffer('step_counter', step_counter)
+            self.mean_count = 0
+            self.local_step = 0

     def forward(self, x, d, bound):
         # x: [B, N, 3], in [-bound, bound]
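Note: the new `step_counter` buffer is a 64-slot ring that records, per training step, how many sample points the march generated; `update_extra_state` later averages it into `mean_count`, which is fed back into `generate_points`, presumably to size the point buffer. A standalone sketch of that bookkeeping (the real counter is filled atomically by the CUDA kernel; here it is set by hand):

    import torch

    step_counter = torch.zeros(64, 2, dtype=torch.int32)  # one slot = (points, rays)
    local_step = 0

    def record_step(num_points, num_rays):
        global local_step
        counter = step_counter[local_step % 64]  # reuse the 64 slots as a ring
        counter.zero_()
        counter[0], counter[1] = num_points, num_rays
        local_step += 1

    for _ in range(100):
        record_step(4096, 256)

    total_step = min(64, local_step)
    mean_count = int(step_counter[:total_step, 0].sum().item() / total_step)
    print(mean_count)  # 4096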
@@ -299,7 +303,7 @@ def run(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color):

         # mix background color
         if bg_color is None:
-            bg_color = 1 # let it broadcast
+            bg_color = 1
         image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

         return depth, image
@@ -314,38 +318,43 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color):
             bg_color = torch.ones(3, dtype=rays_o.dtype, device=rays_o.device)

         ### generate points (forward only)
-        points, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, self.training)
+        if self.training:
+            counter = self.step_counter[self.local_step % 64]
+            counter.zero_() # set to 0
+            self.local_step += 1
+            force_all_rays = False
+        else:
+            counter = None
+            force_all_rays = True
+
+        xyzs, dirs, deltas, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, self.training, 128, force_all_rays)

         ### call network inference
-        sigmas, rgbs = self(points[:, :3], points[:, 3:6], bound=bound)
+        sigmas, rgbs = self(xyzs, dirs, bound=bound)

         ### accumulate rays (need backward)
         # inputs: sigmas: [M], rgbs: [M, 3], offsets: [N+1]
         # outputs: depth: [N], image: [N, 3]
-        depth, image = raymarching.accumulate_rays(sigmas, rgbs, points, rays, bound, bg_color)
+        depth, image = raymarching.accumulate_rays(sigmas, rgbs, deltas, rays, bound, bg_color)

         depth = depth.reshape(B, N)
         image = image.reshape(B, N, 3)

         return depth, image
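Note: `accumulate_rays` now takes `deltas` (per-sample step sizes) instead of the packed `points` tensor. The math it implements per ray is the usual volume-rendering quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), with each sample weighted by alpha_i times the transmittance accumulated before it. A dense single-ray PyTorch sketch of the same computation (the CUDA kernel instead walks each ray's compacted samples via the `rays` index tensor):

    import torch

    def composite_one_ray(sigmas, rgbs, deltas, bg_color=1.0):
        # sigmas: [S], rgbs: [S, 3], deltas: [S] step sizes along one ray
        alphas = 1.0 - torch.exp(-sigmas * deltas)                  # [S]
        transmittance = torch.cumprod(
            torch.cat([torch.ones(1), 1.0 - alphas]), dim=0)[:-1]  # [S], T_0 = 1
        weights = alphas * transmittance                            # [S]
        rgb = (weights.unsqueeze(-1) * rgbs).sum(dim=0)             # [3]
        return rgb + (1.0 - weights.sum()) * bg_color               # composite background

    print(composite_one_ray(torch.rand(8), torch.rand(8, 3), torch.full((8,), 0.01)))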


-    def update_density_grid(self, bound, decay=0.95, split_size=128):
-        # call before run_cuda, prepare a coarse density grid.
+    def update_extra_state(self, bound, decay=0.95):
+        # call before each epoch to update extra states.

-        if self.density_grid is None:
+        if not self.cuda_raymarching:
             return

+        ### update density grid
         resolution = self.density_grid.shape[0]

-        N = split_size # chunk to avoid OOM
-
-        X = torch.linspace(-bound, bound, resolution).split(N)
-        Y = torch.linspace(-bound, bound, resolution).split(N)
-        Z = torch.linspace(-bound, bound, resolution).split(N)
-
-        # all_pts = []
-        # all_density = []
+        X = torch.linspace(-bound, bound, resolution).split(128)
+        Y = torch.linspace(-bound, bound, resolution).split(128)
+        Z = torch.linspace(-bound, bound, resolution).split(128)

         tmp_grid = torch.zeros_like(self.density_grid)
         with torch.no_grad():
Expand All @@ -354,39 +363,31 @@ def update_density_grid(self, bound, decay=0.95, split_size=128):
                 for zi, zs in enumerate(Z):
                     lx, ly, lz = len(xs), len(ys), len(zs)
                     xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij')
-                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(tmp_grid.device).unsqueeze(0) # [1, N, 3]
-                    density = self.density(pts, bound).reshape(lx, ly, lz).detach()
-                    tmp_grid[xi * N: xi * N + lx, yi * N: yi * N + ly, zi * N: zi * N + lz] = density
-
-                    # all_pts.append(pts[0])
-                    # all_density.append(density.reshape(-1))
+                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3]
+                    # manual padding for ffmlp
+                    n = pts.shape[0]
+                    pad_n = 128 - (n % 128)
+                    if pad_n != 0:
+                        pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0)
+                    density = self.density(pts.to(tmp_grid.device), bound)[:n].reshape(lx, ly, lz).detach()
+                    tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = density

         # smooth by maxpooling
         tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1))
         tmp_grid = F.max_pool3d(tmp_grid.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=1).squeeze(0).squeeze(0)

         # ema update
         #self.density_grid = tmp_grid
         self.density_grid = torch.maximum(self.density_grid * decay, tmp_grid)

         self.mean_density = torch.mean(self.density_grid).item()
         self.iter_density += 1

-        # TMP: save as voxel volume (point cloud format...)
-        # all_pts = torch.cat(all_pts, dim=0).detach().cpu().numpy() # [N, 3]
-        # all_density = torch.cat(all_density, dim=0).detach().cpu().numpy() # [N]
-        # mask = all_density > 10
-        # all_pts = all_pts[mask]
-        # all_density = all_density[mask]
-        # plot_pointcloud(all_pts, map_color(all_density))
-
-        #vertices, triangles = mcubes.marching_cubes(tmp_grid.detach().cpu().numpy(), 5)
-        #vertices = vertices / (resolution - 1.0) * 2 * bound - bound
-        #mesh = trimesh.Trimesh(vertices, triangles)
-        #mesh.export(f'./{self.iter_density}.ply')
+        ### update step counter
+        total_step = min(64, self.local_step)
+        if total_step > 0:
+            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
+        self.local_step = 0

-        print(f'[density grid] iter={self.iter_density} min={self.density_grid.min().item()}, max={self.density_grid.max().item()}, mean={self.mean_density}')
+        print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}')
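Note: the grid update is a 2x2x2 max-pool (after padding, so resolution is preserved) followed by a decayed maximum, not a plain EMA: fresh evidence can raise a cell immediately, while stale occupancy fades at rate `decay`. The core of it in isolation (same shapes as the diff, assuming the 129-cube grid):

    import torch
    import torch.nn.functional as F

    def merge_density(density_grid, tmp_grid, decay=0.95):
        # pad one voxel on the high side, then max-pool with stride 1:
        # each cell takes the max over its 2x2x2 neighborhood, same resolution out
        tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1))
        tmp_grid = F.max_pool3d(tmp_grid[None, None], kernel_size=2, stride=1)[0, 0]
        # decayed maximum instead of EMA: rises fast, decays slowly
        return torch.maximum(density_grid * decay, tmp_grid)

    grid = torch.zeros(129, 129, 129)
    grid = merge_density(grid, torch.rand(129, 129, 129))
    print(grid.shape)  # torch.Size([129, 129, 129])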


     def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=4096, bg_color=None, **kwargs):
87 changes: 43 additions & 44 deletions nerf/network_ff.py
@@ -91,7 +91,7 @@ def __init__(self,
                  geo_feat_dim=15,
                  num_layers_color=3,
                  hidden_dim_color=64,
-                 density_grid_size=-1, # density grid size
+                 cuda_raymarching=False,
                  ):
         super().__init__()

@@ -121,23 +121,27 @@ def __init__(self,
             num_layers=self.num_layers_color,
         )

-        # density grid
-        if density_grid_size > 0:
-            # buffer is like parameter but never requires_grad
-            density_grid = torch.zeros([density_grid_size + 1] * 3) # +1 because we save values at grid
+        # extra state for cuda raymarching
+        self.cuda_raymarching = cuda_raymarching
+        if cuda_raymarching:
+            # density grid
+            density_grid = torch.zeros([128 + 1] * 3) # +1 because we save values at grid
             self.register_buffer('density_grid', density_grid)
             self.mean_density = 0
             self.iter_density = 0
-        else:
-            self.density_grid = None
+            # step counter
+            step_counter = torch.zeros(64, 2, dtype=torch.int32) # 64 is hardcoded for averaging...
+            self.register_buffer('step_counter', step_counter)
+            self.mean_count = 0
+            self.local_step = 0

     def forward(self, x, d, bound):
         # x: [B, N, 3], in [-bound, bound]
         # d: [B, N, 3], normalized in [-1, 1]

         prefix = x.shape[:-1]
-        x = x.reshape(-1, 3)
-        d = d.reshape(-1, 3)
+        x = x.view(-1, 3)
+        d = d.view(-1, 3)

         # sigma
         x = self.encoder(x, size=bound)

@@ -156,24 +160,24 @@ def forward(self, x, d, bound):
         # sigmoid activation for rgb
         color = torch.sigmoid(h)

-        sigma = sigma.reshape(*prefix)
-        color = color.reshape(*prefix, -1)
+        sigma = sigma.view(*prefix)
+        color = color.view(*prefix, -1)

         return sigma, color

     def density(self, x, bound):
         # x: [B, N, 3], in [-bound, bound]

         prefix = x.shape[:-1]
-        x = x.reshape(-1, 3)
+        x = x.view(-1, 3)

         x = self.encoder(x, size=bound)
         h = self.sigma_net(x)

         #sigma = torch.exp(torch.clamp(h[..., 0], -15, 15))
         sigma = F.relu(h[..., 0])

-        sigma = sigma.reshape(*prefix)
+        sigma = sigma.view(*prefix)

         return sigma
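Note: the `reshape` to `view` swaps in this file are a small optimization with a safety property: `view` is guaranteed zero-copy and raises if the tensor's memory layout does not allow it, whereas `reshape` silently falls back to a copy. Standard PyTorch behavior, illustrated:

    import torch

    x = torch.randn(4, 3)
    t = x.t()                    # transpose shares storage but is non-contiguous
    print(t.is_contiguous())     # False
    print(t.reshape(-1).shape)   # torch.Size([12]): reshape silently copies
    try:
        t.view(-1)               # view refuses to copy
    except RuntimeError as err:
        print('view failed:', err)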

@@ -286,46 +290,43 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color):
             bg_color = torch.ones(3, dtype=rays_o.dtype, device=rays_o.device)

         ### generate points (forward only)
-        points, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, self.training)
-
-        ### call network inference
-        # manual pad for ffmlp (slow, should be avoided...)
-        n = points.shape[0]
-        pad_n = 128 - (n % 128)
-        if pad_n > 0:
-            points = torch.cat([points, torch.zeros(pad_n, points.shape[1], device=points.device, dtype=points.dtype)], dim=0)
-
-        sigmas, rgbs = self(points[:, :3], points[:, 3:6], bound=bound)
-
-        if pad_n > 0:
-            sigmas = sigmas[:n]
-            rgbs = rgbs[:n]
+        if self.training:
+            counter = self.step_counter[self.local_step % 64]
+            counter.zero_() # set to 0
+            self.local_step += 1
+            force_all_rays = False
+        else:
+            counter = None
+            force_all_rays = True
+
+        xyzs, dirs, deltas, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, self.training, 128, force_all_rays)
+
+        ### call network inference
+        sigmas, rgbs = self(xyzs, dirs, bound=bound)

         ### accumulate rays (need backward)
         # inputs: sigmas: [M], rgbs: [M, 3], offsets: [N+1]
         # outputs: depth: [N], image: [N, 3]
-        depth, image = raymarching.accumulate_rays(sigmas, rgbs, points, rays, bound, bg_color)
+        depth, image = raymarching.accumulate_rays(sigmas, rgbs, deltas, rays, bound, bg_color)

         depth = depth.reshape(B, N)
         image = image.reshape(B, N, 3)

         return depth, image
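Note: the manual host-side padding of `points` is gone from `run_cuda`; the new `128` argument to `generate_points` presumably lets the kernel emit the sample count in multiples of 128, which is the batch alignment the fully-fused MLP (ffmlp) requires. For reference, the branch-free way to round a count up to such a multiple (illustration only, not the kernel's code):

    def round_up(n: int, multiple: int = 128) -> int:
        # smallest multiple of `multiple` that is >= n
        return (n + multiple - 1) // multiple * multiple

    assert round_up(0) == 0
    assert round_up(1) == 128
    assert round_up(128) == 128
    assert round_up(129) == 256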


-    def update_density_grid(self, bound, decay=0.95, split_size=128):
-        # call before run_cuda, prepare a coarse density grid.
+    def update_extra_state(self, bound, decay=0.95):
+        # call before each epoch to update extra states.

-        if self.density_grid is None:
+        if not self.cuda_raymarching:
             return

+        ### update density grid
         resolution = self.density_grid.shape[0]

-        N = split_size # chunk to avoid OOM
-
-        X = torch.linspace(-bound, bound, resolution).split(N)
-        Y = torch.linspace(-bound, bound, resolution).split(N)
-        Z = torch.linspace(-bound, bound, resolution).split(N)
+        X = torch.linspace(-bound, bound, resolution).split(128)
+        Y = torch.linspace(-bound, bound, resolution).split(128)
+        Z = torch.linspace(-bound, bound, resolution).split(128)

         tmp_grid = torch.zeros_like(self.density_grid)
         with torch.no_grad():
Expand All @@ -341,26 +342,24 @@ def update_density_grid(self, bound, decay=0.95, split_size=128):
                     if pad_n != 0:
                         pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0)
                     density = self.density(pts.to(tmp_grid.device), bound)[:n].reshape(lx, ly, lz).detach()
-                    tmp_grid[xi * N: xi * N + lx, yi * N: yi * N + ly, zi * N: zi * N + lz] = density
+                    tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = density

         # smooth by maxpooling
         tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1))
         tmp_grid = F.max_pool3d(tmp_grid.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=1).squeeze(0).squeeze(0)

         # ema update
         #self.density_grid = tmp_grid
         self.density_grid = torch.maximum(self.density_grid * decay, tmp_grid)

         self.mean_density = torch.mean(self.density_grid).item()
         self.iter_density += 1

-        # TMP: save mesh for debug
-        # vertices, triangles = mcubes.marching_cubes(tmp_grid.detach().cpu().numpy(), 5)
-        # vertices = vertices / (resolution - 1.0) * 2 * bound - bound
-        # mesh = trimesh.Trimesh(vertices, triangles)
-        # mesh.export(f'./tmp/{self.iter_density}.ply')
+        ### update step counter
+        total_step = min(64, self.local_step)
+        if total_step > 0:
+            self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
+        self.local_step = 0

-        print(f'[density grid] iter={self.iter_density} min={self.density_grid.min().item()}, max={self.density_grid.max().item()}, mean={self.mean_density}')
+        print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}')
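Note: the padding retained in this file's `update_extra_state` has a benign quirk: `pad_n = 128 - (n % 128)` yields 128 rather than 0 when `n` is already a multiple of 128, so the `if pad_n != 0` guard never skips and a full extra block of zeros is appended (harmless, since the result is sliced back with `[:n]`). A form that actually skips in the aligned case would be (editorial sketch, not part of the commit):

    import torch

    pts = torch.rand(256, 3)     # example: count already a multiple of 128
    n = pts.shape[0]
    pad_n = (-n) % 128           # 0 when n is already aligned
    if pad_n > 0:
        pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0)
    print(pts.shape)             # torch.Size([256, 3]): no spurious padding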


     def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=4096, bg_color=None, **kwargs):