Skip to content

Commit

Permalink
Merge pull request nv-tlabs#15 from orperel/renderer_fixes
Browse files Browse the repository at this point in the history
renderer fixes
  • Loading branch information
tovacinni authored Aug 10, 2021
2 parents 02b9e74 + 87c1534 commit cbad458
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions sdf-net/app/sdf_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,19 @@

from lib.renderer import Renderer
from lib.models import *
from lib.tracer import *
from lib.options import parse_options
from lib.geoutils import sample_unif_sphere, sample_fib_sphere, normalized_slice


def write_exr(path, data):
pyexr.write(path, data,
channel_names={'normal': ['X','Y','Z'],
'x': ['X','Y','Z'],
'view': ['X','Y','Z']},
channel_names={'normal': ['X', 'Y', 'Z'],
'x': ['X', 'Y', 'Z'],
'view': ['X', 'Y', 'Z']},
precision=pyexr.HALF)


if __name__ == '__main__':

# Parse
Expand Down Expand Up @@ -91,17 +94,17 @@ def write_exr(path, data):
net = globals()[args.net](args)
if args.jit:
net = torch.jit.script(net)

net.load_state_dict(torch.load(args.pretrained))

net.to(device)
net.eval()

print("Total number of parameters: {}".format(sum(p.numel() for p in net.parameters())))

if args.export is not None:
net = SOL_NGLOD(net)

net.save(args.export)
sys.exit()

Expand All @@ -121,78 +124,79 @@ def write_exr(path, data):
if not os.path.exists(_dir):
os.makedirs(_dir)

renderer = Renderer(args, device, net).eval()
tracer = globals()[args.tracer](args)
renderer = Renderer(tracer, args=args, device=device)

if args.rotate is not None:
rad = np.radians(args.rotate)
model_matrix = torch.FloatTensor(R.from_rotvec(rad * np.array([0,1,0])).as_matrix())
model_matrix = torch.FloatTensor(R.from_rotvec(rad * np.array([0, 1, 0])).as_matrix())
else:
model_matrix = torch.eye(3)

if args.r360:
for angle in np.arange(0, 360, 2):
rad = np.radians(angle)
model_matrix = torch.FloatTensor(R.from_rotvec(rad * np.array([0,1,0])).as_matrix())
model_matrix = torch.FloatTensor(R.from_rotvec(rad * np.array([0, 1, 0])).as_matrix())

out = renderer.shade_images(f=args.camera_origin,
t=args.camera_lookat,
fv=args.camera_fov,
fov=args.camera_fov,
aa=not args.disable_aa,
mm=model_matrix)

data = out.float().numpy().exrdict()

idx = int(math.floor(100 * angle))

if args.exr:
write_exr('{}/exr/{:06d}.exr'.format(ins_dir, idx), data)

img_out = out.image().byte().numpy()
Image.fromarray(img_out.rgb).save('{}/rgb/{:06d}.png'.format(ins_dir, idx), mode='RGB')
Image.fromarray(img_out.normal).save('{}/normal/{:06d}.png'.format(ins_dir, idx), mode='RGB')

elif args.rsphere:
views = sample_fib_sphere(args.nb_poses)
cam_origins = args.cam_radius * views
for p, cam_origin in enumerate(cam_origins):
out = renderer.shade_images(f=cam_origin,
t=args.camera_lookat,
fv=args.camera_fov,
fov=args.camera_fov,
aa=not args.disable_aa,
mm=model_matrix)

data = out.float().numpy().exrdict()

if args.exr:
write_exr('{}/exr/{:06d}.exr'.format(ins_dir, p), data)

img_out = out.image().byte().numpy()
Image.fromarray(img_out.rgb).save('{}/rgb/{:06d}.png'.format(ins_dir, p), mode='RGB')
Image.fromarray(img_out.normal).save('{}/normal/{:06d}.png'.format(ins_dir, p), mode='RGB')

else:

out = renderer.shade_images(f=args.camera_origin,
t=args.camera_lookat,
fv=args.camera_fov,
aa=not args.disable_aa,
out = renderer.shade_images(net=net,
f=args.camera_origin,
t=args.camera_lookat,
fov=args.camera_fov,
aa=not args.disable_aa,
mm=model_matrix)

data = out.float().numpy().exrdict()

if args.render_2d:
depth = args.depth
data['sdf_slice'] = renderer.sdf_slice(depth=depth)
data['rgb_slice'] = renderer.rgb_slice(depth=depth)
data['normal_slice'] = renderer.normal_slice(depth=depth)

if args.exr:
write_exr(f'{ins_dir}/out.exr', data)

img_out = out.image().byte().numpy()

Image.fromarray(img_out.rgb).save('{}/{}_rgb.png'.format(ins_dir, name), mode='RGB')
Image.fromarray(img_out.depth).save('{}/{}_depth.png'.format(ins_dir, name), mode='RGB')
Image.fromarray(img_out.normal).save('{}/{}_normal.png'.format(ins_dir, name), mode='RGB')
Image.fromarray(img_out.hit).save('{}/{}_hit.png'.format(ins_dir, name), mode='L')

0 comments on commit cbad458

Please sign in to comment.