Skip to content

Commit

Permalink
working commit used to compute v1 results for DCASE2020 FOA/MIC datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
sharathadavanne committed Jan 2, 2021
1 parent 13a6912 commit 48576f1
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 45 deletions.
2 changes: 0 additions & 2 deletions cls_feature_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,6 @@ def extract_all_labels(self):
create_folder(self._label_dir)

for file_cnt, file_name in enumerate(os.listdir(self._desc_dir)):
if len(file_name)!=26: #checking clean metadata files #TODO this is not required if the DCASE2020 dataset is clean
continue
wav_filename = '{}.wav'.format(file_name.split('.')[0])
desc_file_polar = self.load_output_format_file(os.path.join(self._desc_dir, file_name))
desc_file = self.convert_output_format_polar_to_cartesian(desc_file_polar)
Expand Down
5 changes: 3 additions & 2 deletions cls_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ def partial_compute_metric(self, dist_mat, gt_activity, pred_activity=None):

def get_results(self):
    """Return the accumulated localization metric(s).

    The internal accumulators (``self._localization_error``, ``self._tp_doa``,
    ``self._total_gt``) are assumed to hold angular error in radians and
    raw counts — TODO confirm against partial_compute_metric.

    Returns:
        If ``self._is_baseline`` is truthy: the mean localization error
        in degrees (float).
        Otherwise: a tuple ``(localization_error_deg, localization_recall_pct)``
        where the error is averaged over true-positive DOA matches only and
        recall is the percentage of ground-truth DOAs that were matched.
    """
    if self._is_baseline:
        # Baseline: average the accumulated angular error over all
        # ground-truth instances.
        localization_error = self._localization_error/self._total_gt
        # Convert radians to degrees for reporting.
        return 180.*localization_error/np.pi
    else:
        # Average angular error over true-positive matches only, and
        # report how many ground truths were recalled.
        localization_error = self._localization_error/self._tp_doa
        localization_recall = self._tp_doa/self._total_gt
        # Report error in degrees and recall as a percentage.
        return 180.*localization_error/np.pi, 100.*localization_recall



14 changes: 8 additions & 6 deletions doanet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,14 @@ def __init__(self, in_feat_shape, out_shape, params):
# Branch for activity detection
self.fnn_act_list = torch.nn.ModuleList()
if self.use_hnet and self.use_activity_out:
for fc_cnt in range(params['nb_fnn_act_layers']):
self.fnn_act_list.append(
torch.nn.Linear(params['fnn_act_size'] if fc_cnt else params['rnn_size'] , params['fnn_act_size'], bias=True)
)
self.fnn_act_list.append(
torch.nn.Linear(params['rnn_size'] , params['fnn_size'], bias=True)
torch.nn.Linear(params['fnn_act_size'] if params['nb_fnn_act_layers'] else params['rnn_size'], params['unique_classes'], bias=True)
)

self.fnn_act_list.append(
torch.nn.Linear(params['fnn_size'] if params['nb_fnn_layers'] else params['rnn_size'], params['unique_classes'], bias=True)
)

def forward(self, x):
'''input: (batch_size, mic_channels, time_steps, mel_bins)'''
Expand Down Expand Up @@ -203,8 +204,9 @@ def forward(self, x):
'''(batch_size, time_steps, label_dim)'''

if self.use_hnet and self.use_activity_out:
activity = torch.relu_(self.fnn_act_list[0](x_rnn))
activity = self.fnn_act_list[1](activity)
for fnn_cnt in range(len(self.fnn_act_list)-1):
x_rnn = torch.relu_(self.fnn_act_list[fnn_cnt](x_rnn))
activity = torch.tanh(self.fnn_act_list[-1](x_rnn))

return doa, activity
else:
Expand Down
43 changes: 35 additions & 8 deletions doanet_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_params(argv='1'):

# DATASET LOADING PARAMETERS
mode='dev', # 'dev' - development or 'eval' - evaluation dataset
dataset='foa', # 'foa' - ambisonic or 'mic' - microphone signals
dataset='mic', # 'foa' - ambisonic or 'mic' - microphone signals

