Skip to content

Commit

Permalink
fix running time bug
Browse files Browse the repository at this point in the history
  • Loading branch information
dunbar12138 committed Aug 30, 2021
1 parent 935487c commit f20a20b
Showing 1 changed file with 116 additions and 17 deletions.
133 changes: 116 additions & 17 deletions run_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

from loss import SigmaLoss


from data import RayDataset
from torch.utils.data import DataLoader

from utils.generate_renderpath import generate_renderpath
import cv2
# import time

# concate_time, iter_time, split_time, loss_time, backward_time = [], [], [], [], []


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -216,9 +220,9 @@ def render_test_ray(rays_o, rays_d, hwf, ndc, near, far, use_viewdirs, N_samples

z_vals = z_vals.reshape([rays_o.shape[0], N_samples])

rgb, sigma, depth_maps = sample_sigma(rays_o, rays_d, viewdirs, network, z_vals, network_query_fn)
rgb, sigma, depth_maps, weights = sample_sigma(rays_o, rays_d, viewdirs, network, z_vals, network_query_fn)

return rgb, sigma, z_vals, depth_maps
return rgb, sigma, z_vals, depth_maps, weights


def create_nerf(args):
Expand Down Expand Up @@ -614,6 +618,8 @@ def config_parser():
help="single forward for both depth and rgb")
parser.add_argument("--normalize_depth", action='store_true',
help="normalize depth before calculating loss")
parser.add_argument("--depth_rays_prop", type=float, default=0.5,
help="Proportion of depth rays.")
return parser


Expand All @@ -622,7 +628,6 @@ def train():
parser = config_parser()
args = parser.parse_args()

# Load data

if args.dataset_type == 'llff':
if args.colmap_depth:
Expand Down Expand Up @@ -756,11 +761,13 @@ def train():
index_pose = i_train[0]
rays_o, rays_d = get_rays_by_coord_np(H, W, focal, poses[index_pose,:3,:4], depth_gts[index_pose]['coord'])
rays_o, rays_d = torch.Tensor(rays_o).to(device), torch.Tensor(rays_d).to(device)
rgb, sigma, z_vals, depth_maps = render_test_ray(rays_o, rays_d, hwf, network=render_kwargs_test['network_fine'], **render_kwargs_test)
rgb, sigma, z_vals, depth_maps, weights = render_test_ray(rays_o, rays_d, hwf, network=render_kwargs_test['network_fine'], **render_kwargs_test)
# sigma = sigma.reshape(H, W, -1).cpu().numpy()
# z_vals = z_vals.reshape(H, W, -1).cpu().numpy()
# np.savez(os.path.join(testsavedir, 'rays.npz'), rgb=rgb.cpu().numpy(), sigma=sigma.cpu().numpy(), z_vals=z_vals.cpu().numpy())
visualize_sigma(sigma[0, :].cpu().numpy(), z_vals[0, :].cpu().numpy(), os.path.join(testsavedir, 'rays.png'))
# visualize_sigma(sigma[0, :].cpu().numpy(), z_vals[0, :].cpu().numpy(), os.path.join(testsavedir, 'rays.png'))
for k in range(20):
visualize_weights(weights[k*100, :].cpu().numpy(), z_vals[k*100, :].cpu().numpy(), os.path.join(testsavedir, f'rays_weights_%d.png' % k))
print("colmap depth:", depth_gts[index_pose]['depth'][0])
print("Estimated depth:", depth_maps[0].cpu().numpy())
print(depth_gts[index_pose]['coord'])
Expand All @@ -776,7 +783,11 @@ def train():
return

# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
if not args.colmap_depth:
N_rgb = args.N_rand
else:
N_depth = int(args.N_rand * args.depth_rays_prop)
N_rgb = args.N_rand - N_depth
use_batching = not args.no_batching
if use_batching:
# For random ray batching
Expand Down Expand Up @@ -818,7 +829,7 @@ def train():
print('shuffle depth rays')
np.random.shuffle(rays_depth)

max_depth = np.max(rays_depth[:,3,0])
max_depth = np.max(rays_depth[:,3,0])
print('done')
i_batch = 0

Expand All @@ -830,8 +841,8 @@ def train():
if use_batching:
# rays_rgb = torch.Tensor(rays_rgb).to(device)
# rays_depth = torch.Tensor(rays_depth).to(device) if rays_depth is not None else None
raysRGB_iter = iter(DataLoader(RayDataset(rays_rgb), batch_size = N_rand, shuffle=True, num_workers=0))
raysDepth_iter = iter(DataLoader(RayDataset(rays_depth), batch_size = N_rand, shuffle=True, num_workers=0)) if rays_depth is not None else None
raysRGB_iter = iter(DataLoader(RayDataset(rays_rgb), batch_size = N_rgb, shuffle=True, num_workers=0))
raysDepth_iter = iter(DataLoader(RayDataset(rays_depth), batch_size = N_depth, shuffle=True, num_workers=0)) if rays_depth is not None else None


