Skip to content

Commit

Permalink
add normal supervision
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuanyu committed Dec 19, 2023
1 parent 160e048 commit dd5f3e8
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 17 deletions.
60 changes: 51 additions & 9 deletions nerf/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):
frames = transform["frames"]
#frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort...

# load normals and depths
self.load_normal_depth()

# for colmap, manually interpolate a test set.
if self.mode == 'colmap' and type == 'test':
# choose two random poses, and interpolate between.
Expand All @@ -181,24 +184,30 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):
# self.poses.append(pose)

# New way for testing: interpolate between two frames from the training dataset
# self.poses = []
# self.images = None
# for i in range(len(frames)):
# pose0 = nerf_matrix_to_ngp(np.array(frames[i]['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
# pose1 = nerf_matrix_to_ngp(np.array(frames[(i+1)%len(frames)]['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
# rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
# slerp = Slerp([0, 1], rots)
# pose_m = np.eye(4, dtype=np.float32)
# pose_m[:3, :3] = slerp(0.5).as_matrix()
# pose_m[:3, 3] = 0.5 * pose0[:3, 3] + 0.5 * pose1[:3, 3]
# self.poses.append(pose0)
# self.poses.append(pose_m)

self.poses = []
self.images = None
for i in range(len(frames)):
pose0 = nerf_matrix_to_ngp(np.array(frames[i]['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
pose1 = nerf_matrix_to_ngp(np.array(frames[(i+1)%len(frames)]['transform_matrix'], dtype=np.float32), scale=self.scale, offset=self.offset) # [4, 4]
rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]]))
slerp = Slerp([0, 1], rots)
pose_m = np.eye(4, dtype=np.float32)
pose_m[:3, :3] = slerp(0.5).as_matrix()
pose_m[:3, 3] = 0.5 * pose0[:3, 3] + 0.5 * pose1[:3, 3]
self.poses.append(pose0)
self.poses.append(pose_m)

else:
# for colmap, manually split a valid set (the first frame).
if self.mode == 'colmap':
if type == 'train':
frames = frames[1:]
frames = frames[0:]
elif type == 'val':
frames = frames[:1]
# else 'all' or 'trainval' : use all frames
Expand Down Expand Up @@ -235,7 +244,7 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):

self.poses.append(pose)
self.images.append(image)

self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4]
if self.images is not None:
self.images = torch.from_numpy(np.stack(self.images, axis=0)) # [N, H, W, C]
Expand Down Expand Up @@ -267,6 +276,9 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):
self.images = self.images.to(dtype).to(self.device)
if self.error_map is not None:
self.error_map = self.error_map.to(self.device)
if self.normals is not None and self.depths is not None:
self.normals = self.normals.to(self.device)
self.depths = self.depths.to(self.device)

# load intrinsics
if 'fl_x' in transform or 'fl_y' in transform:
Expand All @@ -285,6 +297,22 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):
cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)

self.intrinsics = np.array([fl_x, fl_y, cx, cy])

def load_normal_depth(self):
    """Load optional normal/depth supervision from ``<root_path>/aovs.npz``.

    Sets ``self.normals`` and ``self.depths``. Both are ``None`` when the
    archive does not exist; otherwise they are torch tensors built from the
    archive's ``"normals"`` and ``"depths"`` arrays.

    The normals are indexed as a 4-D array whose last axis holds the vector
    components, and are remapped in place from (x, y, z) to (z, -y, -x).
    NOTE(review): presumably this converts the exporter's axis convention to
    the one used by the poses -- confirm against the data producer.
    """
    self.normals = None
    self.depths = None

    arr_path = os.path.join(self.root_path, "aovs.npz")
    if not os.path.exists(arr_path):
        return

    # np.load on an .npz returns a lazy NpzFile that keeps the zip file open;
    # read both arrays inside the context manager so the handle is released.
    with np.load(arr_path) as arr:
        normals = arr["normals"]
        depths = arr["depths"]

    # TODO: add transformations here
    # Net effect of the three steps below: (x, y, z) -> (z, -y, -x).
    normals[:, :, :, 1] *= -1
    # Fancy indexing on the right-hand side yields a copy, so this is a safe swap.
    normals[:, :, :, [0, 2]] = normals[:, :, :, [2, 0]]
    normals[:, :, :, 2] *= -1

    self.normals = torch.from_numpy(normals)
    self.depths = torch.from_numpy(depths)


def collate(self, index):
Expand Down Expand Up @@ -327,6 +355,20 @@ def collate(self, index):
C = images.shape[-1]
images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
results['images'] = images

if self.normals is not None:
normals = self.normals[index].to(self.device) # [B, H, W, 3]
if self.training:
C = normals.shape[-1]
normals = torch.gather(normals.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3]
results['normals'] = normals

if self.depths is not None:
depths = self.depths[index].to(self.device) # [B, H, W, 1]
if self.training:
C = depths.shape[-1]
depths = torch.gather(depths.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 1]
results['depths'] = depths

# need inds to update error_map
if error_map is not None:
Expand Down
6 changes: 6 additions & 0 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,12 @@ def train_step(self, data):

# MSE loss
loss = self.criterion(pred_rgb, gt_rgb).mean(-1) # [B, N, 3] --> [B, N]

# normal loss
if 'normals' in data and self.epoch > 5:
gt_normal = data['normals'] * images[..., 3:] + bg_color * (1 - images[..., 3:])
pred_normal = outputs['normal']
loss += 0.2 * self.criterion(pred_normal, gt_normal).mean(-1) # [B, N, 3] --> [B, N]

# patch-based rendering
if self.opt.patch_size > 1:
Expand Down
21 changes: 13 additions & 8 deletions nerf2occ.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def load_model(self):
@torch.no_grad
def run(
self,
grid_size = 128,
sample_interval = 2
grid_size = 64,
sample_interval = 8 # if grid_size set to 64, max value of sample_interval should be 16
):
bound = 1.0
# sample points
gsize = grid_size*sample_interval
X = torch.arange(gsize, dtype=torch.float32, device=self.device)/(0.5 * gsize) - 1.0
Y = torch.arange(gsize, dtype=torch.float32, device=self.device)/(0.5 * gsize) - 1.0
Z = torch.arange(gsize, dtype=torch.float32, device=self.device)/(0.5 * gsize) - 1.0
Y = (torch.arange(gsize, dtype=torch.float32, device=self.device)/(gsize) - 0.5) * (2 * bound)
Z = (torch.arange(gsize, dtype=torch.float32, device=self.device)/(gsize) - 0.5) * (2 * bound)
X = (torch.arange(gsize, dtype=torch.float32, device=self.device)/(gsize) - 0.5) * (2 * bound)
points = torch.concat([t.unsqueeze(dim=3) for t in torch.meshgrid(X, Y, Z, indexing='ij')], dim=3)
# get density (sigma and alpha)
points = points.view(-1, 3)
Expand All @@ -65,9 +66,13 @@ def run(
alphas = F.max_pool3d(alphas, kernel_size=sample_interval, stride=sample_interval)
alphas = alphas.view(grid_size, grid_size, grid_size)
# occupancy thresholding
thres = 0.7
alphas[alphas > thres] = 1
alphas[alphas <= thres] = 0
thres = 0.3
if thres < 1:
alphas[alphas > thres] = 1
alphas[alphas <= thres] = 0
else:
alphas[alphas < thres] = 0
alphas[alphas >= thres] = 1
mesh = voxel2mesh(alphas.cpu().numpy())
mesh.export("test.obj")

Expand Down

0 comments on commit dd5f3e8

Please sign in to comment.