#FEATURE PARAMS
fs=24000,
Expand All @@ -50,12 +50,16 @@ def get_params(argv='1'):

nb_fnn_layers=2,
fnn_size=128, # FNN contents, length of list = number of layers, list value = number of nodes
nb_epochs=100, # Train for maximum epochs

nb_fnn_act_layers=2,
fnn_act_size=128, # FNN contents, length of list = number of layers, list value = number of nodes

nb_epochs=200, # Train for maximum epochs
lr=1e-3,
dMOTA_wt = 1,
dMOTP_wt = 1,
dMOTP_wt = 50,
IDS_wt = 1,
branch_weights=[1, 1],
branch_weights=[1, 10.],
use_dmot_only=False,
)

Expand All @@ -64,14 +68,28 @@ def get_params(argv='1'):
print("USING DEFAULT PARAMETERS\n")

elif argv == '50':
params['use_dmotp_only']= True
params['use_hnet']= True
params['use_dmot_only']= True
params['feat_label_dir']='/scratch/asignal/sharath/DCASE2020_SELD_dataset/feat_label/'

elif argv == '51':
params['use_dmotp_only']= False
params['use_hnet']= True
params['use_dmot_only']= False
params['feat_label_dir']='/scratch/asignal/sharath/DCASE2020_SELD_dataset/feat_label/'

elif argv == '52':
params['use_hnet']= False
params['feat_label_dir']='/scratch/asignal/sharath/DCASE2020_SELD_dataset/feat_label_baseline/'

elif argv == '53':
params['use_hnet']= True
params['use_dmot_only']= True
params['feat_label_dir']='/scratch/asignal/sharath/DCASE2020_SELD_dataset/feat_label_augmented/'

elif argv == '54':
params['batch_size']= 256
params['label_sequence_length']= 60
params['use_hnet']= True
params['use_dmot_only']= False
params['feat_label_dir']='/scratch/asignal/sharath/DCASE2020_SELD_dataset/feat_label_augmented/'

elif argv == '55':
params['batch_size']= 128
Expand Down Expand Up @@ -243,6 +261,15 @@ def get_params(argv='1'):
elif argv == '96':
params['dMOTP_wt'] = 100

elif argv == '97':
params['branch_weights'] = [1, 10]

elif argv == '98':
params['branch_weights'] = [1, 50]

elif argv == '99':
params['branch_weights'] = [1, 100]

