forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_wav_from_fbank.py
executable file
·187 lines (149 loc) · 5.89 KB
/
generate_wav_from_fbank.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import os
import time
import h5py
import numpy as np
import pysptk
import torch
from scipy.io.wavfile import write
from sklearn.preprocessing import StandardScaler
from espnet.nets.pytorch_backend.wavenet import decode_mu_law
from espnet.nets.pytorch_backend.wavenet import encode_mu_law
from espnet.nets.pytorch_backend.wavenet import WaveNet
from espnet.utils.cli_readers import file_reader_helper
from espnet.utils.cli_utils import get_commandline_args
class TimeInvariantMLSAFilter(object):
    """Time invariant MLSA filter.

    This module performs the noise shaping described in
    `An investigation of noise shaping with perceptual
    weighting for WaveNet-based speech generation`_.

    Args:
        coef (ndarray): MLSA filter coefficient (D,).
        alpha (float): All pass constant value.
        n_shift (int): Shift length in points.

    .. _`An investigation of noise shaping with perceptual
        weighting for WaveNet-based speech generation`:
        https://ieeexplore.ieee.org/abstract/document/8461332

    """

    def __init__(self, coef, alpha, n_shift):
        self.coef = coef
        self.n_shift = n_shift
        # A single MLSA digital filter is built once and reused for every
        # frame, which is what makes the filter "time invariant".
        mlsadf = pysptk.synthesis.MLSADF(order=coef.shape[0] - 1, alpha=alpha)
        self.mlsa_filter = pysptk.synthesis.Synthesizer(mlsadf, hopsize=n_shift)

    def __call__(self, y):
        """Apply the time invariant MLSA filter.

        Args:
            y (ndarray): Waveform signal normalized from -1 to 1 (N,).

        Returns:
            ndarray: Filtered waveform signal normalized from -1 to 1 (N,).

        """
        # check shape and type
        assert y.ndim == 1
        waveform = np.float64(y)
        # replicate the single coefficient vector once per frame
        num_frames = len(waveform) // self.n_shift + 1
        frame_coefs = np.tile(self.coef, [num_frames, 1])
        return self.mlsa_filter.synthesis(waveform, frame_coefs)
def get_parser():
    """Construct the command-line argument parser.

    Returns:
        argparse.ArgumentParser: Parser with sampling/FFT options,
        the WaveNet model path, the input rspecifier, and the output
        directory.

    """
    parser = argparse.ArgumentParser(
        description="generate wav from FBANK using wavenet vocoder",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # optional signal-processing / model settings
    parser.add_argument("--fs", type=int, default=22050, help="Sampling frequency")
    parser.add_argument("--n_fft", type=int, default=1024, help="FFT length in point")
    parser.add_argument("--n_shift", type=int, default=256, help="Shift length in point")
    parser.add_argument("--model", type=str, default=None, help="WaveNet model")
    parser.add_argument(
        "--filetype",
        type=str,
        default="mat",
        choices=["mat", "hdf5"],
        help=(
            "Specify the file format for the rspecifier. "
            '"mat" is the matrix format in kaldi'
        ),
    )
    # positional arguments
    parser.add_argument("rspecifier", type=str, help="Input feature e.g. scp:feat.scp")
    parser.add_argument("outdir", type=str, help="Output directory")
    return parser
def main():
    """Generate waveforms from FBANK features with a trained WaveNet vocoder.

    Reads log-mel filterbank features via the rspecifier, normalizes them
    with the statistics stored next to the model, autoregressively generates
    mu-law samples with the WaveNet, applies MLSA noise shaping, and writes
    one 16-bit PCM ``<utt_id>.wav`` per utterance into ``outdir``.
    """
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    logging.info(get_commandline_args())

    # check directory (exist_ok avoids the check-then-create race)
    os.makedirs(args.outdir, exist_ok=True)

    # load model config (training-time hyperparameters live next to the model)
    model_dir = os.path.dirname(args.model)
    train_args = torch.load(os.path.join(model_dir, "model.conf"))

    # load statistics; open read-only so the stats file is never modified
    # (older h5py defaults to append mode when no mode is given)
    scaler = StandardScaler()
    with h5py.File(os.path.join(model_dir, "stats.h5"), "r") as f:
        scaler.mean_ = f["/melspc/mean"][()]
        scaler.scale_ = f["/melspc/scale"][()]
        # TODO(kan-bayashi): include following info as default
        coef = f["/mlsa/coef"][()]
        alpha = f["/mlsa/alpha"][()]

    # define MLSA filter for noise shaping
    mlsa_filter = TimeInvariantMLSAFilter(
        coef=coef,
        alpha=alpha,
        n_shift=args.n_shift,
    )

    # define model and load parameters
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = WaveNet(
        n_quantize=train_args.n_quantize,
        n_aux=train_args.n_aux,
        n_resch=train_args.n_resch,
        n_skipch=train_args.n_skipch,
        dilation_depth=train_args.dilation_depth,
        dilation_repeat=train_args.dilation_repeat,
        kernel_size=train_args.kernel_size,
        upsampling_factor=train_args.upsampling_factor,
    )
    model.load_state_dict(torch.load(args.model, map_location="cpu")["model"])
    model.eval()
    model.to(device)

    for idx, (utt_id, lmspc) in enumerate(
        file_reader_helper(args.rspecifier, args.filetype), 1
    ):
        logging.info("(%d) %s" % (idx, utt_id))

        # perform preprocessing
        x = encode_mu_law(
            np.zeros(1), mu=train_args.n_quantize
        )  # quantize initial seed waveform
        h = scaler.transform(lmspc)  # normalize features

        # convert to tensor
        x = torch.tensor(x, dtype=torch.long, device=device)  # (1,)
        h = torch.tensor(h, dtype=torch.float, device=device)  # (T, n_aux)

        # get length of waveform
        n_samples = (h.shape[0] - 1) * args.n_shift + args.n_fft

        # generate
        start_time = time.time()
        with torch.no_grad():
            y = model.generate(x, h, n_samples, interval=100)
        logging.info(
            "generation speed = %s (sec / sample)"
            % ((time.time() - start_time) / (len(y) - 1))
        )
        y = decode_mu_law(y, mu=train_args.n_quantize)

        # apply mlsa filter for noise shaping
        y = mlsa_filter(y)

        # save as .wav file (scale [-1, 1] float to int16 PCM)
        write(
            os.path.join(args.outdir, "%s.wav" % utt_id),
            args.fs,
            (y * np.iinfo(np.int16).max).astype(np.int16),
        )
if __name__ == "__main__":
main()