WaveNet wheeze fix, Tacotron synthesis constraint, Docker, bug fixes
- Fix WaveNet wheeze problem (checkerboard artifacts and training scheme optimization)
- Add WaveNet upsample layer types and nearest-neighbor upsample init (see the sketch below)
- Add multi-GPU WaveNet implementation
- Fix WaveNet scopes/names
- Add Tacotron attention synthesis constraint (long utterance generation)
- Make Tacotron encoder masking the default
- Add Tacotron batch norm after ReLU option (now the default)
- Add Docker build options
- Fix major bugs (train + synthesis)
- Fix minor bugs
- Performance debugging and optimization

TODO: finish documentation + minor bug fixes (if any)
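
The nearest-neighbor upsample init referenced above is what removes the checkerboard ("wheeze") artifacts of the transposed-convolution upsampling. A minimal sketch of the idea, not the repository's exact layer (the function name, shapes, and single-channel setup are illustrative):

import tensorflow as tf  # TF 1.x, matching the rest of this repository

def nn_init_upsample(mels, time_rate):
    """Upsample mel frames along time with a transposed conv whose kernel is
    initialized to constant weights, so the layer starts out as exact
    nearest-neighbor upsampling instead of a checkerboard-prone random init."""
    x = tf.expand_dims(mels, axis=3)                     # [batch, time, n_mels, 1]
    y = tf.layers.conv2d_transpose(
        x, filters=1,
        kernel_size=(time_rate, 1),                      # kernel == stride -> no overlap
        strides=(time_rate, 1),
        padding='same',
        kernel_initializer=tf.constant_initializer(1.))
    return tf.squeeze(y, axis=3)                         # [batch, time * time_rate, n_mels]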
Rayhane-mamah authored Jan 4, 2019
1 parent 970b080 commit 4dc16ce
Showing 27 changed files with 1,749 additions and 675 deletions.
50 changes: 46 additions & 4 deletions README.md
@@ -1,6 +1,12 @@
# Tacotron-2:
Tensorflow implementation of DeepMind's Tacotron-2, a deep neural network architecture described in this paper: [Natural TTS synthesis by conditioning Wavenet on Mel spectrogram predictions](https://arxiv.org/pdf/1712.05884.pdf)

This repository contains additional improvements and experiments on top of the paper; we therefore provide a **paper_hparams.py** file which holds the exact hyperparameters needed to reproduce the paper's results without any extras.

The suggested **hparams.py** file, used by default, contains those hyperparameters plus extras that proved to give better results in most cases. Feel free to adjust the parameters as needed.

DIFFERENCES WILL BE HIGHLIGHTED IN DOCUMENTATION SHORTLY.


# Repository Structure:
Tacotron-2
@@ -20,14 +26,25 @@ Tensorflow implementation of DeepMind's Tacotron-2. A deep neural network archit
│  │  └── wavs
│   ├── mel-spectrograms
│   ├── plots
│   ├── pretrained
│   ├── taco_pretrained
│   ├── metas
│   └── wavs
├── logs-Wavenet (4)
│   ├── eval-dir
│   │  ├── plots
│  │  └── wavs
│   ├── plots
│   ├── pretrained
│   ├── wave_pretrained
│   ├── metas
│   └── wavs
├── logs-Tacotron-2 ( * )
│   ├── eval-dir
│   │  ├── plots
│  │  └── wavs
│   ├── plots
│   ├── taco_pretrained
│   ├── wave_pretrained
│   ├── metas
│   └── wavs
├── papers
├── tacotron
@@ -60,6 +77,8 @@ The previous tree shows the current state of the repository (separate training,
- Step **(4)**: Train your Wavenet model. Yield the **logs-Wavenet** folder.
- Step **(5)**: Synthesize audio using the Wavenet model. Gives the **wavenet_output** folder.

- Note: Steps 2, 3, and 4 can be run in a single pass for both Tacotron and WaveNet (Tacotron-2, step ( * )); see the example command below.
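
For example, assuming train.py exposes the --model switch documented further down this README (the exact flag may differ):

> python train.py --model='Tacotron-2'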


Note:
- **Our preprocessing only supports LJSpeech and LJSpeech-like datasets (M-AILABS speech data)!** If your dataset is stored differently, you will probably need to write your own preprocessing script.
@@ -87,12 +106,33 @@ To have an overview of our advance on this project, please refer to [this discus
Since the two parts of the global model are trained separately, we can start by training the feature prediction model and use its predictions later during WaveNet training.

# How to start
first, you need to have python 3 installed along with [Tensorflow](https://www.tensorflow.org/install/).
- **Machine Setup:**

First, you need to have python 3 installed along with [Tensorflow](https://www.tensorflow.org/install/).

next you can install the requirements. If you are an Anaconda user: (else replace **pip** with **pip3** and **python** with **python3**)
Next, you need to install some Linux dependencies to ensure audio libraries work properly:

> apt-get install -y libasound-dev portaudio19-dev libportaudio2 libportaudiocpp0 ffmpeg libav-tools
Finally, you can install the requirements. If you are an Anaconda user: (else replace **pip** with **pip3** and **python** with **python3**)

> pip install -r requirements.txt
- **Docker:**

Alternatively, one can build the **docker image** to ensure everything is set up automatically, and use the project inside Docker containers.
**The Dockerfile is inside the "docker" folder.**

The docker image can be built with:

> docker build -t tacotron-2_image docker/
Containers can then be run with:

> docker run -i --name new_container tacotron-2_image
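If you need GPU support inside the container (the provided Dockerfile builds on a GPU TensorFlow image), you will typically also need the NVIDIA container runtime; for example, assuming nvidia-docker2 is installed (this flag is not part of this commit):

> docker run -i --runtime=nvidia --name new_container tacotron-2_image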
Please report any issues you hit when using our models through Docker; I'll get to them. Thanks!

# Dataset:
We tested the code above on the [ljspeech dataset](https://keithito.com/LJ-Speech-Dataset/), which has almost 24 hours of labeled recordings of a single actress. (Further information on the dataset is available in the README file you get when downloading it.)

@@ -105,6 +145,8 @@ Before proceeding, you must pick the hyperparameters that suit best your needs.

To pick optimal FFT parameters, I have made a **griffin_lim_synthesis_tool** notebook that you can use to invert real extracted mel/linear spectrograms and assess how good your preprocessing is. All other options are well explained in **hparams.py** and have meaningful names so that you can experiment with them.
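
A minimal sketch of the kind of round-trip the notebook performs, using the helpers from datasets/audio.py shown further down this page (the wav path and the sample_rate hparam are assumptions; adapt them to your data):

import librosa
from datasets import audio
from hparams import hparams

# Extract a mel spectrogram with the current hparams, then invert it with
# Griffin-Lim so you can judge the chosen n_fft / hop / win parameters by ear.
wav, _ = librosa.load('LJSpeech-1.1/wavs/LJ001-0001.wav', sr=hparams.sample_rate)
mel = audio.melspectrogram(wav, hparams)
reconstructed = audio.inv_mel_spectrogram(mel, hparams)
audio.save_wav(reconstructed, 'griffin_lim_check.wav', sr=hparams.sample_rate)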

DOCUMENTATION ON HPARAMS IS COMING SHORTLY!!

# Preprocessing
Before running the following steps, please make sure you are inside the **Tacotron-2** folder.

57 changes: 45 additions & 12 deletions datasets/audio.py
@@ -14,8 +14,10 @@ def save_wav(wav, path, sr):
#proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))

def save_wavenet_wav(wav, path, sr):
librosa.output.write_wav(path, wav, sr=sr)
def save_wavenet_wav(wav, path, sr, inv_preemphasize, k):
wav = inv_preemphasis(wav, k, inv_preemphasize)
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, sr, wav.astype(np.int16))

def preemphasis(wav, k, preemphasize=True):
if preemphasize:
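
The inv_preemphasis call used by the new save_wavenet_wav is not shown in this hunk; a minimal sketch of the usual pre-emphasis / inverse filter pair (the repository's own implementation may differ in details):

from scipy import signal

def preemphasis(wav, k, preemphasize=True):
    # FIR filter y[t] = x[t] - k * x[t-1] (boosts high frequencies before analysis)
    return signal.lfilter([1, -k], [1], wav) if preemphasize else wav

def inv_preemphasis(wav, k, inv_preemphasize=True):
    # Exact inverse, the IIR filter y[t] = x[t] + k * y[t-1]
    return signal.lfilter([1], [1, -k], wav) if inv_preemphasize else wav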
@@ -57,16 +59,18 @@ def get_hop_size(hparams):
return hop_size

def linearspectrogram(wav, hparams):
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
# D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
D = _stft(wav, hparams)
S = _amp_to_db(np.abs(D)**hparams.magnitude_power, hparams) - hparams.ref_level_db

if hparams.signal_normalization:
return _normalize(S, hparams)
return S

def melspectrogram(wav, hparams):
D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
# D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
D = _stft(wav, hparams)
S = _amp_to_db(_linear_to_mel(np.abs(D)**hparams.magnitude_power, hparams), hparams) - hparams.ref_level_db

if hparams.signal_normalization:
return _normalize(S, hparams)
@@ -79,7 +83,7 @@ def inv_linear_spectrogram(linear_spectrogram, hparams):
else:
D = linear_spectrogram

S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
S = _db_to_amp(D + hparams.ref_level_db)**(1/hparams.magnitude_power) #Convert back to linear

if hparams.use_lws:
processor = _lws_processor(hparams)
@@ -97,7 +101,7 @@ def inv_mel_spectrogram(mel_spectrogram, hparams):
else:
D = mel_spectrogram

S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db)**(1/hparams.magnitude_power), hparams) # Convert back to linear

if hparams.use_lws:
processor = _lws_processor(hparams)
@@ -127,7 +131,7 @@ def _stft(y, hparams):
if hparams.use_lws:
return _lws_processor(hparams).stft(y).T
else:
return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size, pad_mode='constant')

def _istft(y, hparams):
return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
@@ -155,11 +159,16 @@ def pad_lr(x, fsize, fshift):
return pad, pad + r
##########################################################
#Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
'''compute right padding (final frame)
'''
return int(fsize // 2)

assert pad_sides in (1, 2)
# return int(fsize // 2)
pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
if pad_sides == 1:
return 0, pad
else:
return pad // 2, pad // 2 + pad % 2

# Conversions
_mel_basis = None
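
For intuition, a small worked check of the new right-padding arithmetic (numbers are illustrative only):

import numpy as np

fshift = 275                                    # example hop_size
wav = np.zeros(10000)                           # dummy signal

# pad_sides == 1: pad only on the right, up to the next multiple of the hop size,
# so audio samples and mel frames stay aligned (len(out) == mel_frames * hop_size).
pad = (wav.shape[0] // fshift + 1) * fshift - wav.shape[0]
out = np.pad(wav, (0, pad), mode='constant')

assert len(out) % fshift == 0                   # 10175 == 37 * 275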
@@ -216,3 +225,27 @@ def _denormalize(D, hparams):
return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
else:
return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)

def normalize_tf(S, hparams):
#[0, 1]
if hparams.normalize_for_wavenet:
if hparams.allow_clipping_in_normalization:
return tf.minimum(tf.maximum((S - hparams.min_level_db) / (-hparams.min_level_db),
-hparams.max_abs_value), hparams.max_abs_value)

else:
return (S - hparams.min_level_db) / (-hparams.min_level_db)

#[-max, max] or [0, max]
if hparams.allow_clipping_in_normalization:
if hparams.symmetric_mels:
return tf.minimum(tf.maximum((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
-hparams.max_abs_value), hparams.max_abs_value)
else:
return tf.minimum(tf.maximum(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0), hparams.max_abs_value)

assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
if hparams.symmetric_mels:
return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
else:
return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
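
As a quick sanity check of the symmetric branch above (illustrative values, assuming min_level_db = -100 and max_abs_value = 4):

# S = -50 dB  ->  2*4 * ((-50 - (-100)) / 100) - 4 = 8 * 0.5 - 4 = 0.0
# i.e. the midpoint of the dB range maps to the center of [-max_abs_value, max_abs_value].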
13 changes: 10 additions & 3 deletions datasets/preprocessor.py
@@ -69,10 +69,17 @@ def _process_utterance(mel_dir, linear_dir, wav_dir, index, wav_path, text, hpar
wav_path))
return None

#Pre-emphasize
wav = audio.preemphasis(wav, hparams.preemphasis, hparams.preemphasize)

#rescale wav
if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max

#Assert all audio is in [-1, 1]
if (wav > 1.).any() or (wav < -1.).any():
raise RuntimeError('wav has invalid value: {}'.format(wav))

#M-AILABS extra silence specific
if hparams.trim_silence:
wav = audio.trim_silence(wav, hparams)
@@ -125,10 +132,10 @@ def _process_utterance(mel_dir, linear_dir, wav_dir, index, wav_path, text, hpar
out = np.pad(out, (l, r), mode='constant', constant_values=constant_values)
else:
#Ensure time resolution adjustement between audio and mel-spectrogram
pad = audio.librosa_pad_lr(wav, hparams.n_fft, audio.get_hop_size(hparams))
l_pad, r_pad = audio.librosa_pad_lr(wav, hparams.n_fft, audio.get_hop_size(hparams), hparams.wavenet_pad_sides)

#Reflect pad audio signal (Just like it's done in Librosa to avoid frame inconsistency)
out = np.pad(out, pad, mode='reflect')
#Reflect pad audio signal on the right (Just like it's done in Librosa to avoid frame inconsistency)
out = np.pad(out, (l_pad, r_pad), mode='constant', constant_values=constant_values)

assert len(out) >= mel_frames * audio.get_hop_size(hparams)

11 changes: 9 additions & 2 deletions datasets/wavenet_preprocessor.py
@@ -63,10 +63,17 @@ def _process_utterance(mel_dir, wav_dir, index, wav_path, hparams):
wav_path))
return None

#Pre-emphasize
wav = audio.preemphasis(wav, hparams.preemphasis, hparams.preemphasize)

#rescale wav
if hparams.rescale:
wav = wav / np.abs(wav).max() * hparams.rescaling_max

#Assert all audio is in [-1, 1]
if (wav > 1.).any() or (wav < -1.).any():
raise RuntimeError('wav has invalid value: {}'.format(wav))

#M-AILABS extra silence specific
if hparams.trim_silence:
wav = audio.trim_silence(wav, hparams)
@@ -112,10 +119,10 @@ def _process_utterance(mel_dir, wav_dir, index, wav_path, hparams):
out = np.pad(out, (l, r), mode='constant', constant_values=constant_values)
else:
#Ensure time resolution adjustement between audio and mel-spectrogram
pad = audio.librosa_pad_lr(wav, hparams.n_fft, audio.get_hop_size(hparams))
l_pad, r_pad = audio.librosa_pad_lr(wav, hparams.n_fft, audio.get_hop_size(hparams))

#Reflect pad audio signal (Just like it's done in Librosa to avoid frame inconsistency)
out = np.pad(out, pad, mode='reflect')
out = np.pad(out, (l_pad, r_pad), mode='constant', constant_values=constant_values)

assert len(out) >= mel_frames * audio.get_hop_size(hparams)

14 changes: 14 additions & 0 deletions docker/Dockerfile
@@ -0,0 +1,14 @@
FROM continuumio/anaconda3:latest
FROM tensorflow/tensorflow:latest-gpu-py3

RUN apt-get update
RUN apt-get install -y libasound-dev portaudio19-dev libportaudio2 libportaudiocpp0 ffmpeg libav-tools wget git vim

RUN wget http://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
RUN tar -jxvf LJSpeech-1.1.tar.bz2

RUN git clone https://github.com/Rayhane-mamah/Tacotron-2.git

WORKDIR Tacotron-2
RUN ln -s ../LJSpeech-1.1 .
RUN pip install -r requirements.txt
