-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
27 lines (22 loc) · 1.23 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
from torch.distributions.beta import Beta
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """Randomly mix per-frequency feature statistics between samples in a batch.

    Frequency-wise variant of MixStyle: for every sample, the mean and standard
    deviation are taken over the channel and time axes (dims 1 and 3), giving
    one statistic per frequency bin. The input is normalized with its own
    statistics and then de-normalized with a convex combination of its own
    statistics and those of another sample picked by a random batch
    permutation.

    Args:
        x: 4-D tensor laid out as (batch, channel, frequency, time).
        p: Probability of applying the mixing. Otherwise ``x`` itself is
            returned unchanged (the same object).
        alpha: Concentration of the symmetric Beta(alpha, alpha) distribution
            from which one mixing weight per sample is drawn.
        eps: Added to the variance before the square root for numerical
            stability.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    # NOTE: the gate draws from NumPy's global RNG, so torch.manual_seed alone
    # does not make it reproducible -- seed NumPy as well.
    if np.random.rand() > p:
        return x

    batch_size = x.size(0)

    # Original MixStyle reduces over (frequency, time) -> channel-wise stats.
    # Here we reduce over (channel, time) -> frequency-wise stats, shape
    # (batch, 1, frequency, 1).
    f_mu = x.mean(dim=[1, 3], keepdim=True)
    f_var = x.var(dim=[1, 3], keepdim=True)
    f_sig = (f_var + eps).sqrt()
    # Statistics are treated as constants: gradients flow only through x.
    f_mu, f_sig = f_mu.detach(), f_sig.detach()
    x_normed = (x - f_mu) / f_sig

    # Pair every sample with a randomly chosen partner (possibly itself).
    perm = torch.randperm(batch_size).to(x.device)
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]

    # One convex weight per sample, broadcast over the remaining three axes.
    # Beta samples are float32; cast to x's dtype so half/bfloat16 inputs are
    # not silently promoted to float32 by the arithmetic below.
    lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1))
    lmda = lmda.to(device=x.device, dtype=x.dtype)

    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)
    # Re-apply the mixed frequency statistics to the normalized input.
    return x_normed * sig_mix + mu_mix