# Version 3.0 14/09/20
# Authors: Stephen Marsland, Nirosha Priyadarshani, Julius Juodakis, Virginia Listanti
# AviaNZ bioacoustic analysis program
# Copyright (C) 2017--2020
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Holds most of the code for training CNNs
import os, gc, re, json, tempfile
from shutil import copyfile
from shutil import disk_usage
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import model_from_json
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import numpy as np
import matplotlib.pyplot as plt
from time import strftime, gmtime
import math
import SupportClasses
import SignalProc
import CNN
import Segment, WaveletSegment
import AviaNZ_batch
import wavio
class CNNtrain:
    def __init__(self, configdir, filterdir, folderTrain1=None, folderTrain2=None, recogniser=None, imgWidth=0, CLI=False):
        # Two important things:
        # 1. LearningParams.txt, which is a dictionary of parameters *** including spectrogram parameters
        # 2. CLI: whether it runs off the command line, which makes picking the ROC curve parameters hard
        # Qn: what is imgWidth? Why not a learning param?
        self.filterdir = filterdir
        self.configdir = configdir
        cl = SupportClasses.ConfigLoader()
        self.FilterDict = cl.filters(filterdir, bats=False)
        self.LearningDict = cl.learningParams(os.path.join(configdir, "LearningParams.txt"))
        self.sp = SignalProc.SignalProc(self.LearningDict['sgramWindowWidth'], self.LearningDict['sgramHop'])
        self.imgsize = [self.LearningDict['imgX'], self.LearningDict['imgY']]
        self.tmpdir1 = False
        self.tmpdir2 = False
        self.ROCdata = {}
        self.CLI = CLI
        if CLI:
            self.filterName = recogniser
            self.folderTrain1 = folderTrain1
            self.folderTrain2 = folderTrain2
            self.imgWidth = imgWidth
            self.autoThr = True
            self.correction = True
            self.annotatedAll = True
        else:
            self.autoThr = False
            self.correction = False
            self.imgWidth = imgWidth
    def setP1(self, folderTrain1, folderTrain2, recogniser, annotationLevel):
        # This is a function that the Wizard calls to set parameters
        self.folderTrain1 = folderTrain1
        self.folderTrain2 = folderTrain2
        self.filterName = recogniser
        self.annotatedAll = annotationLevel

    def setP6(self, recogniser):
        # This is a function that the Wizard calls to set parameters
        self.newFilterName = recogniser
    def cliTrain(self):
        # This is the main training function for CLI-based learning.
        # It proceeds very much like the wizard
        # Get info from the wavelet filter
        self.readFilter()

        # Load data
        # Note: no error checking in the CLI version
        # Find the segments belonging to each class in the training data
        self.genSegmentDataset(hasAnnotation=True)

        # Check on memory space
        self.checkDisk()

        # Derive the spectrogram window size and increment from the target image size
        self.windowWidth = self.imgsize[0] * self.LearningDict['windowScaling']
        self.windowInc = int(np.ceil(self.imgWidth * self.fs / (self.imgsize[1] - 1)))
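        # Worked example (illustrative values, not from LearningParams.txt): with
        # imgX=128 and windowScaling=4, windowWidth = 512 samples; with imgWidth=0.5 s,
        # fs=16000 and imgY=128, windowInc = ceil(0.5*16000/127) = 63 samples, so the
        # imgY spectrogram columns together span one training image of imgWidth seconds.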
        # Train
        self.train()

        # Save the output
        self.saveFilter()
    def readFilter(self):
        # Read the current (wavelet) filter and get the details
        if self.filterName.lower().endswith('.txt'):
            self.currfilt = self.FilterDict[self.filterName[:-4]]
        else:
            self.currfilt = self.FilterDict[self.filterName]
        self.fs = self.currfilt["SampleRate"]
        self.species = self.currfilt["species"]
        mincallengths = []
        maxcallengths = []
        f1 = []
        f2 = []
        self.maxgaps = []
        self.calltypes = []
        for fi in self.currfilt['Filters']:
            self.calltypes.append(fi['calltype'])
            mincallengths.append(fi['TimeRange'][0])
            maxcallengths.append(fi['TimeRange'][1])
            self.maxgaps.append(fi['TimeRange'][3])
            f1.append(fi['FreqRange'][0])
            f2.append(fi['FreqRange'][1])
        self.mincallength = np.max(mincallengths)
        self.maxcallength = np.max(maxcallengths)
        self.f1 = np.min(f1)
        self.f2 = np.max(f2)

        print("Manually annotated: %s" % self.folderTrain1)
        print("Auto processed and reviewed: %s" % self.folderTrain2)
        print("Recogniser: %s" % self.currfilt)
        print("Species: %s" % self.species)
        print("Call types: %s" % self.calltypes)
        print("Call length: %.2f - %.2f sec" % (self.mincallength, self.maxcallength))
        print("Sample rate: %d Hz" % self.fs)
        print("Frequency range: %d - %d Hz" % (self.f1, self.f2))
    def checkDisk(self):
        # Check disk usage
        totalbytes, usedbytes, freebytes = disk_usage(os.path.expanduser("~"))
        freeGB = freebytes / 1024 / 1024 / 1024
        print('\nFree space in the user directory: %.2f GB/ %.2f GB\n' % (freeGB, totalbytes / 1024 / 1024 / 1024))
        if freeGB < 10:
            print('Warning: You may run out of space in the user directory!')
        return freeGB, totalbytes / 1024 / 1024 / 1024
    def genSegmentDataset(self, hasAnnotation):
        # Prepares segments for input to the learners
        self.traindata = []
        self.DataGen = CNN.GenerateData(self.currfilt, 0, 0, 0, 0, 0, 0, 0)
        # For manually annotated data where the user is confident about full annotation,
        # choose anything else in the spectrograms as noise examples
        if self.annotatedAll == "All":
            self.noisedata1 = self.DataGen.findNoisesegments(self.folderTrain1)
            print('----noise data1:')
            for x in self.noisedata1:
                self.traindata.append(x)
        if self.annotatedAll == "All-nowt":
            self.noisedata1 = self.DataGen.findAllsegments(self.folderTrain1)
            print('----noise data1:')
            for x in self.noisedata1:
                self.traindata.append(x)
        # Call type segments
        print('----CT data1:')
        if hasAnnotation:
            for i in range(len(self.calltypes)):
                ctdata = self.DataGen.findCTsegments(self.folderTrain1, i)
                print(self.calltypes[i])
                for x in ctdata:
                    self.traindata.append(x)

        # For wavelet outputs that have been manually verified, get noise segments from .corrections
        if os.path.isdir(self.folderTrain2):
            for root, dirs, files in os.walk(str(self.folderTrain2)):
                for file in files:
                    if file.lower().endswith('.wav') and file + '.corrections' in files:
                        # Read the .corrections (from all-species review)
                        cfile = os.path.join(root, file + '.corrections')
                        wavfile = os.path.join(root, file)
                        try:
                            f = open(cfile, 'r')
                            annots = json.load(f)
                            f.close()
                        except Exception as e:
                            print("ERROR: file %s failed to load with error:" % file)
                            print(e)
                            return
                        for seg in annots:
                            if isinstance(seg, dict):
                                continue
                            if len(seg) != 2:
                                print("Warning: old format corrections detected")
                                continue
                            oldlabel = seg[0][4]
                            # check in cases like: [kiwi] -> [kiwi, morepork]
                            # (these will be stored in .corrections, but aren't incorrect detections)
                            newsp = [lab["species"] for lab in seg[1]]
                            if len(oldlabel) != 1:
                                # this was made manually
                                print("Warning: ignoring labels with multiple species")
                                continue
                            if oldlabel[0]['species'] == self.species and self.species not in newsp:
                                # store this as "noise" calltype
                                self.traindata.append([wavfile, seg[0][:2], len(self.calltypes)])
                                self.correction = True
                    elif file.lower().endswith('.wav') and file + '.corrections_' + self.cleanSpecies(self.species) in files:
                        # Read the .corrections (from single species review)
                        cfile = os.path.join(root, file + '.corrections_' + self.cleanSpecies(self.species))
                        wavfile = os.path.join(root, file)
                        try:
                            f = open(cfile, 'r')
                            annots = json.load(f)
                            f.close()
                        except Exception as e:
                            print("ERROR: file %s failed to load with error:" % file)
                            print(e)
                            return
                        for seg in annots:
                            if isinstance(seg, dict):
                                continue
                            else:
                                # store this as "noise" calltype
                                self.traindata.append([wavfile, seg[:2], len(self.calltypes)])
                                self.correction = True

        # Call type segments
        print('----CT data2:')
        for i in range(len(self.calltypes)):
            ctdata = self.DataGen.findCTsegments(self.folderTrain2, i)
            print(self.calltypes[i])
            for x in ctdata:
                self.traindata.append(x)

        # How many of each class
        target = np.array([rec[-1] for rec in self.traindata])
        self.trainN = [np.sum(target == i) for i in range(len(self.calltypes) + 1)]
    def genImgDataset(self, hop):
        ''' Generate training images for each calltype and noise'''
        for ct in range(len(self.calltypes) + 1):
            os.makedirs(os.path.join(self.tmpdir1.name, str(ct)))
        self.imgsize[1], self.Nimg = self.DataGen.generateFeatures(dirName=self.tmpdir1.name, dataset=self.traindata, hop=hop)
    def train(self):
        # Create temp dirs to hold img data and model
        try:
            if self.tmpdir1:
                self.tmpdir1.cleanup()
            if self.tmpdir2:
                self.tmpdir2.cleanup()
        except:
            pass
        self.tmpdir1 = tempfile.TemporaryDirectory(prefix='CNN_')
        print('Temporary img dir:', self.tmpdir1.name)
        self.tmpdir2 = tempfile.TemporaryDirectory(prefix='CNN_')
        print('Temporary model dir:', self.tmpdir2.name)

        # Find train segments belonging to each class
        self.DataGen = CNN.GenerateData(self.currfilt, self.imgWidth, self.windowWidth, self.windowInc, self.imgsize[0], self.imgsize[1], self.f1, self.f2)

        # Find how many images the default hop (=imgWidth) would give, then adjust the hop to
        # produce a good number of images while keeping space for some in-built augmenting (width-shift)
        hop = [self.imgWidth for i in range(len(self.calltypes) + 1)]
        imgN = self.DataGen.getImgCount(dirName=self.tmpdir1.name, dataset=self.traindata, hop=hop)
        print('Expected number of images when no overlap: ', imgN)
        print('Updating hop...')
        hop = self.updateHop(imgN, hop)
        imgN = self.DataGen.getImgCount(dirName=self.tmpdir1.name, dataset=self.traindata, hop=hop)
        print('Expected number of images with updated hop: ', imgN)

        print('Generating CNN images...')
        self.genImgDataset(hop)
        print('\nGenerated images:\n')
        for i in range(len(self.calltypes)):
            print("\t%s:\t%d\n" % (self.calltypes[i], self.Nimg[i]))
        print("\t%s:\t%d\n" % ("Noise", self.Nimg[-1]))

        # CNN training
        cnn = CNN.CNN(self.configdir, self.species, self.calltypes, self.fs, self.imgWidth, self.windowWidth, self.windowInc, self.imgsize[0], self.imgsize[1])

        # 1. Data augmentation
        print('Data augmenting...')
        filenames, labels = cnn.getImglist(self.tmpdir1.name)
        labels = np.argmax(labels, axis=1)
        ns = [np.shape(np.where(labels == i)[0])[0] for i in range(len(self.calltypes) + 1)]
        # create the in-built image data augmentation generator
        datagen = ImageDataGenerator(width_shift_range=0.3, fill_mode='nearest')
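        # Note: width_shift_range=0.3 translates each spectrogram in time by up to 30%
        # of its width, and fill_mode='nearest' pads the vacated columns from the
        # nearest existing column; presumably this is the headroom that updateHop
        # reserves per class via tWidthShift.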
        # Data augmentation for each call type
        for ct in range(len(self.calltypes) + 1):
            if self.LearningDict['t'] - ns[ct] > self.LearningDict['batchsize']:
                # load this call type's images
                samples = cnn.loadCTImg(os.path.join(self.tmpdir1.name, str(ct)))
                # prepare iterator
                it = datagen.flow(samples, batch_size=self.LearningDict['batchsize'])
                # generate samples
                batch = it.next()
                for j in range(int((self.LearningDict['t'] - ns[ct]) / self.LearningDict['batchsize'])):
                    newbatch = it.next()
                    batch = np.vstack((batch, newbatch))
                # Save augmented data
                k = 0
                for sgRaw in batch:
                    np.save(os.path.join(self.tmpdir1.name, str(ct), str(ct) + '_aug' + "%06d" % k + '.npy'), sgRaw)
                    k += 1
                try:
                    del batch
                    del samples
                    del newbatch
                except:
                    pass
                gc.collect()
        # 2. TRAIN - use custom image generator
        filenamesall, labelsall = cnn.getImglist(self.tmpdir1.name)
        print('Final CNN images...')
        labelsalld = np.argmax(labelsall, axis=1)
        ns = [np.shape(np.where(labelsalld == i)[0])[0] for i in range(len(self.calltypes) + 1)]
        for i in range(len(self.calltypes)):
            print("\t%s:\t%d\n" % (self.calltypes[i], ns[i]))
        print("\t%s:\t%d\n" % ("Noise", ns[-1]))

        filenamesall, labelsall = shuffle(filenamesall, labelsall)
        X_train_filenames, X_val_filenames, y_train, y_val = train_test_split(filenamesall, labelsall, test_size=self.LearningDict['test_size'], random_state=1)
        training_batch_generator = CNN.CustomGenerator(X_train_filenames, y_train, self.LearningDict['batchsize'], self.tmpdir1.name, cnn.imageheight, cnn.imagewidth, 1)
        validation_batch_generator = CNN.CustomGenerator(X_val_filenames, y_val, self.LearningDict['batchsize'], self.tmpdir1.name, cnn.imageheight, cnn.imagewidth, 1)

        print('Creating CNN architecture...')
        cnn.createArchitecture()

        print('Training...')
        cnn.train(modelsavepath=self.tmpdir2.name, training_batch_generator=training_batch_generator, validation_batch_generator=validation_batch_generator)
        print('Training complete!')

        self.bestThr = [[0, 0] for i in range(len(self.calltypes))]
        self.bestThrInd = [0 for i in range(len(self.calltypes))]

        # 3. Prepare ROC plots
        print('Generating ROC statistics...')
        # Load the model
        # Find best weights
        weights = []
        epoch = []
        for r, d, files in os.walk(self.tmpdir2.name):
            for f in files:
                if f.endswith('.h5') and 'weights' in f:
                    epoch.append(int(f.split('weights.')[-1][:2]))
                    weights.append(f)
        j = np.argmax(epoch)
        weightfile = weights[j]
        model = os.path.join(self.tmpdir2.name, 'model.json')
        self.bestweight = os.path.join(self.tmpdir2.name, weightfile)

        # Load the model and prepare
        jsonfile = open(model, 'r')
        loadedmodeljson = jsonfile.read()
        jsonfile.close()
        model = model_from_json(loadedmodeljson)
        # Load weights into new model
        model.load_weights(self.bestweight)
        # Compile the model
        model.compile(loss=self.LearningDict['loss'], optimizer=self.LearningDict['optimizer'], metrics=self.LearningDict['metrics'])
        # model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        print('Loaded CNN model from ', self.tmpdir2.name)
        TPs = [0 for i in range(len(self.calltypes) + 1)]
        FPs = [0 for i in range(len(self.calltypes) + 1)]
        TNs = [0 for i in range(len(self.calltypes) + 1)]
        FNs = [0 for i in range(len(self.calltypes) + 1)]
        CTps = [[[] for i in range(len(self.calltypes) + 1)] for j in range(len(self.calltypes) + 1)]

        # Do all the plots based on the validation set (eliminate augmented?)
        # N = len(filenames)
        N = len(X_val_filenames)
        y_val = np.argmax(y_val, axis=1)
        print('Validation data: ', N)
        if os.path.isdir(self.tmpdir2.name):
            print('Model directory exists')
        else:
            print('Model directory DOES NOT exist')
        if os.path.isdir(self.tmpdir1.name):
            print('Img directory exists')
        else:
            print('Img directory DOES NOT exist')
        for i in range(int(np.ceil(N / self.LearningDict['batchsize_ROC']))):
            # imagesb = cnn.loadImgBatch(filenames[i * self.LearningDict['batchsize_ROC']:min((i + 1) * self.LearningDict['batchsize_ROC'], N)])
            # labelsb = labels[i * self.LearningDict['batchsize_ROC']:min((i + 1) * self.LearningDict['batchsize_ROC'], N)]
            imagesb = cnn.loadImgBatch(X_val_filenames[i * self.LearningDict['batchsize_ROC']:min((i + 1) * self.LearningDict['batchsize_ROC'], N)])
            labelsb = y_val[i * self.LearningDict['batchsize_ROC']:min((i + 1) * self.LearningDict['batchsize_ROC'], N)]
            for ct in range(len(self.calltypes) + 1):
                res, ctp = self.testCT(ct, imagesb, labelsb, model)  # res=[thrlist, TPs, FPs, TNs, FNs], ctp=[[0to0 probs], [0to1 probs], [0to2 probs]]
                for j in range(len(self.calltypes) + 1):
                    CTps[ct][j] += ctp[j]
                if TPs[ct] == 0:
                    TPs[ct] = res[1]
                    FPs[ct] = res[2]
                    TNs[ct] = res[3]
                    FNs[ct] = res[4]
                else:
                    TPs[ct] = [TPs[ct][i] + res[1][i] for i in range(len(TPs[ct]))]
                    FPs[ct] = [FPs[ct][i] + res[2][i] for i in range(len(FPs[ct]))]
                    TNs[ct] = [TNs[ct][i] + res[3][i] for i in range(len(TNs[ct]))]
                    FNs[ct] = [FNs[ct][i] + res[4][i] for i in range(len(FNs[ct]))]
                self.Thrs = res[0]
        print('Thrs: ', self.Thrs)
        print('validation TPs[0]: ', TPs[0])

        self.TPRs = [[0.0 for i in range(len(self.Thrs))] for j in range(len(self.calltypes) + 1)]
        self.FPRs = [[0.0 for i in range(len(self.Thrs))] for j in range(len(self.calltypes) + 1)]
        self.Precisions = [[0.0 for i in range(len(self.Thrs))] for j in range(len(self.calltypes) + 1)]
        self.Accs = [[0.0 for i in range(len(self.Thrs))] for j in range(len(self.calltypes) + 1)]

        plt.style.use('ggplot')
        fig, axs = plt.subplots(len(self.calltypes) + 1, len(self.calltypes) + 1, sharey=True, sharex='col')
        for ct in range(len(self.calltypes) + 1):
            self.TPRs[ct] = [TPs[ct][i] / (TPs[ct][i] + FNs[ct][i]) for i in range(len(self.Thrs))]
            self.FPRs[ct] = [FPs[ct][i] / (TNs[ct][i] + FPs[ct][i]) for i in range(len(self.Thrs))]
            self.Precisions[ct] = [0.0 if (TPs[ct][i] + FPs[ct][i]) == 0 else TPs[ct][i] / (TPs[ct][i] + FPs[ct][i]) for i in range(len(self.Thrs))]
            self.Accs[ct] = [(TPs[ct][i] + TNs[ct][i]) / (TPs[ct][i] + TNs[ct][i] + FPs[ct][i] + FNs[ct][i]) for i in range(len(self.Thrs))]

            # Temp plot is saved in the train data directory - prediction probabilities for instances of the current ct
            for i in range(len(self.calltypes) + 1):
                CTps[ct][i] = sorted(CTps[ct][i], key=float)
                axs[i, ct].plot(CTps[ct][i], 'k')
                axs[i, ct].plot(CTps[ct][i], 'bo')
                if ct == i == len(self.calltypes):
                    axs[i, 0].set_ylabel('Noise')
                    axs[0, ct].set_title('Noise')
                elif ct == i:
                    axs[i, 0].set_ylabel(str(self.calltypes[ct]))
                    axs[0, ct].set_title(str(self.calltypes[ct]))
                if i == len(self.calltypes):
                    axs[i, ct].set_xlabel('Number of samples')
        fig.suptitle('Human')
        if self.folderTrain1:
            fig.savefig(os.path.join(self.folderTrain1, 'validation-plots.png'))
            print('Validation plot is saved: ', os.path.join(self.folderTrain1, 'validation-plots.png'))
        else:
            fig.savefig(os.path.join(self.folderTrain2, 'validation-plots.png'))
            print('Validation plot is saved: ', os.path.join(self.folderTrain2, 'validation-plots.png'))
        plt.close()
        # Collate ROC data
        self.ROCdata["TPR"] = self.TPRs
        self.ROCdata["FPR"] = self.FPRs
        self.ROCdata["thr"] = self.Thrs
        print('TPR: ', self.ROCdata["TPR"])
        print('FPR: ', self.ROCdata["FPR"])

        # 4. Auto-select the upper threshold (the first threshold with FPR = 0)
        for ct in range(len(self.calltypes)):
            try:
                self.bestThr[ct][1] = self.Thrs[self.FPRs[ct].index(0.0)]
            except:
                self.bestThr[ct][1] = self.Thrs[len(self.FPRs[ct]) - 1]

        # 5. Auto-select the lower threshold IF the user asked for it
        if self.autoThr:
            for ct in range(len(self.calltypes)):
                # Get min distance to the ROC curve from (0 FPR, 1 TPR)
                distarr = (np.float64(1) - self.TPRs[ct]) ** 2 + (np.float64(0) - self.FPRs[ct]) ** 2
                self.thr_min_ind = np.unravel_index(np.argmin(distarr), distarr.shape)[0]
                self.bestThr[ct][0] = self.Thrs[self.thr_min_ind]
                self.bestThrInd[ct] = self.thr_min_ind
        return True
    def updateHop(self, imgN, hop):
        ''' Update the hop for each class '''
        # Compare against the expected total number of images per class (t)
        for i in range(len(self.calltypes) + 1):
            fillratio1 = imgN[i] / (self.LearningDict['t'] - self.LearningDict['tWidthShift'])
            fillratio2 = imgN[i] / self.LearningDict['t']
            if fillratio1 < 0.75:  # too few images, decrease hop
                if i == len(self.calltypes):
                    print('Noise: only %d images, adjusting hop from %.2f to %.2f' % (imgN[i], hop[i], hop[i] * fillratio1))
                else:
                    print('%s: only %d images, adjusting hop from %.2f to %.2f' % (self.calltypes[i], imgN[i], hop[i], hop[i] * fillratio1))
                hop[i] = hop[i] * fillratio1
            elif fillratio1 > 1 and fillratio2 > 0.75:  # increase hop and make room for augmenting
                if i == len(self.calltypes):
                    print('Noise: %d images, adjusting hop from %.2f to %.2f' % (imgN[i], hop[i], hop[i] * fillratio1))
                else:
                    print('%s: %d images, adjusting hop from %.2f to %.2f' % (self.calltypes[i], imgN[i], hop[i], hop[i] * fillratio1))
                hop[i] = hop[i] * fillratio1
            elif fillratio2 > 1:  # too many images, increase hop to reduce overlap
                if i == len(self.calltypes):
                    print('Noise: %d images, adjusting hop from %.2f to %.2f' % (imgN[i], hop[i], hop[i] * fillratio2))
                else:
                    print('%s: %d images, adjusting hop from %.2f to %.2f' % (self.calltypes[i], imgN[i], hop[i], hop[i] * fillratio2))
                hop[i] = hop[i] * fillratio2
        return hop
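    # Worked example for updateHop (illustrative numbers): with LearningDict values
    # t=3000 and tWidthShift=1000, a class with imgN[i]=1000 images at hop[i]=0.5 gives
    # fillratio1 = 1000/2000 = 0.5 < 0.75, so hop[i] becomes 0.5*0.5 = 0.25 and the
    # segments are re-cut with more overlap, roughly doubling that class's image count.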
    def testCT(self, ct, testimages, targets, model):
        '''
        :param ct: index of the call type under test
        :return: [thrlist, TPs, FPs, TNs, FNs], ctprob
        '''
        self.thrs = []
        self.TPs = []
        self.FPs = []
        self.TNs = []
        self.FNs = []

        # Predict and temp plot
        pre = model.predict(testimages)
        ctprob = [[] for i in range(len(self.calltypes) + 1)]
        for i in range(len(targets)):
            if targets[i] == ct:
                for ind in range(len(self.calltypes) + 1):
                    ctprob[ind].append(pre[i][ind])

        # Get the stats over different thresholds
        labels = [i for i in range(len(self.calltypes) + 1)]
        for thr in np.linspace(0.00001, 1, 100):
            predictions = [self.pred(p, thr=thr, ct=ct) for p in pre]
            CM = confusion_matrix(predictions, targets, labels=labels)
            TP = CM[ct][ct]
            FP = np.sum(CM[ct][:]) - TP
            colct = 0
            for i in range(len(self.calltypes) + 1):
                colct += CM[i][ct]
            FN = colct - TP
            TN = np.sum(CM) - FP - FN - TP
            self.thrs.append(thr)
            self.TPs.append(TP)
            self.FPs.append(FP)
            self.TNs.append(TN)
            self.FNs.append(FN)
        return [self.thrs, self.TPs, self.FPs, self.TNs, self.FNs], ctprob
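    # Note: confusion_matrix is called here with predictions as the first argument, so in
    # CM rows index predictions and columns index targets. Hence, for call type ct:
    # TP = CM[ct][ct]; FP = (row ct sum) - TP (predicted ct, labelled otherwise);
    # FN = (column ct sum) - TP (labelled ct, predicted otherwise); TN = everything else.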
    def pred(self, p, thr, ct):
        if p[ct] > thr:
            prediction = ct
        elif ct == len(self.calltypes):
            prediction = 0
        else:
            prediction = len(self.calltypes)
        return prediction
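    # Worked example for pred: with two call types (so Noise has index 2), p = [0.7, 0.2, 0.1],
    # ct=0 and thr=0.5 returns 0 (the call type); the same p with thr=0.8 returns 2 (Noise),
    # since p[0] no longer clears the threshold and ct is not the Noise index.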
    def saveFilter(self):
        # Add CNN component to the current filter
        self.addCNNFilter()
        # CNNdic = {}
        # CNNdic["CNN_name"] = "CNN_name"
        # CNNdic["loss"] = self.LearningDict['loss']
        # CNNdic["optimizer"] = self.LearningDict['optimizer']
        # CNNdic["windowInc"] = [self.windowWidth, self.windowInc]
        # CNNdic["win"] = [self.imgWidth, self.imgWidth/5]   # TODO: remove hop
        # CNNdic["inputdim"] = self.imgsize
        # output = {}
        # thr = []
        # for ct in range(len(self.calltypes)):
        #     output[str(ct)] = self.calltypes[ct]
        #     thr.append(self.bestThr[ct])
        # output[str(len(self.calltypes))] = "Noise"
        # # thr.append(self.wizard().parameterPage.bestThr[len(self.calltypes)])
        # CNNdic["output"] = output
        # CNNdic["thr"] = thr
        # print(CNNdic)
        # self.currfilt["CNN"] = CNNdic

        if self.CLI:
            # write out the filter and CNN model
            modelsrc = os.path.join(self.tmpdir2.name, 'model.json')
            CNN_name = self.species + strftime("_%H-%M-%S", gmtime())
            self.currfilt["CNN"]["CNN_name"] = CNN_name
            rocfilename = self.species + "_ROCNN" + strftime("_%H-%M-%S", gmtime())
            self.currfilt["ROCNN"] = rocfilename
            rocfilename = os.path.join(self.filterdir, rocfilename + '.json')
            modelfile = os.path.join(self.filterdir, CNN_name + '.json')
            weightsrc = self.bestweight
            weightfile = os.path.join(self.filterdir, CNN_name + '.h5')
            filename = os.path.join(self.filterdir, self.filterName)
            if not filename.lower().endswith('.txt'):
                filename = filename + '.txt'
            print("Updating the existing recogniser ", filename)
            f = open(filename, 'w')
            f.write(json.dumps(self.currfilt))
            f.close()
            # Actually copy the model
            copyfile(modelsrc, modelfile)
            copyfile(weightsrc, weightfile)
            # save ROC
            f = open(rocfilename, 'w')
            f.write(json.dumps(self.ROCdata))
            f.close()
            # And remove temp dirs
            self.tmpdir1.cleanup()
            self.tmpdir2.cleanup()
            print("Recogniser saved, don't forget to test it!")
    def addCNNFilter(self):
        # Add CNN component to the current filter
        CNNdic = {}
        CNNdic["CNN_name"] = "CNN_name"
        CNNdic["loss"] = self.LearningDict['loss']
        CNNdic["optimizer"] = self.LearningDict['optimizer']
        CNNdic["windowInc"] = [self.windowWidth, self.windowInc]
        CNNdic["win"] = [self.imgWidth, self.imgWidth / 5]   # TODO: remove hop
        CNNdic["inputdim"] = [int(self.imgsize[0]), int(self.imgsize[1])]
        if self.f1 == 0 and self.f2 == self.fs / 2:
            print('no frequency masking used')
        else:
            print('frequency masking used', self.f1, self.f2)
            CNNdic["fRange"] = [int(self.f1), int(self.f2)]
        output = {}
        thr = []
        for ct in range(len(self.calltypes)):
            output[str(ct)] = self.calltypes[ct]
            thr.append(self.bestThr[ct])
        output[str(len(self.calltypes))] = "Noise"
        # thr.append(self.wizard().parameterPage.bestThr[len(self.calltypes)])
        CNNdic["output"] = output
        CNNdic["thr"] = thr
        print(CNNdic)
        self.currfilt["CNN"] = CNNdic
    def cleanSpecies(self, species):
        """ Returns cleaned species name"""
        return re.sub(r'[^A-Za-z0-9()-]', "_", species)
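    # Example: cleanSpecies("Kiwi (Nth Is)") returns "Kiwi_(Nth_Is)", since any character
    # outside [A-Za-z0-9()-] (here, the spaces) is replaced with an underscore.
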
class CNNtest:
    # Test a previously-trained CNN
    def __init__(self, testDir, currfilt, filtname, configdir, filterdir, CLI=False):
        """ currfilt: the recogniser to be used (dict) """
        self.testDir = testDir
        self.outfile = open(os.path.join(self.testDir, "test-results.txt"), "w")
        self.currfilt = currfilt
        self.filtname = filtname
        self.configdir = configdir
        self.filterdir = filterdir
        # Note: this is just the species name, unlike the self.species in Batch mode
        species = self.currfilt['species']
        self.sampleRate = self.currfilt['SampleRate']
        self.calltypes = []
        for fi in self.currfilt['Filters']:
            self.calltypes.append(fi['calltype'])

        self.outfile.write("Recogniser name: %s\n" % (filtname))
        self.outfile.write("Species name: %s\n" % (species))
        self.outfile.write("Using data: %s\n" % (self.testDir))

        # 0. Generate GT files from annotations in the test folder
        self.manSegNum = 0
        self.window = 1
        inc = None
        print('Generating GT...')
        for root, dirs, files in os.walk(self.testDir):
            for file in files:
                wavFile = os.path.join(root, file)
                if file.lower().endswith('.wav') and os.stat(wavFile).st_size != 0 and file + '.data' in files:
                    segments = Segment.SegmentList()
                    segments.parseJSON(wavFile + '.data')
                    self.manSegNum += len(segments.getSpecies(species))
                    # Currently, we ignore call types here and just
                    # look for all calls for the target species.
                    segments.exportGT(wavFile, species, resolution=self.window)
        if self.manSegNum == 0:
            print("ERROR: no segments for species %s found" % species)
            self.text = 0
            return

        # 1. Run Batch Processing up to the wavelet filter (WF) and generate .tmpdata files (no post-proc)
        avianz_batch = AviaNZ_batch.AviaNZ_batchProcess(parent=None, configdir=self.configdir, mode="test", sdir=self.testDir, recogniser=filtname, wind=1)
        # NOTE: will use wind-robust detection

        # 2. Report statistics of WF followed by general post-proc steps (no CNN, but wind - merge neighbours - delete short)
        self.text = self.getSummary(CNN=False)

        # 3. Report statistics of WF followed by post-proc steps (wind - CNN - merge neighbours - delete short)
        if "CNN" in self.currfilt:
            cl = SupportClasses.ConfigLoader()
            filterlist = cl.filters(self.filterdir, bats=False)
            CNNDicts = cl.CNNmodels(filterlist, self.filterdir, [filtname])
            # Providing one filter, so only one CNN should be returned:
            if len(CNNDicts) != 1:
                print("ERROR: Couldn't find a unique matching CNN!")
                self.outfile.write("No matching CNN found!\n")
                self.outfile.write("-- End of testing --\n")
                self.outfile.close()
                return
            CNNmodel = list(CNNDicts)[0]
            self.text = self.getSummary(CNN=True)
        self.outfile.write("-- End of testing --\n")
        self.outfile.close()
        print("Testing output written to " + os.path.join(self.testDir, "test-results.txt"))

        # Tidy up
        # for root, dirs, files in os.walk(self.testDir):
        #     for file in files:
        #         if file.endswith('.tmpdata'):
        #             os.remove(os.path.join(root, file))
    def getOutput(self):
        return self.text

    def findCTsegments(self, datafile, calltypei):
        calltypeSegments = []
        species = self.currfilt["species"]
        segments = Segment.SegmentList()
        segments.parseJSON(datafile)
        if len(self.calltypes) == 1:
            ctSegments = segments.getSpecies(species)
        else:
            ctSegments = segments.getCalltype(species, self.calltypes[calltypei])
        calltypeSegments = [segments[indx][:2] for indx in ctSegments]
        return calltypeSegments
    def getSummary(self, CNN=False):
        autoSegCTnum = [0] * len(self.calltypes)
        ws = WaveletSegment.WaveletSegment()
        TP = FP = TN = FN = 0
        for root, dirs, files in os.walk(self.testDir):
            for file in files:
                wavFile = os.path.join(root, file)
                if file.lower().endswith('.wav') and os.stat(wavFile).st_size != 0 and \
                        file + '.tmpdata' in files and file[:-4] + '-GT.txt' in files:
                    # Extract all segments and back-convert to 0/1:
                    _, duration, _, _ = wavio.readFmt(wavFile)
                    duration = math.ceil(duration)
                    det01 = np.zeros(duration)
                    for i in range(len(self.calltypes)):
                        if CNN:
                            # read segments
                            ctsegments = self.findCTsegments(wavFile + '.tmpdata', i)
                        else:
                            # read segments from an identical postproc pipeline w/o CNN
                            ctsegments = self.findCTsegments(wavFile + '.tmp2data', i)
                        autoSegCTnum[i] += len(ctsegments)
                        for seg in ctsegments:
                            det01[math.floor(seg[0]):math.ceil(seg[1])] = 1
                    # get and parse the agreement metrics
                    GT = self.loadGT(os.path.join(root, file[:-4] + '-GT.txt'), duration)
                    _, _, tp, fp, tn, fn = ws.fBetaScore(GT, det01)
                    TP += tp
                    FP += fp
                    TN += tn
                    FN += fn
        # Summary
        total = TP + FP + TN + FN
        if total == 0:
            print("ERROR: failed to find any testing data")
            return
        if TP + FN != 0:
            recall = TP / (TP + FN)
        else:
            recall = 0
        if TP + FP != 0:
            precision = TP / (TP + FP)
        else:
            precision = 0
        if TN + FP != 0:
            specificity = TN / (TN + FP)
        else:
            specificity = 0
        accuracy = (TP + TN) / (TP + FP + TN + FN)

        if CNN:
            self.outfile.write("\n\n-- Wavelet Pre-Processor + CNN detection summary --\n")
        else:
            self.outfile.write("\n-- Wavelet Pre-Processor detection summary --\n")
        self.outfile.write("TP | FP | TN | FN seconds:\t %.2f | %.2f | %.2f | %.2f\n" % (TP, FP, TN, FN))
        self.outfile.write("Specificity:\t\t%.2f %%\n" % (specificity * 100))
        self.outfile.write("Recall (sensitivity):\t%.2f %%\n" % (recall * 100))
        self.outfile.write("Precision (PPV):\t%.2f %%\n" % (precision * 100))
        self.outfile.write("Accuracy:\t\t%.2f %%\n\n" % (accuracy * 100))
        self.outfile.write("Manually labelled segments:\t%d\n" % (self.manSegNum))
        for i in range(len(self.calltypes)):
            self.outfile.write("Auto suggested \'%s\' segments:\t%d\n" % (self.calltypes[i], autoSegCTnum[i]))
        self.outfile.write("Total auto suggested segments:\t%d\n\n" % sum(autoSegCTnum))

        if CNN:
            text = "Wavelet Pre-Processor + CNN detection summary\n\n\tTrue Positives:\t%d seconds (%.2f %%)\n\tFalse Positives:\t%d seconds (%.2f %%)\n\tTrue Negatives:\t%d seconds (%.2f %%)\n\tFalse Negatives:\t%d seconds (%.2f %%)\n\n\tSpecificity:\t%.2f %%\n\tRecall:\t\t%.2f %%\n\tPrecision:\t%.2f %%\n\tAccuracy:\t%.2f %%\n" \
                   % (TP, TP * 100 / total, FP, FP * 100 / total, TN, TN * 100 / total, FN, FN * 100 / total,
                      specificity * 100, recall * 100, precision * 100, accuracy * 100)
        else:
            text = "Wavelet Pre-Processor detection summary\n\n\tTrue Positives:\t%d seconds (%.2f %%)\n\tFalse Positives:\t%d seconds (%.2f %%)\n\tTrue Negatives:\t%d seconds (%.2f %%)\n\tFalse Negatives:\t%d seconds (%.2f %%)\n\n\tSpecificity:\t%.2f %%\n\tRecall:\t\t%.2f %%\n\tPrecision:\t%.2f %%\n\tAccuracy:\t%.2f %%\n" \
                   % (TP, TP * 100 / total, FP, FP * 100 / total, TN, TN * 100 / total, FN, FN * 100 / total,
                      specificity * 100, recall * 100, precision * 100, accuracy * 100)
        return text
    def loadGT(self, filename, length):
        import csv
        annotation = []
        # Get the segmentation from the txt file
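        # Assumed layout of the -GT.txt file (as written by Segment.exportGT):
        # one tab-separated row per second of audio, with the 0/1 species
        # presence in the second column, e.g.
        #   1   0
        #   2   1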
        with open(filename) as f:
            reader = csv.reader(f, delimiter="\t")
            d = list(reader)
        if d[-1] == []:
            d = d[:-1]
        if len(d) != length:
            print("ERROR: annotation length %d does not match file duration %d!" % (len(d), length))
            self.annotation = []
            return False
        # for each second, store 0/1 presence:
        for row in d:
            annotation.append(int(row[1]))
        return annotation
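
# A minimal CLI usage sketch (not part of the original module). The directory paths
# and recogniser name below are hypothetical placeholders for an AviaNZ-style setup:
# a LearningParams.txt in configdir and a trained wavelet recogniser in filterdir.
if __name__ == "__main__":
    trainer = CNNtrain(configdir="Config", filterdir="Filters",
                       folderTrain1="TrainData/manual",
                       folderTrain2="TrainData/reviewed",
                       recogniser="Kiwi", imgWidth=0.5, CLI=True)
    trainer.cliTrain()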