Skip to content

Commit

Permalink
GUI: support depth mode, add max_steps slider, show FPS
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jul 31, 2022
1 parent 87bb685 commit 37a1c8c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 29 deletions.
38 changes: 26 additions & 12 deletions dnerf/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,14 @@ def orbit(self, dx, dy):
rotvec_y = side * np.radians(-0.1 * dy)
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

# wrong: rotate along global x/y axis
#self.rot = R.from_euler('xy', [-dy * 0.1, -dx * 0.1], degrees=True) * self.rot

def scale(self, delta):
    """Rescale ``self.radius`` by the factor ``1.1 ** (-delta)``.

    A positive ``delta`` shrinks the radius, a negative one grows it, and
    ``delta == 0`` leaves it unchanged.
    """
    factor = 1.1 ** (-delta)
    self.radius *= factor

def pan(self, dx, dy, dz=0):
# Shift self.center by the offset (dx, dy, dz) after multiplying it by the
# current rotation matrix (self.rot.as_matrix()) and a small constant factor.
# pan in camera coordinate system (mind the sensitivity factor!)
self.center += 0.001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])

# wrong: pan in global coordinate system
#self.center += 0.001 * np.array([-dx, -dy, dz])
# NOTE(review): both the 0.001 update above and the 0.0005 update below are
# present in this view; this looks like the removed and the added line of a
# diff shown together -- confirm that only one of them is meant to run.
self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])



class NeRFGUI:
def __init__(self, opt, trainer, train_loader=None, debug=True):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
Expand All @@ -78,6 +71,7 @@ def __init__(self, opt, trainer, train_loader=None, debug=True):
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
self.need_update = True # camera moved, should reset accumulation
self.spp = 1 # sample per pixel
self.mode = 'image' # choose from ['image', 'depth']
self.time = 0 # time for dynamic scene, in [0, 1]

self.dynamic_resolution = True
Expand Down Expand Up @@ -107,7 +101,7 @@ def train_step(self):
self.step += self.train_steps
self.need_update = True

dpg.set_value("_log_train_time", f'{t:.4f}ms')
dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

# dynamic train steps
Expand All @@ -117,6 +111,12 @@ def train_step(self):
if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
self.train_steps = train_steps

def prepare_buffer(self, outputs):
    """Select the array to display for the current view mode.

    When ``self.mode`` is ``'image'`` the entry ``outputs['image']`` is
    returned as-is. For any other mode, ``outputs['depth']`` receives a new
    trailing axis which is then repeated three times.
    """
    if self.mode != 'image':
        depth = np.expand_dims(outputs['depth'], -1)
        # replicate along the new last axis so the result has 3 channels
        return np.repeat(depth, 3, axis=-1)
    return outputs['image']


def test_step(self):
# TODO: seems we have to move data from GPU --> CPU --> GPU?
Expand All @@ -141,14 +141,14 @@ def test_step(self):
self.downscale = downscale

if self.need_update:
self.render_buffer = outputs['image']
self.render_buffer = self.prepare_buffer(outputs)
self.spp = 1
self.need_update = False
else:
self.render_buffer = (self.render_buffer * self.spp + outputs['image']) / (self.spp + 1)
self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
self.spp += 1

dpg.set_value("_log_infer_time", f'{t:.4f}ms')
dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
dpg.set_value("_log_spp", self.spp)
dpg.set_value("_texture", self.render_buffer)
Expand Down Expand Up @@ -278,6 +278,13 @@ def callback_set_dynamic_resolution(sender, app_data):
dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

# mode combo
def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True

dpg.add_combo(('image', 'depth'), label='image_mode', default_value=self.mode, callback=callback_change_mode)

# time slider
def callback_set_time(sender, app_data):
self.time = app_data
Expand Down Expand Up @@ -306,6 +313,13 @@ def callback_set_dt_gamma(sender, app_data):

dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

# max_steps slider
def callback_set_max_steps(sender, app_data):
self.opt.max_steps = app_data
self.need_update = True

dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

# aabb slider
def callback_set_aabb(sender, app_data, user_data):
# user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
Expand Down
38 changes: 26 additions & 12 deletions nerf/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,14 @@ def orbit(self, dx, dy):
rotvec_y = side * np.radians(-0.1 * dy)
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

# wrong: rotate along global x/y axis
#self.rot = R.from_euler('xy', [-dy * 0.1, -dx * 0.1], degrees=True) * self.rot

def scale(self, delta):
    """Rescale ``self.radius`` by the factor ``1.1 ** (-delta)``.

    A positive ``delta`` shrinks the radius, a negative one grows it, and
    ``delta == 0`` leaves it unchanged.
    """
    factor = 1.1 ** (-delta)
    self.radius *= factor

def pan(self, dx, dy, dz=0):
# Shift self.center by the offset (dx, dy, dz) after multiplying it by the
# current rotation matrix (self.rot.as_matrix()) and a small constant factor.
# pan in camera coordinate system (mind the sensitivity factor!)
self.center += 0.001 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])

