Skip to content

Commit

Permalink
Move dataset folder
Browse files Browse the repository at this point in the history
  • Loading branch information
jacbz committed Jun 25, 2021
1 parent d637e98 commit f729450
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.idea
node_modules
dist
datasets/hooktheory
datasets/processed
model/dataset/hooktheory
model/dataset/processed
*.npy
*.pth
__pycache__
Expand Down
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,13 @@ def forward(self, x):
chord_prediction = self.chord_linear(hx)
output = [chord_prediction]

# force stop when reaching max length
# stop when reaching max length
max_chord_progression_length = 325
for i in range(max_chord_progression_length):
hx, cx = self.cell(hx, (hx, cx))
chord_prediction = self.chord_linear(hx)
output.append(chord_prediction)
# break when all have predicted an 8

output = torch.stack(output, dim=1)
preds = output.argmax(dim=2)
Expand Down
23 changes: 14 additions & 9 deletions model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,31 @@
from model import Model

device = "cuda" if torch.cuda.is_available() else "cpu"
file = "model.pth"

parser = argparse.ArgumentParser(description='Generate some chords.')
parser.add_argument('input', type=str)
args = parser.parse_args()

if __name__ == '__main__':
def predict_chords(input):
print("Loading model...", end=" ")
model = Model()
model.load_state_dict(torch.load("model.pth"))
model.load_state_dict(torch.load(file))
print(f"Loaded {file}.")
model.to(device)

model.eval()

input = args.input
embedding, length = make_embedding(input)

input = pack_padded_sequence(embedding[None], torch.tensor([length]), batch_first=True, enforce_sorted=False)
pred, _ = model(input)

chords = pred.argmax(dim=2)[0].tolist()
chords.append(8)
chords = chords[:chords.index(8)-1] # cut off 8
chords = chords[:chords.index(8) - 1] # cut off 8

print(f"Chord progression: {' '.join(map(str, chords))}")

print(f"Chord progression: {' '.join(map(str, chords))}")

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Generate some chords.')
parser.add_argument('input', type=str)
args = parser.parse_args()
predict_chords(args.input)
2 changes: 1 addition & 1 deletion model/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __getitem__(self, index):


if __name__ == '__main__':
dataset_folder = "../datasets/processed"
dataset_folder = "dataset/processed"
dataset_files = os.listdir(dataset_folder)
embeddings_file = "embeddings"
embedding_lengths_file = "embedding_lengths.json"
Expand Down

0 comments on commit f729450

Please sign in to comment.