Skip to content

Commit

Permalink
Update calculate_rtf.py
Browse files Browse the repository at this point in the history
Make calculate_rtf.py work with espnet2
  • Loading branch information
espnetUser authored May 20, 2022
1 parent 047d0c4 commit 837553f
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions utils/calculate_rtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import argparse
import codecs
from dateutil import parser
import glob
import os

Expand All @@ -20,6 +21,12 @@ def get_parser():
default=None,
help="path to logging directory",
)
parser.add_argument(
"--log-name",
type=str,
default="decode",
help="name of logfile, e.g., 'decode' or 'asr_inference'",
)
return parser


Expand All @@ -29,12 +36,15 @@ def main():

audio_sec = 0
decode_sec = 0
latency_sec = 0
n_utt = 0

audio_durations = []
start_times = []
end_times = []
for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")):
eos_times = []
log_files = args.log_name + ".*.log"
for x in glob.glob(os.path.join(args.log_dir, log_files)):
with codecs.open(x, "r", "utf-8") as f:
for line in f:
x = line.strip()
Expand All @@ -43,6 +53,8 @@ def main():
start_times += [parser.parse(x.split("(")[0])]
elif "INFO: prediction" in x:
end_times += [parser.parse(x.split("(")[0])]
elif "INFO: received final input" in x:
eos_times += [parser.parse(x.split("(")[0])]
assert len(audio_durations) == len(end_times), (
len(audio_durations),
len(end_times),
Expand All @@ -55,13 +67,24 @@ def main():
for start, end in zip(start_times, end_times)
]
)
if len(eos_times):
assert len(eos_times) == len(end_times), (len(eos_times), len(end_times))
latency_sec += sum(
[
(end - start).total_seconds()
for start, end in zip(eos_times, end_times)
]
)
n_utt += len(audio_durations)

print("Total audio duration: %.3f [sec]" % audio_sec)
print("Total decoding time: %.3f [sec]" % decode_sec)
rtf = decode_sec / audio_sec if audio_sec > 0 else 0
print("RTF: %.3f" % rtf)
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
if len(eos_times):
latency = latency_sec * 1000 / n_utt if n_utt > 0 else 0
else:
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
print("Latency: %.3f [ms/sentence]" % latency)


Expand Down

0 comments on commit 837553f

Please sign in to comment.