-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_pro.py
104 lines (72 loc) · 2.96 KB
/
test_pro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import argparse
import os

import numpy as np
import torch

import DeepDT_dataloader
import DeepDT_util
import DTU
import R_GCN_model
from DeepDT_data import *
from DeepDT_Parallel import *
# Command-line interface: a single option pointing at the experiment config.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='Path to config file')
args = parser.parse_args()

# Build the run configuration with the project's three passes, applied in
# order: load the file, augment it, then check it.
cfg = DeepDT_util.check_config(
    DeepDT_util.augment_config(
        DeepDT_util.load_config(args.config)
    )
)
# Predictions are written into sub-directories of experiment_dir, so the
# directory itself must already exist.
if not os.path.exists(cfg["experiment_dir"]):
    raise RuntimeError("experiment_dir does not exist.")

# Size of the geometric input feature passed to R_GCN: 7 when the config
# enables normals, 1 otherwise.
geo_in = 7 if cfg['use_normal'] else 1

train_model = R_GCN_model.R_GCN(geo_in)
model_path = cfg["model_path"]

if cfg["cuda"]:
    # Wrap for multi-GPU execution; parameters are placed on the first
    # listed device.
    train_model = DeepDTParallel(train_model, device_ids=cfg["device_ids"])
    device = torch.device("cuda:{}".format(cfg["device_ids"][0]))
    train_model = train_model.to(device)

if cfg["pretrained"]:
    # This is an evaluation script: falling through with a missing checkpoint
    # would silently test a randomly initialised network and write its labels
    # as results, so treat a missing file as an error.
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            "pretrained model not found: {}".format(model_path))
    # Load on CPU; load_state_dict copies the values into the model's own
    # parameters wherever they live.
    state_dict = torch.load(model_path, map_location="cpu")
    train_model.load_state_dict(state_dict)
    print("pretrained model loaded")
else:
    print("WARNING: cfg['pretrained'] is False; testing an untrained model")
def _empty_cuda_caches():
    """Release cached, currently unused CUDA memory on every configured GPU.

    Does nothing unless cfg["cuda"] is set, so CPU-only configurations never
    touch the CUDA runtime.
    """
    if not cfg["cuda"]:
        return
    for dev_id in cfg["device_ids"]:
        with torch.cuda.device(torch.device("cuda:{}".format(dev_id))):
            torch.cuda.empty_cache()


test_data = DTU.DTUDelDataset(cfg, "test")
test_data_loader = DeepDT_dataloader.DataListLoader(
    test_data, cfg["batch_size"], num_workers=cfg["num_workers"])

# Inference only: put the model in evaluation mode once, and disable autograd
# so no computation graph is kept alive for each batch.
train_model.eval()
with torch.no_grad():
    for test_data_list in test_data_loader:
        _empty_cuda_caches()

        # Convert each sample's adjacency into a torch sparse tensor for the
        # model (assumes data.adj arrives as a scipy sparse matrix -- TODO
        # confirm against DTU.DTUDelDataset).
        for data in test_data_list:
            data.adj = sparse_mx_to_torch_sparse_tensor(data.adj)

        cell_pred, loss1, loss2 = train_model(test_data_list)
        # Arg-max class per row, shifted by +1 before being written out.
        labels_pr = cell_pred.max(dim=1)[1].detach().cpu() + 1

        # The slicing below relies on the batch outputs being concatenated
        # along dim 0 in test_data_list order, one row per entry of
        # data.cell_vertex_idx.
        cnt = 0
        for data in test_data_list:
            label_begin = cnt
            label_end = cnt + data.cell_vertex_idx.shape[0]
            cnt = label_end

            data_labels_pr = labels_pr[label_begin:label_end].numpy()
            loss1_pr = torch.mean(loss1[label_begin:label_end]).item()
            loss2_pr = torch.mean(loss2[label_begin:label_end]).item()

            output_dir = os.path.join(cfg["experiment_dir"], data.data_name)
            # makedirs also tolerates nested data names.
            os.makedirs(output_dir, exist_ok=True)
            np.savetxt(os.path.join(output_dir, "pre_label.txt"),
                       data_labels_pr, fmt='%d')
            print('loss1 %.6f, loss2 %.6f' % (loss1_pr, loss2_pr))
            print("test", data.data_name, "done.")

_empty_cuda_caches()