Skip to content

Commit

Permalink
Added frequency masking
Browse files Browse the repository at this point in the history
  • Loading branch information
Nirosha201 committed Jan 13, 2021
1 parent b5fa674 commit bf7d875
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 17 deletions.
4 changes: 3 additions & 1 deletion AviaNZ_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ def openFile(self, fileName=None):
self.listLoadFile(fileNameOld)

#self.fillFileList(self.SoundFileDir, current)
self.listFiles.setCurrentItem(current)
# self.listFiles.setCurrentItem(current) # TODO: Check this

def listLoadFile(self,current):
""" Listener for when the user clicks on a filename (also called by openFile() )
Expand Down Expand Up @@ -4765,6 +4765,7 @@ def segment(self):
# 5. Delete short segmentsost process to remove short segments, wind, rain, and use F0 check.
if str(alg) != 'Wavelets':
print('Segments detected: ', len(newSegments))
print(newSegments)
print('Post-processing...')
post = Segment.PostProcess(configdir=self.configdir, audioData=self.audiodata, sampleRate=self.sampleRate,
segments=newSegments, subfilter={})
Expand All @@ -4779,6 +4780,7 @@ def segment(self):
newSegments = post.segments
else:
print('Segments detected: ', sum(isinstance(seg, list) for subf in newSegments for seg in subf))
print(newSegments)
print('Post-processing...')
# load target CNN model if exists
self.CNNDicts = self.ConfigLoader.CNNmodels(self.FilterDicts, self.filtersDir, [filtname])
Expand Down
11 changes: 10 additions & 1 deletion CNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ class GenerateData:
.correction has segments for the noise class
3. when extracted pieces of sounds (of call types and noise) are presented TODO
"""
def __init__(self, filter, length, windowwidth, inc, imageheight, imagewidth):
def __init__(self, filter, length, windowwidth, inc, imageheight, imagewidth, f1, f2):
self.filter = filter
self.species = filter["species"]
# not sure if this is needed?
Expand All @@ -419,6 +419,8 @@ def __init__(self, filter, length, windowwidth, inc, imageheight, imagewidth):
for fi in filter['Filters']:
self.calltypes.append(fi['calltype'])
self.fs = filter["SampleRate"]
self.f1 = f1
self.f2 = f2
self.length = length
self.windowwidth = windowwidth
self.inc = inc
Expand Down Expand Up @@ -621,6 +623,13 @@ def generateFeatures(self, dirName, dataset, hop):
if sgstart < 0:
continue
sgRaw_i = sgRaw[sgstart:sgend, :]
# Frequency masking
bin_width = self.fs / 2 / np.shape(sgRaw_i)[1]
lb = int(np.ceil(self.f1 / bin_width))
ub = int(np.floor(self.f2 / bin_width))
sgRaw_i[:, 0:lb] = 0.0
sgRaw_i[:, ub:] = 0.0

maxg = np.max(sgRaw_i)
# Normalize and rotate
sgRaw_i = np.rot90(sgRaw_i / maxg)
Expand Down
117 changes: 110 additions & 7 deletions DialogsTraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,9 @@ def __init__(self, filtdir, config, configdir, parent=None):

# P3
self.parameterPage = BuildCNNWizard.WPageParameters(self.cnntrain, config)
self.parameterPage.registerField("frqMasked*", self.parameterPage.cbfrange, "isChecked")
self.parameterPage.registerField("f1*", self.parameterPage.f1, "value", self.parameterPage.f1.valueChanged)
self.parameterPage.registerField("f2*", self.parameterPage.f2, "value", self.parameterPage.f2.valueChanged)
self.addPage(self.parameterPage)

# add the Save & Test button
Expand All @@ -1863,7 +1866,7 @@ def __init__(self, cnntrain, config, parent=None):
self.setTitle('Select data')
self.setSubTitle('Choose the recogniser that you want to extend with CNN, then select training data.')

self.setMinimumSize(250, 200)
self.setMinimumSize(300, 200)
self.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
self.adjustSize()

