Add MIDI random shift augmentation at training time
yqzhishen committed Sep 13, 2023
1 parent 45c8cc7 commit 2528b6b
Showing 3 changed files with 19 additions and 13 deletions.
1 change: 1 addition & 0 deletions configs/some.yaml
@@ -21,6 +21,7 @@ pe_ckpt: pretrained/rmvpe/model.pt
 midi_min: 0
 midi_max: 128
 midi_prob_deviation: 0.5
+midi_shift_range: [-12, 12]
 
 # neural networks
 sort_by_len: true
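Note (illustration, not part of the commit): the new midi_shift_range option gives the bounds, in semitones, of a uniform random transposition drawn once per sample at training time. A minimal standalone sketch of how such a value is drawn, mirroring the collater change in training/me_task.py below:

    import random

    midi_shift_range = [-12, 12]  # from configs/some.yaml: up to one octave either way

    low, high = midi_shift_range
    # one uniform draw per training sample; validation uses a shift of 0
    shift = random.random() * (high - low) + low  # float in [-12.0, 12.0)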
7 changes: 4 additions & 3 deletions training/base_task.py
@@ -43,13 +43,14 @@ class BaseDataset(Dataset):
     the index function.
     """
 
-    def __init__(self, config: dict, data_dir, prefix):
+    def __init__(self, config: dict, data_dir, prefix, allow_aug=False):
         super().__init__()
         self.config = config
         self.prefix = prefix
         self.data_dir = data_dir if isinstance(data_dir, pathlib.Path) else pathlib.Path(data_dir)
         self.sizes = np.load(self.data_dir / f'{self.prefix}.lengths')
         self.indexed_ds = IndexedDataset(self.data_dir, self.prefix)
+        self.allow_aug = allow_aug
 
     @property
     def _sizes(self):
@@ -133,11 +134,11 @@ def setup(self, stage):
         self.build_losses_and_metrics()
         self.train_dataset = self.dataset_cls(
             config=self.config, data_dir=self.config['binary_data_dir'],
-            prefix=self.config['train_set_name']
+            prefix=self.config['train_set_name'], allow_aug=True
         )
         self.valid_dataset = self.dataset_cls(
             config=self.config, data_dir=self.config['binary_data_dir'],
-            prefix=self.config['valid_set_name']
+            prefix=self.config['valid_set_name'], allow_aug=False
         )
 
     def get_need_freeze_state_dict_key(self, model_state_dict) -> list:
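Note (hypothetical illustration, not part of the diff): allow_aug is stored on the base dataset so that subclasses can gate any randomized transform on it, while setup() enables it only for the training split:

    # Hypothetical subclass for illustration; apply_augmentation is an assumed helper.
    class AugmentedDataset(BaseDataset):
        def collater(self, samples):
            batch = super().collater(samples)
            if self.allow_aug:  # True for train_dataset, False for valid_dataset
                batch = self.apply_augmentation(batch)
            return batch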
24 changes: 14 additions & 10 deletions training/me_task.py
@@ -1,3 +1,5 @@
+import random
+
 import torch
 import torch.nn.functional as F
 from torch import nn
@@ -15,15 +17,21 @@ def __init__(self, *args, **kwargs):
         self.deviation = self.config['midi_prob_deviation']
         self.interval = (self.midi_max - self.midi_min) / (self.num_bins - 1) # align with centers of bins
         self.sigma = self.deviation / self.interval
+        self.midi_shift_min, self.midi_shift_max = self.config['midi_shift_range']
 
     def midi_to_bin(self, midi):
         return (midi - self.midi_min) / self.interval
 
     def collater(self, samples):
         batch = super().collater(samples)
+        midi_shifts = [
+            random.random() * (self.midi_shift_max - self.midi_shift_min) + self.midi_shift_min
+            if self.allow_aug else 0
+            for _ in range(len(samples))
+        ]
         batch['units'] = collate_nd([s['units'] for s in samples]) # [B, T_s, C]
-        batch['pitch'] = collate_nd([s['pitch'] for s in samples]) # [B, T_s]
-        batch['note_midi'] = collate_nd([s['note_midi'] for s in samples]) # [B, T_n]
+        batch['pitch'] = collate_nd([s['pitch'] + d for s, d in zip(samples, midi_shifts)]) # [B, T_s]
+        batch['note_midi'] = collate_nd([s['note_midi'] + d for s, d in zip(samples, midi_shifts)]) # [B, T_n]
         batch['note_rest'] = collate_nd([s['note_rest'] for s in samples]) # [B, T_n]
         batch['note_dur'] = collate_nd([s['note_dur'] for s in samples]) # [B, T_n]
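Note (illustrative, with assumed values): the same shift d is added to both the frame-level pitch curve and the note-level note_midi, so the two stay transposed together. And since midi_to_bin is affine, a shift of d semitones moves the target by exactly d / interval bins, keeping the soft-label construction below consistent with the shifted pitch:

    midi_min, midi_max, num_bins = 0, 128, 256  # num_bins assumed for illustration
    interval = (midi_max - midi_min) / (num_bins - 1)

    def midi_to_bin(midi):
        return (midi - midi_min) / interval

    d = 2.0  # a sampled shift of two semitones
    print(midi_to_bin(60.0 + d) - midi_to_bin(60.0))  # d / interval ~= 3.98 bins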

@@ -38,15 +46,11 @@ def collater(self, samples):
         unit2note_ = unit2note[..., None].repeat([1, 1, self.num_bins])
         probs = torch.gather(probs, 1, unit2note_)
         batch['probs'] = probs # [B, T_s, N]
-        batch['mask'] = unit2note > 0
 
-        bound = torch.diff(
+        bounds = torch.diff(
             unit2note, dim=1, prepend=unit2note.new_zeros((batch['size'], 1))
-        )
-        bounds=bound>0
-
-
-        batch['bounds'] = torch.zeros_like(bound,dtype=torch.float).masked_fill(bounds,1) # [B, T_s]
+        ) > 0
+        batch['bounds'] = bounds.float() # [B, T_s]
 
         return batch
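Note (worked example with illustrative values): unit2note maps each frame to its note index, with 0 reserved for padding, so a positive first difference marks each note onset:

    import torch

    unit2note = torch.tensor([[1, 1, 2, 2, 2, 3, 0]])  # one padded sample
    bounds = torch.diff(
        unit2note, dim=1, prepend=unit2note.new_zeros((1, 1))
    ) > 0
    print(bounds.float())  # tensor([[1., 0., 1., 0., 0., 1., 0.]])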

@@ -81,7 +85,7 @@ def run_model(self, sample, infer=False):
         """
         spec = sample['units'] # [B, T_ph]
         target = (sample['probs'],sample['bounds']) # [B, T_s, M]
-        mask=sample['mask']
+        mask = sample['unit2note'] > 0
 
         f0 = sample['pitch']
         output=self.model(x=spec,f0=f0,mask=mask)
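Note (illustration): with batch['mask'] removed from the collater, the padding mask is now recomputed here directly from unit2note, where index 0 marks padded frames:

    import torch

    unit2note = torch.tensor([[1, 1, 2, 0, 0]])  # illustrative padded sample
    mask = unit2note > 0  # tensor([[ True,  True,  True, False, False]])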
