Skip to content

Commit

Permalink
Fixed windowing, buffering, and added multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
Joao Felipe Santos committed Nov 6, 2015
1 parent 4f00a7a commit 5c550f5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
7 changes: 5 additions & 2 deletions srmrpy/hilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def hilbert(x, N=None, axis=-1):
raise ValueError("x must be real.")
if N is None:
N = x.shape[axis]
# Make N multiple of 16 to make sure the transform will be fast
if N % 16:
N = int(ceil(N/16)*16)
if N <= 0:
raise ValueError("N must be positive.")

Expand All @@ -64,6 +67,6 @@ def hilbert(x, N=None, axis=-1):
ind = [newaxis] * x.ndim
ind[axis] = slice(None)
h = h[ind]
x = ifft(Xf * h, axis=axis)
return x
y = ifft(Xf * h, axis=axis)
return y[:x.shape[axis]]

2 changes: 2 additions & 0 deletions srmrpy/segmentaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
b[..., (overlap):(l_orig+overlap)] = a
b[..., (l_orig+overlap):] = endvalue
a = b
else:
raise ValueError("end has to be either 'cut', 'pad', 'wrap', or 'delay'.")

a = a.swapaxes(-1,axis)

Expand Down
50 changes: 31 additions & 19 deletions srmrpy/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from gammatone.filters import centre_freqs, make_erb_filters, erb_filterbank
from srmrpy.segmentaxis import segment_axis

from scipy.io.wavfile import read as readwav

def calc_erbs(low_freq, fs, n_filters):
ear_q = 9.26449 # Glasberg and Moore Parameters
min_bw = 24.7
Expand Down Expand Up @@ -49,27 +51,27 @@ def srmr(x, fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=128, fast=
mod_filter_cfs = compute_modulation_cfs(min_cf, max_cf, 8)
MF = modulation_filterbank(mod_filter_cfs, mfs, 2)

n_frames = np.ceil((gt_env.shape[1])/wInc)
w = hamming(wLength)
n_frames = 1 + (gt_env.shape[1] - wLength)//wInc
w = hamming(wLength+1)[:-1] # window is periodic, not symmetric

energy = np.zeros((n_cochlear_filters, 8, n_frames))
for i, ac_ch in enumerate(gt_env):
mod_out = modfilt(MF, ac_ch)
for j, mod_ch in enumerate(mod_out):
mod_out_frame = segment_axis(mod_ch, wLength, overlap=wLength-wInc, end='delay')
energy[i,j,:] = np.sum((w*mod_out_frame)**2, axis=1)
mod_out_frame = segment_axis(mod_ch, wLength, overlap=wLength-wInc, end='pad')
energy[i,j,:] = np.sum((w*mod_out_frame[:n_frames])**2, axis=1)

if norm:
peak_energy = np.max(np.mean(energy, axis=0))
min_energy = peak_energy*0.001
energy[energy < min_energy] = min_energy
energy[energy > peak_energy] = peak_energy

erbs = np.flipud(calc_erbs(low_freq, fs, n_cochlear_filters))

avg_energy = np.mean(energy, axis=2)
total_energy = np.sum(avg_energy)

AC_energy = np.sum(avg_energy, axis=1)
AC_perc = AC_energy*100/total_energy

Expand All @@ -91,10 +93,20 @@ def srmr(x, fs, n_cochlear_filters=23, low_freq=125, min_cf=4, max_cf=128, fast=

return np.sum(avg_energy[:, :4])/np.sum(avg_energy[:, 4:Kstar]), energy

def process_file(f, args):
    """Compute the SRMR score for a single WAV file.

    Parameters
    ----------
    f : str
        Path of the WAV file to read.
    args : object
        Parsed command-line options. The attributes ``n_cochlear_filters``,
        ``min_cf``, ``max_cf``, ``fast`` and ``norm`` are forwarded to
        ``srmr``.

    Returns
    -------
    tuple
        ``(f, r)``: the input path and the ratio returned by ``srmr``.
        Returning the path alongside the score lets callers that process
        many files in parallel match each result to its input.
    """
    fs, s = readwav(f)
    # ``np.int`` was only an alias of the builtin ``int`` and has been
    # removed from NumPy (1.24); ``np.integer`` is the abstract base of all
    # integer dtypes, so every integer PCM format is detected here.
    if np.issubdtype(s.dtype, np.integer):
        # Scale integer samples by the largest value the dtype can hold.
        s = s.astype('float') / np.iinfo(s.dtype).max
    # The per-frame energy returned by srmr is not needed here.
    r, _ = srmr(s, fs, n_cochlear_filters=args.n_cochlear_filters,
                min_cf=args.min_cf,
                max_cf=args.max_cf,
                fast=args.fast,
                norm=args.norm)
    return f, r

def main():
import argparse
from scipy.io.wavfile import read as readwav
import numpy as np
import argparse, multiprocessing, functools

parser = argparse.ArgumentParser(description='Compute the SRMR metric for a given WAV file')
parser.add_argument('-f', '--fast', dest='fast', action='store_true', default=False,
help='Use the faster version based on the gammatonegram')
Expand All @@ -109,16 +121,16 @@ def main():
parser.add_argument('path', metavar='path', nargs='+',
help='Path of the file or files to be processed. Can also be a folder.')
args = parser.parse_args()
for f in args.path:
fs, s = readwav(f)
if np.issubdtype(s.dtype, np.int):
s = s.astype('float')/np.iinfo(s.dtype).max
r, energy = srmr(s, fs, n_cochlear_filters=args.n_cochlear_filters,
min_cf=args.min_cf,
max_cf=args.max_cf,
fast=args.fast,
norm=args.norm)
print('%s, %f' % (f, r))

if len(args.path) > 1:
p = multiprocessing.Pool(multiprocessing.cpu_count())
results = dict(p.map(functools.partial(process_file, args=args), args.path))
for f in args.path:
print('{}: {}'.format(f, results[f]))
else:
f, r = process_file(args.path[0], args)
print('{}: {}'.format(f, r))

if __name__ == '__main__':
main()

0 comments on commit 5c550f5

Please sign in to comment.