Skip to content

Commit

Permalink
added test for energy
Browse files Browse the repository at this point in the history
  • Loading branch information
kan-bayashi committed Jul 28, 2020
1 parent 75fe817 commit 38b7a46
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
6 changes: 5 additions & 1 deletion espnet2/tts/feats_extract/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,13 @@ def get_parameters(self) -> Dict[str, Any]:
return dict(
fs=self.fs,
n_fft=self.n_fft,
n_shift=self.hop_length,
hop_length=self.hop_length,
window=self.window,
win_length=self.win_length,
center=self.stft.center,
pad_mode=self.stft.pad_mode,
normalized=self.stft.normalized,
use_token_averaged_energy=self.use_token_averaged_energy
)

def forward(
Expand Down
31 changes: 31 additions & 0 deletions test/espnet2/tts/feats_extract/test_energy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest
import torch

from espnet2.tts.feats_extract.energy import Energy


@pytest.mark.parametrize("use_token_averaged_energy", [False, True])
def test_forward(use_token_averaged_energy):
layer = Energy(
n_fft=128,
hop_length=64,
fs="16k",
use_token_averaged_energy=use_token_averaged_energy,
)
x = torch.randn(2, 256)
if not use_token_averaged_energy:
layer(x, torch.LongTensor([256, 128]))
else:
d = torch.LongTensor([[1, 2, 2], [3, 0, 0]])
dlens = torch.LongTensor([3, 1])
layer(x, torch.LongTensor([256, 128]), durations=d, durations_lengths=dlens)


def test_output_size():
layer = Energy(n_fft=4, hop_length=1, fs="16k")
print(layer.output_size())


def test_get_parameters():
layer = Energy(n_fft=4, hop_length=1, fs="16k")
print(layer.get_parameters())

0 comments on commit 38b7a46

Please sign in to comment.