
Commit

Update minor changes.
danielhavir committed Jul 30, 2018
1 parent f9ba632 commit a4679dc
Showing 4 changed files with 81 additions and 26 deletions.
33 changes: 21 additions & 12 deletions core/dataset.py
@@ -42,22 +42,30 @@
'44mel256_test': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([-2.44529629], [1.96563387])
])
]),
'24mel256_train': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([-2.1824522], [2.08129025])
]),
'24mel256_test': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([-2.1824522], [2.08129025])
]),
}


def cache_spectrogram(filename: str):
pcm = mf.read_wav(os.path.join('data', 'train', filename), target_sample_rate=44100)
pcm = mf.read_wav(os.path.join('data', 'train', filename), target_sample_rate=24000)
spec = spectrum.mel(pcm)
name, file_extension = os.path.splitext(filename)
utils.save_array(spec, os.path.join('data', 'cache', '44mel256_train', name + '.h5'))
utils.save_array(spec, os.path.join('data', 'cache', '24mel256_train', name + '.h5'))


def cache_test_spectrogram(filename: str):
pcm = mf.read_wav(os.path.join('data', 'test', filename), target_sample_rate=44100)
pcm = mf.read_wav(os.path.join('data', 'test', filename), target_sample_rate=24000)
spec = spectrum.mel(pcm)
name, file_extension = os.path.splitext(filename)
utils.save_array(spec, os.path.join('data', 'cache', 'test', name + '.h5'))
utils.save_array(spec, os.path.join('data', 'cache', '24mel256_test', name + '.h5'))


def load_and_slice(entry: dict):
@@ -90,12 +98,13 @@ def load_and_slice_test(entry: dict):


class SoundData(object):
def __init__(self, cache_prefix='mel256', test_size=0.2, num_processes=8, seed=42):
def __init__(self, cache_prefix='mel256', test_size=0.2, num_processes=8, seed=42, prevent_cache=False):
self.df = pd.read_csv(os.path.join(DATA_PATH, 'train.csv'))
self.cache_dir = os.path.join(DATA_PATH, 'cache', f'{cache_prefix}_train')
if not os.path.exists(self.cache_dir):
os.mkdir(self.cache_dir)
self.num_processes = num_processes
self.prevent_cache = prevent_cache
self.unique_label = np.sort(self.df.label.unique()).tolist()
self.label2idx = dict(zip(self.unique_label, range(len(self.unique_label))))
self.idx2label = dict(zip(range(len(self.unique_label)), self.unique_label))
@@ -108,7 +117,7 @@ def __init__(self, cache_prefix='mel256', test_size=0.2, num_processes=8, seed=4
self.train_idx, self.test_idx = train_test_split(self.idxs, test_size=test_size, random_state=seed)

def cache_samples(self):
if not os.listdir(self.cache_dir):
if not os.listdir(self.cache_dir) and not self.prevent_cache:
print(f"Caching in {self.num_processes} processes...")
pool = mp.Pool(processes=self.num_processes)
pool.map(cache_spectrogram, (self.df.fname).tolist())
@@ -201,11 +210,11 @@ def __getitem__(self, idx):
if __name__ == '__main__':
from time import time
t0 = time()
sound_data = SoundData(cache_prefix='44mel256', num_processes=6)
#sound_data = SoundData(cache_prefix='24mel256', num_processes=6)
#train_df, test_df = sound_data.get_train_test_split()
trainset = Dset(sound_data.df, num_processes=6, transform=data_transforms['44mel256_train'])
#trainset = Dset(sound_data.df, num_processes=6, transform=data_transforms['24mel256_train'])
#valset = Dset(test_df, num_processes=6, transform=data_transforms['test'])
#testset = TestDset(num_processes=2, transform=data_transforms['test'])
testset = TestDset(cache_prefix='24mel256', num_processes=4, transform=data_transforms['44mel256_test'])
print(time() - t0)
print(len(os.listdir(sound_data.cache_dir)))
print(trainset[0][0].size())
print(len(os.listdir(testset.cache_dir)))
print(testset[0][0].size())
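
Note on the new Normalize([-2.1824522], [2.08129025]) entries for the 24 kHz cache: they look like a dataset-wide mean and standard deviation of the cached log-mel spectrograms. A minimal sketch of how such constants could be recomputed from a cache directory follows; the load_array helper is assumed to be the counterpart of the utils.save_array call used above, not a confirmed API of this repository.

import os
import numpy as np
import utils  # assumed to expose load_array mirroring save_array

def cache_stats(cache_dir: str):
    # Accumulate a global mean/std over every cached spectrogram.
    total, total_sq, count = 0.0, 0.0, 0
    for fname in os.listdir(cache_dir):
        spec = utils.load_array(os.path.join(cache_dir, fname))  # hypothetical helper
        total += float(spec.sum())
        total_sq += float(np.square(spec).sum())
        count += spec.size
    mean = total / count
    std = np.sqrt(total_sq / count - mean ** 2)
    return mean, std

# cache_stats(os.path.join('data', 'cache', '24mel256_train')) would be expected
# to land near the constants above.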
29 changes: 24 additions & 5 deletions eval.py
@@ -9,20 +9,29 @@
from tqdm import tqdm
from collections import defaultdict
import argparse
import gc

CKPT_DIR = os.path.join('checkpoints')

# Collect arguments (if any)
parser = argparse.ArgumentParser()

# Cache prefix
parser.add_argument('cache_prefix', nargs='?', type=str, choices=['mel256', 'wavelet'], default='mel256', help="Mel spectrogram or wavelets.")
parser.add_argument('cache_prefix', nargs='?', type=str, choices=['mel256', 'wavelet', '44mel256', '24mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Checkpoint directory
parser.add_argument('-dir', '--ckpt_dir', type=str, choices=os.listdir(CKPT_DIR), default=sorted(os.listdir(CKPT_DIR))[-1], help="Checkpoints dir.")
# Checkpoint directory
parser.add_argument('-dir2', '--ckpt_dir2', type=str, choices=os.listdir(CKPT_DIR), default=sorted(os.listdir(CKPT_DIR))[-2], help="Checkpoints dir.")
parser.add_argument('-dir2', '--ckpt_dir2', type=str, choices=os.listdir(CKPT_DIR), default=sorted(os.listdir(CKPT_DIR))[-2], help="Second checkpoints dir.")
# Cache prefix
parser.add_argument('--cache_prefix2', type=str, choices=['mel256', 'wavelet'], default='mel256', help="Mel spectrogram or wavelets.")
parser.add_argument('--cache_prefix2', type=str, choices=['mel256', 'wavelet', '44mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Checkpoint directory
parser.add_argument('-dir3', '--ckpt_dir3', type=str, choices=os.listdir(CKPT_DIR), default=None, help="Third checkpoints dir.")
# Cache prefix
parser.add_argument('--cache_prefix3', type=str, choices=['mel256', 'wavelet', '44mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Checkpoint directory
parser.add_argument('-dir4', '--ckpt_dir4', type=str, choices=os.listdir(CKPT_DIR), default=None, help="Fourth checkpoints dir.")
# Cache prefix
parser.add_argument('--cache_prefix4', type=str, choices=['mel256', 'wavelet', '44mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Type of evaluation
parser.add_argument('-t', '--type', type=str, choices=['all', 'last', 'combine-last', 'combine-all'], default='last', help="Type of experiment evaluation.")
# Batch size
Expand All @@ -39,7 +48,7 @@

print(f"Loading snapshots from experiment: {args.ckpt_dir}")

idx2label = cd.SoundData(cache_prefix=args.cache_prefix).idx2label
idx2label = cd.SoundData(prevent_cache=True).idx2label
#sound_data = cd.SoundData(phase='test', num_processes=args.num_workers)
testset = cd.TestDset(cache_prefix=args.cache_prefix, num_processes=args.num_workers, transform=cd.data_transforms[f'{args.cache_prefix}_test'])
testloader = thd.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
Expand All @@ -54,8 +63,17 @@
snaps_dir2 = os.path.join(RES_DIR2, 'snaps')
runs += [os.path.join(snaps_dir2, run_name) for run_name in sorted(os.listdir(snaps_dir2))]
prefixes += ([args.cache_prefix2] * len(os.listdir(snaps_dir2)))
if args.ckpt_dir3 is not None:
RES_DIR3 = os.path.join(CKPT_DIR, args.ckpt_dir3)
snaps_dir3 = os.path.join(RES_DIR3, 'snaps')
runs += [os.path.join(snaps_dir3, run_name) for run_name in sorted(os.listdir(snaps_dir3))]
prefixes += ([args.cache_prefix3] * len(os.listdir(snaps_dir3)))
if args.ckpt_dir4 is not None:
RES_DIR4 = os.path.join(CKPT_DIR, args.ckpt_dir4)
snaps_dir4 = os.path.join(RES_DIR4, 'snaps')
runs += [os.path.join(snaps_dir4, run_name) for run_name in sorted(os.listdir(snaps_dir4))]
prefixes += ([args.cache_prefix4] * len(os.listdir(snaps_dir4)))
is_ensemble = len(runs) > 1
import pdb; pdb.set_trace()

def eval_model(loader, model, model_num):
predictions = defaultdict(list)
@@ -83,6 +101,7 @@ def eval_model(loader, model, model_num):
if mname.endswith('last.model'):
print(f"Evaluating model {mname}")
model = torch.load(os.path.join(run_dir, mname))
model.eval()
if args.multi_gpu:
model = nn.DataParallel(model)
if not prefixes[split_num] == active_prefix:
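
The new -dir3/-dir4 checkpoint directories and their cache_prefix3/cache_prefix4 counterparts extend the existing two-directory ensemble in eval.py; the script itself walks the snapshot lists built above. Purely as an illustrative sketch (the function and variable names here are not the repository's API), averaging class probabilities across snapshot models works roughly like this:

import torch
import torch.nn.functional as F

def ensemble_probs(models, inputs):
    # Average softmax outputs over an ensemble of snapshot models.
    probs = None
    with torch.no_grad():
        for model in models:
            model.eval()  # dropout off, batch-norm frozen, as in the added model.eval() call
            out = F.softmax(model(inputs), dim=1)
            probs = out if probs is None else probs + out
    return probs / len(models)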
31 changes: 27 additions & 4 deletions experiment.py
@@ -19,6 +19,7 @@
from copy import deepcopy
import argparse
import logging
import gc

logger = logging.getLogger()
logger, RUN_DIR = log.setup_logger(logger)
@@ -42,7 +43,7 @@

class Experiment(object):
def __init__(self, model: str, batch_size: int, epochs: int, lr: float, cache_prefix: str='mel256',
eval_interval: int=1, optimizer: str='sgd', schedule: str=None, step_size: int=10, gamma: float=0.5,
eval_interval: int=1, optimizer: str='sgd', schedule: str=None, step_size: int=50, gamma: float=0.5,
use_mixup: bool=True, mixup_alpha: float=0.5, weighted: bool=False, cross_validate: bool=False,
n_splits: int=5, seed: int=42, metric: str='accuracy', no_snaps: bool=False, debug_limit: int=None,
device: str=('cuda' if torch.cuda.is_available() else 'cpu'), num_processes: int=8, multi_gpu: bool=False, **kwargs):
@@ -102,6 +103,7 @@ def __init__(self, model: str, batch_size: int, epochs: int, lr: float, cache_pr
self.eye = torch.eye(self.num_classes).to(self.device)
self.mixup = Mixup(mixup_alpha, self.device)

#self.run_snaps_dir = os.path.join('checkpoints', '20180725-22:03:57', 'snaps')
self.model = self.load_model()

if optimizer == 'sgd':
@@ -177,12 +179,26 @@ def load_model(self):

return model.to(self.device)

def continue_model(self, run_snaps_dir, split_num):
model_path = os.path.join(run_snaps_dir, f'run-{split_num}', f'{self.model_str}-last.model')
logger.info(f'Loading model: {model_path}')
model = torch.load(model_path)

logger.info(f"Num params: {sum([np.prod(p.size()) for p in model.parameters()])}")
logger.info(f"Num trainable params: {sum([np.prod(p.size()) for p in model.parameters() if p.requires_grad])}")

if self.multi_gpu:
model = nn.DataParallel(model)

return model.to(self.device)

def train_loop(self, epoch):
train_loader = tqdm(self.loaders['train'], desc=f'TRAIN Epoch {epoch}',
total=(len(self.trainset)//self.batch_size + 1))

total_loss = 0.0
correct = 0; total = 0;
self.model.train()
for inputs, targets, ids in train_loader:
inputs, targets = inputs.to(self.device).float(), targets.to(self.device)

@@ -219,7 +235,7 @@ def eval_loop(self, epoch, phase):
predictions = defaultdict(list)
labels = defaultdict(list)
eval_loader = tqdm(self.loaders[phase], desc=f'EVALUATION Epoch {epoch}', total=(len(self.loaders[phase].dataset)//self.batch_size + 1))

self.model.eval()
with torch.no_grad():
for inputs, targets, ids in eval_loader:
inputs, targets = inputs.to(self.device).float(), targets.to(self.device)
@@ -307,6 +323,13 @@ def split_run(self):

self.single_run(run_fname=f'run-{split_num}')

del self.loaders
del self.trainset.data
del self.testset.data
del self.trainset
del self.testset
gc.collect()

def run(self):
if self.no_snaps:
logger.info('Preventing from snapshots')
@@ -325,7 +348,7 @@ def run(self):
# Pretrained model
parser.add_argument('model', type=str, choices=pretrained_models.keys(), help="Model to run.")
# Cache prefix
parser.add_argument('cache_prefix', nargs='?', type=str, choices=['mel256', 'wavelet'], default='mel256', help="Mel spectrogram or wavelets.")
parser.add_argument('cache_prefix', nargs='?', type=str, choices=['mel256', 'wavelet', '44mel256', '24mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Batch size
parser.add_argument('-bs', '--batch_size', type=int, default=64, help='Batch size.')
# Epochs
@@ -366,6 +389,6 @@ def run(self):
torch.cuda.set_device(args.gpu_device)

exp = Experiment(args.model, args.batch_size, args.epochs, args.learning_rate, args.cache_prefix, eval_interval= args.eval_interval,
use_mixup=(not args.no_mixup), mixup_alpha=args.mixup_alpha, cross_validate=args.cross_validate, schedule=args.scheduler,
use_mixup=(not args.no_mixup), mixup_alpha=args.mixup_alpha, cross_validate=args.cross_validate, schedule=args.scheduler, gamma=args.gamma,
seed=args.seed, no_snaps=args.no_snaps, debug_limit=args.debug_limit, num_processes=args.num_workers, multi_gpu=args.multi_gpu)
exp.run()
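
Two of the experiment.py changes are easy to miss: the default step_size moves from 10 to 50, and gamma is now forwarded from the command line. Assuming the scheduler behind the schedule argument is torch.optim.lr_scheduler.StepLR (which matches the step_size/gamma parameters but is not confirmed by this diff), the practical effect is that the learning rate is halved only every 50 epochs:

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)  # stand-in model, not the repository's Experiment class
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

for epoch in range(150):
    optimizer.step()   # placeholder for a real training epoch
    scheduler.step()
# lr stays at 0.1 for epochs 0-49, drops to 0.05 for 50-99, and 0.025 for 100-149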
14 changes: 9 additions & 5 deletions statistics.py
@@ -11,6 +11,7 @@
from tqdm import tqdm
from collections import defaultdict
import argparse
import gc

CKPT_DIR = os.path.join('checkpoints')

@@ -19,6 +20,8 @@

# Pretrained model
parser.add_argument('model', type=str, help="Model to run.")
# Cache prefix
parser.add_argument('cache_prefix', nargs='?', type=str, choices=['mel256', 'wavelet', '44mel256', '24mel256'], default='mel256', help="Mel spectrogram or wavelets.")
# Checkpoint directory
parser.add_argument('-dir', '--ckpt_dir', type=str, choices=os.listdir(CKPT_DIR), default=sorted(os.listdir(CKPT_DIR))[-1], help="Checkpoints dir.")
# Type of evaluation
@@ -39,7 +42,7 @@

print(f"Loading snapshots from experiment: {args.ckpt_dir}")

sound_data = cd.SoundData(num_processes=args.num_workers, seed=args.seed)
sound_data = cd.SoundData(cache_prefix=args.cache_prefix, num_processes=args.num_workers, seed=args.seed)
device = torch.device(args.device)
RES_DIR = os.path.join(CKPT_DIR, args.ckpt_dir)
snaps_dir = os.path.join(RES_DIR, 'snaps')
@@ -80,15 +83,16 @@ def eval_model(loader, model, model_num, phase):
for split_num, (train, test) in enumerate(kfold.split(sound_data.idxs, sound_data.df.target)):
sound_data.reset_index(train, test)
train_df, test_df = sound_data.get_train_test_split()
trainset = cd.Dset(train_df, args.num_workers, transform=cd.data_transforms['train'], phase='train')
testset = cd.Dset(test_df, args.num_workers, transform=cd.data_transforms['test'], phase='test')
loaders = {'train': thd.DataLoader(trainset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers),
'test': thd.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)}
testset = cd.Dset(test_df, args.num_workers, transform=cd.data_transforms[f'{args.cache_prefix}_test'], phase='test')
loaders = {'test': thd.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)}
if args.type == 'last':
model = torch.load(os.path.join(runs[split_num], args.model + '-last.model'))
model.eval()
eval_model(loaders['test'], model, split_num, 'test')
elif args.type == 'all':
for model_num, mname in enumerate(os.listdir(runs[split_num])):
if mname.endswith('.model'):
model = torch.load(os.path.join(runs[split_num], mname))
eval_model(loaders['test'], model, f'{split_num} / {model_num}', 'test')
del testset, loaders
gc.collect()
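
The statistics.py changes mirror the pattern added in experiment.py: call model.eval() before inference and release each fold's data with del plus gc.collect() before moving to the next split. A generic sketch of that per-fold loop, with illustrative names only:

import gc
import torch

def evaluate_folds(models, make_loader, folds):
    # One snapshot model per fold; free the fold's loader before the next split.
    for split_num, fold in enumerate(folds):
        loader = make_loader(fold)
        model = models[split_num]
        model.eval()              # inference mode for dropout/batch-norm
        with torch.no_grad():
            for inputs, targets in loader:
                _ = model(inputs)
        del loader
        gc.collect()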
