New and tested transformer code
semjon00 committed May 17, 2024
1 parent 05883d1 commit 7d8f27b
Showing 4 changed files with 186 additions and 235 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2024 Semjon Kravtšenko & Kermo Saarse
Copyright (c) 2024 Semjon Kravtšenko

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
120 changes: 41 additions & 79 deletions model.py
@@ -1,4 +1,3 @@
from itertools import chain
import torch
from torch import Tensor, nn
import einops
@@ -7,57 +6,46 @@

from components import PriorityNoise
from cnn import CNN
from transformer import PlanePositionalEmbedding, Transformer
from transformer import MultidimPositionalEmbedding, Transformer

# TODO: Re-introduce the history

# TODO: Swin
# TODO: Better loss function
# TODO: in CNN, different convolutions should be used for different pitches (use groups parameter)

# TODO: Refactor training parameters into a separate class (don't forget kl_loss!)

# TODO: Do a code self-review, there can be some "fun" surprises!


class EchoMorphParameters:
"""Training parameters"""
def __init__(self, **kwargs):
"""By default, contains large model specs"""
one_sec_len = round(24000 / 84 / 64) * 64 # sample_rate / hop_length; approximately
self.target_sample_len = one_sec_len // 2
self.history_len = one_sec_len // 2
self.fragment_len = one_sec_len // 8
self.target_sample_len = one_sec_len // 32
self.history_len = one_sec_len // 32
self.fragment_len = one_sec_len // 32
assert self.target_sample_len == self.history_len, "oh no! - speaker encoding is TODO"

self.spect_width = 128 # x_width
self.length_of_patch = 8

self.embed_dim = 128
self.embed_dim = 64

self.se_convrec = (2, 8, 32, 64)
self.se_convrepeat = 6
self.se_blocks = (10, 0, 0)
self.se_heads = 4
self.se_hidden_dim = 4 * self.embed_dim
self.se_convrec = (2,)
self.se_convrepeat = 4
self.se_blocks = 4
self.se_output_tokens = 256
self.se_kl_loss_k = 0.003

self.ae_convrec = (2, 8, 16, 32)
self.ae_convrec = (2,)
self.ae_convrepeat = 4
self.ae_blocks = (6, 0, 0)
self.ae_heads = 4
self.ae_hidden_dim = 3 * self.embed_dim
self.ae_blocks = 4

self.ad_blocks = (16, 0, 0)
self.ad_heads = 8
self.ad_hidden_dim = 6 * self.embed_dim
self.ad_blocks = 4

self.drop = 0.00
self.rm_k_min = 0.0
self.rm_k_min = 1.0
self.rm_k_max = 1.0
self.rm_fun = 'lin'
self.mid_repeat_interval = (2, 5) # (inclusive, exclusive)
self.se_kl_loss_k = 0.000

for key, value in kwargs.items():
setattr(self, key, value)
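For orientation, one_sec_len works out to 256: 24000 / 84 ≈ 285.7 spectrogram frames per second, rounded to the nearest multiple of 64. The kwargs loop lets callers override any default; a minimal usage sketch (the overridden values are illustrative, not from this commit):

    # Hypothetical override for a small smoke-test configuration
    pars = EchoMorphParameters(embed_dim=32, ad_blocks=2)
    assert pars.embed_dim == 32  # kwargs take precedence over the defaults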
@@ -68,32 +56,30 @@ def __init__(self, pars: EchoMorphParameters):
super().__init__()

self.cnn = CNN(pars.se_convrec, pars.se_convrepeat)
self.transformer = Transformer(embed_dim=pars.embed_dim, mlp_hidden_dim=pars.se_hidden_dim, heads=pars.se_heads,
drop=pars.drop, blocks_num=pars.se_blocks, cross_n=0)
self.entok = nn.Linear(self.cnn.out_channels(), pars.embed_dim)
reduction = self.cnn.res_reduction_factor()
self.pos_embed = PlanePositionalEmbedding(
pars.history_len // reduction, pars.spect_width // reduction, pars.embed_dim
self.transformer = Transformer(
input_dim=self.cnn.out_channels, output_dim=pars.embed_dim,
input_size=(pars.history_len // reduction, pars.spect_width // reduction),
num_blocks=pars.se_blocks, embed_dim=pars.embed_dim, cross_n=0,
rearrange_back=False
)

self.out_tokens = pars.se_output_tokens
self.mean_linear = nn.Linear(pars.embed_dim, pars.embed_dim)
self.log_var_linear = nn.Linear(pars.embed_dim, pars.embed_dim)

def forward_shared(self, x: Tensor, mid_rep) -> (Tensor, Tensor):
def forward_shared(self, x: Tensor) -> (Tensor, Tensor):
if len(x.shape) < 4:
x = x.unsqueeze(0)
x = self.cnn(x)
x = self.entok(x)
x = self.pos_embed(x)
x = einops.rearrange(x, '... l w d -> ... (l w) d')
x = self.transformer(x, [], mid_rep)
x = self.transformer(x, [])

ret = x[..., :self.out_tokens, :]
assert ret.size(-2) == self.out_tokens
return ret

def forward_train(self, x, mid_rep):
ret_tok = self.forward_shared(x, mid_rep)
def forward_train(self, x):
ret_tok = self.forward_shared(x)
means = self.mean_linear(ret_tok)
log_vars = self.log_var_linear(ret_tok)
kl_loss = torch.mean(0.5 * torch.sum(torch.exp(log_vars) + means ** 2 - log_vars - 1, dim=-1))
@@ -103,57 +89,46 @@ def forward_train(self, x, mid_rep):
z = std * epsilon + means
return z, kl_loss
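This is the standard VAE reparameterization trick: rather than sampling z ~ N(mu, sigma^2) directly, the model samples epsilon ~ N(0, I) and computes z = sigma * epsilon + mu, which keeps the sample differentiable with respect to both linear heads. The kl_loss line is the closed-form KL(N(mu, sigma^2) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - log sigma^2 - 1). A self-contained sketch of the same computation (tensor shapes are illustrative):

    import torch
    means = torch.zeros(4, 256, 64)     # stand-in for mean_linear output
    log_vars = torch.zeros(4, 256, 64)  # stand-in for log_var_linear output
    std = torch.exp(0.5 * log_vars)     # sigma recovered from log-variance
    epsilon = torch.randn_like(std)     # epsilon ~ N(0, I)
    z = std * epsilon + means           # differentiable sample
    kl = torch.mean(0.5 * torch.sum(torch.exp(log_vars) + means ** 2 - log_vars - 1, dim=-1))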

def forward_use(self, x, mid_rep):
return self.mean_linear(self.forward_shared(x, mid_rep))
def forward_use(self, x):
return self.mean_linear(self.forward_shared(x))


class AudioEncoder(nn.Module):
def __init__(self, pars: EchoMorphParameters):
super().__init__()

self.cnn = CNN(pars.ae_convrec, pars.ae_convrepeat)
self.transformer = Transformer(embed_dim=pars.embed_dim, mlp_hidden_dim=pars.ae_hidden_dim, heads=pars.ae_heads,
drop=pars.drop, blocks_num=pars.ae_blocks, cross_n=0)
self.entok = nn.Linear(self.cnn.out_channels(), pars.embed_dim)
reduction = self.cnn.res_reduction_factor()
self.pos_embed = PlanePositionalEmbedding(
pars.fragment_len // reduction, pars.spect_width // reduction, pars.embed_dim
self.transformer = Transformer(
input_dim=self.cnn.out_channels, output_dim=pars.embed_dim,
input_size=(pars.fragment_len // reduction, pars.spect_width // reduction),
num_blocks=pars.ae_blocks, embed_dim=pars.embed_dim, cross_n=0,
rearrange_back=False
)

def forward(self, x: Tensor, mid_rep) -> Tensor:
def forward(self, x: Tensor) -> Tensor:
x = self.cnn(x)
x = self.entok(x)
x = self.pos_embed(x)
x = einops.rearrange(x, '... l w d -> ... (l w) d')
x = self.transformer.forward(x, [], mid_rep)
x = self.transformer(x, [])
return x
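The diff for transformer.py itself is not expanded on this page, but the call sites here pin down the new interface: the rewritten Transformer subsumes the separate entok linear, the positional embedding, and the einops flattening that the old code wired up by hand. A hypothetical skeleton consistent with those calls (everything except the keyword argument names is a guess):

    from torch import nn

    class Transformer(nn.Module):  # hypothetical; reconstructed from call sites only
        def __init__(self, input_dim, output_dim, input_size, num_blocks,
                     embed_dim, cross_n, rearrange_back=True):
            super().__init__()
            ...  # MultidimPositionalEmbedding over input_size, attention blocks, etc.

        def forward(self, x, cross):  # cross: list of cross-attention inputs
            ...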


class AudioDecoder(Transformer):
def __init__(self, pars: EchoMorphParameters):
super().__init__(embed_dim=pars.embed_dim, mlp_hidden_dim=pars.ad_hidden_dim, heads=pars.ad_heads,
drop=pars.drop, blocks_num=pars.ad_blocks, cross_n=2)
super().__init__(input_dim=pars.embed_dim, output_dim=2 * pars.length_of_patch,
input_size=(pars.fragment_len // pars.length_of_patch, pars.spect_width),
num_blocks=pars.ad_blocks, embed_dim=pars.embed_dim, cross_n=2)

self.spect_width = pars.spect_width
self.fragment_len = pars.fragment_len
self.length_of_patch = pars.length_of_patch
self.embed_dim = pars.embed_dim

self.detok = nn.Linear(pars.embed_dim, 2 * pars.length_of_patch)
self.pos_embed = PlanePositionalEmbedding(
self.fragment_len // self.length_of_patch, self.spect_width, self.embed_dim
)

def forward(self, im: Tensor, sc: Tensor, mid_rep) -> Tensor:
def forward(self, im: Tensor, sc: Tensor) -> Tensor:
dims = [self.fragment_len // self.length_of_patch, self.spect_width, self.embed_dim]
if len(im.size()) > 2:
dims = [im.size(0)] + dims
feed = self.pos_embed(torch.zeros(dims, dtype=im.dtype, device=im.device))
feed = einops.rearrange(feed, '... l w d -> ... (l w) d')
x = super().forward(feed, [im, sc], mid_rep)

x = einops.rearrange(x, '... (l w) d -> ... l w d', w=self.spect_width)
x = self.detok(x)
x = super().forward(feed, [im, sc])
x = einops.rearrange(x, ' ... l w (c ld) -> ... (l ld) w c', ld=self.length_of_patch)
return x
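The closing rearrange unpacks each decoder token's (c * ld) feature dimension back into length_of_patch time steps of a two-channel spectrogram, so the output spans the full fragment length again. A toy shape check (sizes are illustrative):

    import torch
    import einops
    ld = 8                             # length_of_patch
    x = torch.randn(32, 128, 2 * ld)   # (l, w, c*ld) with c=2 channels
    y = einops.rearrange(x, '... l w (c ld) -> ... (l ld) w c', ld=ld)
    assert y.shape == (32 * ld, 128, 2)  # (l*ld, w, c), i.e. (fragment_len, spect_width, 2)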

@@ -169,28 +144,15 @@ def __init__(self, pars: EchoMorphParameters):
)
self.audio_decoder = AudioDecoder(pars)

def forward(self, target_sample, source_fragment, middle_repeats):
def forward(self, target_sample, source_fragment):
"""Used for training, use inference.py for inference"""
speaker_characteristic, se_loss = self.speaker_encoder.forward_train(target_sample, middle_repeats)
intermediate = self.audio_encoder(source_fragment, middle_repeats)
speaker_characteristic, se_loss = self.speaker_encoder.forward_train(target_sample)
intermediate = self.audio_encoder(source_fragment)
intermediate = self.bottleneck(intermediate)
output = self.audio_decoder(intermediate, speaker_characteristic, middle_repeats)
output = self.audio_decoder(intermediate, speaker_characteristic)
extra_loss = self.pars.se_kl_loss_k * se_loss
return output, extra_loss
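End to end, the training forward pass consumes a target-speaker sample and a source fragment and returns a fragment-shaped prediction plus the KL term already weighted by se_kl_loss_k. A hedged usage sketch (batch size is illustrative; the input shapes mirror the torchinfo_summary call in training.py below):

    import torch
    pars = EchoMorphParameters()
    model = EchoMorph(pars)
    target = torch.randn(4, pars.history_len, pars.spect_width, 2)
    source = torch.randn(4, pars.fragment_len, pars.spect_width, 2)
    output, extra_loss = model(target, source)  # output is fragment-shaped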

def get_multiplication_parameters(self):
return chain.from_iterable([
m.blocks_mid.parameters() for m in [self.speaker_encoder.transformer,
self.audio_encoder.transformer,
self.audio_decoder]
])

def get_base_parameters(self):
mult_params = set(self.get_multiplication_parameters())
all_params = set(self.parameters())
base_params = list(all_params - mult_params)
return base_params


def load_model(directory, device, dtype, verbose=False):
fp = directory / 'model.bin'
17 changes: 5 additions & 12 deletions training.py
@@ -43,6 +43,7 @@ def print(*args, **kwargs):
def print_cuda_stats():
if str(device) == "cpu":
print(f'Cuda memory | CPU is used')
return
try:
vals = torch.cuda.mem_get_info()
print(f'Cuda memory | free:{str(vals[0])} total:{str(vals[1])}')
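torch.cuda.mem_get_info() returns (free, total) in raw bytes, which is exactly what the f-string above prints. If human-readable units were preferred, a sketch (illustrative only, not part of this commit):

    free_b, total_b = torch.cuda.mem_get_info()
    print(f'Cuda memory | free:{free_b / 2**30:.2f}GiB total:{total_b / 2**30:.2f}GiB')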
@@ -201,7 +202,7 @@ def load_progress():
for d in [1, 4]:
torchinfo_summary(model, ((args.batch_size, model.pars.history_len, model.pars.spect_width, 2),
(args.batch_size, model.pars.fragment_len, model.pars.spect_width, 2),),
middle_repeats=1, depth=d)
depth=d)
print(pars.__dict__)

try:
@@ -314,12 +315,10 @@ def eval_model(model, eval_datasets):
total_items = 0
with torch.inference_mode():
model.eval()
r = random.Random(42)
model.bottleneck.deterministic(42)
for target_sample, dataloader in eval_datasets:
for history, fragments in iter(dataloader):
rep = r.randint(*model.pars.mid_repeat_interval)
pred, extra_loss = model(history, fragments, rep)
pred, extra_loss = model(history, fragments)
loss: Tensor = loss_function(pred.float(), fragments.float()).to(dtype=precision) + extra_loss
if loss.isnan():
raise LossNaNException()
@@ -361,8 +360,7 @@ def train_on_bite(model: EchoMorph, optimizer: torch.optim.Optimizer, train_spec
model.train()
for history, fragments in iter(dataloader):
optimizer.zero_grad()
mid_rep = random.randint(*model.pars.mid_repeat_interval)
pred, extra_loss = model(history, fragments, mid_rep)
pred, extra_loss = model(history, fragments)
loss: Tensor = loss_function(pred.float(), fragments.float()).to(dtype=precision) + extra_loss
if loss.isnan():
raise LossNaNException()
@@ -384,12 +382,7 @@ def training():
lr, = training_params
eval_datasets = create_eval_datasets(model.pars)
last_save = time.time()
optimizer = torch.optim.Adam([
{'params': model.get_base_parameters(),
'lr': lr},
{'params': model.get_multiplication_parameters(),
'lr': (lr / (sum(model.pars.mid_repeat_interval) - 1))}
])
optimizer = torch.optim.Adam(model.parameters(), lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=30, min_lr=1e-8,
threshold=0.001, threshold_mode='rel')
print_cuda_stats()
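With the multiplication/base parameter split gone (mid_repeat_interval no longer exists), a single Adam parameter group suffices. ReduceLROnPlateau then halves the learning rate once 30 consecutive evaluations fail to improve the loss by 0.1% relative. A sketch of that contract (loss values are illustrative):

    for val_loss in [1.0] * 40:   # progress stalls immediately
        scheduler.step(val_loss)  # after `patience`=30 stalled steps, lr *= 0.5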