forked from ashawkey/torch-ngp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
provider.py
88 lines (67 loc) · 2.68 KB
/
provider.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import torch
from torch.utils.data import Dataset
import trimesh
import pysdf
def map_color(value, cmap_name='viridis', vmin=None, vmax=None):
    """Map scalar values to RGB colors using a matplotlib colormap.

    Args:
        value: array-like of floats, shape [N].
        cmap_name: name of a registered matplotlib colormap.
        vmin: value mapped to the low end of the colormap; defaults to value.min().
        vmax: value mapped to the high end of the colormap; defaults to value.max().

    Returns:
        float array of shape [N, 3] with RGB components in [0, 1]
        (the alpha channel returned by the colormap is dropped).
    """
    import matplotlib

    value = np.asarray(value)
    if value.size == 0:
        # .min()/.max() would raise on an empty array; there is nothing to color.
        return np.zeros((0, 3))
    if vmin is None:
        vmin = value.min()
    if vmax is None:
        vmax = value.max()

    span = vmax - vmin
    if span == 0:
        # Constant input: avoid 0/0 (NaN would map to the colormap's "bad"
        # color); map everything to the low end of the colormap instead.
        normalized = np.zeros(value.shape)
    else:
        normalized = (value - vmin) / span  # range in [0, 1] for default vmin/vmax

    try:
        cmap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        # matplotlib < 3.5 has no colormap registry; cm.get_cmap is the old API
        # (deprecated in 3.7 and removed in 3.9).
        import matplotlib.cm as cm
        cmap = cm.get_cmap(cmap_name)

    # The colormap returns RGBA; keep only the RGB channels.
    return cmap(normalized)[:, :3]
def plot_pointcloud(pc, sdfs):
    """Open a trimesh scene viewer showing a point cloud colored by SDF value.

    Args:
        pc: point coordinates, shape [N, 3].
        sdfs: signed distance per point, shape [N, 1] or [N].
    """
    # Flatten to [N] so both column vectors and 1-D arrays are accepted.
    colors = map_color(sdfs.reshape(-1))
    cloud = trimesh.PointCloud(pc, colors)
    trimesh.Scene([cloud]).show()
# SDF dataset
class SDFDataset(Dataset):
    """Online-sampled training points and signed distances for a single mesh.

    The mesh is loaded once, re-centered on its bounding-box center, and scaled
    so that its bounding-box diagonal has length 2 * 0.95. Every vertex then
    lies within distance 0.95 of the origin, i.e. inside [-1, 1]^3. A
    ``pysdf.SDF`` query object is built from the normalized mesh.

    Each ``__getitem__`` call ignores its index and draws a fresh batch of
    ``num_samples`` points.

    Args:
        path: mesh file passed to ``trimesh.load``.
        size: value reported by ``__len__`` (batches per epoch).
        num_samples: points per batch; must be divisible by 8.
        clip_sdf: if not None, SDF values are clipped to [-clip_sdf, clip_sdf].
    """

    def __init__(self, path, size=100, num_samples=2**18, clip_sdf=None):
        super().__init__()
        # Check the cheap argument before the (potentially slow) mesh loading
        # and SDF construction below.
        assert num_samples % 8 == 0, "num_samples must be divisible by 8."
        self.path = path

        # load mesh
        self.mesh = trimesh.load(path, force='mesh')

        # normalize to [-1, 1] (different from instant-sdf, which uses [0, 1])
        vs = self.mesh.vertices
        vmin = vs.min(0)
        vmax = vs.max(0)
        v_center = (vmin + vmax) / 2
        diagonal = np.sqrt(np.sum((vmax - vmin) ** 2))
        if diagonal == 0:
            raise ValueError(f"mesh at {path!r} has zero extent; cannot normalize it.")
        v_scale = 2 / diagonal * 0.95
        self.mesh.vertices = (vs - v_center[None, :]) * v_scale

        print(f"[INFO] mesh: {self.mesh.vertices.shape} {self.mesh.faces.shape}")

        if not self.mesh.is_watertight:
            print("[WARN] mesh is not watertight! SDF maybe incorrect.")

        self.sdf_fn = pysdf.SDF(self.mesh.vertices, self.mesh.faces)

        self.num_samples = num_samples
        self.clip_sdf = clip_sdf
        self.size = size

    def __len__(self):
        """Return the configured number of batches per epoch."""
        return self.size

    def __getitem__(self, _):
        """Draw a fresh batch of points and their signed distances.

        With N = num_samples, the rows are laid out as:
          * [0, N/2): points sampled on the mesh surface, SDF set to 0;
          * [N/2, 7N/8): surface samples perturbed by Gaussian noise (std 0.01);
          * [7N/8, N): points drawn uniformly from [-1, 1]^3.

        NOTE(review): sampling uses the global NumPy RNG. With a DataLoader
        using num_workers > 0, workers may start from the same NumPy state and
        draw identical batches unless a worker_init_fn reseeds them.

        Returns:
            dict with 'points' (float32, [N, 3]) and 'sdfs' (float32, [N, 1]).
        """
        n = self.num_samples
        half = n // 2
        n_surface = n * 7 // 8

        # surface samples; perturb the second part of them
        points_surface = self.mesh.sample(n_surface)
        points_surface[half:] += 0.01 * np.random.randn(n_surface - half, 3)

        # uniform samples in [-1, 1]^3
        points_uniform = np.random.rand(n // 8, 3) * 2 - 1

        points = np.concatenate([points_surface, points_uniform], axis=0).astype(np.float32)

        # float32 to match `points`; rows on the exact surface keep SDF 0.
        sdfs = np.zeros((n, 1), dtype=np.float32)
        # The query result is negated; presumably pysdf reports positive values
        # inside the mesh, so stored SDFs are negative inside. TODO confirm.
        sdfs[half:] = -self.sdf_fn(points[half:])[:, None]

        if self.clip_sdf is not None:
            sdfs = sdfs.clip(-self.clip_sdf, self.clip_sdf)

        return {
            'sdfs': sdfs,
            'points': points,
        }