# wrong: pan in global coordinate system
#self.center += 0.001 * np.array([-dx, -dy, dz])
# NOTE(review): both the 0.001 update above and the 0.0005 update below are
# present in this view; this looks like the removed and the added line of a
# diff shown together -- confirm that only one of them is meant to run.
self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz])



class NeRFGUI:
def __init__(self, opt, trainer, train_loader=None, debug=True):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
Expand All @@ -78,6 +71,7 @@ def __init__(self, opt, trainer, train_loader=None, debug=True):
self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32)
self.need_update = True # camera moved, should reset accumulation
self.spp = 1 # sample per pixel
self.mode = 'image' # choose from ['image', 'depth']

self.dynamic_resolution = True
self.downscale = 1
Expand Down Expand Up @@ -106,7 +100,7 @@ def train_step(self):
self.step += self.train_steps
self.need_update = True

dpg.set_value("_log_train_time", f'{t:.4f}ms')
dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}')

# dynamic train steps
Expand All @@ -116,6 +110,12 @@ def train_step(self):
if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
self.train_steps = train_steps

def prepare_buffer(self, outputs):
    """Select the array to display for the current view mode.

    When ``self.mode`` is ``'image'`` the entry ``outputs['image']`` is
    returned as-is. For any other mode, ``outputs['depth']`` receives a new
    trailing axis which is then repeated three times.
    """
    if self.mode != 'image':
        depth = np.expand_dims(outputs['depth'], -1)
        # replicate along the new last axis so the result has 3 channels
        return np.repeat(depth, 3, axis=-1)
    return outputs['image']


def test_step(self):
# TODO: seems we have to move data from GPU --> CPU --> GPU?
Expand All @@ -140,14 +140,14 @@ def test_step(self):
self.downscale = downscale

if self.need_update:
self.render_buffer = outputs['image']
self.render_buffer = self.prepare_buffer(outputs)
self.spp = 1
self.need_update = False
else:
self.render_buffer = (self.render_buffer * self.spp + outputs['image']) / (self.spp + 1)
self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1)
self.spp += 1

dpg.set_value("_log_infer_time", f'{t:.4f}ms')
dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)')
dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}')
dpg.set_value("_log_spp", self.spp)
dpg.set_value("_texture", self.render_buffer)
Expand Down Expand Up @@ -277,6 +277,13 @@ def callback_set_dynamic_resolution(sender, app_data):
dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution)
dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution")

# mode combo
def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True

dpg.add_combo(('image', 'depth'), label='image_mode', default_value=self.mode, callback=callback_change_mode)

# bg_color picker
def callback_change_bg(sender, app_data):
self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1]
Expand All @@ -298,6 +305,13 @@ def callback_set_dt_gamma(sender, app_data):

dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma)

# max_steps slider
def callback_set_max_steps(sender, app_data):
self.opt.max_steps = app_data
self.need_update = True

dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps)

# aabb slider
def callback_set_aabb(sender, app_data, user_data):
# user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax)
Expand Down
3 changes: 2 additions & 1 deletion nerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ def update_extra_state(self, decay=0.95, S=128):
# ema update
valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask])
self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 non-training regions are viewed as 0 density.
self.mean_density = torch.mean(self.density_grid.clamp(min=0)).item() # -1 regions are viewed as 0 density.
#self.mean_density = torch.mean(self.density_grid[self.density_grid > 0]).item() # do not count -1 regions
self.iter_density += 1

# convert to bitfield
Expand Down
9 changes: 5 additions & 4 deletions raymarching/src/raymarching.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ __global__ void kernel_march_rays_train(
// get mip level
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

const float mip_bound = fminf((float)(1 << level), bound);
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
const float mip_rbound = 1 / mip_bound;

// convert to nearest grid position
Expand Down Expand Up @@ -439,7 +439,7 @@ __global__ void kernel_march_rays_train(
// get mip level
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

const float mip_bound = fminf((float)(1 << level), bound);
const float mip_bound = fminf(scalbnf(1.0f, level), bound);
const float mip_rbound = 1 / mip_bound;

// convert to nearest grid position
Expand Down Expand Up @@ -767,7 +767,7 @@ __global__ void kernel_march_rays(
// get mip level
const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

const float mip_bound = fminf((float)(1 << level), bound);
const float mip_bound = fminf(scalbnf(1, level), bound);
const float mip_rbound = 1 / mip_bound;

// convert to nearest grid position
Expand Down Expand Up @@ -887,7 +887,8 @@ __global__ void kernel_composite_rays(
//printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

// ray is terminated if T is too small
if (T < 1e-4) break;
// use a larger bound to further accelerate inference
if (T < 1e-2f) break;

// locate
sigmas++;
Expand Down

0 comments on commit 37a1c8c

Please sign in to comment.