-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtrsf_genre_classification.py
61 lines (50 loc) · 1.87 KB
/
trsf_genre_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
from glob import glob
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from models import transformer_classifier
from prepare_data import get_id_from_path, DataGenerator
if __name__ == "__main__":
    # Script entry point: build a train/validation split of the pre-computed
    # ``.npy`` files, then train the transformer classifier while
    # checkpointing the weights that achieve the best validation loss.
    from collections import Counter

    h5_name = "transformer.h5"
    batch_size = 32
    epochs = 50

    # Load the metadata with context managers so the file handles are closed
    # deterministically instead of being left to the garbage collector.
    with open(
        "/media/ml/data_ml/fma_metadata/mapping.json", encoding="utf-8"
    ) as mapping_file:
        CLASS_MAPPING = json.load(mapping_file)

    with open(
        "/media/ml/data_ml/fma_metadata/tracks_genre.json", encoding="utf-8"
    ) as genres_file:
        # JSON object keys are always strings; convert them to int ids so
        # they match the ``int(get_id_from_path(...))`` lookups below.
        id_to_genres = {int(k): v for k, v in json.load(genres_file).items()}

    base_path = "/media/ml/data_ml/fma_large"

    # Pair every file with its genre entry (a single lookup per file), then
    # keep only the files whose entry is truthy (i.e. non-empty).
    # A file whose id is missing from ``id_to_genres`` raises ``KeyError``.
    tagged = (
        (path, id_to_genres[int(get_id_from_path(path))])
        for path in sorted(glob(base_path + "/*/*.npy"))
    )
    samples = [(path, genres) for path, genres in tagged if genres]
    print(len(samples))

    # Stratify on the last genre of each sample. Genres seen at most twice
    # are pooled under "" for the stratified split.
    # NOTE(review): presumably this avoids classes too small to split;
    # the pooled "" bucket itself could still be tiny -- confirm on the data.
    strat = [genres[-1] for _, genres in samples]
    cnt = Counter(strat)
    strat = [genre if cnt[genre] > 2 else "" for genre in strat]

    train, val = train_test_split(
        samples, test_size=0.2, random_state=1337, stratify=strat
    )

    model = transformer_classifier(n_classes=len(CLASS_MAPPING))

    # Only the weights of the best epoch (lowest val_loss) are written.
    checkpoint = ModelCheckpoint(
        h5_name,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        mode="min",
        save_weights_only=True,
    )
    reduce_o_p = ReduceLROnPlateau(
        monitor="val_loss", patience=20, min_lr=1e-7, mode="min"
    )

    # NOTE(review): ``fit_generator`` is deprecated in recent TensorFlow
    # releases in favour of ``fit``; kept here to preserve behaviour on the
    # TensorFlow version this project targets -- confirm before upgrading.
    model.fit_generator(
        DataGenerator(train, batch_size=batch_size, class_mapping=CLASS_MAPPING),
        validation_data=DataGenerator(
            val, batch_size=batch_size, class_mapping=CLASS_MAPPING
        ),
        epochs=epochs,
        callbacks=[checkpoint, reduce_o_p],
        use_multiprocessing=True,
        workers=12,
        verbose=2,
        max_queue_size=64,
    )