Skip to content

Commit

Permalink
[MMS] TTS text uromanization + cpu inference (facebookresearch#5140)
Browse files Browse the repository at this point in the history
* mms tts uroman + cpu support for inference

* remove mps support to accommodate all pytorch versions

* add explanation to arg

---------

Co-authored-by: Bowen Shi <[email protected]>
  • Loading branch information
chevalierNoir and Bowen Shi authored May 24, 2023
1 parent 1082b61 commit fea3361
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions examples/mms/tts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# LICENSE file in the root directory of this source tree.

import os
import re
import glob
import json
import tempfile
import math
import torch
from torch import nn
Expand All @@ -15,6 +17,7 @@
import commons
import utils
import argparse
import subprocess
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from scipy.io.wavfile import write
Expand All @@ -41,6 +44,24 @@ def text_to_sequence(self, text, cleaner_names):
sequence += [symbol_id]
return sequence

def uromanize(self, text, uroman_pl):
iso = "xxx"
with tempfile.NamedTemporaryFile() as tf, \
tempfile.NamedTemporaryFile() as tf2:
with open(tf.name, "w") as f:
f.write("\n".join([text]))
cmd = f"perl " + uroman_pl
cmd += f" -l {iso} "
cmd += f" < {tf.name} > {tf2.name}"
os.system(cmd)
outtexts = []
with open(tf2.name) as f:
for line in f:
line = re.sub(r"\s+", " ", line).strip()
outtexts.append(line)
outtext = outtexts[0]
return outtext

def get_text(self, text, hps):
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
Expand All @@ -59,9 +80,16 @@ def generate():
parser.add_argument('--model-dir', type=str, help='model checkpoint dir')
parser.add_argument('--wav', type=str, help='output wav path')
parser.add_argument('--txt', type=str, help='input text')
parser.add_argument('--uroman-dir', type=str, help='uroman lib dir (will download if not specified)')
args = parser.parse_args()
ckpt_dir, wav_path, txt = args.model_dir, args.wav, args.txt

if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")

print(f"Run inference with {device}")
vocab_file = f"{ckpt_dir}/vocab.txt"
config_file = f"{ckpt_dir}/config.json"
assert os.path.isfile(config_file), f"{config_file} doesn't exist"
Expand All @@ -72,7 +100,7 @@ def generate():
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
net_g.cuda()
net_g.to(device)
_ = net_g.eval()

g_pth = f"{ckpt_dir}/G_100000.pth"
Expand All @@ -81,12 +109,24 @@ def generate():
_ = utils.load_checkpoint(g_pth, net_g, None)

print(f"text: {txt}")
is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
if is_uroman:
with tempfile.TemporaryDirectory() as tmp_dir:
if args.uroman_dir is None:
cmd = f"git clone [email protected]:isi-nlp/uroman.git {tmp_dir}"
print(cmd)
subprocess.check_output(cmd, shell=True)
args.uroman_dir = tmp_dir
uroman_pl = os.path.join(args.uroman_dir, "bin", "uroman.pl")
print(f"uromanize")
txt = text_mapper.uromanize(txt, uroman_pl)
print(f"uroman text: {txt}")
txt = txt.lower()
txt = text_mapper.filter_oov(txt)
stn_tst = text_mapper.get_text(txt, hps)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).cuda()
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
x_tst = stn_tst.unsqueeze(0).to(device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
hyp = net_g.infer(
x_tst, x_tst_lengths, noise_scale=.667,
noise_scale_w=0.8, length_scale=1.0
Expand Down

0 comments on commit fea3361

Please sign in to comment.