feat: add alternative autoregressive model (#29)
* feat: add teacher forcing

* feat: generate and temperature

* feat: update model
tsunrise authored Nov 17, 2022
1 parent a9117f7 commit c609d65
Showing 10 changed files with 242 additions and 71 deletions.
3 changes: 2 additions & 1 deletion datasets.toml
@@ -1,3 +1,4 @@
[datasets]
# All files should be a zip archive of MIDI files (it's ok to have recursive directories)
-adl-piano-midi = {midi = "https://r2.tomshen.io/cs230/adl-piano-midi.zip", prepared = "https://r2.tomshen.io/cs230/prepared_mono_all_files.pkl"}
+adl-piano-midi = {midi = "https://r2.tomshen.io/cs230/adl-piano-midi.zip", prepared = "https://r2.tomshen.io/cs230/prepared_mono_all_files.pkl"}
+# maestro = {midi = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip", metadata = "https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0.csv"}
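
The `prepared` key points at a pickled, preprocessed copy of the dataset so training can skip MIDI parsing. As a minimal sketch of how a config like this can be consumed (the repository has its own loader; the download step and output file name below are assumptions for illustration):

import tomllib  # stdlib in Python 3.11+; use the third-party `toml` package on older versions
import urllib.request

with open("datasets.toml", "rb") as f:
    config = tomllib.load(f)

# Fetch the preprocessed pickle for a dataset when one is declared.
entry = config["datasets"]["adl-piano-midi"]
if "prepared" in entry:
    urllib.request.urlretrieve(entry["prepared"], "prepared_mono_all_files.pkl")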
65 changes: 37 additions & 28 deletions main.py
@@ -5,7 +5,8 @@
import numpy as np
import torch
import torch.utils.data
-from torch.utils.tensorboard import SummaryWriter
+from torch.utils.tensorboard.writer import SummaryWriter
+from models.lstm_tf import DeepBeatsLSTM

import utils.devices as devices
from models.lstm import DeepBeats
@@ -31,7 +32,12 @@ def train(args):
print(f"Using {device} device")

# initialize mdoel
model = DeepBeats(args.n_notes, args.embed_dim, args.hidden_dim).to(device)
if args.model_name == "lstm":
model = DeepBeats(args.n_notes, args.embed_dim, args.hidden_dim).to(device)
elif args.model_name == "lstm_tf":
model = DeepBeatsLSTM(args.n_notes, args.embed_dim, args.hidden_dim).to(device)
else:
raise NotImplementedError("Model {} is not implemented.".format(args.model))
print(model)

if args.load_checkpoint:
@@ -46,62 +52,65 @@ def train(args):
    indices = np.arange(args.n_files if args.n_files != -1 else ADL_PIANO_TOTAL_SIZE)
    np.random.seed(0)
    np.random.shuffle(indices)
-    train_size = int(0.8 * len(indices))
+    # train_size = int(0.8 * len(indices))
+    train_size = len(indices)
    train_indices = indices[: train_size]
    val_indices = indices[train_size: ]
    train_dataset = BeatsRhythmsDataset(mono=True, num_files=args.n_files, seq_len=args.seq_len, save_freq=128, max_files_to_parse=args.max_files_to_parse, indices = train_indices)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=0, collate_fn=collate_fn)
-    val_dataset = BeatsRhythmsDataset(mono=True, num_files=args.n_files, seq_len=args.seq_len, save_freq=128, max_files_to_parse=args.max_files_to_parse, indices = val_indices)
-    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=0, collate_fn=collate_fn)
+    # val_dataset = BeatsRhythmsDataset(mono=True, num_files=args.n_files, seq_len=args.seq_len, save_freq=128, max_files_to_parse=args.max_files_to_parse, indices = val_indices)
+    # val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, num_workers=0, collate_fn=collate_fn)

    # define tensorboard writer
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M")
    log_dir = paths.tensorboard_dir / "{}_{}/{}".format(args.model_name, "all" if args.n_files == -1 else args.n_files, current_time)
    writer = SummaryWriter(log_dir = log_dir, flush_secs= 60)

    # training loop
-    best_val_loss = np.inf
+    # best_val_loss = np.inf
    for epoch in range(args.n_epochs):
        train_batch_loss = 0
-        val_batch_loss = 0
+        # val_batch_loss = 0
        num_train_batches = 0
-        num_val_batches = 0
+        # num_val_batches = 0

        model.train()
-        for X, y in train_loader:
+        for X, y, y_prev in train_loader:
            optimizer.zero_grad()
            input_seq = torch.from_numpy(X.astype(np.float32)).to(device)
-            target_seq = torch.from_numpy(y.astype(np.float32)).to(device)
-            output = model(input_seq)
+            target_seq = torch.from_numpy(y).long().to(device)
+            target_prev_seq = torch.from_numpy(y_prev).long().to(device)
+            output, _ = model(input_seq, target_prev_seq)
            loss = model.loss_function(output, target_seq)
            train_batch_loss += loss.item()
            loss.backward()
            model.clip_gradients_(5) # TODO: magic number
            optimizer.step()
            num_train_batches += 1

-        model.eval()
-        for X, y in val_loader:
-            input_seq = torch.from_numpy(X.astype(np.float32)).to(device)
-            target_seq = torch.from_numpy(y.astype(np.float32)).to(device)
-            with torch.no_grad():
-                output = model(input_seq)
-            loss = model.loss_function(output, target_seq)
-            val_batch_loss += loss.item()
-            num_val_batches += 1
+        # model.eval()
+        # for X, y in val_loader:
+        #     input_seq = torch.from_numpy(X.astype(np.float32)).to(device)
+        #     target_seq = torch.from_numpy(y.astype(np.float32)).to(device)
+        #     with torch.no_grad():
+        #         output = model(input_seq)
+        #     loss = model.loss_function(output, target_seq)
+        #     val_batch_loss += loss.item()
+        #     num_val_batches += 1

        avg_train_loss = train_batch_loss / num_train_batches
-        avg_val_loss = val_batch_loss / num_val_batches
+        # avg_val_loss = val_batch_loss / num_val_batches

        print('Epoch: {}/{}.............'.format(epoch, args.n_epochs), end=' ')
-        print("Train Loss: {:.4f} Validation Loss: {:.4f}".format(avg_train_loss, avg_val_loss))
+        print("Train Loss: {:.4f}".format(avg_train_loss))
        writer.add_scalar("train loss", avg_train_loss, epoch)
-        writer.add_scalar("validation loss", avg_val_loss, epoch)
+        # writer.add_scalar("validation loss", avg_val_loss, epoch)

        # save checkpoint with lowest validation loss
-        if avg_val_loss < best_val_loss:
-            best_val_loss = avg_val_loss
-            save_model(model, paths, args.model_name, args.n_files, "best")
-            print("Minimum Validation Loss of {:.4f} at epoch {}/{}".format(best_val_loss, epoch, args.n_epochs))
+        # if avg_val_loss < best_val_loss:
+        #     best_val_loss = avg_val_loss
+        #     save_model(model, paths, args.model_name, args.n_files, "best")
+        #     print("Minimum Validation Loss of {:.4f} at epoch {}/{}".format(best_val_loss, epoch, args.n_epochs))

        # save snapshots
        if (epoch + 1) % args.snapshots_freq == 0:
@@ -113,7 +122,7 @@ def train(args):

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Train DeepBeats')
-    parser.add_argument('--model_name', type=str, default="lstm")
+    parser.add_argument('--model_name', type=str, default="lstm_tf")
    parser.add_argument('--load_checkpoint', type=str, default="", help="load checkpoint path to continue training")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--embed_dim', type=int, default=32)
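
The reworked loop unpacks `(X, y, y_prev)` batches and feeds the shifted labels `y_prev` to the model, which is the teacher-forcing setup from the commit message: at every step the ground-truth previous note, not the model's own prediction, conditions the next prediction. A self-contained sketch of one training step with dummy data (the Adam optimizer and the dummy tensors are illustrative assumptions; shapes follow the loader: `X` is `(batch, seq_len, 2)`, `y` and `y_prev` are `(batch, seq_len)`):

import numpy as np
import torch
from models.lstm_tf import DeepBeatsLSTM

batch_size, seq_len, n_notes = 4, 64, 128
model = DeepBeatsLSTM(num_notes=n_notes, embed_size=32, hidden_dim=256)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Dummy batch standing in for one (X, y, y_prev) from the DataLoader.
X = np.random.rand(batch_size, seq_len, 2).astype(np.float32)
y = np.random.randint(0, n_notes, (batch_size, seq_len))
y_prev = np.roll(y, 1, axis=1)
y_prev[:, 0] = 0  # the first step has no previous note

optimizer.zero_grad()
output, _ = model(torch.from_numpy(X), torch.from_numpy(y_prev).long())
loss = model.loss_function(output, torch.from_numpy(y).long())
loss.backward()
model.clip_gradients_(5)
optimizer.step()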
21 changes: 18 additions & 3 deletions models/lstm.py
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn

+from preprocess.dataset import BeatsRhythmsDataset
+from torch.utils.data import DataLoader

class DeepBeats(nn.Module):
    def __init__(self, num_notes, embed_size, hidden_dim):
@@ -9,15 +11,28 @@ def __init__(self, num_notes, embed_size, hidden_dim):
        self.layer1 = nn.LSTM(embed_size, hidden_dim, batch_first=True)
        self.layer2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.notes_output = nn.Linear(hidden_dim, num_notes)
+        self.num_notes = num_notes

-    def forward(self, x):
+    def forward(self, x, y_prev):
+        _ = y_prev # unused
        x = self.durs_embed(x)
        x = self.layer1(x)[0]
        x = self.layer2(x)[0]
        predicted_notes = self.notes_output(x)
-        return predicted_notes
+        return predicted_notes, None

+    def sample(self, x):
+        return self.forward(x, None)

+    def loss_function(self, pred, target):
+        """
+        Pred: (batch_size, seq_len, num_notes), logits
+        Target: (batch_size, seq_len), range from 0 to num_notes-1
+        """
+        criterion = nn.CrossEntropyLoss()
-        loss = criterion(pred, target)
+        target_one_hot = torch.nn.functional.one_hot(target, self.num_notes).float()
+        # move num_notes to dim 1, where CrossEntropyLoss expects the class axis
+        loss = criterion(pred.transpose(1, 2), target_one_hot.transpose(1, 2))
+        return loss

    def clip_gradients_(self, max_value):
        torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value)
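
`nn.CrossEntropyLoss` expects the class scores on dimension 1, so logits shaped `(batch_size, seq_len, num_notes)` need the note axis moved there (or the batch and time axes flattened) before the loss is applied; otherwise the softmax runs over time steps instead of notes. A standalone shape check of the two equivalent formulations:

import torch
import torch.nn as nn

batch_size, seq_len, num_notes = 2, 16, 128
pred = torch.randn(batch_size, seq_len, num_notes)           # logits
target = torch.randint(0, num_notes, (batch_size, seq_len))  # note labels

criterion = nn.CrossEntropyLoss()
one_hot = nn.functional.one_hot(target, num_notes).float()
loss_a = criterion(pred.transpose(1, 2), one_hot.transpose(1, 2))
loss_b = criterion(pred.reshape(-1, num_notes), target.reshape(-1))
assert torch.allclose(loss_a, loss_b)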
95 changes: 95 additions & 0 deletions models/lstm_tf.py
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn

from models.model_utils import ConcatPrev
import numpy as np

class DeepBeatsLSTM(nn.Module):
    """
    DeepBeats with Teacher Forcing. This is the same as DeepBeats, except that the label from the previous step is used as input at the next step.
    """
    def __init__(self, num_notes, embed_size, hidden_dim):
        super(DeepBeatsLSTM, self).__init__()
        self.note_embedding = nn.Embedding(num_notes, embed_size)
        self.concat_prev = ConcatPrev()
        self.concat_input_fc = nn.Linear(embed_size + 2, embed_size + 2)
        self.concat_input_activation = nn.LeakyReLU()
        self.layer1 = nn.LSTM(embed_size + 2, hidden_dim, batch_first=True)
        self.layer2 = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.notes_logits_output = nn.Linear(hidden_dim, num_notes)
        self.num_notes = num_notes
        self.hidden_dim = hidden_dim

        self._initializer_weights()

    def _default_init_hidden(self, batch_size):
        device = next(self.parameters()).device
        h1_0 = torch.zeros(1, batch_size, self.layer1.hidden_size).to(device)
        c1_0 = torch.zeros(1, batch_size, self.layer1.hidden_size).to(device)
        h2_0 = torch.zeros(1, batch_size, self.layer2.hidden_size).to(device)
        c2_0 = torch.zeros(1, batch_size, self.layer2.hidden_size).to(device)
        return h1_0, c1_0, h2_0, c2_0

    def _initializer_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, y_prev, init_hidden=None):
        """
        x: input, shape: (batch_size, seq_len, 2)
        y_prev: label, shape: (batch_size, seq_len), range from 0 to num_notes-1
        y_prev[i] should be the note label for x[i-1], and y_prev[0] is 0.
        """
        h1_0, c1_0, h2_0, c2_0 = self._default_init_hidden(x.shape[0]) if init_hidden is None else init_hidden
        # Embedding
        y_prev_embed = self.note_embedding(y_prev)
        X = self.concat_prev(x, y_prev_embed)
        # Concat input
        X_fc = self.concat_input_fc(X)
        X_fc = self.concat_input_activation(X_fc)
        # residual connection
        X_fc = X_fc + X

        X, (h1, c1) = self.layer1(X_fc, (h1_0, c1_0))
        X, (h2, c2) = self.layer2(X, (h2_0, c2_0))
        predicted_notes = self.notes_logits_output(X)
        return predicted_notes, (h1, c1, h2, c2)

    def sample(self, x, y_init, temperature=1.0):
        """
        x: input, shape: (seq_len, 2)
        y_init: initial label, shape: (1), range from 0 to num_notes-1
        This function uses a for loop to generate the sequence one step at a time.
        """
        assert self.training == False, "This function should be used in eval mode."
        assert len(x.shape) == 2, "x should be a 2D tensor"
        ys = [y_init]
        hidden = self._default_init_hidden(1)
        for i in range(x.shape[0]):
            x_curr = x[i].reshape(1, 1, 2)
            y_prev = ys[-1].reshape(1, 1)
            scores, hidden = self.forward(x_curr, y_prev, hidden)
            scores = scores.squeeze(0)
            scores = scores / temperature
            scores = torch.nn.functional.softmax(scores, dim=1)
            y = torch.multinomial(scores, 1)
            ys.append(y)
        out = [y.item() for y in ys]
        print(out)
        return np.array(out)

    def loss_function(self, pred, target):
        """
        Pred: (batch_size, seq_len, num_notes), logits
        Target: (batch_size, seq_len), range from 0 to num_notes-1
        """
        criterion = nn.CrossEntropyLoss()
        target_one_hot = torch.nn.functional.one_hot(target, self.num_notes).float()
        # move num_notes to dim 1, where CrossEntropyLoss expects the class axis
        loss = criterion(pred.transpose(1, 2), target_one_hot.transpose(1, 2))
        return loss

    def clip_gradients_(self, max_value):
        torch.nn.utils.clip_grad.clip_grad_value_(self.parameters(), max_value)
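
`sample` divides the logits by a temperature before the softmax, then draws from the resulting distribution with `torch.multinomial`: temperatures below 1 sharpen the distribution toward the highest-scoring note, temperatures above 1 flatten it toward uniform, and 1.0 leaves it unchanged. A small standalone illustration (the logits are made up):

import torch

logits = torch.tensor([[2.0, 1.0, 0.1]])
for temperature in (0.5, 1.0, 2.0):
    probs = torch.nn.functional.softmax(logits / temperature, dim=1)
    note = torch.multinomial(probs, num_samples=1)
    print(temperature, probs.numpy().round(3), note.item())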
14 changes: 14 additions & 0 deletions models/model_utils.py
@@ -0,0 +1,14 @@
import torch
from torch import nn

class ConcatPrev(nn.Module):
    def forward(self, x, y_prev):
        """
        x: input, shape: (batch_size, seq_len, 2)
        y_prev: label embedding, shape: (batch_size, seq_len, embedding_dim),
        y_prev[i] should be the embedding of the note label for x[i-1], and y_prev[0] is 0.
        """
        # concat x and y_prev along the feature dimension
        X = torch.cat((x, y_prev), dim=2)
        return X
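
A quick shape check for `ConcatPrev` with dummy tensors (dimensions mirror the docstring; an embedding size of 32 matches the training default):

import torch
from models.model_utils import ConcatPrev

x = torch.rand(4, 64, 2)        # (batch_size, seq_len, 2) beat features
y_prev = torch.rand(4, 64, 32)  # (batch_size, seq_len, embedding_dim) note embeddings
out = ConcatPrev()(x, y_prev)
assert out.shape == (4, 64, 34)  # 2 + embedding_dim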
40 changes: 28 additions & 12 deletions predict_stream.py
@@ -2,6 +2,7 @@

import music21
import numpy as np
+from models.lstm_tf import DeepBeatsLSTM
import preprocess.dataset
import torch
from models.lstm import DeepBeats
@@ -30,58 +31,73 @@ def convert_to_stream(notes, prev_rest, curr_durs):
        s.append(a)
    return s

-def predict_notes_sequence(durs_seq, model, device):
+def predict_notes_sequence(durs_seq, model, init_note, device, temperature):
    """
    Predict notes sequence given durations sequence
    """
    model.to(device)
    model.eval()
-    prob_n = model(torch.from_numpy(durs_seq.astype(np.float32)).to(device)) # (1, seq_length, 128)
-    prob_n = prob_n.cpu().detach().numpy()
-    notes_seq = np.argmax(prob_n, -1) # (1, seq_length)
-    notes_seq = notes_seq.squeeze(0) # (seq_length,)
-    prev_rest_seq = durs_seq.squeeze(0)[:, 0] # (seq_length,)
-    curr_durs_seq = durs_seq.squeeze(0)[:, 1] # (seq_length,)
+    dur_seq_t = torch.from_numpy(durs_seq).to(device)
+    init_note_t = torch.tensor(init_note, dtype=torch.long).to(device)
+
+    notes_seq = model.sample(dur_seq_t, init_note_t, temperature)
+    prev_rest_seq = durs_seq[:, 0] # (seq_length,)
+    curr_durs_seq = durs_seq[:, 1] # (seq_length,)

    return notes_seq, prev_rest_seq, curr_durs_seq

if __name__ == '__main__':
    parser = argparse.ArgumentParser('Save Predicted Notes Sequence to Midi')
    parser.add_argument('--device', type=str, default="cpu")
    parser.add_argument('--load_checkpoint', type=str, default=".project_data/snapshots/lstm_all_10.pth")
    parser.add_argument('--midi_filename', type=str, default="output.mid")
    parser.add_argument('--embed_dim', type=int, default=32)
    parser.add_argument('--hidden_dim', type=int, default=256)
    parser.add_argument('--n_notes', type=int, default=128)
    parser.add_argument('--seq_len', type=int, default=64)
    parser.add_argument('--source', type=str, default="interactive")
+    parser.add_argument('--init_note', type=int, default=60)
+    parser.add_argument('--temperature', type=float, default=0.5)

    main_args = parser.parse_args()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    paths = DataPaths()

    # sample one midi file
    if main_args.source == 'interactive':
        X = create_beat()
        X[0][0] = 2.
+        # convert to float32
+        X = np.array(X, dtype=np.float32)
    elif main_args.source == 'dataset':
        dataset = preprocess.dataset.BeatsRhythmsDataset(num_files=1)
-        X, _ = next(iter(dataset))
+        it = iter(dataset)
+        # skip the first 24 files
+        for _ in range(24):
+            next(it)
+        X, _, _ = next(it)
        X[0][0] = 2.
    else:
        with open(main_args.source, 'rb') as f:
            X = np.load(f, allow_pickle=True)
        X[0][0] = 2.
+        X = np.array(X, dtype=np.float32)

    # load model
-    model = DeepBeats(main_args.n_notes, main_args.embed_dim, main_args.hidden_dim).to(main_args.device)
+    model = DeepBeatsLSTM(main_args.n_notes, main_args.embed_dim, main_args.hidden_dim).to(device)
    if main_args.load_checkpoint:
        model.load_state_dict(torch.load(main_args.load_checkpoint))
    print(model)

    # generate notes seq given durs seq
    notes, prev_rest, curr_durs = predict_notes_sequence(
-        durs_seq = X[np.newaxis, :].copy(), # select the first durs seq for now, batch size = 1
+        durs_seq = X.copy(),
        model=model,
+        init_note=main_args.init_note,
        device=device,
+        temperature=main_args.temperature
    )

    # convert stream to midi
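
With a checkpoint loaded, generation reduces to a single `sample` call over a duration sequence. A condensed sketch of the flow above (the checkpoint path and the random beat array are placeholders; in the script the beats come from `create_beat()`, a dataset file, or a saved array):

import numpy as np
import torch
from models.lstm_tf import DeepBeatsLSTM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepBeatsLSTM(num_notes=128, embed_size=32, hidden_dim=256).to(device)
model.load_state_dict(torch.load(".project_data/snapshots/lstm_tf_all_10.pth"))  # placeholder path
model.eval()  # sample() asserts eval mode

durs = np.random.rand(64, 2).astype(np.float32)  # stand-in for a (seq_len, 2) beat sequence
notes = model.sample(
    torch.from_numpy(durs).to(device),
    torch.tensor(60, dtype=torch.long).to(device),  # init_note: middle C
    temperature=0.5,
)
# notes has length seq_len + 1: the initial note plus one sampled note per beat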
