-
Notifications
You must be signed in to change notification settings - Fork 2
/
obj2mesh.py
121 lines (92 loc) · 3.88 KB
/
obj2mesh.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import os
import torch
import psutil
import gc
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.objaverse import load_obj
from src.utils import mesh
from src.utils.material import Material
import argparse
def bytes_to_megabytes(bytes):
    """Convert a byte count to megabytes, using 1 MB = 1024**2 bytes.

    The parameter keeps its original name (which shadows the ``bytes``
    builtin inside this function) so keyword callers keep working.
    """
    bytes_per_mb = 1 << 20  # 1024 * 1024
    return bytes / bytes_per_mb
def bytes_to_gigabytes(bytes):
    """Convert a byte count to gigabytes, using 1 GB = 1024**3 bytes.

    The parameter keeps its original name (which shadows the ``bytes``
    builtin inside this function) so keyword callers keep working.
    """
    bytes_per_gb = 1 << 30  # 1024 * 1024 * 1024
    return bytes / bytes_per_gb
def print_memory_usage(stage):
    """Print one line of memory statistics tagged with *stage*.

    Reports the resident set size (RSS) of the current process (via psutil)
    plus ``torch.cuda.memory_allocated()`` and ``torch.cuda.memory_reserved()``
    (labelled "Cached"). All values are in MB of 1024**2 bytes with two
    decimals.
    """
    mb = 1024**2
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / mb
    cuda_allocated_mb = torch.cuda.memory_allocated() / mb
    cuda_reserved_mb = torch.cuda.memory_reserved() / mb
    print(
        f"[{stage}] Process memory: {rss_mb:.2f} MB, "
        f"Allocated CUDA memory: {cuda_allocated_mb:.2f} MB, Cached CUDA memory: {cuda_reserved_mb:.2f} MB"
    )
def _mtl_has_too_many_lines(mtl_path, max_lines):
    """Return True if the file at *mtl_path* has at least *max_lines* lines.

    Counting stops as soon as the threshold is reached, so a huge file is not
    read in full. Text is decoded as UTF-8 with undecodable bytes replaced, so
    an unusual encoding does not by itself cause a failure. Raises OSError
    (e.g. FileNotFoundError) if the file cannot be opened.
    """
    n_lines = 0
    with open(mtl_path, "r", encoding="utf-8", errors="replace") as fh:
        for _ in fh:
            n_lines += 1
            if n_lines >= max_lines:
                break
    return n_lines >= max_lines


