forked from ashawkey/torch-ngp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
243 lines (176 loc) · 9.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
from nerf.utils import *
from nerf.utils import Trainer as _Trainer
class Trainer(_Trainer):
    """Trainer for time-conditioned (dynamic) NeRF models.

    Thin specialization of ``nerf.utils.Trainer``: every render/density call
    additionally receives a per-frame ``time`` tensor, and training adds a
    small L1 regularizer on the predicted deformation field.
    """

    def __init__(self,
                 name,  # name of this experiment
                 opt,  # extra conf
                 model,  # network
                 criterion=None,  # loss function, if None, assume inline implementation in train_step
                 optimizer=None,  # optimizer
                 ema_decay=None,  # if use EMA, set the decay
                 lr_scheduler=None,  # scheduler
                 metrics=None,  # metrics for evaluation; None -> use val_loss to measure performance, else use the first metric.
                 local_rank=0,  # which GPU am I
                 world_size=1,  # total num of GPUs
                 device=None,  # device to use, usually setting to None is OK. (auto choose device)
                 mute=False,  # whether to mute all print
                 fp16=False,  # amp optimize level
                 eval_interval=1,  # eval once every $ epoch
                 max_keep_ckpt=2,  # max num of saved ckpts in disk
                 workspace='workspace',  # workspace to save logs & ckpts
                 best_mode='min',  # the smaller/larger result, the better
                 use_loss_as_metric=True,  # use loss as the first metric
                 report_metric_at_train=False,  # also report metrics at training
                 use_checkpoint="latest",  # which ckpt to use at init time
                 use_tensorboardX=True,  # whether to use tensorboard for logging
                 scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
                 ):
        # Avoid the shared-mutable-default pitfall (was `metrics=[]`, which
        # aliases one list across every Trainer instance). `None` is converted
        # to a fresh list, so the default behavior is unchanged.
        if metrics is None:
            metrics = []
        # Keep the factories around; the base trainer may rebuild the
        # optimizer/scheduler later (e.g. after loading a checkpoint).
        self.optimizer_fn = optimizer
        self.lr_scheduler_fn = lr_scheduler
        super().__init__(name, opt, model, criterion, optimizer, ema_decay, lr_scheduler,
                         metrics, local_rank, world_size, device, mute, fp16, eval_interval,
                         max_keep_ckpt, workspace, best_mode, use_loss_as_metric,
                         report_metric_at_train, use_checkpoint, use_tensorboardX,
                         scheduler_update_every_step)

    ### ------------------------------

    def train_step(self, data):
        """Run one training step.

        Returns ``(pred_rgb, gt_rgb_or_None, loss)``; ``gt_rgb`` is ``None``
        when training without ground-truth images (CLIP loss mode).
        """
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        time = data['time']  # [B, 1]

        # if there is no gt image, we train with CLIP loss.
        if 'images' not in data:
            B, N = rays_o.shape[:2]
            H, W = data['H'], data['W']

            # currently fix white bg, MUST force all rays!
            outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=None,
                                        perturb=True, force_all_rays=True, **vars(self.opt))
            # reshape to NCHW image for the CLIP loss.
            pred_rgb = outputs['image'].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()

            # [debug] uncomment to plot the images used in train_step
            # torch_vis_2d(pred_rgb[0])

            loss = self.clip_loss(pred_rgb)
            return pred_rgb, None, loss

        images = data['images']  # [B, N, 3/4]
        B, N, C = images.shape

        if self.opt.color_space == 'linear':
            images[..., :3] = srgb_to_linear(images[..., :3])

        if C == 3 or self.model.bg_radius > 0:
            # no alpha channel, or a learned background model: fixed white bg.
            bg_color = 1
        else:
            # train with random background color if not using a bg model and
            # the images have an alpha channel — prevents the network from
            # cheating with a constant background.
            # bg_color = torch.ones(3, device=self.device)  # [3], fixed white background
            # bg_color = torch.rand(3, device=self.device)  # [3], frame-wise random.
            bg_color = torch.rand_like(images[..., :3])  # [N, 3], pixel-wise random.

        if C == 4:
            # composite ground truth over the chosen background.
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
        else:
            gt_rgb = images

        outputs = self.model.render(rays_o, rays_d, time, staged=False, bg_color=bg_color,
                                    perturb=True, force_all_rays=False, **vars(self.opt))

        pred_rgb = outputs['image']

        loss = self.criterion(pred_rgb, gt_rgb).mean(-1)  # [B, N, 3] --> [B, N]

        # special case for CCNeRF's rank-residual training
        if len(loss.shape) == 3:  # [K, B, N]
            loss = loss.mean(0)

        # update error_map (used by the sampler for importance ray sampling).
        if self.error_map is not None:
            index = data['index']  # [B]
            inds = data['inds_coarse']  # [B, N]

            # take out, this is an advanced indexing and the copy is unavoidable.
            error_map = self.error_map[index]  # [B, H * W]

            # [debug] uncomment to save and visualize error map
            # if self.global_step % 1001 == 0:
            #     tmp = error_map[0].view(128, 128).cpu().numpy()
            #     print(f'[write error map] {tmp.shape} {tmp.min()} ~ {tmp.max()}')
            #     tmp = (tmp - tmp.min()) / (tmp.max() - tmp.min())
            #     cv2.imwrite(os.path.join(self.workspace, f'{self.global_step}.jpg'), (tmp * 255).astype(np.uint8))

            error = loss.detach().to(error_map.device)  # [B, N], already in [0, 1]

            # ema update of the sampled entries only.
            ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error
            error_map.scatter_(1, inds, ema_error)

            # put back
            self.error_map[index] = error_map

        loss = loss.mean()

        # deform regularization: keep the predicted deformation field small.
        if 'deform' in outputs and outputs['deform'] is not None:
            loss = loss + 1e-3 * outputs['deform'].abs().mean()

        return pred_rgb, gt_rgb, loss

    def eval_step(self, data):
        """Full-image evaluation step.

        Returns ``(pred_rgb, pred_depth, gt_rgb, loss)`` with images reshaped
        to ``[B, H, W, ...]``.
        """
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        time = data['time']  # [B, 1]
        images = data['images']  # [B, H, W, 3/4]
        B, H, W, C = images.shape

        if self.opt.color_space == 'linear':
            images[..., :3] = srgb_to_linear(images[..., :3])

        # eval with fixed background color
        bg_color = 1
        if C == 4:
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
        else:
            gt_rgb = images

        outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color,
                                    perturb=False, **vars(self.opt))

        pred_rgb = outputs['image'].reshape(B, H, W, 3)
        pred_depth = outputs['depth'].reshape(B, H, W)

        loss = self.criterion(pred_rgb, gt_rgb).mean()

        return pred_rgb, pred_depth, gt_rgb, loss

    # moved out bg_color and perturb for more flexible control...
    def test_step(self, data, bg_color=None, perturb=False):
        """Render a frame without ground truth; returns ``(pred_rgb, pred_depth)``."""
        rays_o = data['rays_o']  # [B, N, 3]
        rays_d = data['rays_d']  # [B, N, 3]
        time = data['time']  # [B, 1]
        H, W = data['H'], data['W']

        if bg_color is not None:
            bg_color = bg_color.to(self.device)

        outputs = self.model.render(rays_o, rays_d, time, staged=True, bg_color=bg_color,
                                    perturb=perturb, **vars(self.opt))

        pred_rgb = outputs['image'].reshape(-1, H, W, 3)
        pred_depth = outputs['depth'].reshape(-1, H, W)

        return pred_rgb, pred_depth

    # [GUI] test on a single image
    def test_gui(self, pose, intrinsics, W, H, time=0, bg_color=None, spp=1, downscale=1):
        """Render one frame for the interactive GUI.

        Returns a dict with ``'image'`` ([H, W, 3] numpy) and ``'depth'``
        ([H, W] numpy) at the requested full resolution.
        """
        # render resolution (may need downscale to for better frame rate)
        rH = int(H * downscale)
        rW = int(W * downscale)
        intrinsics = intrinsics * downscale

        pose = torch.from_numpy(pose).unsqueeze(0).to(self.device)

        rays = get_rays(pose, intrinsics, rH, rW, -1)

        data = {
            'time': torch.FloatTensor([[time]]).to(self.device),  # from scalar to [1, 1] tensor.
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
            'H': rH,
            'W': rW,
        }

        self.model.eval()

        # temporarily swap in the EMA weights for rendering.
        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=self.fp16):
                # here spp is used as perturb random seed!
                preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp)

        if self.ema is not None:
            self.ema.restore()

        # interpolation to the original resolution
        if downscale != 1:
            # TODO: have to permute twice with torch...
            preds = F.interpolate(preds.permute(0, 3, 1, 2), size=(H, W),
                                  mode='nearest').permute(0, 2, 3, 1).contiguous()
            preds_depth = F.interpolate(preds_depth.unsqueeze(1), size=(H, W),
                                        mode='nearest').squeeze(1)

        if self.opt.color_space == 'linear':
            preds = linear_to_srgb(preds)

        pred = preds[0].detach().cpu().numpy()
        pred_depth = preds_depth[0].detach().cpu().numpy()

        outputs = {
            'image': pred,
            'depth': pred_depth,
        }

        return outputs

    def save_mesh(self, time, save_path=None, resolution=256, threshold=10):
        """Extract an isosurface at the given time and save it as a .ply mesh.

        time: scalar in [0, 1].
        """
        time = torch.FloatTensor([[time]]).to(self.device)

        if save_path is None:
            save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply')

        self.log(f"==> Saving mesh to {save_path}")

        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        def query_func(pts):
            # density query for marching cubes; no grad, AMP matches training.
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    sigma = self.model.density(pts.to(self.device), time)['sigma']
            return sigma

        vertices, triangles = extract_geometry(self.model.aabb_infer[:3], self.model.aabb_infer[3:],
                                               resolution=resolution, threshold=threshold,
                                               query_func=query_func)

        # important, process=True leads to seg fault...
        mesh = trimesh.Trimesh(vertices, triangles, process=False)
        mesh.export(save_path)

        self.log("==> Finished saving mesh.")