add specaug
wq2012 committed May 10, 2022
1 parent 4280d58 commit 9066122
Showing 5 changed files with 52 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dataset.py
@@ -38,7 +38,7 @@ def get_csv_spk_to_utts(csv_file):

def get_triplet(spk_to_utts):
"""Get a triplet of anchor/pos/neg samples."""
pos_spk, neg_spk = random.sample(spk_to_utts.keys(), 2)
pos_spk, neg_spk = random.sample(list(spk_to_utts.keys()), 2)
anchor_utt, pos_utt = random.sample(spk_to_utts[pos_spk], 2)
neg_utt = random.sample(spk_to_utts[neg_spk], 1)[0]
return (anchor_utt, pos_utt, neg_utt)
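
Note that the list() wrapper is the whole fix here: passing a set-like object (such as a dict keys view) to random.sample() was deprecated in Python 3.9 and removed in Python 3.11, where it raises TypeError, so the old line breaks on newer interpreters. A minimal sketch of the difference, using a made-up toy mapping:

import random

spk_to_utts = {"spk_a": ["a1.flac", "a2.flac"], "spk_b": ["b1.flac"]}  # hypothetical toy data
# random.sample(spk_to_utts.keys(), 2)                         # TypeError on Python 3.11+
pos_spk, neg_spk = random.sample(list(spk_to_utts.keys()), 2)  # works on all versions
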
14 changes: 9 additions & 5 deletions feature_extraction.py
@@ -6,6 +6,7 @@

import myconfig
import dataset
+import specaug


def extract_features(audio_file):
@@ -43,11 +44,14 @@ def get_triplet_features(spk_to_utts):
            extract_features(neg_utt))


-def trim_features(features):
+def trim_features(features, apply_specaug):
    """Trim features to SEQ_LEN."""
    full_length = features.shape[0]
    start = random.randint(0, full_length - myconfig.SEQ_LEN)
-    return features[start: start + myconfig.SEQ_LEN, :]
+    trimmed_features = features[start: start + myconfig.SEQ_LEN, :]
+    if apply_specaug:
+        trimmed_features = specaug.apply_specaug(trimmed_features)
+    return trimmed_features
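
For reference, a minimal usage sketch of the updated helper (the input matrix is random, made-up data; only the shapes matter):

import numpy as np
import feature_extraction
import myconfig

features = np.random.rand(myconfig.SEQ_LEN + 50, myconfig.N_MFCC)  # fake MFCC matrix, longer than SEQ_LEN
trimmed = feature_extraction.trim_features(features, apply_specaug=True)
print(trimmed.shape)  # (myconfig.SEQ_LEN, myconfig.N_MFCC), possibly with masked bands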


class TrimmedTripletFeaturesFetcher:
@@ -63,9 +67,9 @@ def __call__(self, _):
               pos.shape[0] < myconfig.SEQ_LEN or
               neg.shape[0] < myconfig.SEQ_LEN):
            anchor, pos, neg = get_triplet_features(self.spk_to_utts)
-        return np.stack([trim_features(anchor),
-                         trim_features(pos),
-                         trim_features(neg)])
+        return np.stack([trim_features(anchor, myconfig.SPEC_AUG_TRAINING),
+                         trim_features(pos, myconfig.SPEC_AUG_TRAINING),
+                         trim_features(neg, myconfig.SPEC_AUG_TRAINING)])
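
Each call to the fetcher thus yields one stacked triplet. Assuming the fetcher is constructed with a spk_to_utts mapping (its constructor is not part of this diff), a rough shape check would look like:

fetcher = feature_extraction.TrimmedTripletFeaturesFetcher(spk_to_utts)  # spk_to_utts assumed loaded via dataset.py
triplet = fetcher(None)
print(triplet.shape)  # (3, myconfig.SEQ_LEN, myconfig.N_MFCC): anchor, positive, negative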


def get_batched_triplet_input(spk_to_utts, batch_size, pool=None):
3 changes: 3 additions & 0 deletions myconfig.py
@@ -26,6 +26,9 @@
# Number of MFCCs for librosa.feature.mfcc.
N_MFCC = 40

+# Whether we are going to train with SpecAugment.
+SPEC_AUG_TRAINING = False

# Hidden size of LSTM layers.
LSTM_HIDDEN_SIZE = 64
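
The new flag defaults to False, so existing training runs behave exactly as before; to train with SpecAugment one would presumably just flip it in myconfig.py:

SPEC_AUG_TRAINING = True  # enable random frequency/time masking of training features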

31 changes: 31 additions & 0 deletions specaug.py
@@ -0,0 +1,31 @@
import numpy as np
import random

import myconfig

FREQ_MASK_PROB = 0.2
TIME_MASK_PROB = 0.2

FREQ_MASK_MAX_WIDTH = myconfig.N_MFCC // 10
TIME_MASK_MAX_WIDTH = myconfig.SEQ_LEN // 10


def apply_specaug(features):
"""Apply SpecAugment to features."""
seq_len, n_mfcc = features.shape
outputs = features
mean_feature = np.mean(features)

# Frequancy masking.
if random.random() < FREQ_MASK_PROB:
width = random.randint(1, FREQ_MASK_MAX_WIDTH)
start = random.randint(0, n_mfcc - width)
outputs[:, start: start + width] = mean_feature

# Time masking.
if random.random() < TIME_MASK_PROB:
width = random.randint(1, TIME_MASK_MAX_WIDTH)
start = random.randint(0, seq_len - width)
outputs[start: start + width, :] = mean_feature

return outputs
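
A quick sketch of what the augmentation does to a feature matrix (random, made-up input; with the default 0.2 probabilities a given call may apply no mask at all, and because outputs aliases the input the masking happens in place):

import numpy as np
import myconfig
import specaug

features = np.random.rand(myconfig.SEQ_LEN, myconfig.N_MFCC)  # fake MFCC matrix
masked = specaug.apply_specaug(features)
print(masked.shape)  # unchanged: (myconfig.SEQ_LEN, myconfig.N_MFCC)
# With probability FREQ_MASK_PROB, a band of up to N_MFCC // 10 coefficients is
# overwritten with the mean feature value; with probability TIME_MASK_PROB, the
# same happens to a span of up to SEQ_LEN // 10 frames.
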
8 changes: 8 additions & 0 deletions tests.py
@@ -6,6 +6,7 @@
import tempfile

import dataset
+import specaug
import feature_extraction
import neural_net
import evaluation
@@ -52,6 +53,13 @@ def test_get_triplet(self):
        self.assertNotEqual(neg1_spk, neg2_spk)


+class TestSpecAug(unittest.TestCase):
+    def test_specaug(self):
+        features = np.random.rand(myconfig.SEQ_LEN, myconfig.N_MFCC)
+        outputs = specaug.apply_specaug(features)
+        self.assertEqual(outputs.shape, (myconfig.SEQ_LEN, myconfig.N_MFCC))
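
The shape assertion above passes even when neither mask fires. A stricter follow-up test (hypothetical, not part of this commit, and relying on the numpy/myconfig/specaug imports already in tests.py) could force both masks and check that values actually change:

class TestSpecAugMasksApplied(unittest.TestCase):
    def test_masks_change_values(self):
        features = np.random.rand(myconfig.SEQ_LEN, myconfig.N_MFCC)
        original = features.copy()
        old_freq, old_time = specaug.FREQ_MASK_PROB, specaug.TIME_MASK_PROB
        specaug.FREQ_MASK_PROB = specaug.TIME_MASK_PROB = 1.0  # always apply both masks
        try:
            outputs = specaug.apply_specaug(features)
        finally:
            specaug.FREQ_MASK_PROB, specaug.TIME_MASK_PROB = old_freq, old_time
        self.assertFalse(np.array_equal(outputs, original))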


class TestFeatureExtraction(unittest.TestCase):
    def setUp(self):
        self.spk_to_utts = dataset.get_librispeech_spk_to_utts(
