Skip to content

Commit

Permalink
Merge pull request fgnt#29 from boeddeker/master
Browse files Browse the repository at this point in the history
update docstrings for numpy implementations
  • Loading branch information
LukasDrude authored Jun 19, 2019
2 parents db03e9b + 911b9a4 commit 54e2e3f
Showing 1 changed file with 71 additions and 3 deletions.
74 changes: 71 additions & 3 deletions nara_wpe/wpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,28 @@ def wpe_v0(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu

def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'):
"""
Batched WPE implementation.
Short of wpe_v7 with no extern references.
Applicable in for-loops.
Args:
Y: Complex valued STFT signal with shape (..., D, T).
taps: Filter order
delay: Delay as a guard interval, such that X does not become zero.
iterations:
psd_context: Defines the number of elements in the time window
to improve the power estimation. Total number of elements will
be (psd_context + 1 + psd_context).
statistics_mode: Either 'full' or 'valid'.
'full': Pad the observation with zeros on the left for the
estimation of the correlation matrix and vector.
'valid': Only calculate correlation matrix and vector on valid
slices of the observation.
Returns:
Estimated signal with the same shape as Y
"""

if statistics_mode == 'full':
Expand All @@ -394,7 +414,24 @@ def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu

def wpe_v7(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'):
"""
Modular wpe version.
Batched and modular WPE version.
Args:
Y: Complex valued STFT signal with shape (..., D, T).
taps: Filter order
delay: Delay as a guard interval, such that X does not become zero.
iterations:
psd_context: Defines the number of elements in the time window
to improve the power estimation. Total number of elements will
be (psd_context + 1 + psd_context).
statistics_mode: Either 'full' or 'valid'.
'full': Pad the observation with zeros on the left for the
estimation of the correlation matrix and vector.
'valid': Only calculate correlation matrix and vector on valid
slices of the observation.
Returns:
Estimated signal with the same shape as Y
"""
X = Y
Y_tilde = build_y_tilde(Y, taps, delay)
Expand All @@ -415,7 +452,36 @@ def wpe_v7(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu

def wpe_v8(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'):
"""
v8 is faster than v7 and offers an optional batch mode.
Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2] implementation
Is many cases this implementation is the fastes numpy implementation.
It loops over the independent axes. This reduces the memory footprint
and in experiments it turned out that this is faster than a batched
implementation (i.e. `wpe_v6` or `wpe_v7`).
Args:
Y: Complex valued STFT signal with shape (..., D, T).
taps: Filter order
delay: Delay as a guard interval, such that X does not become zero.
iterations:
psd_context: Defines the number of elements in the time window
to improve the power estimation. Total number of elements will
be (psd_context + 1 + psd_context).
statistics_mode: Either 'full' or 'valid'.
'full': Pad the observation with zeros on the left for the
estimation of the correlation matrix and vector.
'valid': Only calculate correlation matrix and vector on valid
slices of the observation.
Returns:
Estimated signal with the same shape as Y
[1] "Generalization of multi-channel linear prediction methods for blind MIMO
impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012
[2] NARA-WPE: A Python package for weighted prediction error dereverberation in
Numpy and Tensorflow for online and offline processing, Drude, Lukas and
Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018
"""
ndim = Y.ndim
if ndim == 2:
Expand Down Expand Up @@ -452,7 +518,9 @@ def wpe_v8(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='fu
else:
return np.stack(out, axis=batch_axis)
else:
raise NotImplementedError('Input shape is to be (F, D, T) or (D, T).')
raise NotImplementedError(
'Input shape has to be (..., D, T) and not {}.'.format(Y.shape)
)


wpe = wpe_v7
Expand Down

0 comments on commit 54e2e3f

Please sign in to comment.