-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathloss.py
38 lines (28 loc) · 1.19 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
from torch import nn
class RBF(nn.Module):
    """Multi-scale Gaussian (RBF) kernel.

    For an input ``X`` of shape ``(N, D)`` this returns the ``(N, N)`` matrix

        K[i, j] = sum_k exp(-||X[i] - X[j]||^2 / (bandwidth * m_k))

    with multipliers ``m_k = mul_factor ** (k - n_kernels // 2)`` for
    ``k in range(n_kernels)``.

    Args:
        n_kernels: Number of Gaussian kernels summed together.
        mul_factor: Ratio between consecutive bandwidth multipliers.
        bandwidth: Fixed base bandwidth. If ``None``, it is estimated per call
            as the mean off-diagonal squared pairwise distance.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        # Floating exponents so an integer ``mul_factor`` (e.g. 2) does not hit
        # torch's "integers to negative integer powers" error.
        exponents = (
            torch.arange(n_kernels, dtype=torch.get_default_dtype())
            - n_kernels // 2
        )
        # Registered as a (non-persistent) buffer instead of calling .cuda():
        # it follows ``module.to(device)``, works on CPU-only machines, and is
        # not added to the state_dict, so existing checkpoints still load.
        self.register_buffer(
            "bandwidth_multipliers", mul_factor ** exponents, persistent=False
        )
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for this call.

        If a fixed ``bandwidth`` was configured it is returned unchanged.
        Otherwise the sum of the square ``L2_distances`` matrix is divided by
        the number of off-diagonal entries (``n**2 - n``). The estimate is
        detached so no gradient flows through the bandwidth.
        """
        if self.bandwidth is not None:
            return self.bandwidth
        n_samples = L2_distances.shape[0]
        # NOTE(review): with a single sample the denominator is 0 -> NaN.
        return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

    def forward(self, X):
        """Compute the summed multi-scale RBF kernel matrix of ``X`` with itself.

        Args:
            X: 2-D tensor of shape ``(N, D)``.

        Returns:
            ``(N, N)`` tensor on the same device as ``X``.
        """
        L2_distances = torch.cdist(X, X) ** 2
        # Keep all computation on X's device/dtype rather than forcing CUDA.
        multipliers = self.bandwidth_multipliers.to(
            device=L2_distances.device, dtype=L2_distances.dtype
        )
        scales = self.get_bandwidth(L2_distances) * multipliers  # (n_kernels,)
        # Broadcast (1, N, N) against (n_kernels, 1, 1), then sum over kernels.
        return torch.exp(-L2_distances[None, ...] / scales[:, None, None]).sum(dim=0)
class MMDLoss(nn.Module):
    """Squared Maximum Mean Discrepancy (MMD^2) between two samples.

    The kernel matrix is computed once on the row-stacked samples and the
    estimate ``mean(K_XX) - 2 * mean(K_XY) + mean(K_YY)`` is returned. Block
    means include the diagonal terms, i.e. this is the biased (V-statistic)
    estimator.

    Args:
        kernel: Callable mapping an ``(N, D)`` tensor to an ``(N, N)`` kernel
            matrix. Defaults to a new ``RBF()`` created for this instance.
    """

    def __init__(self, kernel=None):
        super().__init__()
        # Build the default per instance: a default argument ``kernel=RBF()``
        # would be evaluated once at class-definition time and the same module
        # object shared by every MMDLoss.
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return the scalar MMD^2 estimate between samples ``X`` and ``Y``.

        Args:
            X: tensor of shape ``(N, D)`` — expected 2-D.
            Y: tensor of shape ``(M, D)`` with the same feature dimension.
        """
        K = self.kernel(torch.vstack([X, Y]))
        n_x = X.shape[0]
        XX = K[:n_x, :n_x].mean()
        # Only the K_XY block is read; assumes a symmetric kernel, so K_YX
        # has the same mean (hence the factor 2).
        XY = K[:n_x, n_x:].mean()
        YY = K[n_x:, n_x:].mean()
        return XX - 2 * XY + YY