Expand Down Expand Up @@ -2010,6 +2013,8 @@ def __init__(self, cnntrain, configdir, parent=None):
self.msgrecclens.setStyleSheet("QLabel { color : #808080; }")
self.msgrecfs = QLabel("")
self.msgrecfs.setStyleSheet("QLabel { color : #808080; }")
self.msgrecfrange = QLabel("")
self.msgrecfrange.setStyleSheet("QLabel { color : #808080; }")
self.warnLabel = QLabel("")
self.warnLabel.setStyleSheet("QLabel { color : #800000; }")
self.warnoise = QLabel("")
Expand Down Expand Up @@ -2046,7 +2051,8 @@ def __init__(self, cnntrain, configdir, parent=None):
layout.addWidget(self.msgreccts, 15, 2)
layout.addWidget(self.msgrecclens, 16, 2)
layout.addWidget(self.msgrecfs, 17, 2)
layout.addWidget(self.warnLabel, 18, 2)
layout.addWidget(self.msgrecfrange, 18, 2)
layout.addWidget(self.warnLabel, 19, 2)
self.setLayout(layout)

def initializePage(self):
Expand Down Expand Up @@ -2112,6 +2118,7 @@ def initializePage(self):
self.msgreccts.setText("<b>Call types:</b> %s" % (self.cnntrain.calltypes))
self.msgrecclens.setText("<b>Call length:</b> %.2f - %.2f sec" % (self.cnntrain.mincallength, self.cnntrain.maxcallength))
self.msgrecfs.setText("<b>Sample rate:</b> %d Hz" % (self.cnntrain.fs))
self.msgrecfrange.setText("<b>Frequency range:</b> %d - %d Hz" % (self.cnntrain.f1, self.cnntrain.f2))

for i in range(len(self.cnntrain.calltypes)):
self.msgseg.setText("%s:\t%d\t" % (self.msgseg.text() + self.cnntrain.calltypes[i], self.cnntrain.trainN[i]))
Expand Down Expand Up @@ -2181,7 +2188,7 @@ def __init__(self, cnntrain, config,parent=None):
self.setSizePolicy(QSizePolicy.Minimum, QSizePolicy.Minimum)
self.adjustSize()

self.cnntrain = cnntrain
# self.cnntrain = cnntrain
self.config = config
self.indx = np.ndarray(0)

Expand All @@ -2197,6 +2204,26 @@ def __init__(self, cnntrain, config,parent=None):
self.imgsec.valueChanged.connect(self.imglenChange)
self.imgtext = QLabel('0.25 sec')

self.cbfrange = QCheckBox("Limit frequency range")
self.cbfrange.setStyleSheet("QCheckBox { font-weight: bold; }")
self.cbfrange.toggled.connect(self.onClickedFrange)

self.f1 = QSlider(Qt.Horizontal)
self.f1.setTickPosition(QSlider.TicksBelow)
self.f1.setTickInterval(1000)
# self.f1.setRange(0, self.cnntrain.fs) # 0-6 sec
# self.f1.setValue(self.cnntrain.f1)
# self.f1.valueChanged.connect(self.f1Change)
self.f1text = QLabel('')

self.f2 = QSlider(Qt.Horizontal)
self.f2.setTickPosition(QSlider.TicksBelow)
self.f2.setTickInterval(1000)
# self.f2.setRange(0, self.cnntrain.fs) # 0-6 sec
# self.f2.setValue(self.cnntrain.f2)
# self.f2.valueChanged.connect(self.f2Change)
self.f2text = QLabel('')

space = QLabel()
space.setFixedSize(10, 30)
msglayout = QVBoxLayout()
Expand All @@ -2215,12 +2242,25 @@ def __init__(self, cnntrain, config,parent=None):

layout0 = QVBoxLayout()
layout0.addLayout(msglayout)
layout0.addWidget(space)
# layout0.addWidget(space)
layout0.addWidget(QLabel('<b>Choose call length (sec) you want to show to CNN</b>'))
layout0.addWidget(QLabel('Make sure an image covers at least couple of syllables when appropriate'))
layout0.addWidget(space)
# layout0.addWidget(space)
layout0.addWidget(self.imgtext)
layout0.addWidget(self.imgsec)
layout0.addWidget(self.cbfrange)
layout0a = QHBoxLayout()
layout0a1 = QVBoxLayout()
# layout0a1.addWidget(QLabel('Lower frq. limit (Hz)'))
layout0a1.addWidget(self.f1text)
layout0a1.addWidget(self.f1)
layout0a2 = QVBoxLayout()
# layout0a2.addWidget(QLabel('Upper frq. limit (Hz)'))
layout0a2.addWidget(self.f2text)
layout0a2.addWidget(self.f2)
layout0a.addLayout(layout0a1)
layout0a.addLayout(layout0a2)
layout0.addLayout(layout0a)

