Skip to content

Commit

Permalink
fix staged inference to avoid OOM
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jan 23, 2022
1 parent 99f8181 commit 5c0c61e
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 40 deletions.
86 changes: 50 additions & 36 deletions nerf/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,34 +108,8 @@ def __init__(self,
color_net.append(nn.Linear(in_dim, out_dim, bias=False))

self.color_net = nn.ModuleList(color_net)


def forward(self, pts, rays_d, bound, staged=False, max_batch_size=256000):
    """Query density and color for sample points, optionally in chunks.

    Args:
        pts: sample positions, [B, N, 3].
        rays_d: per-sample view directions, [B, N, 3].
        bound: scene bound forwarded unchanged to ``self.run``.
        staged: when True, evaluate ``self.run`` on slices of at most
            ``max_batch_size`` points per batch element and write the
            results into preallocated output buffers.
        max_batch_size: maximum number of points per ``self.run`` call in
            staged mode; ignored otherwise.

    Returns:
        Tuple ``(sigmas, rgbs)``. In staged mode these are [B, N] and
        [B, N, 3] tensors on ``pts.device``; otherwise they are whatever
        ``self.run`` returns for the full input.

    Raises:
        ValueError: if ``staged`` is True and ``max_batch_size`` is not
            positive (the chunk loop could otherwise never advance).
    """
    B, N = pts.shape[:2]

    if not staged:
        sigmas, rgbs = self.run(pts, rays_d, bound)
        return sigmas, rgbs

    if max_batch_size <= 0:
        raise ValueError(f"max_batch_size must be positive, got {max_batch_size}")

    sigmas = torch.zeros((B, N), device=pts.device)
    rgbs = torch.zeros((B, N, 3), device=pts.device)

    for b in range(B):
        # A stepped range guarantees forward progress and covers the tail chunk.
        for head in range(0, N, max_batch_size):
            tail = min(head + max_batch_size, N)

            sigmas_, rgbs_ = self.run(pts[b:b+1, head:tail], rays_d[b:b+1, head:tail], bound)

            sigmas[b:b+1, head:tail] = sigmas_.reshape(1, -1)
            rgbs[b:b+1, head:tail] = rgbs_.reshape(1, -1, 3)

    return sigmas, rgbs

def run(self, x, d, bound):
def forward(self, x, d, bound):
# x: [B, N, 3], in [-bound, bound]
# d: [B, N, 3], normalized in [-1, 1]

Expand All @@ -161,14 +135,28 @@ def run(self, x, d, bound):

return sigma, color

def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_batch_size=256000, **kwargs):
# rays_o, rays_d: [B, N, 3]
def density(self, x, bound):
    """Evaluate only the density branch of the network.

    Args:
        x: positions, [B, N, 3], in [-bound, bound].
        bound: passed to the encoder as ``size``.

    Returns:
        Channel 0 of the final ``sigma_net`` layer's output, i.e. the
        input's leading dimensions with the feature axis dropped.
    """
    feat = self.encoder(x, size=bound)

    last = self.num_layers - 1
    for idx in range(self.num_layers):
        feat = self.sigma_net[idx](feat)
        # ReLU follows every layer except the final one.
        if idx < last:
            feat = F.relu(feat, inplace=True)

    return feat[..., 0]

def run(self, rays_o, rays_d, num_steps, bound, upsample_steps):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# return: pred_rgb: [B, N, 3]

# sample steps
B, N = rays_o.shape[:2]
device = rays_o.device

# sample steps
near = rays_o.norm(dim=-1, keepdim=True) - bound # [B, N, 1]
far = near + 2 * bound

Expand All @@ -192,7 +180,7 @@ def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False,

# query density (sigma) and RGB
rays_d_ = rays_d.unsqueeze(-2).expand_as(pts)
sigmas, rgbs = self(pts.reshape(B, -1, 3), rays_d=rays_d_.reshape(B, -1, 3), bound=bound, staged=staged, max_batch_size=max_batch_size)
sigmas, rgbs = self(pts.reshape(B, -1, 3), rays_d_.reshape(B, -1, 3), bound=bound)
rgbs = rgbs.reshape(B, N, num_steps, 3) # [B, N, T, 3]
sigmas = sigmas.reshape(B, N, num_steps) # [B, N, T]

Expand All @@ -216,7 +204,7 @@ def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False,

