-
Notifications
You must be signed in to change notification settings - Fork 1
/
fma.py
84 lines (76 loc) · 3 KB
/
fma.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import json
import random
import pickle
import numpy as np
import pandas as pd
import torch
from typing import Callable, List, Dict, Any
from torch.utils.data import Dataset
class FMA_Dataset(Dataset):
    """Tagging dataset that pairs precomputed audio embeddings with tag labels.

    The constructor reads three JSON files from ``<data_path>/fma/``:

    * ``track_split.json`` -- object with the keys ``train_track``,
      ``valid_track`` and ``test_track``, each an iterable of track ids.
    * ``annotation.json``  -- object keyed by track id (as a string); every
      value is a record providing at least ``track_id`` and ``tag``.
    * ``fma_tags.json``    -- ordered list of tag names; the position of a tag
      in this list is its index in the multi-hot label vector.

    Args:
        data_path: Root directory that contains the ``fma`` sub-directory.
        split: One of ``"TRAIN"``, ``"VALID"``, ``"TEST"`` or ``"ALL"``.
        audio_embs: Object indexable by ``str(track_id)`` that returns the
            audio representation for that track (returned to callers as-is).

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """

    def __init__(self, data_path: str, split: str, audio_embs):
        self.data_path = data_path
        self.split = split
        self.audio_embs = audio_embs
        # Track ids that are dropped from every split.
        # NOTE(review): the reason these ids are excluded is not visible here.
        self.black_list = [99134, 108925, 133297]
        self.get_split()
        self.get_file_list()

    def _load_json(self, filename: str):
        """Parse ``<data_path>/fma/<filename>`` and close the file afterwards."""
        path = os.path.join(self.data_path, "fma", filename)
        with open(path, "r") as handle:
            return json.load(handle)

    def get_split(self):
        """Load the train/valid/test track id lists from ``track_split.json``."""
        track_split = self._load_json("track_split.json")
        self.train_track = track_split['train_track']
        self.valid_track = track_split['valid_track']
        self.test_track = track_split['test_track']

    def get_file_list(self):
        """Build ``self.fl``, the annotation records of the selected split.

        Also sets ``self.list_of_label`` (tag vocabulary) and
        ``self.tag_to_idx`` (tag name -> vocabulary index).

        Raises:
            ValueError: If ``self.split`` is not a supported split name.
        """
        annotation = self._load_json("annotation.json")
        self.list_of_label = self._load_json("fma_tags.json")
        self.tag_to_idx = {tag: idx for idx, tag in enumerate(self.list_of_label)}

        # Build the lookup set once instead of scanning the list per track.
        blocked = set(self.black_list)
        tracks_by_split = {
            "TRAIN": self.train_track,
            "VALID": self.valid_track,
            "TEST": self.test_track,
        }

        if self.split == "ALL":
            self.fl = [
                record for track_id, record in annotation.items()
                if int(track_id) not in blocked
            ]
        elif self.split in tracks_by_split:
            self.fl = [
                annotation[str(track_id)]
                for track_id in tracks_by_split[self.split]
                if int(track_id) not in blocked
            ]
        else:
            raise ValueError(f"Unexpected split name: {self.split}")

    def tag_to_binary(self, text):
        """Encode a tag or list of tags as a multi-hot ``float32`` vector.

        Args:
            text: A single tag name (``str``) or a ``list`` of tag names.
                Any other type yields an all-zero vector.

        Returns:
            1-D ``numpy`` array of length ``len(self.list_of_label)`` with
            ``1.0`` at the index of every given tag.

        Raises:
            KeyError: If a tag is not part of the vocabulary.
        """
        binary = np.zeros([len(self.list_of_label),], dtype=np.float32)
        if isinstance(text, str):
            binary[self.tag_to_idx[text]] = 1.0
        elif isinstance(text, list):
            for tag in text:
                binary[self.tag_to_idx[tag]] = 1.0
        return binary

    def get_train_item(self, index):
        """Return ``{"audio", "binary"}`` for the record at ``index``."""
        item = self.fl[index]
        binary = self.tag_to_binary(item['tag'])
        audio_tensor = self.audio_embs[str(item['track_id'])]
        return {
            "audio": audio_tensor,
            "binary": binary,
        }

    def get_eval_item(self, index):
        """Return the evaluation view of the record at ``index``.

        The result holds the audio embedding, the track id, the full tag
        vocabulary (``"tags"``), the multi-hot label vector and the record's
        tags rendered as a comma-separated string (``"text"``).
        """
        item = self.fl[index]
        tag_list = item['tag']
        binary = self.tag_to_binary(tag_list)
        # tag_to_binary accepts a bare string tag, so do the same here:
        # joining a str would split it into single characters.
        text = tag_list if isinstance(tag_list, str) else ", ".join(tag_list)
        track_id = item['track_id']
        audio = self.audio_embs[str(track_id)]
        return {
            "audio": audio,
            "track_id": track_id,
            "tags": self.list_of_label,
            "binary": binary,
            "text": text,
        }

    def __getitem__(self, index):
        """Return the train-style item for TRAIN/VALID, the eval item otherwise."""
        if self.split in ('TRAIN', 'VALID'):
            return self.get_train_item(index)
        return self.get_eval_item(index)

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.fl)