Skip to content

Commit

Permalink
Support training with multi-processing feature extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
wq2012 committed May 10, 2022
1 parent 16b2dc4 commit e67c23f
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 27 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ There are many reasons, for example:

* Data: We are only using a small subset of LibriSpeech clean, which is very simple data with very few speakers. Also, no data augmentation is used.
* Feature: We are simply using default params of MFCC in librosa.
* Model: We are simply using 3 layers of uni-directional LSTM.
* Model: We are simply using 3 layers of LSTM.
* Loss: We are using a simple triplet loss.
* Efficiency: Code is written for simplicity and readability, not really for efficiency. Multi-processing and multi-threading are not used.
* Efficiency: Code is written for simplicity and readability, not really for efficiency.

The good news is that at least we have a system working end-to-end. It hooks up with data, extracts features, defines a neural network, trains the network, and evaluates it, all successfully.

Expand Down
44 changes: 25 additions & 19 deletions feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,32 @@ def trim_features(features):
return features[start: start + myconfig.SEQ_LEN, :]


def get_triplet_features_trimmed(spk_to_utts):
    """Sample an anchor/pos/neg triplet and trim each to SEQ_LEN frames.

    Triplets are re-drawn via `get_triplet_features` until all three
    feature arrays have at least `myconfig.SEQ_LEN` rows; each is then
    passed through `trim_features`.

    Args:
        spk_to_utts: forwarded unchanged to `get_triplet_features`.

    Returns:
        A tuple (anchor, pos, neg) of trimmed features.
    """
    while True:
        anchor, pos, neg = get_triplet_features(spk_to_utts)
        has_short_member = any(
            feats.shape[0] < myconfig.SEQ_LEN
            for feats in (anchor, pos, neg))
        if not has_short_member:
            # Trim in anchor, pos, neg order, same as the returned tuple.
            return (trim_features(anchor),
                    trim_features(pos),
                    trim_features(neg))


def get_batched_triplet_input(spk_to_utts, batch_size):
class TrimmedTripletFeaturesFetcher:
    """Callable that fetches one trimmed triplet, for multi-processing.

    An instance holds `spk_to_utts` so that it can be used as the
    function argument of `map` or of a process pool's `map`.
    """

    def __init__(self, spk_to_utts):
        # Kept on the instance and forwarded to `get_triplet_features`
        # on every call.
        self.spk_to_utts = spk_to_utts

    def __call__(self, _):
        """Get a stacked triplet of trimmed anchor/pos/neg features.

        The single positional argument is ignored; it only exists so the
        instance can be mapped over a range of indices.

        Returns:
            The three trimmed feature arrays stacked along a new leading
            axis, in anchor, pos, neg order.
        """
        while True:
            anchor, pos, neg = get_triplet_features(self.spk_to_utts)
            triplet = (anchor, pos, neg)
            # Re-draw until every member has at least SEQ_LEN rows.
            if not any(feats.shape[0] < myconfig.SEQ_LEN
                       for feats in triplet):
                break
        return np.stack([trim_features(feats) for feats in triplet])


def get_batched_triplet_input(spk_to_utts, batch_size, pool=None):
"""Get batched triplet input for PyTorch."""
input_arrays = []
for _ in range(batch_size):
anchor, pos, neg = get_triplet_features_trimmed(
spk_to_utts)
input_arrays += [anchor, pos, neg]
batch_input = torch.from_numpy(np.stack(input_arrays)).float()
fetcher = TrimmedTripletFeaturesFetcher(spk_to_utts)
if pool is None:
input_arrays = list(map(fetcher, range(batch_size)))
else:
input_arrays = pool.map(fetcher, range(batch_size))
batch_input = torch.from_numpy(np.concatenate(input_arrays)).float()
return batch_input


Expand Down
6 changes: 5 additions & 1 deletion myconfig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file has the configurations of the experiments.
import os
import torch
import multiprocessing

# Path of downloaded LibriSpeech datasets.
TRAIN_DATA_DIR = os.path.join(
Expand Down Expand Up @@ -50,7 +51,7 @@
SAVE_MODEL_FREQUENCY = 10000

# Number of steps to train.
TRAINING_STEPS = 50000
TRAINING_STEPS = 100000

# Number of triplets to evaluate for computing Equal Error Rate (EER).
# Both the number of positive trials and number of negative trials will be
Expand All @@ -60,5 +61,8 @@
# Step of threshold sweeping for computing Equal Error Rate (EER).
EVAL_THRESHOLD_STEP = 0.001

# Number of processes for multi-processing.
NUM_PROCESSES = min(multiprocessing.cpu_count(), BATCH_SIZE)

# Whether to use GPU or CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10 changes: 7 additions & 3 deletions neural_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing

import feature_extraction
import myconfig
Expand Down Expand Up @@ -76,7 +77,7 @@ def batch_inference(batch_input, encoder):
return batch_output


def train_network(num_steps, saved_model=None):
def train_network(num_steps, saved_model=None, pool=None):
start_time = time.time()
losses = []
spk_to_utts = feature_extraction.get_spk_to_utts(myconfig.TRAIN_DATA_DIR)
Expand All @@ -91,7 +92,7 @@ def train_network(num_steps, saved_model=None):

# Build batched input.
batch_input = feature_extraction.get_batched_triplet_input(
spk_to_utts, myconfig.BATCH_SIZE).to(myconfig.DEVICE)
spk_to_utts, myconfig.BATCH_SIZE, pool).to(myconfig.DEVICE)

# Compute loss.
batch_output = batch_inference(batch_input, encoder)
Expand All @@ -115,7 +116,10 @@ def train_network(num_steps, saved_model=None):


def run_training():
losses = train_network(myconfig.TRAINING_STEPS, myconfig.SAVED_MODEL_PATH)
with multiprocessing.Pool(myconfig.NUM_PROCESSES) as pool:
losses = train_network(myconfig.TRAINING_STEPS,
myconfig.SAVED_MODEL_PATH,
pool)
plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")
Expand Down
10 changes: 8 additions & 2 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import unittest
import numpy as np
import multiprocessing

import feature_extraction
import neural_net
Expand Down Expand Up @@ -55,8 +56,12 @@ def test_get_triplet_features(self):
self.assertEqual(myconfig.N_MFCC, neg.shape[1])

def test_get_triplet_features_trimmed(self):
anchor, pos, neg = feature_extraction.get_triplet_features_trimmed(
fetcher = feature_extraction.TrimmedTripletFeaturesFetcher(
self.spk_to_utts)
fetched = fetcher(None)
anchor = fetched[0, :, :]
pos = fetched[1, :, :]
neg = fetched[2, :, :]
self.assertEqual(anchor.shape, (myconfig.SEQ_LEN, myconfig.N_MFCC))
self.assertEqual(pos.shape, (myconfig.SEQ_LEN, myconfig.N_MFCC))
self.assertEqual(neg.shape, (myconfig.SEQ_LEN, myconfig.N_MFCC))
Expand Down Expand Up @@ -119,7 +124,8 @@ def test_train_unilstm_network(self):
def test_train_bilstm_network(self):
myconfig.BI_LSTM = True
myconfig.FRAME_AGGREGATION_MEAN = True
losses = neural_net.train_network(num_steps=2)
with multiprocessing.Pool(myconfig.NUM_PROCESSES) as pool:
losses = neural_net.train_network(num_steps=2, pool=pool)
self.assertEqual(len(losses), 2)


Expand Down

0 comments on commit e67c23f

Please sign in to comment.