
Commit

major update
ashawkey committed Mar 27, 2022
1 parent a91e9cb commit dd4378c
Showing 24 changed files with 293 additions and 2,367 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -3,4 +3,6 @@ tmp*
data/
trial*/
volsdf/
-**volsdf*
+**volsdf*
+tensorf/
+**tensorf*
16 changes: 11 additions & 5 deletions ffmlp/ffmlp.py
@@ -17,8 +17,6 @@ def forward(ctx, inputs, weights, input_dim, output_dim, hidden_dim, num_layers,

B = inputs.shape[0]

-assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}."

inputs = inputs.contiguous()
weights = weights.contiguous()

@@ -148,12 +146,20 @@ def forward(self, inputs):
# return: [B, output_dim]

#print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad)


B, C = inputs.shape
#assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}."

# pad input
pad = 128 - (B % 128)
if pad > 0:
inputs = torch.cat([inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)], dim=0)

outputs = ffmlp_forward(inputs, self.weights, self.input_dim, self.padded_output_dim, self.hidden_dim, self.num_layers, self.activation, self.output_activation, not self.training, inputs.requires_grad)

# unpad output
-if self.padded_output_dim != self.output_dim:
-outputs = outputs[:, :self.output_dim]
+if B != outputs.shape[0] or self.padded_output_dim != self.output_dim:
+outputs = outputs[:B, :self.output_dim]

#print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

4 changes: 2 additions & 2 deletions main_nerf.py
@@ -19,8 +19,8 @@
parser.add_argument('--num_rays', type=int, default=4096)
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
# (only valid when not using --cuda_ray)
-parser.add_argument('--num_steps', type=int, default=128)
-parser.add_argument('--upsample_steps', type=int, default=128)
+parser.add_argument('--num_steps', type=int, default=512)
+parser.add_argument('--upsample_steps', type=int, default=0)
parser.add_argument('--max_ray_batch', type=int, default=4096)
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
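The new defaults (num_steps 128 -> 512, upsample_steps 128 -> 0) apply only to the non-CUDA ray marcher and trade the two-pass coarse/fine scheme for a single, denser uniform pass. A rough sketch of how the two counts would combine per ray, assuming uniform spacing for both passes (the function and the stand-in for importance sampling are illustrative, not the repo's renderer):

```python
import torch

def sample_depths(near, far, num_steps, upsample_steps):
    """Illustrative only: coarse uniform samples plus an optional second batch."""
    z = torch.linspace(near, far, num_steps)
    if upsample_steps > 0:
        # the real renderer would place these by importance sampling the coarse weights;
        # uniform spacing here is just a placeholder
        z = torch.sort(torch.cat([z, torch.linspace(near, far, upsample_steps)]))[0]
    return z

old_z = sample_depths(2.0, 6.0, 128, 128)  # 256 samples, half from a second pass
new_z = sample_depths(2.0, 6.0, 512, 0)    # 512 samples in one uniform pass
```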
4 changes: 2 additions & 2 deletions main_tensoRF.py
@@ -20,8 +20,8 @@
parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch")
parser.add_argument('--l1_reg_weight', type=float, default=4e-5)
# (only valid when not using --cuda_ray)
-parser.add_argument('--num_steps', type=int, default=128)
-parser.add_argument('--upsample_steps', type=int, default=128)
+parser.add_argument('--num_steps', type=int, default=512)
+parser.add_argument('--upsample_steps', type=int, default=0)
parser.add_argument('--max_ray_batch', type=int, default=4096)
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
6 changes: 3 additions & 3 deletions main_tensorf.py
@@ -32,7 +32,7 @@
parser.add_argument('--N_voxel_init', type=int, default=128**3)
parser.add_argument('--N_voxel_final', type=int, default=300**3)
parser.add_argument("--upsamp_list", type=int, action="append", default=[2000,3000,4000,5500,7000])
parser.add_argument("--update_AlphaMask_list", type=int, action="append", default=[2000,4000])
parser.add_argument("--update_AlphaMask_list", type=int, action="append", default=[]) # [2000,4000]
parser.add_argument('--lindisp', default=False, action="store_true", help='use disparity depth sampling')
parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--accumulate_decay", type=float, default=0.998)
@@ -71,7 +71,7 @@

aabb = (torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]) * opt.bound).to(device)
reso_cur = N_to_reso(opt.N_voxel_init, aabb)
-nSamples = min(opt.nSamples, cal_n_samples(reso_cur, opt.step_ratio))
+nSamples = 512 # min(opt.nSamples, cal_n_samples(reso_cur, opt.step_ratio))
near_far = [2.0, 6.0] # fixed for blender
N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(opt.N_voxel_init), np.log(opt.N_voxel_final), len(opt.upsamp_list)+1))).long()).tolist()[1:]

@@ -109,7 +109,7 @@

scheduler = lambda optimizer: optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.33)

-trainer = Trainer('tensorf', vars(opt), model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, metrics=[PSNRMeter()], use_checkpoint='latest', eval_interval=50)
+trainer = Trainer('tensorf', vars(opt), model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, metrics=[PSNRMeter()], use_checkpoint='scratch', eval_interval=50)

# attach extra things
trainer.aabb = aabb
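The upsampling schedule in this file is driven by the N_voxel_list expression: it interpolates log-uniformly between N_voxel_init and N_voxel_final, yielding one target voxel count per milestone in upsamp_list (each entry grows by a constant factor, roughly 1.67x with the defaults). A small self-contained check of that formula:

```python
import numpy as np
import torch

N_voxel_init, N_voxel_final = 128**3, 300**3
upsamp_list = [2000, 3000, 4000, 5500, 7000]

# log-spaced voxel counts; drop the first entry since training starts at N_voxel_init
N_voxel_list = torch.round(torch.exp(torch.linspace(
    np.log(N_voxel_init), np.log(N_voxel_final), len(upsamp_list) + 1
))).long().tolist()[1:]

print(N_voxel_list)  # five counts, ending at 300**3 = 27,000,000
```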
43 changes: 39 additions & 4 deletions nerf/network.py
@@ -66,8 +66,8 @@ def __init__(self,


def forward(self, x, d):
-# x: [B, N, 3], in [-bound, bound]
-# d: [B, N, 3], normalized in [-1, 1]
+# x: [N, 3], in [-bound, bound]
+# d: [N, 3], normalized in [-1, 1]

# sigma
x = self.encoder(x, bound=self.bound)
@@ -96,7 +96,7 @@ def forward(self, x, d):
return sigma, color

def density(self, x):
-# x: [B, N, 3], in [-bound, bound]
+# x: [N, 3], in [-bound, bound]

x = self.encoder(x, bound=self.bound)
h = x
@@ -106,5 +106,40 @@ def density(self, x):
h = F.relu(h, inplace=True)

sigma = F.relu(h[..., 0])
geo_feat = h[..., 1:]

return {
'sigma': sigma,
'geo_feat': geo_feat,
}

# allow masked inference
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
# x: [N, 3] in [-bound, bound]
# mask: [N,], bool, indicates where we actually need to compute rgb.

if mask is not None:
rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]

d = self.encoder_dir(d)
h = torch.cat([d, geo_feat], dim=-1)
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)

# sigmoid activation for rgb
h = torch.sigmoid(h)

if mask is not None:
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
else:
rgbs = h

-return sigma
+return rgbs
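With this change, density() returns sigma together with a geometry feature, and color() takes an optional boolean mask so rgb is only evaluated for points that actually contribute. A minimal sketch of how a renderer might consume the split API (the weight threshold and helper name are illustrative, not the repo's actual renderer code):

```python
import torch

def shade_points(model, xyzs, dirs, weights, thresh=1e-4):
    """Query density everywhere, but only run the color branch where weights matter."""
    out = model.density(xyzs)                     # {'sigma': [N], 'geo_feat': [N, C]}
    sigma, geo_feat = out['sigma'], out['geo_feat']

    mask = weights > thresh                       # [N] bool
    rgbs = model.color(xyzs, dirs, mask=mask, geo_feat=geo_feat)  # zeros where masked out
    return sigma, rgbs
```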
22 changes: 20 additions & 2 deletions nerf/network_ff.py
@@ -93,24 +93,42 @@ def color(self, x, d, mask=None, geo_feat=None, **kwargs):
# x: [N, 3] in [-bound, bound]
# mask: [N,], bool, indicates where we actually need to compute rgb.

#starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
#starter.record()

if mask is not None:
rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]

#print(x.shape, rgbs.shape)

#ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}')
#starter.record()

d = self.encoder_dir(d)

p = torch.zeros_like(geo_feat[..., :1]) # manual input padding
h = torch.cat([d, geo_feat, p], dim=-1)

h = self.color_net(h)

# sigmoid activation for rgb
h = torch.sigmoid(h)

#ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}')
#starter.record()

if mask is not None:
-rgbs = torch.zeros(np.prod(prefix), 3, dtype=h.dtype, device=h.device) # [N, 3]
-rgbs[mask] = h
+rgbs[mask] = h.to(rgbs.dtype)
else:
rgbs = h

#ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}')
#starter.record()

return rgbs
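The manual input padding above (p = torch.zeros_like(geo_feat[..., :1])) bumps the concatenated direction-encoding + geo_feat width by one column, presumably so the fully-fused color net sees an input width aligned to what the CUDA kernel expects (e.g. a multiple of 16); that alignment requirement is an assumption here, not stated in the diff. A generic sketch of padding a feature dimension to an alignment:

```python
import torch

def pad_feature_dim(h: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the last (feature) dimension up to the next multiple (illustrative helper)."""
    pad = (-h.shape[-1]) % multiple
    if pad > 0:
        h = torch.cat([h, h.new_zeros(*h.shape[:-1], pad)], dim=-1)
    return h

d_enc = torch.randn(4096, 16)     # e.g. spherical-harmonics direction encoding
geo_feat = torch.randn(4096, 15)  # e.g. 15 geometry features
h = pad_feature_dim(torch.cat([d_enc, geo_feat], dim=-1))  # 31 -> 32 columns
```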
54 changes: 40 additions & 14 deletions nerf/network_tcnn.py
@@ -81,12 +81,9 @@ def __init__(self,


def forward(self, x, d):
-# x: [B, N, 3], in [-bound, bound]
-# d: [B, N, 3], normalized in [-1, 1]
+# x: [N, 3], in [-bound, bound]
+# d: [N, 3], normalized in [-1, 1]

-prefix = x.shape[:-1]
-x = x.view(-1, 3)
-d = d.view(-1, 3)

# sigma
x = (x + self.bound) / (2 * self.bound) # to [0, 1]
@@ -106,25 +103,54 @@

# sigmoid activation for rgb
color = torch.sigmoid(h)

-sigma = sigma.view(*prefix)
-color = color.view(*prefix, -1)

return sigma, color

def density(self, x):
-# x: [B, N, 3], in [-bound, bound]

-prefix = x.shape[:-1]
-x = x.view(-1, 3)
+# x: [N, 3], in [-bound, bound]

x = (x + self.bound) / (2 * self.bound) # to [0, 1]
x = self.encoder(x)
h = self.sigma_net(x)

#sigma = torch.exp(torch.clamp(h[..., 0], -15, 15))
sigma = F.relu(h[..., 0])
geo_feat = h[..., 1:]

return {
'sigma': sigma,
'geo_feat': geo_feat,
}

# allow masked inference
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
# x: [N, 3] in [-bound, bound]
# mask: [N,], bool, indicates where we actually need to compute rgb.

x = (x + self.bound) / (2 * self.bound) # to [0, 1]

if mask is not None:
rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]

# color
d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1]
d = self.encoder_dir(d)

h = torch.cat([d, geo_feat], dim=-1)
h = self.color_net(h)

# sigmoid activation for rgb
h = torch.sigmoid(h)

-sigma = sigma.view(*prefix)
if mask is not None:
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
else:
rgbs = h

-return sigma
+return rgbs
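Both tcnn-backed branches remap their inputs into [0, 1]: positions go from [-bound, bound] to [0, 1] before the hash-grid encoder, and directions from [-1, 1] to [0, 1] before the SH encoder. A tiny sketch of the two remaps in isolation (bound and the random inputs are placeholders):

```python
import torch
import torch.nn.functional as F

bound = 2.0
x = torch.empty(1024, 3).uniform_(-bound, bound)  # world-space positions
d = F.normalize(torch.randn(1024, 3), dim=-1)     # unit view directions

x01 = (x + bound) / (2 * bound)  # [-bound, bound] -> [0, 1] for the hash grid
d01 = (d + 1) / 2                # [-1, 1] -> [0, 1] for the SH encoding
assert 0.0 <= x01.min().item() <= x01.max().item() <= 1.0
```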

