Skip to content

Commit

Permalink
Porting the torchaudio kaldi fbank implementation to audio_utils (hug…
Browse files Browse the repository at this point in the history
…gingface#26182)

* add kaldi fbank

* make style

* add herz_to_mel_kaldi tests

* add mel to hertz kaldi test

* integration tests

* correct test and remove comment

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <[email protected]>

* change parameter name

* Apply suggestions from Arthur review

Co-authored-by: Arthur <[email protected]>

* Update remove_dc_offset description

* fix bug  + make style

* fix error in using np.exp instead of np.power

* make style

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
3 people authored Sep 21, 2023
1 parent b132c17 commit 9a30753
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 11 deletions.
46 changes: 35 additions & 11 deletions src/transformers/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
freq (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in hertz (Hz).
mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`.
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns:
`float` or `np.ndarray`: The frequencies on the mel scale.
"""

if mel_scale not in ["slaney", "htk"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

if mel_scale == "htk":
return 2595.0 * np.log10(1.0 + (freq / 700.0))
elif mel_scale == "kaldi":
return 1127.0 * np.log(1.0 + (freq / 700.0))

min_log_hertz = 1000.0
min_log_mel = 15.0
Expand All @@ -64,17 +66,19 @@ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Unio
mels (`float` or `np.ndarray`):
The frequency, or multiple frequencies, in mels.
mel_scale (`str`, *optional*, `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`.
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
Returns:
`float` or `np.ndarray`: The frequencies in hertz.
"""

if mel_scale not in ["slaney", "htk"]:
raise ValueError('mel_scale should be one of "htk" or "slaney".')
if mel_scale not in ["slaney", "htk", "kaldi"]:
raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')

if mel_scale == "htk":
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
elif mel_scale == "kaldi":
return 700.0 * (np.exp(mels / 1127.0) - 1.0)

min_log_hertz = 1000.0
min_log_mel = 15.0
Expand Down Expand Up @@ -120,6 +124,7 @@ def mel_filter_bank(
sampling_rate: int,
norm: Optional[str] = None,
mel_scale: str = "htk",
triangularize_in_mel_space: bool = False,
) -> np.ndarray:
"""
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
Expand Down Expand Up @@ -155,7 +160,10 @@ def mel_filter_bank(
norm (`str`, *optional*):
If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
mel_scale (`str`, *optional*, defaults to `"htk"`):
The mel frequency scale to use, `"htk"` or `"slaney"`.
The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
Returns:
`np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
Expand All @@ -164,15 +172,21 @@ def mel_filter_bank(
if norm is not None and norm != "slaney":
raise ValueError('norm must be one of None or "slaney"')

# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

# center points of the triangular mel filters
mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)

if triangularize_in_mel_space:
# frequencies of FFT bins in Hz, but filters triangularized in mel space
fft_bin_width = sampling_rate / (num_frequency_bins * 2)
fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
filter_freqs = mel_freqs
else:
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)

mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)

