forked from DingXiaoH/GSM-SGD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
show_log.py
56 lines (48 loc) · 1.97 KB
/
show_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import glob
import os
import re

import numpy as np
# Regexes that capture the numeric value of a "key=value" metric field in a
# log line, e.g. "top1=75.123" -> "75.123".
# - Raw strings avoid the invalid "\d" / "\-" escape warnings that the plain
#   string literals triggered.
# - The optional exponent group keeps values written in scientific notation
#   (e.g. "loss=1.2e-05") intact. The previous pattern captured only "1.2".
top1_pattern = re.compile(r'top1=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
top5_pattern = re.compile(r'top5=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
loss_pattern = re.compile(r'loss=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)')
def get_value_by_pattern(pattern, line):
    """Return the first match of *pattern* in *line*, converted to float.

    When *pattern* has a capturing group, the group's text is what gets
    converted. Raises IndexError if *pattern* does not occur in *line*.
    """
    captured = re.findall(pattern, line)
    first_capture = captured[0]
    return float(first_capture)
def parse_top1_top5_loss_from_log_line(log_line):
    """Extract the (top1, top5, loss) values reported on one log line.

    Each value is the first ``top1=``, ``top5=`` and ``loss=`` field in
    *log_line*, converted to float. An IndexError propagates if any of the
    three fields is missing.
    """
    metric_patterns = (top1_pattern, top5_pattern, loss_pattern)
    return tuple(get_value_by_pattern(p, log_line) for p in metric_patterns)
# Experiment roots to scan; each run is expected at <root_dir>/<run_name>/log.txt.
root_dirs = ['gsm_lottery_ticket_exps', 'gsm_exps']
# Number of trailing metric lines averaged per run. Runs with fewer such
# lines are skipped entirely.
num_logs = 10
# Compiled once (it used to be rebuilt for every file) and written as a raw
# string so "\d" is not treated as an invalid string escape.
last_epoch_pattern = re.compile(r'epoch (\d+)')

log_files = []
for root_dir in root_dirs:
    # glob() returns paths in arbitrary order; sort for a stable report.
    log_files += sorted(glob.glob('{}/*/log.txt'.format(root_dir)))

for file_path in log_files:
    with open(file_path, 'r') as f:
        origin_lines = f.readlines()

    # Metric lines are the ones mentioning all three of top1, top5 and loss.
    log_lines = [line for line in origin_lines
                 if 'top1' in line and 'top5' in line and 'loss' in line]

    # Report the last 8 characters of the most recent
    # "TRAIN LOSS collected over last" line, or 'NULL' if there is none.
    avg_loss = 'NULL'
    for line in reversed(origin_lines):
        if 'TRAIN LOSS collected over last' in line:
            avg_loss = line.strip()[-8:]
            break

    last_lines = log_lines[-num_logs:]
    if len(last_lines) < num_logs:
        continue

    # Every line in last_lines already passed the metric filter above, so
    # no second membership check is needed before parsing.
    parsed = [parse_top1_top5_loss_from_log_line(line) for line in last_lines]
    top1_list, top5_list, loss_list = zip(*parsed)

    # Run name = directory holding log.txt. os.path handles both '/' and the
    # '\\' separators glob() returns on Windows, unlike split('/')[1].
    network_try_arg = os.path.basename(os.path.dirname(file_path)).replace('_train', '')

    # Epoch of the final metric line. Fall back to 'NULL' (as avg_loss does)
    # instead of crashing with IndexError when no "epoch N" marker is present.
    epoch_match = last_epoch_pattern.search(last_lines[-1])
    last_epoch = int(epoch_match.group(1)) if epoch_match else 'NULL'

    print('{}, \t top1={:.3f}, \t top5={:.3f}, \t loss={:.5f}, \t {} logs, train_loss={}, last_epoch={}'.format(
        network_try_arg,
        np.mean(top1_list), np.mean(top5_list), np.mean(loss_list),
        len(top1_list), avg_loss, last_epoch))