elif argv == '999':
print("QUICK TEST MODE\n")
params['quick_test'] = True
Expand Down
53 changes: 28 additions & 25 deletions train_doanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ def test_epoch(data_generator, model, hnet_model, activity_loss, criterion, metr
da_mat = da_mat.sigmoid() # (batch*sequence, max_nb_doas, max_nb_doas)
da_mat = da_mat.view(dist_mat.shape)
da_mat = (da_mat>0.5)*da_mat
da_activity = da_mat.max(-1)[0]

# Compute dMOTP loss for true positives
dMOTP_loss = (torch.mul(dist_mat, da_mat).sum(-1).sum(-1) * da_mat.sum(-1).sum(-1)*params['dMOTP_wt']).sum()/da_mat.sum()
# dMOTP_loss = torch.mul(dist_mat, da_mat).sum()/ da_mat.sum()

# Compute dMOTA loss
M = da_mat.max(-1)[0].sum(-1)
M = da_activity.sum(-1)
N = torch.Tensor(nb_framewise_doas_gt).to(device)
FP = torch.clamp(M-N, min=0)
FN = torch.clamp(N-M, min=0)
Expand All @@ -101,7 +102,7 @@ def test_epoch(data_generator, model, hnet_model, activity_loss, criterion, metr
train_dMOTA_loss += dMOTA_loss.item()
loss = dMOTP_loss+params['dMOTA_wt']*dMOTA_loss
if not params['use_dmot_only']:
act_loss = activity_loss(activity_out, da_mat.max(-1)[0])
act_loss = activity_loss(activity_out, (da_activity>0.5).float())
loss = params['branch_weights'][0] * loss + params['branch_weights'][1] * act_loss
train_act_loss += act_loss.item()
else:
Expand Down Expand Up @@ -147,6 +148,7 @@ def train_epoch(data_generator, optimizer, model, hnet_model, activity_loss, cri
else:
output, activity_out = model(data)
activity_out = activity_out.view(-1, activity_out.shape[-1])
# activity_binary = (torch.sigmoid(activity_out).cpu().detach().numpy() > 0.5)
else:
output = model(data)

Expand All @@ -168,32 +170,33 @@ def train_epoch(data_generator, optimizer, model, hnet_model, activity_loss, cri
da_mat = da_mat.sigmoid() # (batch*sequence, max_nb_doas, max_nb_doas)
da_mat = da_mat.view(dist_mat.shape)
da_mat = (da_mat>0.5)*da_mat
da_activity = da_mat.max(-1)[0]

# Compute dMOTP loss for true positives
dMOTP_loss = (torch.mul(dist_mat, da_mat).sum(-1).sum(-1) * da_mat.sum(-1).sum(-1)*params['dMOTP_wt']).sum()/da_mat.sum()
# dMOTP_loss = torch.mul(dist_mat, da_mat).sum()/ da_mat.sum()

# Compute dMOTA loss
M = da_mat.max(-1)[0].sum(-1)
M = da_activity.sum(-1)
N = torch.Tensor(nb_framewise_doas_gt).to(device)
FP = torch.clamp(M-N, min=0)
FN = torch.clamp(N-M, min=0)
IDS = (da_mat[1:]*(1-da_mat[:-1])).sum(-1).sum(-1)
IDS = torch.cat((torch.Tensor([0]).to(device), IDS))
dMOTA_loss = ((FP + FN + params['IDS_wt']* IDS).sum() / (M+ torch.finfo(torch.float32).eps).sum())

train_dMOTP_loss += dMOTP_loss.item()
train_dMOTA_loss += dMOTA_loss.item()
loss = dMOTP_loss+params['dMOTA_wt']*dMOTA_loss
if not params['use_dmot_only']:
act_loss = activity_loss(activity_out, da_mat.max(-1)[0])
act_loss = activity_loss(activity_out, (da_activity>0.5).float())
loss = params['branch_weights'][0] * loss + params['branch_weights'][1] * act_loss
train_act_loss += act_loss.item()
else:
loss = criterion(output, target)
loss.backward()
optimizer.step()

train_loss += loss.item()
nb_train_batches += 1
if params['quick_test'] and nb_train_batches == 4:
Expand Down Expand Up @@ -282,23 +285,24 @@ def main(argv):
params=params, split=val_splits[split_cnt], shuffle=False
)

# Collect the reference labels for validation data
# Collect i/o data size and load model configuration
data_in, data_out = data_gen_train.get_data_sizes()
print('FEATURES:\n\tdata_in: {}\n\tdata_out: {}\n'.format(data_in, data_out))
model = doanet_model.CRNN(data_in, data_out, params).to(device)
# model.load_state_dict(torch.load("models/95_4441371_foa_dev_split1_model.h5", map_location='cpu'))

print('---------------- DOA-net -------------------')
print('FEATURES:\n\tdata_in: {}\n\tdata_out: {}\n'.format(data_in, data_out))
print('MODEL:\n\tdropout_rate: {}\n\tCNN: nb_cnn_filt: {}, f_pool_size{}, t_pool_size{}\n\trnn_size: {}, fnn_size: {}\n'.format(
params['dropout_rate'], params['nb_cnn2d_filt'], params['f_pool_size'], params['t_pool_size'], params['rnn_size'],
params['fnn_size']))

model = doanet_model.CRNN(data_in, data_out, params).to(device)
# model.load_state_dict(torch.load("models/1_4415973_foa_dev_split1_model.h5", map_location='cpu'))
print('---------------- DOA-net -------------------')
print(model)
best_val_loss = 99999

# start training
best_val_epoch = -1
best_doa, best_recall = 180, 0
patience_cnt = 0

nb_epoch = 200 if params['quick_test'] else params['nb_epochs']
nb_epoch = 2 if params['quick_test'] else params['nb_epochs']
tr_loss_list = np.zeros(nb_epoch)
val_loss_list = np.zeros(nb_epoch)
hung_tr_loss_list = np.zeros(nb_epoch)
Expand All @@ -308,7 +312,6 @@ def main(argv):
criterion = torch.nn.MSELoss()
activity_loss = nn.BCEWithLogitsLoss()

# start training
for epoch_cnt in range(nb_epoch):
# ---------------------------------------------------------------------
# TRAINING
Expand All @@ -327,25 +330,25 @@ def main(argv):
val_hung_loss, val_recall_doa = val_metric.get_results()
else:
val_hung_loss = val_metric.get_results()
val_recall_doa = 100.
val_time = time.time() - start_time

# Save model if loss is good
if val_hung_loss < best_val_loss:
best_val_loss = val_hung_loss
best_val_epoch = epoch_cnt
if val_hung_loss < best_doa:
best_doa, best_val_epoch, best_recall = val_hung_loss, epoch_cnt, val_recall_doa
torch.save(model.state_dict(), model_name)

# Print stats and plot scores
print(
'epoch: {}, time: {:0.2f}/{:0.2f}, '
'train_loss: {:0.2f} {}, val_loss: {:0.2f} {}, '
'val_hung_loss_deg: {:0.3f}{}, '
'best_val_epoch: {}'.format(
'LE/LR: {:0.3f}/{}, '
'best_val_epoch: {} {}'.format(
epoch_cnt, train_time, val_time,
train_loss, '({:0.2f},{:0.2f},{:0.2f})'.format(train_dMOTP_loss, train_dMOTA_loss, train_act_loss) if params['use_hnet'] else '',
val_loss, '({:0.2f},{:0.2f},{:0.2f})'.format(val_dMOTP_loss, val_dMOTA_loss, val_act_loss) if params['use_hnet'] else '',
180*val_hung_loss/np.pi, '/{:0.2f}'.format(val_recall_doa*100.0) if params['use_hnet'] and not params['use_dmot_only']else '',
best_val_epoch)
val_hung_loss, '{:0.2f}'.format(val_recall_doa) if params['use_hnet'] and not params['use_dmot_only']else '100.0',
best_val_epoch, '({:0.2f}, {:0.2f})'.format(best_doa, best_recall) if params['use_hnet'] and not params['use_dmot_only']else '({:0.2f})'.format(best_doa))
)

tr_loss_list[epoch_cnt], val_loss_list[epoch_cnt], hung_val_loss_list[epoch_cnt] = train_loss, val_loss, val_hung_loss
Expand Down Expand Up @@ -375,9 +378,9 @@ def main(argv):
test_hung_loss = test_metric.get_results()

print(
'test_loss: {:0.2f} {}, test_hung_loss_deg: {:0.3f}{}'.format(
'test_loss: {:0.2f} {}, LE/LR: {:0.3f}/{}'.format(
test_loss, '({:0.2f},{:0.2f},{:0.2f})'.format(test_dMOTP_loss, test_dMOTA_loss, test_act_loss) if params['use_hnet'] else '',
180*test_hung_loss/np.pi, '/{:0.2f}'.format(test_recall_doa*100.0) if params['use_hnet'] and not params['use_dmot_only'] else '')
test_hung_loss, '{:0.2f}'.format(test_recall_doa) if params['use_hnet'] and not params['use_dmot_only'] else '100.0')
)


Expand Down
3 changes: 1 addition & 2 deletions visualize_doanet_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def main(argv):
with torch.no_grad():
file_cnt = 0
for data, target in data_gen_test.generate():

data, target = torch.tensor(data).to(device).float(), torch.tensor(target[:,:,:-1]).to(device).float()
data, target = torch.tensor(data).to(device).float(), torch.tensor(target[:,:,:-params['unique_classes']]).to(device).float()
output, activity_out = model(data)

# (batch, sequence, max_nb_doas*3) to (batch, sequence, 3, max_nb_doas)
Expand Down

0 comments on commit 48576f1

Please sign in to comment.