# trainerAE.py
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm

from model import CADTransformer
from .base import BaseTrainer
from .loss import CADLoss
from .scheduler import GradualWarmupScheduler
from cadlib.macro import *


class TrainerAE(BaseTrainer):
    """Trainer for the CAD sequence autoencoder (CADTransformer optimized with CADLoss)."""

    def build_net(self, cfg):
        # Transformer encoder-decoder over CAD command/args sequences
        self.net = CADTransformer(cfg).cuda()

    def set_optimizer(self, cfg):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), cfg.lr)
        self.scheduler = GradualWarmupScheduler(self.optimizer, 1.0, cfg.warmup_step)
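        # With multiplier 1.0, GradualWarmupScheduler ramps the learning rate
        # linearly up to cfg.lr over cfg.warmup_step steps and then holds it
        # (assuming the usual GradualWarmupScheduler semantics; no after-warmup decay is chained here).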

    def set_loss_function(self):
        self.loss_func = CADLoss(self.cfg).cuda()

    def forward(self, data):
        commands = data['command'].cuda()  # (N, S)
        args = data['args'].cuda()  # (N, S, N_ARGS)

        outputs = self.net(commands, args)
        loss_dict = self.loss_func(outputs)

        return outputs, loss_dict
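
    # A training step through the BaseTrainer looks roughly like the sketch below
    # (hedged: the exact driver lives in .base / the training script, and
    # `update_network` is an assumed BaseTrainer helper, not defined in this file):
    #
    #   outputs, losses = trainer.forward(batch)   # batch: {'command': (N, S), 'args': (N, S, N_ARGS)}
    #   trainer.update_network(losses)             # assumed: backprop on the summed loss terms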

    def encode(self, data, is_batch=False):
        """encode into latent vectors"""
        commands = data['command'].cuda()
        args = data['args'].cuda()
        if not is_batch:
            # add a batch dimension for a single example
            commands = commands.unsqueeze(0)
            args = args.unsqueeze(0)
        z = self.net(commands, args, encode_mode=True)
        return z

    def decode(self, z):
        """decode given latent vectors"""
        outputs = self.net(None, None, z=z, return_tgt=False)
        return outputs
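
    # Sketch of chaining encode/decode outside of training (hedged example; `data` is a
    # batch dict from the CAD dataloader with 'command' of shape (N, S) and
    # 'args' of shape (N, S, N_ARGS)):
    #
    #   z = trainer.encode(data, is_batch=True)   # latent codes, one per shape
    #   outputs = trainer.decode(z)               # logits for the reconstructed sequences
    #   cad_vec = trainer.logits2vec(outputs)     # integer CAD vectors, see below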

    def logits2vec(self, outputs, refill_pad=True, to_numpy=True):
        """network outputs (logits) to final CAD vector"""
        out_command = torch.argmax(torch.softmax(outputs['command_logits'], dim=-1), dim=-1)  # (N, S)
        out_args = torch.argmax(torch.softmax(outputs['args_logits'], dim=-1), dim=-1) - 1  # (N, S, N_ARGS)
        if refill_pad:  # fill all unused elements with -1
            mask = ~torch.tensor(CMD_ARGS_MASK).bool().cuda()[out_command.long()]
            out_args[mask] = -1

        out_cad_vec = torch.cat([out_command.unsqueeze(-1), out_args], dim=-1)
        if to_numpy:
            out_cad_vec = out_cad_vec.detach().cpu().numpy()
        return out_cad_vec
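
    # Each row of the resulting CAD vector is [command_index, arg_0, ..., arg_{N_ARGS-1}],
    # with -1 marking arguments that the predicted command does not use (per CMD_ARGS_MASK).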

    def evaluate(self, test_loader):
        """evaluation during training"""
        self.net.eval()
        pbar = tqdm(test_loader)
        pbar.set_description("EVALUATE[{}]".format(self.clock.epoch))

        all_ext_args_comp = []
        all_line_args_comp = []
        all_arc_args_comp = []
        all_circle_args_comp = []

        for i, data in enumerate(pbar):
            with torch.no_grad():
                commands = data['command'].cuda()
                args = data['args'].cuda()
                outputs = self.net(commands, args)
                out_args = torch.argmax(torch.softmax(outputs['args_logits'], dim=-1), dim=-1) - 1
                out_args = out_args.long().detach().cpu().numpy()  # (N, S, n_args)

            gt_commands = commands.squeeze(1).long().detach().cpu().numpy()  # (N, S)
            gt_args = args.squeeze(1).long().detach().cpu().numpy()  # (N, S, n_args)

            # positions of each command type in the ground-truth sequences
            ext_pos = np.where(gt_commands == EXT_IDX)
            line_pos = np.where(gt_commands == LINE_IDX)
            arc_pos = np.where(gt_commands == ARC_IDX)
            circle_pos = np.where(gt_commands == CIRCLE_IDX)

            # per-argument correctness, restricted to the arguments each command actually uses
            args_comp = (gt_args == out_args).astype(int)
            all_ext_args_comp.append(args_comp[ext_pos][:, -N_ARGS_EXT:])
            all_line_args_comp.append(args_comp[line_pos][:, :2])
            all_arc_args_comp.append(args_comp[arc_pos][:, :4])
            all_circle_args_comp.append(args_comp[circle_pos][:, [0, 1, 4]])

        all_ext_args_comp = np.concatenate(all_ext_args_comp, axis=0)
        sket_plane_acc = np.mean(all_ext_args_comp[:, :N_ARGS_PLANE])
        sket_trans_acc = np.mean(all_ext_args_comp[:, N_ARGS_PLANE:N_ARGS_PLANE + N_ARGS_TRANS])
        extent_one_acc = np.mean(all_ext_args_comp[:, -N_ARGS_EXT_PARAM])
        line_acc = np.mean(np.concatenate(all_line_args_comp, axis=0))
        arc_acc = np.mean(np.concatenate(all_arc_args_comp, axis=0))
        circle_acc = np.mean(np.concatenate(all_circle_args_comp, axis=0))

        self.val_tb.add_scalars("args_acc",
                                {"line": line_acc, "arc": arc_acc, "circle": circle_acc,
                                 "plane": sket_plane_acc, "trans": sket_trans_acc, "extent": extent_one_acc},
                                global_step=self.clock.epoch)
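

# Rough end-to-end usage (a sketch, not part of this module; the config class, loader
# names and epoch field are assumptions about how a DeepCAD-style training script
# typically wires things up):
#
#   cfg = ConfigAE('train')                      # hypothetical config providing lr, warmup_step, ...
#   trainer = TrainerAE(cfg)
#   for _ in range(cfg.nr_epochs):               # hypothetical epoch-count field
#       for batch in train_loader:               # yields {'command': (N, S), 'args': (N, S, N_ARGS)}
#           outputs, losses = trainer.forward(batch)
#       trainer.evaluate(val_loader)             # logs per-command argument accuracies to TensorBoard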