forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalculate_rtf.py
executable file
·113 lines (99 loc) · 3.42 KB
/
calculate_rtf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2021 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import glob
import os
from dateutil import parser
def get_parser():
    """Build the command-line parser for the RTF calculation script.

    Returns:
        argparse.ArgumentParser: Parser accepting ``--log-dir``,
            ``--log-name``, ``--input-shift``, ``--start-times-marker``,
            ``--end-times-marker`` and ``--inf-num``.
    """
    # Each entry is (flag, keyword arguments forwarded to ``add_argument``).
    # The order here is the order in which options show up in ``--help``.
    option_specs = (
        (
            "--log-dir",
            dict(type=str, default=None, help="path to logging directory"),
        ),
        (
            "--log-name",
            dict(
                type=str,
                default="decode",
                choices=["decode", "asr_inference"],
                help="name of logfile, e.g., 'decode' (espnet1) "
                "and 'asr_inference' (espnet2)",
            ),
        ),
        (
            "--input-shift",
            dict(type=float, default=10.0, help="shift of inputs in milliseconds"),
        ),
        (
            "--start-times-marker",
            dict(
                type=str,
                default="input lengths",
                choices=["input lengths", "speech length"],
                help="String marking start of decoding in logfile, "
                "e.g., 'input lengths' (espnet1) "
                "and 'speech length' (espnet2)",
            ),
        ),
        (
            "--end-times-marker",
            dict(
                type=str,
                default="prediction",
                choices=["prediction", "best hypo"],
                help="String marking end of decoding in logfile, "
                "e.g., 'prediction' (espnet1) "
                "and 'best hypo' (espnet2)",
            ),
        ),
        (
            "--inf-num",
            dict(
                type=int,
                default=1,
                help="number of inference hypothesis for each utterance, "
                "e.g. >1 in multi-speaker asr.",
            ),
        ),
    )

    cli = argparse.ArgumentParser(description="calculate real time factor (RTF)")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli
def main():
    """Compute and print the real time factor (RTF) and latency from logs.

    Every file matching ``<log-name>.*.log`` inside ``--log-dir`` is scanned
    line by line:

    * a line containing ``"INFO: " + <start-times-marker>`` contributes an
      input length (the integer following ``"<start-times-marker>: "``) and
      a start timestamp (the text before the first ``"("`` on that line);
    * a line containing ``"INFO: " + <end-times-marker>`` contributes an end
      timestamp, parsed the same way.

    Input lengths are converted to seconds with ``--input-shift``
    (milliseconds per unit). The totals, the RTF (decoding time divided by
    audio duration) and the mean latency per utterance are printed to stdout.

    Raises:
        SystemExit: If ``--log-dir`` is not given (argparse usage error).
        ValueError: If a log file has a different number of start and end
            markers (after ``--inf-num`` sub-sampling of the end markers).
    """
    arg_parser = get_parser()
    args = arg_parser.parse_args()
    if args.log_dir is None:
        # Without this, os.path.join(None, ...) fails with an opaque TypeError.
        arg_parser.error("--log-dir must be specified")

    audio_sec = 0
    decode_sec = 0
    n_utt = 0

    log_files = args.log_name + ".*.log"
    start_times_marker = "INFO: " + args.start_times_marker
    end_times_marker = "INFO: " + args.end_times_marker
    # Text that immediately precedes the logged input length on a start line.
    length_prefix = args.start_times_marker + ": "

    for log_path in glob.glob(os.path.join(args.log_dir, log_files)):
        audio_durations = []
        start_times = []
        end_times = []
        with codecs.open(log_path, "r", "utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if start_times_marker in line:
                    audio_durations.append(int(line.split(length_prefix)[1]))
                    start_times.append(parser.parse(line.split("(")[0]))
                elif end_times_marker in line:
                    end_times.append(parser.parse(line.split("(")[0]))

        if args.inf_num > 1:
            # When inf_num > 1, select the last speaker's end as end time
            end_times = end_times[args.inf_num - 1 :: args.inf_num]

        # audio_durations and start_times are always appended together, so a
        # single comparison covers both. An explicit raise (rather than an
        # assert) keeps the check active under ``python -O``; otherwise zip()
        # below would silently truncate and report a wrong RTF.
        if len(start_times) != len(end_times):
            raise ValueError(
                "%s: found %d start marker(s) but %d end marker(s)"
                % (log_path, len(start_times), len(end_times))
            )

        audio_sec += sum(audio_durations) * args.input_shift / 1000  # [sec]
        decode_sec += sum(
            (end - start).total_seconds()
            for start, end in zip(start_times, end_times)
        )
        n_utt += len(audio_durations)

    print("Total audio duration: %.3f [sec]" % audio_sec)
    print("Total decoding time: %.3f [sec]" % decode_sec)
    # Report 0 instead of dividing by zero when no utterances were found.
    rtf = decode_sec / audio_sec if audio_sec > 0 else 0
    print("RTF: %.3f" % rtf)
    latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
    print("Latency: %.3f [ms/sentence]" % latency)
# Run the RTF calculation only when this file is executed as a script, so
# that importing the module (e.g. to reuse get_parser) has no side effects.
if __name__ == "__main__":
    main()