forked from facebookresearch/fairseq
[MMS] TTS text uromanization + cpu inference (facebookresearch#5140)
* mms tts uroman + cpu support for inference
* remove mps support to accommodate all pytorch versions
* add explanation to arg

Co-authored-by: Bowen Shi <[email protected]>
1 parent 1082b61, commit fea3361
Showing 1 changed file with 43 additions and 3 deletions.
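For orientation before the diff: a hedged usage sketch of the updated script on a CPU-only machine. The script name `infer.py` and the checkpoint layout are assumptions based on the MMS TTS example; the flags themselves come from the diff below, including the new `--uroman-dir`.

```python
# Hypothetical invocation of the updated inference script (script name and
# checkpoint layout are assumptions; the flags are the ones in this diff).
import subprocess

subprocess.run(
    [
        "python", "infer.py",
        "--model-dir", "checkpoints/eng",      # expects G_100000.pth, config.json, vocab.txt
        "--wav", "out.wav",
        "--txt", "hello world",
        "--uroman-dir", "third_party/uroman",  # optional: cloned automatically if omitted
    ],
    check=True,  # raise if inference fails
)
```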
```diff
@@ -4,8 +4,10 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import re
 import glob
 import json
+import tempfile
 import math
 import torch
 from torch import nn
```
```diff
@@ -15,6 +17,7 @@
 import commons
 import utils
 import argparse
+import subprocess
 from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
 from models import SynthesizerTrn
 from scipy.io.wavfile import write
```
```diff
@@ -41,6 +44,24 @@ def text_to_sequence(self, text, cleaner_names):
             sequence += [symbol_id]
         return sequence
 
+    def uromanize(self, text, uroman_pl):
+        iso = "xxx"
+        with tempfile.NamedTemporaryFile() as tf, \
+             tempfile.NamedTemporaryFile() as tf2:
+            with open(tf.name, "w") as f:
+                f.write("\n".join([text]))
+            cmd = f"perl " + uroman_pl
+            cmd += f" -l {iso} "
+            cmd += f" < {tf.name} > {tf2.name}"
+            os.system(cmd)
+            outtexts = []
+            with open(tf2.name) as f:
+                for line in f:
+                    line = re.sub(r"\s+", " ", line).strip()
+                    outtexts.append(line)
+            outtext = outtexts[0]
+        return outtext
+
     def get_text(self, text, hps):
         text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
         if hps.data.add_blank:
```
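The new `uromanize` method shells out through `os.system` with file redirection, which silently ignores a failing `perl` call. As a hedged alternative sketch (not part of this commit), the same call can go through `subprocess.run`, assuming `uroman.pl` reads stdin and writes stdout exactly as the redirection in the diff implies:

```python
import re
import subprocess

def uromanize_checked(text: str, uroman_pl: str, iso: str = "xxx") -> str:
    """Romanize `text` with uroman, raising if the perl call fails."""
    result = subprocess.run(
        ["perl", uroman_pl, "-l", iso],
        input=text + "\n",  # uroman emits one romanized line per input line
        capture_output=True,
        text=True,
        check=True,         # surface a non-zero exit as CalledProcessError
    )
    # Same post-processing as the committed code: collapse whitespace runs.
    first_line = result.stdout.splitlines()[0] if result.stdout else ""
    return re.sub(r"\s+", " ", first_line).strip()
```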
```diff
@@ -59,9 +80,16 @@ def generate():
     parser.add_argument('--model-dir', type=str, help='model checkpoint dir')
     parser.add_argument('--wav', type=str, help='output wav path')
     parser.add_argument('--txt', type=str, help='input text')
+    parser.add_argument('--uroman-dir', type=str, help='uroman lib dir (will download if not specified)')
     args = parser.parse_args()
     ckpt_dir, wav_path, txt = args.model_dir, args.wav, args.txt
 
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    print(f"Run inference with {device}")
     vocab_file = f"{ckpt_dir}/vocab.txt"
     config_file = f"{ckpt_dir}/config.json"
     assert os.path.isfile(config_file), f"{config_file} doesn't exist"
```
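This block is what replaces the removed MPS path: per the commit message, MPS was dropped so the script runs on all PyTorch versions, since `torch.backends.mps` only exists from PyTorch 1.12 onward. The `device` chosen here is consumed by `net_g.to(device)` and the input tensors in the later hunks. Purely as an illustration of what a version-safe variant could look like (not what the commit does):

```python
import torch

def pick_device() -> torch.device:
    """CUDA if present, else MPS where the build supports it, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # getattr keeps this safe on older PyTorch builds without torch.backends.mps,
    # which is exactly the compatibility problem the commit avoids by using CPU.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```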
```diff
@@ -72,7 +100,7 @@ def generate():
         hps.data.filter_length // 2 + 1,
         hps.train.segment_size // hps.data.hop_length,
         **hps.model)
-    net_g.cuda()
+    net_g.to(device)
     _ = net_g.eval()
 
     g_pth = f"{ckpt_dir}/G_100000.pth"
```
```diff
@@ -81,12 +109,24 @@ def generate():
     _ = utils.load_checkpoint(g_pth, net_g, None)
 
     print(f"text: {txt}")
+    is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
+    if is_uroman:
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            if args.uroman_dir is None:
+                cmd = f"git clone git@github.com:isi-nlp/uroman.git {tmp_dir}"
+                print(cmd)
+                subprocess.check_output(cmd, shell=True)
+                args.uroman_dir = tmp_dir
+            uroman_pl = os.path.join(args.uroman_dir, "bin", "uroman.pl")
+            print(f"uromanize")
+            txt = text_mapper.uromanize(txt, uroman_pl)
+            print(f"uroman text: {txt}")
     txt = txt.lower()
     txt = text_mapper.filter_oov(txt)
     stn_tst = text_mapper.get_text(txt, hps)
     with torch.no_grad():
-        x_tst = stn_tst.unsqueeze(0).cuda()
-        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()
+        x_tst = stn_tst.unsqueeze(0).to(device)
+        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
         hyp = net_g.infer(
             x_tst, x_tst_lengths, noise_scale=.667,
             noise_scale_w=0.8, length_scale=1.0
```
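The capture cuts off inside the `net_g.infer(...)` call. For context only, a typical VITS-style continuation would look like the sketch below; this is an assumption from the usual VITS inference pattern, not the truncated lines themselves, though `write` is already imported from `scipy.io.wavfile` at the top of the file.

```python
# Assumed continuation (VITS convention: the first element of the returned
# tuple is the waveform batch). The real trailing lines are truncated above.
hyp = net_g.infer(
    x_tst, x_tst_lengths, noise_scale=.667,
    noise_scale_w=0.8, length_scale=1.0
)[0][0, 0].cpu().float().numpy()  # batch 0, channel 0 -> 1-D float array

write(wav_path, hps.data.sampling_rate, hyp)  # scipy.io.wavfile.write
```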