-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_manger.py
58 lines (55 loc) · 2.01 KB
/
data_manger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from torch.utils.data import DataLoader
from mtat import MTAT_Dataset
from gtzan import GTZAN_Dataset
from fma import FMA_Dataset
from kvt import KVT_Dataset
from openmic import OPENMIC_Dataset
from mtg import MTG_Dataset
from emotify import EMOTIFY_Dataset
def get_dataloader(args, split, audio_embs):
    """Build a DataLoader over one split of the dataset selected by ``args``.

    Args:
        args: Config object; this function reads ``args.eval_dataset``,
            ``args.msu_dir``, ``args.batch_size`` and ``args.workers``.
        split: One of "TRAIN", "VALID", "TEST" or "ALL".
        audio_embs: Forwarded unchanged to ``get_dataset``.

    Returns:
        A ``torch.utils.data.DataLoader``. Only "TRAIN" is shuffled;
        "TEST" and "ALL" use ``batch_size=1``. Every split uses
        ``pin_memory=True`` and keeps the final partial batch.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """
    # Resolve the per-split loader settings before touching the dataset, so
    # an unsupported split fails fast with a clear error instead of falling
    # through to an unbound local.
    if split == "TRAIN":
        batch_size, shuffle = args.batch_size, True
    elif split == "VALID":
        batch_size, shuffle = args.batch_size, False
    elif split in ("TEST", "ALL"):
        # TEST and ALL are loaded one sample per batch.
        batch_size, shuffle = 1, False
    else:
        raise ValueError(
            f"Unsupported split {split!r}; expected one of "
            "'TRAIN', 'VALID', 'TEST', 'ALL'"
        )

    dataset = get_dataset(
        eval_dataset=args.eval_dataset,
        data_path=args.msu_dir,
        split=split,
        audio_embs=audio_embs,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )
def get_dataset(
    eval_dataset,
    data_path,
    split,
    audio_embs
):
    """Instantiate the evaluation dataset named by ``eval_dataset``.

    Args:
        eval_dataset: Dataset key. Exact matches: "mtat", "gtzan", "fma",
            "kvt", "openmic", "emotify". Any name containing the substring
            "mtg" selects ``MTG_Dataset``, which also receives the full name.
        data_path: Data directory forwarded to the dataset constructor.
        split: Split name forwarded to the dataset constructor.
        audio_embs: Forwarded unchanged to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``eval_dataset`` matches none of the supported names.
    """
    if eval_dataset == "mtat":
        return MTAT_Dataset(data_path, split, audio_embs)
    if eval_dataset == "gtzan":
        return GTZAN_Dataset(data_path, split, audio_embs)
    if eval_dataset == "fma":
        return FMA_Dataset(data_path, split, audio_embs)
    if eval_dataset == "kvt":
        return KVT_Dataset(data_path, split, audio_embs)
    if eval_dataset == "openmic":
        return OPENMIC_Dataset(data_path, split, audio_embs)
    if eval_dataset == "emotify":
        return EMOTIFY_Dataset(data_path, split, audio_embs)
    if "mtg" in eval_dataset:
        # Substring match: every "mtg..." name goes to one class, which is
        # handed the full name (presumably to pick the variant — confirm).
        return MTG_Dataset(data_path, split, audio_embs, eval_dataset)
    # Fail loudly: the old print-then-return crashed with UnboundLocalError.
    raise ValueError(
        f"Unknown eval_dataset {eval_dataset!r}; expected one of "
        "'mtat', 'gtzan', 'fma', 'kvt', 'openmic', 'emotify', or a name "
        "containing 'mtg'"
    )