N_iters = args.N_iters + 1
Expand All @@ -854,7 +865,7 @@ def train():
try:
batch = next(raysRGB_iter).to(device)
except StopIteration:
raysRGB_iter = iter(DataLoader(RayDataset(rays_rgb), batch_size = N_rand, shuffle=True, num_workers=0))
raysRGB_iter = iter(DataLoader(RayDataset(rays_rgb), batch_size = N_rgb, shuffle=True, num_workers=0))
batch = next(raysRGB_iter).to(device)
batch = torch.transpose(batch, 0, 1)
batch_rays, target_s = batch[:2], batch[2]
Expand All @@ -864,7 +875,7 @@ def train():
try:
batch_depth = next(raysDepth_iter).to(device)
except StopIteration:
raysDepth_iter = iter(DataLoader(RayDataset(rays_depth), batch_size = N_rand, shuffle=True, num_workers=0))
raysDepth_iter = iter(DataLoader(RayDataset(rays_depth), batch_size = N_depth, shuffle=True, num_workers=0))
batch_depth = next(raysDepth_iter).to(device)
batch_depth = torch.transpose(batch_depth, 0, 1)
batch_rays_depth = batch_depth[:2] # 2 x B x 3
Expand Down Expand Up @@ -905,24 +916,43 @@ def train():
coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)

coords = torch.reshape(coords, [-1,2]) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
select_inds = np.random.choice(coords.shape[0], size=[N_rgb], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.stack([rays_o, rays_d], 0)
batch_rays = torch.stack([rays_o, rays_d], 0) # (2, N_rand, 3)
target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)

##### Core optimization loop #####
# timer_0 = time.perf_counter()

if args.colmap_depth:
N_batch = batch_rays.shape[1]
batch_rays = torch.cat([batch_rays, batch_rays_depth], 1) # (2, 2 * N_rand, 3)

# timer_concate = time.perf_counter()


rgb, disp, acc, depth, extras = render(H, W, focal, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)
# timer_iter = time.perf_counter()

if args.colmap_depth and not args.depth_with_rgb:
_, _, _, depth_col, extras_col = render(H, W, focal, chunk=args.chunk, rays=batch_rays_depth,
verbose=i < 10, retraw=True, depths=target_depth,
**render_kwargs_train)
# _, _, _, depth_col, extras_col = render(H, W, focal, chunk=args.chunk, rays=batch_rays_depth,
# verbose=i < 10, retraw=True, depths=target_depth,
# **render_kwargs_train)
rgb = rgb[:N_batch, :]
disp = disp[:N_batch]
acc = acc[:N_batch]
depth, depth_col = depth[:N_batch], depth[N_batch:]
extras = {x:extras[x][:N_batch] for x in extras}
# extras_col = extras[N_rand:, :]

elif args.colmap_depth and args.depth_with_rgb:
depth_col = depth

# timer_split = time.perf_counter()

optimizer.zero_grad()
img_loss = img2mse(rgb, target_s)
Expand All @@ -941,11 +971,13 @@ def train():
sigma_loss = 0
if args.sigma_loss:
sigma_loss = extras_col['sigma_loss'].mean()
print(sigma_loss)
# print(sigma_loss)
trans = extras['raw'][...,-1]
loss = img_loss + args.depth_lambda * depth_loss + args.sigma_lambda * sigma_loss
psnr = mse2psnr(img_loss)

# timer_loss = time.perf_counter()

if 'rgb0' in extras and not args.no_coarse:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
Expand All @@ -954,6 +986,26 @@ def train():
loss.backward()
optimizer.step()

# timer_backward = time.perf_counter()
# print('\nconcate:',timer_concate-timer_0)
# print('iter',timer_iter-timer_concate)
# print('split',timer_split-timer_iter)
# print('loss',timer_loss-timer_split)
# print('backward',timer_backward-timer_loss)
# concate_time.append(timer_concate-timer_0)
# iter_time.append(timer_iter-timer_concate)
# split_time.append(timer_split-timer_iter)
# loss_time.append(timer_loss-timer_split)
# backward_time.append(timer_backward-timer_loss)

# if i%10 == 0:
# print('\nconcate:',np.mean(concate_time))
# print('iter',np.mean(iter_time))
# print('split',np.mean(split_time))
# print('loss',np.mean(loss_time))
# print('backward',np.mean(backward_time))
# print('total:',np.mean(concate_time)+np.mean(iter_time)+np.mean(split_time)+np.mean(loss_time)+np.mean(backward_time))

# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
Expand Down Expand Up @@ -988,6 +1040,12 @@ def train():
imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.nanmax(disps)), fps=30, quality=8)


# if args.use_viewdirs:
# render_kwargs_test['c2w_staticcam'] = render_poses[0][:3,:4]
# with torch.no_grad():
# rgbs_still, _ = render_path(render_poses, hwf, args.chunk, render_kwargs_test)
# render_kwargs_test['c2w_staticcam'] = None
# imageio.mimwrite(moviebase + 'rgb_still.mp4', to8b(rgbs_still), fps=30, quality=8)

if i%args.i_testset==0 and i > 0 and len(i_test) > 0:
testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
Expand All @@ -1005,6 +1063,47 @@ def train():

if i%args.i_print==0:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
tf.contrib.summary.histogram('tran', trans)
if args.N_importance > 0:
tf.contrib.summary.scalar('psnr0', psnr0)
if i%args.i_img==0:
# Log a rendered validation view to Tensorboard
img_i=np.random.choice(i_val)
target = images[img_i]
pose = poses[img_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
**render_kwargs_test)
psnr = mse2psnr(img2mse(rgb, target))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.scalar('psnr_holdout', psnr)
tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
if args.N_importance > 0:
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
"""

global_step += 1

Expand Down

0 comments on commit f20a20b

Please sign in to comment.