# only forward new points to save computation
new_rays_d_ = rays_d.unsqueeze(-2).expand_as(new_pts)
new_sigmas, new_rgbs = self(new_pts.reshape(B, -1, 3), rays_d=new_rays_d_.reshape(B, -1, 3), bound=bound, staged=staged, max_batch_size=max_batch_size)
new_sigmas, new_rgbs = self(new_pts.reshape(B, -1, 3), new_rays_d_.reshape(B, -1, 3), bound=bound)
new_rgbs = new_rgbs.reshape(B, N, upsample_steps, 3) # [B, N, t, 3]
new_sigmas = new_sigmas.reshape(B, N, upsample_steps) # [B, N, t]

Expand Down Expand Up @@ -249,10 +237,36 @@ def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False,
image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [B, N, 3], in [0, 1]
#image = image + (1 - weights_sum).unsqueeze(-1) # white background (infinite depth)

# construct results
results = {}
return depth, image

results['depth'] = depth.reshape(B, -1)
results['rgb'] = image.reshape(B, -1, 3)
def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=256000, **kwargs):
    """Render rays, optionally in ray chunks to limit peak memory.

    Args:
        rays_o, rays_d: ray origins and directions, [B, N, 3].
        num_steps, bound, upsample_steps: forwarded unchanged to ``self.run``.
        staged: when True, call ``self.run`` on slices of at most
            ``max_ray_batch`` rays per batch element and write the results
            into preallocated output buffers.
        max_ray_batch: maximum number of rays per ``self.run`` call in
            staged mode; ignored otherwise.
        **kwargs: accepted and ignored.

    Returns:
        Dict with ``'depth'`` and ``'rgb'``. In staged mode these are
        [B, N] and [B, N, 3] tensors on ``rays_o.device``; otherwise they
        are the two values returned by ``self.run`` for the full input.

    Raises:
        ValueError: if ``staged`` is True and ``max_ray_batch`` is not
            positive (the chunk loop could otherwise never advance).
    """
    B, N = rays_o.shape[:2]
    device = rays_o.device

    if staged:
        if max_ray_batch <= 0:
            raise ValueError(f"max_ray_batch must be positive, got {max_ray_batch}")

        depth = torch.zeros((B, N), device=device)
        image = torch.zeros((B, N, 3), device=device)

        for b in range(B):
            # A stepped range guarantees forward progress and covers the tail chunk.
            for head in range(0, N, max_ray_batch):
                tail = min(head + max_ray_batch, N)

                depth_, image_ = self.run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], num_steps, bound, upsample_steps)

                depth[b:b+1, head:tail] = depth_
                image[b:b+1, head:tail] = image_
    else:
        depth, image = self.run(rays_o, rays_d, num_steps, bound, upsample_steps)

    return {
        'depth': depth,
        'rgb': image,
    }
4 changes: 3 additions & 1 deletion nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self,
model, # network
criterion=None, # loss function, if None, assume inline implementation in train_step
optimizer=None, # optimizer
ema_decay=None, # if use EMA, set the decay
ema_decay=0.95, # if use EMA, set the decay
lr_scheduler=None, # scheduler
metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
local_rank=0, # which GPU am I
Expand Down Expand Up @@ -197,6 +197,8 @@ def __init__(self,
else:
self.ema = None

# TODO: allocate a density_grid for ray marching.

if self.fp16:
self.scaler = torch.cuda.amp.GradScaler()

Expand Down
4 changes: 2 additions & 2 deletions test_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
parser.add_argument('--num_rays', type=int, default=4096)
parser.add_argument('--num_steps', type=int, default=128)
parser.add_argument('--upsample_steps', type=int, default=128)
parser.add_argument('--max_batch_size', type=int, default=12800)
parser.add_argument('--max_ray_batch', type=int, default=4096) # lower if OOM

parser.add_argument('--radius', type=float, default=2, help="assume the camera is located on sphere(0, radius))")
parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in sphere(0, size)")
Expand All @@ -33,6 +33,6 @@
trainer = Trainer('ngp', vars(opt), model, workspace=opt.workspace, use_checkpoint='latest')

# test dataset
test_dataset = NeRFDataset(opt.path, 'test', downscale=2, radius=opt.radius)
test_dataset = NeRFDataset(opt.path, 'test', downscale=1, radius=opt.radius)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1)
trainer.test(test_loader)
2 changes: 1 addition & 1 deletion train_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
parser.add_argument('--num_rays', type=int, default=4096)
parser.add_argument('--num_steps', type=int, default=128)
parser.add_argument('--upsample_steps', type=int, default=128)
parser.add_argument('--max_batch_size', type=int, default=12800)
parser.add_argument('--max_ray_batch', type=int, default=4096)

parser.add_argument('--radius', type=float, default=2, help="assume the camera is located on sphere(0, radius))")
parser.add_argument('--bound', type=float, default=2, help="assume the scene is bounded in sphere(0, size)")
Expand Down

0 comments on commit 5c0c61e

Please sign in to comment.