def _save_atomically(obj, path):
    """``torch.save`` *obj* to *path* via a temporary file and an atomic rename.

    process_obj treats an existing output file as "already done". If a crash
    truncated a file mid-write, that file would be skipped on every rerun.
    Writing to ``<path>.tmp`` first and then ``os.replace``-ing it over *path*
    means *path* only ever appears fully written.
    """
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def process_obj(index, root_dir, final_save_dir, paths, max_mtl_lines=250):
    """Convert one OBJ asset into a ``.pth`` file of mesh tensors.

    The asset id is ``paths[index]``. Inputs are read from
    ``<root_dir>/<uid>/<uid>.obj`` and ``<root_dir>/<uid>/<uid>.mtl``. The
    output is written to ``<final_save_dir>/<uid>.pth`` as a dict of CPU
    tensors plus a ``mat_dict`` of material entries.

    Args:
        index: Position of the asset in *paths*.
        root_dir: Directory containing one sub-directory per asset.
        final_save_dir: Directory that receives the ``.pth`` outputs.
        paths: Sequence of asset ids (sub-directory names of *root_dir*).
        max_mtl_lines: Assets whose .mtl file has at least this many lines are
            skipped. Defaults to 250, the previously hard-coded limit.

    Returns:
        The asset id if a new ``.pth`` file was written. Otherwise ``None``:
        the output already exists, the .mtl is too long, or an error occurred.
        Errors are printed rather than raised so that one bad asset does not
        abort a whole batch.
    """
    uid = paths[index]
    asset_dir = os.path.join(root_dir, uid)
    obj_path = os.path.join(asset_dir, uid + ".obj")
    mtl_path = os.path.join(asset_dir, uid + ".mtl")
    save_path = os.path.join(final_save_dir, f"{uid}.pth")
    if os.path.exists(save_path):
        return None
    try:
        # Apply the cheap .mtl size filter before the expensive mesh loading.
        if _mtl_has_too_many_lines(mtl_path, max_mtl_lines):
            return None
        with torch.no_grad():
            # With return_attributes=True, load_obj returns the mesh first,
            # followed by raw attributes that this function does not use.
            ref_mesh, *_ = load_obj(obj_path, return_attributes=True)
            ref_mesh = mesh.compute_tangents(ref_mesh)
            final_mesh_attributes = {
                "v_pos": ref_mesh.v_pos.detach().cpu(),
                "v_nrm": ref_mesh.v_nrm.detach().cpu(),
                "v_tex": ref_mesh.v_tex.detach().cpu(),
                "v_tng": ref_mesh.v_tng.detach().cpu(),
                "t_pos_idx": ref_mesh.t_pos_idx.detach().cpu(),
                "t_nrm_idx": ref_mesh.t_nrm_idx.detach().cpu(),
                "t_tex_idx": ref_mesh.t_tex_idx.detach().cpu(),
                "t_tng_idx": ref_mesh.t_tng_idx.detach().cpu(),
                "mat_dict": {key: ref_mesh.material[key] for key in ref_mesh.material.mat_keys},
            }
        # Drop the mesh reference before saving; the dict holds CPU copies.
        del ref_mesh
        _save_atomically(final_mesh_attributes, save_path)
        print(f"==> Saved to {final_save_dir}/{uid}.pth")
        return uid
    except Exception as e:
        # Deliberate per-asset boundary: report and move on to the next asset.
        print(f"Failed to process {uid}: {e}")
        return None
    finally:
        gc.collect()
        torch.cuda.empty_cache()
def main(root_dir, save_dir, batch_size=100, max_workers=8):
    """Convert every entry under *root_dir* into a ``.pth`` file in *save_dir*.

    Entries are processed in batches of *batch_size*, each with a fresh thread
    pool of *max_workers* threads running ``process_obj``. Process and CUDA
    memory usage is printed at the start, after each batch, and at the end.

    Args:
        root_dir: Directory with one sub-directory per asset (see process_obj).
        save_dir: Output directory; created if missing. Assets whose output
            already exists are skipped by process_obj.
        batch_size: Assets submitted per batch. Defaults to 100, the previous
            hard-coded value.
        max_workers: Worker threads per batch. Defaults to 8, the previous
            hard-coded value. ``None`` lets ThreadPoolExecutor choose.

    Returns:
        List of asset ids written during this run. The original collected
        this list but discarded it.

    Raises:
        ValueError: If batch_size < 1, or if max_workers is given and < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if max_workers is not None and max_workers < 1:
        raise ValueError(f"max_workers must be >= 1, got {max_workers}")
    os.makedirs(save_dir, exist_ok=True)
    paths = os.listdir(root_dir)
    valid_uid = []
    print_memory_usage("Start")
    num_batches = (len(paths) + batch_size - 1) // batch_size  # ceil division
    for batch in tqdm(range(num_batches)):
        start_index = batch * batch_size
        end_index = min(start_index + batch_size, len(paths))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_obj, index, root_dir, save_dir, paths)
                for index in range(start_index, end_index)
            ]
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    valid_uid.append(result)
        print_memory_usage(f"=====> After processing batch {batch + 1}")
        # Release cached allocator blocks and cyclic garbage between batches.
        torch.cuda.empty_cache()
        gc.collect()
    print_memory_usage("End")
    return valid_uid
if __name__ == "__main__":
    # Command-line entry point: two positional directory arguments.
    cli = argparse.ArgumentParser(
        description="Process OBJ files and save final results."
    )
    for arg_name, arg_help in (
        ("root_dir", "Directory containing the root OBJ files."),
        ("save_dir", "Directory to save the processed results."),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    cli_args = cli.parse_args()
    main(cli_args.root_dir, cli_args.save_dir)