if norm is not None and norm == "slaney":
Expand Down Expand Up @@ -218,6 +232,7 @@ def window_function(
- `"boxcar"`: a rectangular window
- `"hamming"`: the Hamming window
- `"hann"`: the Hann window
- `"povey"`: the Povey window
Args:
window_length (`int`):
Expand All @@ -243,6 +258,8 @@ def window_function(
window = np.hamming(length)
elif name in ["hann", "hann_window"]:
window = np.hanning(length)
elif name in ["povey"]:
window = np.power(np.hanning(length), 0.85)
else:
raise ValueError(f"Unknown window function '{name}'")

Expand Down Expand Up @@ -281,6 +298,7 @@ def spectrogram(
reference: float = 1.0,
min_value: float = 1e-10,
db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = None,
dtype: np.dtype = np.float32,
) -> np.ndarray:
"""
Expand Down Expand Up @@ -363,6 +381,9 @@ def spectrogram(
db_range (`float`, *optional*):
Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
remove_dc_offset (`bool`, *optional*):
Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
`np.complex64`.
Expand Down Expand Up @@ -414,6 +435,9 @@ def spectrogram(
for frame_idx in range(num_frames):
buffer[:frame_length] = waveform[timestep : timestep + frame_length]

if remove_dc_offset:
buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()

if preemphasis is not None:
buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
buffer[0] *= 1 - preemphasis
Expand Down
105 changes: 105 additions & 0 deletions tests/utils/test_audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def test_hertz_to_mel(self):
expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "slaney"), expected))

inputs = np.array([60, 100, 200, 1000, 1001, 2000])
expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
self.assertTrue(np.allclose(hertz_to_mel(inputs, "kaldi"), expected))

with pytest.raises(ValueError):
hertz_to_mel(100, mel_scale=None)

Expand All @@ -63,6 +67,10 @@ def test_mel_to_hertz(self):
expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "slaney"), expected))

inputs = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])
expected = np.array([60, 100, 200, 1000, 1001, 2000])
self.assertTrue(np.allclose(mel_to_hertz(inputs, "kaldi"), expected))

with pytest.raises(ValueError):
mel_to_hertz(100, mel_scale=None)

Expand All @@ -89,6 +97,18 @@ def test_mel_filter_bank_shape(self):
)
self.assertEqual(mel_filters.shape, (513, 13))

mel_filters = mel_filter_bank(
num_frequency_bins=513,
num_mel_filters=13,
min_frequency=100,
max_frequency=4000,
sampling_rate=16000,
norm="slaney",
mel_scale="slaney",
triangularize_in_mel_space=True,
)
self.assertEqual(mel_filters.shape, (513, 13))

def test_mel_filter_bank_htk(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
Expand Down Expand Up @@ -153,6 +173,39 @@ def test_mel_filter_bank_slaney(self):
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected))

def test_mel_filter_bank_kaldi(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
num_mel_filters=4,
min_frequency=0,
max_frequency=2000,
sampling_rate=4000,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)
# fmt: off
expected = np.array(
[[0.0000, 0.0000, 0.0000, 0.0000],
[0.6086, 0.0000, 0.0000, 0.0000],
[0.8689, 0.1311, 0.0000, 0.0000],
[0.4110, 0.5890, 0.0000, 0.0000],
[0.0036, 0.9964, 0.0000, 0.0000],
[0.0000, 0.6366, 0.3634, 0.0000],
[0.0000, 0.3027, 0.6973, 0.0000],
[0.0000, 0.0000, 0.9964, 0.0036],
[0.0000, 0.0000, 0.7135, 0.2865],
[0.0000, 0.0000, 0.4507, 0.5493],
[0.0000, 0.0000, 0.2053, 0.7947],
[0.0000, 0.0000, 0.0000, 0.9752],
[0.0000, 0.0000, 0.0000, 0.7585],
[0.0000, 0.0000, 0.0000, 0.5539],
[0.0000, 0.0000, 0.0000, 0.3599],
[0.0000, 0.0000, 0.0000, 0.1756]]
)
# fmt: on
self.assertTrue(np.allclose(mel_filters, expected, atol=5e-5))

def test_mel_filter_bank_slaney_norm(self):
mel_filters = mel_filter_bank(
num_frequency_bins=16,
Expand Down Expand Up @@ -271,6 +324,58 @@ def test_spectrogram_integration_test(self):
self.assertEqual(spec.shape, (257, 732))
self.assertTrue(np.allclose(spec[:64, 400], expected))

mel_filters = mel_filter_bank(
num_frequency_bins=256,
num_mel_filters=400,
min_frequency=20,
max_frequency=8000,
sampling_rate=16000,
norm=None,
mel_scale="kaldi",
triangularize_in_mel_space=True,
)

mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))

spec = spectrogram(
waveform,
window_function(400, "povey", periodic=False),
frame_length=400,
hop_length=160,
fft_length=512,
power=2.0,
center=False,
pad_mode="reflect",
onesided=True,
preemphasis=0.97,
mel_filters=mel_filters,
log_mel="log",
mel_floor=1.1920928955078125e-07,
remove_dc_offset=True,
)
self.assertEqual(spec.shape, (400, 584))

# fmt: off
expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-6.52463769, -7.73677889, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.18650018, -3.37195286,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-4.70190154, -2.4217066 , -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -5.62755239, -3.53385194,
-15.94238515, -15.94238515, -15.94238515, -15.94238515,
-9.43303023, -8.77480925, -15.94238515, -15.94238515,
-15.94238515, -15.94238515, -4.2951092 , -5.51585994,
-15.94238515, -15.94238515, -15.94238515, -4.40151721,
-3.95228878, -15.94238515, -15.94238515, -15.94238515,
-6.10365415, -4.59494697, -15.94238515, -15.94238515,
-15.94238515, -8.10727767, -6.2585298 , -15.94238515,
-15.94238515, -15.94238515, -5.60161702, -4.47217004,
-15.94238515, -15.94238515, -15.94238515, -5.91641988]
)
# fmt: on
self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))

def test_spectrogram_center_padding(self):
waveform = self._load_datasamples(1)[0]

Expand Down

0 comments on commit 9a30753

Please sign in to comment.