forked from ashawkey/torch-ngp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
provider.py
103 lines (76 loc) · 3.18 KB
/
provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import time
import glob
import numpy as np
import cv2
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
import trimesh
import pysdf
def map_color(value, cmap_name='viridis', vmin=None, vmax=None):
    """Map scalar values to RGB colors through a colormap.

    Args:
        value: [N] array-like of scalars.
        cmap_name: name of a registered matplotlib colormap (default 'viridis'),
            or a callable colormap (e.g. a ``matplotlib.colors.Colormap``) that
            maps an [N] array of values in [0, 1] to an [N, 4] RGBA array.
        vmin: value mapped to the low end of the colormap; defaults to the
            minimum of ``value``.
        vmax: value mapped to the high end of the colormap; defaults to the
            maximum of ``value``.

    Returns:
        [N, 3] float array of RGB colors in [0, 1].
    """
    value = np.asarray(value, dtype=np.float64)
    if value.size == 0:
        # min()/max() of an empty array would raise; nothing to color.
        return np.zeros((0, 3))

    if isinstance(cmap_name, str):
        import matplotlib
        try:
            # Registry API (matplotlib >= 3.5); cm.get_cmap was deprecated in
            # 3.7 and removed in 3.9.
            cmap = matplotlib.colormaps[cmap_name]
        except AttributeError:
            import matplotlib.cm as cm
            cmap = cm.get_cmap(cmap_name)
    else:
        cmap = cmap_name

    if vmin is None:
        vmin = value.min()
    if vmax is None:
        vmax = value.max()
    span = vmax - vmin
    if span == 0:
        # Constant input: 0 / 0 would give NaN, which matplotlib colormaps
        # render as the "bad" color; map everything to the low end instead.
        value = np.zeros_like(value)
    else:
        value = (value - vmin) / span  # range in [0, 1]

    return cmap(value)[:, :3]  # colormaps return RGBA; keep only RGB
def plot_pointcloud(pc, sdfs):
    """Open an interactive trimesh viewer showing a point cloud colored by SDF.

    Args:
        pc: [N, 3] point positions.
        sdfs: [N, 1] or flat [N] SDF values; converted to colors by
            ``map_color`` (min/max-normalized onto its default colormap).
    """
    # reshape(-1) accepts both [N, 1] and [N]; squeeze(1) rejected flat input.
    color = map_color(sdfs.reshape(-1))
    cloud = trimesh.PointCloud(pc, color)
    trimesh.Scene([cloud]).show()
# SDF dataset
class SDFDataset(Dataset):
    """Dataset that draws fresh point samples around a mesh with their SDF values.

    On construction the mesh is loaded, its bounding box is centered at the
    origin and scaled so the box diagonal has length 1.9 (every vertex then
    lies within distance 0.95 of the origin), and a ``pysdf.SDF`` query object
    is built from it.

    Every ``__getitem__`` call ignores its index and samples ``num_samples``
    new points:
      * first 1/2: points sampled on the mesh surface (SDF fixed to 0),
      * next 3/8: surface points perturbed by Gaussian noise of std
        ``surface_noise``,
      * last 1/8: points uniform in [-1, 1]^3.

    Args:
        path: mesh file passed to ``trimesh.load``.
        size: value reported by ``__len__`` (number of items per epoch).
        num_samples: points per item; must be divisible by 8.
        clip_sdf: if not None, SDF values are clipped to [-clip_sdf, clip_sdf].
        surface_noise: std of the Gaussian perturbation of surface samples.

    Raises:
        ValueError: if ``num_samples`` is not divisible by 8.
    """

    def __init__(self, path, size=10000, num_samples=25600, clip_sdf=0.1, surface_noise=0.002):
        super().__init__()
        # Validate before the (potentially slow) mesh load and SDF construction.
        if num_samples % 8 != 0:
            raise ValueError("num_samples must be divisible by 8.")
        self.path = path

        # load obj (force='mesh' asks trimesh for a single mesh object)
        self.mesh = trimesh.load(path, force='mesh')

        # normalize: center the bounding box and scale its diagonal to 1.9
        vs = self.mesh.vertices
        vmin = vs.min(0)
        vmax = vs.max(0)
        v_center = (vmin + vmax) / 2
        v_scale = 2 / np.sqrt(np.sum((vmax - vmin) ** 2)) * 0.95
        self.mesh.vertices = (vs - v_center[None, :]) * v_scale

        print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}")

        if not self.mesh.is_watertight:
            print("[WARN] mesh is not watertight! SDF maybe incorrect.")

        self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces)

        self.num_samples = num_samples
        self.clip_sdf = clip_sdf
        self.surface_noise = surface_noise
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, _):
        """Sample a fresh batch of points and SDFs (the index is ignored).

        Returns:
            dict with 'points' (float32 [num_samples, 3]) and
            'sdfs' (float32 [num_samples, 1]).
        """
        n = self.num_samples
        half = n // 2

        # Allocate float32 to match `points`; a float64 buffer would silently
        # undo the float32 cast of the SDF values assigned below.
        sdfs = np.zeros((n, 1), dtype=np.float32)

        # surface samples
        points_surface = self.mesh.sample(n * 7 // 8)
        # perturb all but the first half, which stays exactly on the surface (sdf = 0)
        points_surface[half:] += self.surface_noise * np.random.randn(n * 3 // 8, 3)

        # uniform samples in [-1, 1]^3
        points_uniform = np.random.rand(n // 8, 3) * 2 - 1

        points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32)

        # The pysdf value is negated, presumably so the result is negative
        # inside the mesh -- confirm against pysdf's sign convention.
        sdfs[half:] = -self.sdf_fn(points[half:])[:, None].astype(np.float32)

        # clip sdf (in place, so the float32 dtype is kept)
        if self.clip_sdf is not None:
            np.clip(sdfs, -self.clip_sdf, self.clip_sdf, out=sdfs)

        results = {
            'sdfs': sdfs,
            'points': points,
        }

        return results