Skip to content

Commit

Permalink
no message
Browse files Browse the repository at this point in the history
  • Loading branch information
songyouwei committed Apr 20, 2019
1 parent cc96f98 commit 0a667d0
Showing 1 changed file with 13 additions and 23 deletions.
36 changes: 13 additions & 23 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _reset_params(self):
stdv = 1. / math.sqrt(p.shape[0])
torch.nn.init.uniform_(p, a=-stdv, b=stdv)

def _train(self, criterion, optimizer, max_test_acc_overall=0):
def _train(self, criterion, optimizer):
writer = SummaryWriter(log_dir=self.opt.logdir)
max_test_acc = 0
max_f1 = 0
Expand Down Expand Up @@ -103,15 +103,16 @@ def _train(self, criterion, optimizer, max_test_acc_overall=0):
n_total += len(outputs)
train_acc = n_correct / n_total

# switch model to evaluation mode
self.model.eval()
test_acc, f1 = self._evaluate_acc_f1()
if test_acc > max_test_acc:
max_test_acc = test_acc
if test_acc > max_test_acc_overall:
if not os.path.exists('state_dict'):
os.mkdir('state_dict')
path = 'state_dict/{0}_{1}_acc{2}'.format(self.opt.model_name, self.opt.dataset, round(test_acc, 4))
torch.save(self.model.state_dict(), path)
print('>> saved: ' + path)
if not os.path.exists('state_dict'):
os.mkdir('state_dict')
path = 'state_dict/{0}_{1}_acc{2}'.format(self.opt.model_name, self.opt.dataset, round(test_acc, 4))
torch.save(self.model.state_dict(), path)
print('>> saved: ' + path)
if f1 > max_f1:
max_f1 = f1

Expand All @@ -124,8 +125,6 @@ def _train(self, criterion, optimizer, max_test_acc_overall=0):
return max_test_acc, max_f1

def _evaluate_acc_f1(self):
# switch model to evaluation mode
self.model.eval()
n_test_correct, n_test_total = 0, 0
t_targets_all, t_outputs_all = None, None
with torch.no_grad():
Expand All @@ -148,24 +147,15 @@ def _evaluate_acc_f1(self):
f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2], average='macro')
return test_acc, f1

def run(self, repeats=1):
def run(self):
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
_params = filter(lambda p: p.requires_grad, self.model.parameters())
optimizer = self.opt.optimizer(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)

max_test_acc_overall = 0
max_f1_overall = 0
for i in range(repeats):
print('repeat: ', i)
self._reset_params()
max_test_acc, max_f1 = self._train(criterion, optimizer, max_test_acc_overall=max_test_acc_overall)
print('max_test_acc: {0} max_f1: {1}'.format(max_test_acc, max_f1))
max_test_acc_overall = max(max_test_acc, max_test_acc_overall)
max_f1_overall = max(max_f1, max_f1_overall)
print('#' * 100)
print("max_test_acc_overall:", max_test_acc_overall)
print("max_f1_overall:", max_f1_overall)
self._reset_params()
max_test_acc, max_f1 = self._train(criterion, optimizer)
print('max_test_acc: {0} max_f1: {1}'.format(max_test_acc, max_f1))


if __name__ == '__main__':
Expand Down Expand Up @@ -259,4 +249,4 @@ def run(self, repeats=1):
if opt.device is None else torch.device(opt.device)

ins = Instructor(opt)
ins.run(1) # _reset_params in every repeat
ins.run()

0 comments on commit 0a667d0

Please sign in to comment.