layout2 = QVBoxLayout()
layout2.addWidget(QLabel('<i>Example images from your dataset</i>'))
Expand Down Expand Up @@ -2252,6 +2292,23 @@ def __init__(self, cnntrain, config,parent=None):
self.setButtonText(QWizard.NextButton, 'Generate CNN images and Train>')

def initializePage(self):
self.cnntrain = self.wizard().confirminputPage.cnntrain
self.cnntrain.windowWidth = 512
self.cnntrain.windowInc = 256
self.f1.setRange(0, self.cnntrain.fs/2)
self.f1.setValue(0)
self.f1text.setText('Lower frq. limit 0 Hz')
self.f2.setRange(0, self.cnntrain.fs/2)
self.f2.setValue(self.cnntrain.fs/2)
self.f2text.setText('Upper frq. limit ' + str(self.cnntrain.fs/2) + ' Hz')
self.f1.valueChanged.connect(self.f1Change)
self.f2.valueChanged.connect(self.f2Change)
self.cbfrange.setChecked(False)
self.f1text.setEnabled(False)
self.f1.setEnabled(False)
self.f2text.setEnabled(False)
self.f2.setEnabled(False)

self.wizard().button(QWizard.NextButton).setDefault(False)
self.msgspp.setText("<b>Species:</b> %s" % (self.cnntrain.species))

Expand Down Expand Up @@ -2285,6 +2342,28 @@ def onClicked(self):
self.redopages = True
self.completeChanged.emit()

def onClickedFrange(self):
cbutton = self.sender()
if cbutton.isChecked():
self.f1.setEnabled(True)
self.f1text.setEnabled(True)
self.f2.setEnabled(True)
self.f2text.setEnabled(True)
if self.f1.value() == 0 and self.f2.value() == self.cnntrain.fs/2:
self.f1.setValue(self.cnntrain.f1)
self.f2.setValue(self.cnntrain.f2)
self.f1text.setText('Lower frq. limit ' + str(self.cnntrain.f1) + ' Hz')
self.f2text.setText('Upper frq. limit ' + str(self.cnntrain.f2) + ' Hz')
else:
self.f1.setValue(0)
self.f2.setValue(self.cnntrain.fs)
self.f1text.setText('Lower frq. limit ' + str(0) + ' Hz')
self.f2text.setText('Upper frq. limit ' + str(self.cnntrain.fs/2) + ' Hz')
self.f1.setEnabled(False)
self.f1text.setEnabled(False)
self.f2.setEnabled(False)
self.f2text.setEnabled(False)

