Skip to content

Commit

Permalink
Fix wrong input-output ASR input utts order (facebookresearch#5149)
Browse files Browse the repository at this point in the history
Co-authored-by: Andros Tjandra <[email protected]>
  • Loading branch information
androstj and androstj authored May 24, 2023
1 parent b50b649 commit 25c20e6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/mms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azer
Run this command to transcribe one or more audio files:
```shell command
cd /path/to/fairseq-py/
python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code --audio "/path/to/audio_1.wav" "/path/to/audio_1.wav"
python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code \
--audio "/path/to/audio_1.wav" "/path/to/audio_2.wav" "/path/to/audio_3.wav"
```

For more advance configuration and calculate CER/WER, you could prepare manifest folder by creating a folder with this format:
Expand Down
22 changes: 16 additions & 6 deletions examples/mms/asr/infer/mms_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,38 @@ def parser():
parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
return parser.parse_args()

def reorder_decode(hypos):
outputs = []
for hypo in hypos:
idx = int(re.findall("\(None-(\d+)\)$", hypo)[0])
hypo = re.sub("\(\S+\)$", "", hypo).strip()
outputs.append((idx, hypo))
outputs = sorted(outputs)
return outputs

def process(args):
with tempfile.TemporaryDirectory() as tmpdir:
print(">>> preparing tmp manifest dir ...", file=sys.stderr)
tmpdir = Path(tmpdir)
with open(tmpdir / "dev.tsv", "w") as fw:
with open(tmpdir / "dev.tsv", "w") as fw, open(tmpdir / "dev.uid", "w") as fu:
fw.write("/\n")
for audio in args.audio:
nsample = sf.SoundFile(audio).frames
fw.write(f"{audio}\t{nsample}\n")
with open(tmpdir / "dev.uid", "w") as fw:
fw.write(f"{audio}\n"*len(args.audio))
fu.write(f"{audio}\n")
with open(tmpdir / "dev.ltr", "w") as fw:
fw.write("d u m m y | d u m m y\n"*len(args.audio))
fw.write("d u m m y | d u m m y |\n"*len(args.audio))
with open(tmpdir / "dev.wrd", "w") as fw:
fw.write("dummy dummy\n"*len(args.audio))
cmd = f"""
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir}
"""
print(">>> loading model & running inference ...", file=sys.stderr)
subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
with open(tmpdir/"hypo.word") as fr:
for ii, hypo in enumerate(fr):
hypos = fr.readlines()
outputs = reorder_decode(hypos)
for ii, hypo in outputs:
hypo = re.sub("\(\S+\)$", "", hypo).strip()
print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')

Expand Down

0 comments on commit 25c20e6

Please sign in to comment.