From 38b7a46327173cb9a00695ba8fc52efc053b73cd Mon Sep 17 00:00:00 2001 From: kan-bayashi Date: Tue, 28 Jul 2020 15:31:49 +0900 Subject: [PATCH] added test for energy --- espnet2/tts/feats_extract/energy.py | 6 +++- test/espnet2/tts/feats_extract/test_energy.py | 31 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 test/espnet2/tts/feats_extract/test_energy.py diff --git a/espnet2/tts/feats_extract/energy.py b/espnet2/tts/feats_extract/energy.py index 3bd71f73c22..a9c4797410f 100644 --- a/espnet2/tts/feats_extract/energy.py +++ b/espnet2/tts/feats_extract/energy.py @@ -66,9 +66,13 @@ def get_parameters(self) -> Dict[str, Any]: return dict( fs=self.fs, n_fft=self.n_fft, - n_shift=self.hop_length, + hop_length=self.hop_length, window=self.window, win_length=self.win_length, + center=self.stft.center, + pad_mode=self.stft.pad_mode, + normalized=self.stft.normalized, + use_token_averaged_energy=self.use_token_averaged_energy ) def forward( diff --git a/test/espnet2/tts/feats_extract/test_energy.py b/test/espnet2/tts/feats_extract/test_energy.py new file mode 100644 index 00000000000..66877c4612e --- /dev/null +++ b/test/espnet2/tts/feats_extract/test_energy.py @@ -0,0 +1,31 @@ +import pytest +import torch + +from espnet2.tts.feats_extract.energy import Energy + + +@pytest.mark.parametrize("use_token_averaged_energy", [False, True]) +def test_forward(use_token_averaged_energy): + layer = Energy( + n_fft=128, + hop_length=64, + fs="16k", + use_token_averaged_energy=use_token_averaged_energy, + ) + x = torch.randn(2, 256) + if not use_token_averaged_energy: + layer(x, torch.LongTensor([256, 128])) + else: + d = torch.LongTensor([[1, 2, 2], [3, 0, 0]]) + dlens = torch.LongTensor([3, 1]) + layer(x, torch.LongTensor([256, 128]), durations=d, durations_lengths=dlens) + + +def test_output_size(): + layer = Energy(n_fft=4, hop_length=1, fs="16k") + print(layer.output_size()) + + +def test_get_parameters(): + layer = Energy(n_fft=4, hop_length=1, fs="16k") + print(layer.get_parameters())