def loadFile(self, filename, duration=0, offset=0, fs=0):
"""
Read audio file.
Expand Down Expand Up @@ -2315,7 +2394,16 @@ def showimg(self, indices=[]):
self.cnntrain.sp.data = audiodata
self.cnntrain.sp.sampleRate = self.cnntrain.fs
sgRaw = self.cnntrain.sp.spectrogram(window_width=self.cnntrain.windowWidth, incr=self.cnntrain.windowInc)
# Frequency masking
f1 = self.f1.value()
f2 = self.f2.value()
# Mask out of band elements
bin_width = self.cnntrain.fs / 2 / np.shape(sgRaw)[1]
lb = int(np.ceil(f1 / bin_width))
ub = int(np.floor(f2 / bin_width))
maxsg = np.min(sgRaw)
sgRaw[:, 0:lb] = 0.0
sgRaw[:, ub:] = 0.0
self.sg = np.abs(np.where(sgRaw == 0, 0.0, 10.0 * np.log10(sgRaw / maxsg)))
self.setColourMap()
picbtn = SupportClasses_GUI.PicButton(1, np.fliplr(self.sg), self.cnntrain.sp.data, self.cnntrain.sp.audioFormat, self.imgsec.value(), 0, 0, self.lut, self.colourStart, self.colourEnd, False,
Expand Down Expand Up @@ -2374,6 +2462,22 @@ def imglenChange(self, value):
self.setWindowInc()
self.showimg(self.indx)

def f1Change(self, value):
value = value//10*10
if value < 0:
value = 0
self.cnntrain.f1 = value
self.f1text.setText('Lower frq. limit ' + str(value) + ' Hz')
self.showimg(self.indx)

def f2Change(self, value):
value = value//10*10
if value < 0:
value = 0
self.cnntrain.f2 = value
self.f2text.setText('Upper frq. limit ' + str(value) + ' Hz')
self.showimg(self.indx)

def cleanupPage(self):
self.imgDirwarn.setText('')
self.img1.setText('')
Expand Down Expand Up @@ -2724,5 +2828,4 @@ def undoROCPages(self):
self.addPage(self.savePage)

self.parameterPage.setFinalPage(False)
self.parameterPage.completeChanged.emit()

self.parameterPage.completeChanged.emit()
18 changes: 14 additions & 4 deletions Segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,11 @@ def __init__(self, configdir, audioData=None, sampleRate=0, tgtsampleRate=0, seg
self.CNNoutputs = CNNmodel[3]
self.CNNwindowInc = CNNmodel[4]
self.CNNthrs = CNNmodel[5]
if CNNmodel[6]:
self.CNNfMask = True
self.CNNfRange = CNNmodel[7]
else:
self.CNNfMask = False
self.tgtsampleRate = tgtsampleRate
else:
self.CNNmodel = None
Expand Down Expand Up @@ -1284,8 +1289,8 @@ def generateFeaturesCNN_frqMasked(self, seg, data, fs):
sp.data = data
sp.sampleRate = fs
_ = sp.spectrogram()
f1 = 90 - 50 # TODO: hardcoded for bittern testing
f2 = 250 + 100
f1 = self.CNNfRange[0]
f2 = self.CNNfRange[1]
# Mask out of band elements
bin_width = fs / 2 / np.shape(sp.sg)[1]
lb = int(np.ceil(f1 / bin_width))
Expand Down Expand Up @@ -1399,7 +1404,7 @@ def CNN1(self):

for ix in reversed(range(len(self.segments))):
seg = self.segments[ix]
# print('\n--- Segment', seg)
print('\n--- Segment', seg)
if seg[0][1] - seg[0][0] > max(self.syllen, 1):
n = 5
else:
Expand Down Expand Up @@ -1501,6 +1506,7 @@ def CNN(self):

for ix in reversed(range(len(self.segments))):
seg = self.segments[ix]
print('\n--- Segment', seg)
# expand the segment if its too small
callength = max(self.CNNwindow, self.maxLen/2)
if callength >= seg[0][1] - seg[0][0]:
Expand All @@ -1522,7 +1528,10 @@ def CNN(self):
sp.sampleRate = self.sampleRate
if self.sampleRate != self.tgtsampleRate:
sp.resample(self.tgtsampleRate)
featuress = self.generateFeaturesCNN(seg=seg[0], data=sp.data, fs=sp.sampleRate)
if self.CNNfMask:
featuress = self.generateFeaturesCNN_frqMasked(seg=seg[0], data=sp.data, fs=sp.sampleRate)
else:
featuress = self.generateFeaturesCNN(seg=seg[0], data=sp.data, fs=sp.sampleRate)
# featuress = self.generateFeaturesCNN_frqMasked(seg=seg[0], data=sp.data, fs=sp.sampleRate)
# featuress = self.generateFeaturesCNN2(seg=seg[0], data=sp.data, fs=sp.sampleRate)
featuress = np.array(featuress)
Expand All @@ -1534,6 +1543,7 @@ def CNN(self):
probs = self.CNNmodel.predict(featuress)
else:
probs = 0
print("probabilities: ", probs)

if isinstance(probs, int):
# Zero images from this segment, very unlikely to be a true seg.
Expand Down
11 changes: 9 additions & 2 deletions SupportClasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,15 @@ def CNNmodels(self, filters, dircnn, targetspecies):
model.load_weights(os.path.join(dircnn, filt["CNN"]["CNN_name"]) + '.h5')
print('Loaded model:', os.path.join(dircnn, filt["CNN"]["CNN_name"]))
model.compile(loss=filt["CNN"]["loss"], optimizer=filt["CNN"]["optimizer"], metrics=['accuracy'])
targetmodels[species] = [model, filt["CNN"]["win"], filt["CNN"]["inputdim"], filt["CNN"]["output"],
filt["CNN"]["windowInc"], filt["CNN"]["thr"]]
if 'fRange' in filt["CNN"]:
targetmodels[species] = [model, filt["CNN"]["win"], filt["CNN"]["inputdim"],
filt["CNN"]["output"],
filt["CNN"]["windowInc"], filt["CNN"]["thr"], True,
filt["CNN"]["fRange"]]
else:
targetmodels[species] = [model, filt["CNN"]["win"], filt["CNN"]["inputdim"],
filt["CNN"]["output"], filt["CNN"]["windowInc"],
filt["CNN"]["thr"], False]
except Exception as e:
print("Could not load CNN model from file:", os.path.join(dircnn, filt["CNN"]["CNN_name"]))
print(e)
Expand Down
Loading

0 comments on commit bf7d875

Please sign in to comment.