misc fix, manually merge ashawkey#83, add basis based dnerf
ashawkey committed Jul 19, 2022
1 parent dd2add6 commit 5a7222a
Showing 9 changed files with 297 additions and 13 deletions.
1 change: 1 addition & 0 deletions assets/update_logs.md
@@ -1,5 +1,6 @@
 ## Update logs
 
+* 7.16: add temporal-basis dynamic NeRF (experimental). It trains much faster than the deformation-based dynamic NeRF, but quality is much worse for now...
 * 6.29: add support for HyperNeRF's dataset.
   * we use a simplified pinhole camera model, which may introduce bias.
 * 6.26: add support for D-NeRF.
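For context: instead of warping points with a deformation field, the temporal-basis model (in the spirit of Fourier PlenOctree and NeuVV) factorizes the dynamic field into static per-point coefficients mixed by a small set of time-dependent basis weights. In our own notation (not taken from the repo), the new `dnerf/network_basis.py` below roughly computes

$$\sigma(x,t) = \mathrm{trunc\_exp}\big(f_\sigma(x)^\top b_\sigma(t)\big), \qquad c(x,d,t) = \mathrm{sigmoid}\big(F_c(x,d)\, b_c(t)\big),$$

where $f_\sigma(x) \in \mathbb{R}^{SB}$ and $F_c(x,d) \in \mathbb{R}^{3 \times CB}$ come from the spatial sigma/color networks, and $b_\sigma(t) \in \mathbb{R}^{SB}$, $b_c(t) \in \mathbb{R}^{CB}$ come from a small MLP over frequency-encoded time.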
262 changes: 262 additions & 0 deletions dnerf/network_basis.py
@@ -0,0 +1,262 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoding import get_encoder
from activation import trunc_exp
from .renderer import NeRFRenderer


class NeRFNetwork(NeRFRenderer):
    def __init__(self,
                 encoding="tiledgrid",
                 encoding_dir="sphere_harmonics",
                 encoding_time="frequency",
                 encoding_bg="hashgrid",
                 num_layers=2,
                 hidden_dim=64,
                 geo_feat_dim=32,
                 num_layers_color=3,
                 hidden_dim_color=64,
                 num_layers_bg=2,
                 hidden_dim_bg=64,
                 sigma_basis_dim=32,
                 color_basis_dim=8,
                 num_layers_basis=5,
                 hidden_dim_basis=128,
                 bound=1,
                 **kwargs,
                 ):
        super().__init__(bound, **kwargs)

        # basis network: maps encoded time to SB + CB mixing coefficients
        self.num_layers_basis = num_layers_basis
        self.hidden_dim_basis = hidden_dim_basis
        self.sigma_basis_dim = sigma_basis_dim
        self.color_basis_dim = color_basis_dim
        self.encoder_time, self.in_dim_time = get_encoder(encoding_time, input_dim=1, multires=6)

        basis_net = []
        for l in range(num_layers_basis):
            if l == 0:
                in_dim = self.in_dim_time
            else:
                in_dim = hidden_dim_basis

            if l == num_layers_basis - 1:
                out_dim = self.sigma_basis_dim + self.color_basis_dim
            else:
                out_dim = hidden_dim_basis

            basis_net.append(nn.Linear(in_dim, out_dim, bias=False))

        self.basis_net = nn.ModuleList(basis_net)

        # sigma network
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.encoder, self.in_dim = get_encoder(encoding, desired_resolution=2048 * bound)

        sigma_net = []
        for l in range(num_layers):
            if l == 0:
                in_dim = self.in_dim
            else:
                in_dim = hidden_dim

            if l == num_layers - 1:
                out_dim = self.sigma_basis_dim + self.geo_feat_dim # SB sigma + features for color
            else:
                out_dim = hidden_dim

            sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))

        self.sigma_net = nn.ModuleList(sigma_net)

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color
        self.encoder_dir, self.in_dim_dir = get_encoder(encoding_dir)

        color_net = []
        for l in range(num_layers_color):
            if l == 0:
                in_dim = self.in_dim_dir + self.geo_feat_dim
            else:
                in_dim = hidden_dim # NOTE: assumes hidden_dim == hidden_dim_color (true with the defaults)

            if l == num_layers_color - 1:
                out_dim = 3 * self.color_basis_dim # 3 * CB rgb
            else:
                out_dim = hidden_dim

            color_net.append(nn.Linear(in_dim, out_dim, bias=False))

        self.color_net = nn.ModuleList(color_net)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048) # much smaller hashgrid

            bg_net = []
            for l in range(num_layers_bg):
                if l == 0:
                    in_dim = self.in_dim_bg + self.in_dim_dir
                else:
                    in_dim = hidden_dim_bg

                if l == num_layers_bg - 1:
                    out_dim = 3 # 3 rgb
                else:
                    out_dim = hidden_dim_bg

                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.bg_net = nn.ModuleList(bg_net)
        else:
            self.bg_net = None


    def forward(self, x, d, t):
        # x: [N, 3], in [-bound, bound]
        # d: [N, 3], normalized in [-1, 1]
        # t: [1, 1], in [0, 1]

        # time --> basis coefficients (shared by all points at this time step)
        enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
        h = enc_t
        for l in range(self.num_layers_basis):
            h = self.basis_net[l](h)
            if l != self.num_layers_basis - 1:
                h = F.relu(h, inplace=True)

        sigma_basis = h[0, :self.sigma_basis_dim] # [SB]
        color_basis = h[0, self.sigma_basis_dim:] # [CB]

        # sigma
        x = self.encoder(x, bound=self.bound)
        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis) # [N, SB] @ [SB] --> [N]
        geo_feat = h[..., self.sigma_basis_dim:]

        # color
        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb; [N, 3, CB] @ [CB] --> [N, 3]
        rgbs = torch.sigmoid(h.view(-1, 3, self.color_basis_dim) @ color_basis)

        return sigma, rgbs, None

    def density(self, x, t):
        # x: [N, 3], in [-bound, bound]
        # t: [1, 1], in [0, 1]

        results = {}

        # time --> basis
        enc_t = self.encoder_time(t) # [1, 1] --> [1, C']
        h = enc_t
        for l in range(self.num_layers_basis):
            h = self.basis_net[l](h)
            if l != self.num_layers_basis - 1:
                h = F.relu(h, inplace=True)

        sigma_basis = h[0, :self.sigma_basis_dim]
        color_basis = h[0, self.sigma_basis_dim:]

        # sigma
        x = self.encoder(x, bound=self.bound)
        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        sigma = trunc_exp(h[..., :self.sigma_basis_dim] @ sigma_basis)
        geo_feat = h[..., self.sigma_basis_dim:]

        results['sigma'] = sigma
        results['geo_feat'] = geo_feat
        # results['color_basis'] = color_basis # not returned yet, which is why the masked color() below is disabled

        return results

    def background(self, x, d):
        # x: [N, 2], in [-1, 1]

        h = self.encoder_bg(x) # [N, C]
        d = self.encoder_dir(d)

        h = torch.cat([d, h], dim=-1)
        for l in range(self.num_layers_bg):
            h = self.bg_net[l](h)
            if l != self.num_layers_bg - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # TODO: non-cuda-ray mode is broken for now... (how to pass color_basis to self.color()?)
    # # allow masked inference
    # def color(self, x, d, mask=None, geo_feat=None, **kwargs):
    #     # x: [N, 3] in [-bound, bound]
    #     # t: [1, 1], in [0, 1]
    #     # mask: [N,], bool, indicates where we actually need to compute rgb.

    #     if mask is not None:
    #         rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
    #         # in case of empty mask
    #         if not mask.any():
    #             return rgbs
    #         x = x[mask]
    #         d = d[mask]
    #         geo_feat = geo_feat[mask]

    #     d = self.encoder_dir(d)
    #     h = torch.cat([d, geo_feat], dim=-1)
    #     for l in range(self.num_layers_color):
    #         h = self.color_net[l](h)
    #         if l != self.num_layers_color - 1:
    #             h = F.relu(h, inplace=True)

    #     # sigmoid activation for rgb
    #     h = torch.sigmoid(h)

    #     if mask is not None:
    #         rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
    #     else:
    #         rgbs = h

    #     return rgbs

    # optimizer utils
    def get_params(self, lr, lr_net):

        params = [
            {'params': self.encoder.parameters(), 'lr': lr},
            {'params': self.sigma_net.parameters(), 'lr': lr_net},
            {'params': self.encoder_dir.parameters(), 'lr': lr},
            {'params': self.color_net.parameters(), 'lr': lr_net},
            {'params': self.encoder_time.parameters(), 'lr': lr},
            {'params': self.basis_net.parameters(), 'lr': lr_net},
        ]
        if self.bg_radius > 0:
            params.append({'params': self.encoder_bg.parameters(), 'lr': lr})
            params.append({'params': self.bg_net.parameters(), 'lr': lr_net})

        return params
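To sanity-check the shapes involved, here is a minimal, self-contained sketch of the basis mixing that `forward` performs; the tensors are random stand-ins for the real encoder/MLP outputs, and `N`, `SB`, `CB` correspond to the point count, `sigma_basis_dim`, and `color_basis_dim`:

```python
import torch

N, SB, CB = 4096, 32, 8                   # points, sigma basis dim, color basis dim

h_sigma = torch.randn(N, SB)              # per-point sigma coefficients (stand-in for sigma_net output)
sigma_basis = torch.randn(SB)             # time-dependent sigma basis (stand-in for basis_net output)
sigma = torch.exp(h_sigma @ sigma_basis)  # [N, SB] @ [SB] -> [N]; the real code uses trunc_exp

h_color = torch.randn(N, 3 * CB)          # per-point color coefficients (stand-in for color_net output)
color_basis = torch.randn(CB)             # time-dependent color basis
rgbs = torch.sigmoid(h_color.view(-1, 3, CB) @ color_basis)  # [N, 3, CB] @ [CB] -> [N, 3]

assert sigma.shape == (N,) and rgbs.shape == (N, 3)
```

Because the basis weights depend on time alone, a new frame only requires re-running the tiny basis MLP rather than a per-point deformation network, which is presumably why this variant trains faster than the deformation-based one.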
4 changes: 2 additions & 2 deletions dnerf/utils.py
@@ -115,8 +115,8 @@ def train_step(self, data):
         loss = loss.mean()
 
         # deform regularization
-        deform = outputs['deform']
-        loss = loss + 1e-3 * deform.abs().mean()
+        if 'deform' in outputs and outputs['deform'] is not None:
+            loss = loss + 1e-3 * outputs['deform'].abs().mean()
 
         return pred_rgb, gt_rgb, loss

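(The guard is needed because the new temporal-basis network has no deformation field, so its outputs never contain a `deform` entry.)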
7 changes: 6 additions & 1 deletion main_dnerf.py
@@ -34,6 +34,7 @@

 ### network backbone options
 parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
+parser.add_argument('--basis', action='store_true', help="[experimental] use temporal basis instead of deformation to model dynamic scene (check Fourier PlenOctree and NeuVV)")
 # parser.add_argument('--ff', action='store_true', help="use fully-fused MLP")
 # parser.add_argument('--tcnn', action='store_true', help="use TCNN backend")

@@ -69,7 +70,11 @@
     opt.cuda_ray = True
     opt.preload = True
 
-from dnerf.network import NeRFNetwork
+if opt.basis:
+    assert opt.cuda_ray, "Non-cuda-ray mode is temporarily broken with temporal basis mode"
+    from dnerf.network_basis import NeRFNetwork
+else:
+    from dnerf.network import NeRFNetwork
 
 print(opt)

13 changes: 9 additions & 4 deletions nerf/renderer.py
@@ -249,6 +249,7 @@ def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None,
         return {
             'depth': depth,
             'image': image,
+            'weights_sum': weights_sum,
         }
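(`weights_sum` is the per-ray accumulated volume-rendering weight, i.e. the total opacity of the ray; returning it is what enables the opacity-entropy regularizer sketched, but left commented out, in `nerf/utils.py` below.)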


@@ -274,6 +275,8 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
         elif bg_color is None:
             bg_color = 1
 
+        results = {}
+
         if self.training:
             # setup counter
             counter = self.step_counter[self.local_step % 16]
@@ -314,6 +317,8 @@
             depth = torch.clamp(depth - nears, min=0) / (fars - nears)
             image = image.view(*prefix, 3)
             depth = depth.view(*prefix)
+
+            results['weights_sum'] = weights_sum
 
         else:
 
@@ -365,11 +370,11 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
             depth = torch.clamp(depth - nears, min=0) / (fars - nears)
             image = image.view(*prefix, 3)
             depth = depth.view(*prefix)
 
+        results['depth'] = depth
+        results['image'] = image
 
-        return {
-            'depth': depth,
-            'image': image,
-        }
+        return results
 
     @torch.no_grad()
     def mark_untrained_grid(self, poses, intrinsic, S=64):
9 changes: 7 additions & 2 deletions nerf/utils.py
@@ -292,7 +292,7 @@ def __init__(self,
         self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
 
         # variable init
-        self.epoch = 1
+        self.epoch = 0
         self.global_step = 0
         self.local_step = 0
         self.stats = {
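(Together with the loop change further below, `self.epoch` now apparently tracks the last *completed* epoch, so a run resumed from a checkpoint continues at the following epoch instead of repeating one.)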
@@ -442,6 +442,11 @@ def train_step(self, data):

         loss = loss.mean()
 
+        # extra loss
+        # pred_weights_sum = outputs['weights_sum'] + 1e-8
+        # loss_ws = - 1e-1 * pred_weights_sum * torch.log(pred_weights_sum) # entropy to encourage weights_sum to be 0 or 1.
+        # loss = loss + loss_ws.mean()
+
         return pred_rgb, gt_rgb, loss

def eval_step(self, data):
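(For reference: -p log p is zero at p = 0 and p = 1 and positive in between, so minimizing it would push each ray's accumulated weight toward fully transparent or fully opaque.)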
@@ -523,7 +528,7 @@ def train(self, train_loader, valid_loader, max_epochs):
         # get a ref to error_map
         self.error_map = train_loader._data.error_map
 
-        for epoch in range(self.epoch, max_epochs + 1):
+        for epoch in range(self.epoch + 1, max_epochs + 1):
             self.epoch = epoch
 
             self.train_one_epoch(train_loader)
4 changes: 4 additions & 0 deletions readme.md
@@ -187,8 +187,12 @@ python main_CCNeRF.py data/nerf_synthetic/hotdog --workspace trial_cc_hotdog -O

 ### D-NeRF
 # almost the same as Instant-ngp NeRF, just replace the main script.
+# use deformation to model dynamic scene
 python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0
 python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --gui
+# use temporal basis to model dynamic scene
+python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_basis_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --basis
+python main_dnerf.py data/dnerf/jumpingjacks --workspace trial_dnerf_basis_jumpingjacks -O --bound 1.0 --scale 0.8 --dt_gamma 0 --basis --gui
 # for the hypernerf dataset, first convert it into nerf-compatible format:
 python scripts/hyper2nerf.py data/split-cookie --downscale 2 # will generate transforms*.json
 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies -O --bound 1 --scale 0.3 --dt_gamma 0