loss.py (forked from Ma-Lab-Berkeley/ReduNet)

import torch
import torch.nn as nn
from opt_einsum import contract

class MaximalCodingRateReduction(nn.Module):
    """Maximal Coding Rate Reduction (MCR^2): maximize the coding rate of
    the whole feature set while compressing the rate of each class."""

    def __init__(self, eps=0.1, gam=1.):
        super(MaximalCodingRateReduction, self).__init__()
        self.eps = eps  # distortion (precision) parameter
        self.gam = gam  # weight on the compression term

    def discrimn_loss(self, Z):
        # Expansion (discriminative) term: 1/2 logdet(I + d/(m*eps) Z^T Z).
        m, d = Z.shape
        I = torch.eye(d).to(Z.device)
        c = d / (m * self.eps)
        return logdet(c * covariance(Z) + I) / 2.

    def compress_loss(self, Z, Pi):
        # Compression term: per-class coding rate, weighted by class size.
        # Pi currently holds integer class labels (see forward()).
        m, d = Z.shape
        I = torch.eye(d).to(Z.device)
        loss_comp = 0.
        for j in Pi.unique():
            Z_j = Z[(Pi == int(j)).flatten()]
            m_j = Z_j.shape[0]
            c_j = d / (m_j * self.eps)
            logdet_j = logdet(I + c_j * Z_j.T @ Z_j)
            loss_comp += logdet_j * m_j / (2 * m)
        return loss_comp

    def forward(self, Z, y):
        Pi = y  # TODO: change this to a probability distribution over classes
        loss_discrimn = self.discrimn_loss(Z)
        loss_compress = self.compress_loss(Z, Pi)
        # Rate reduction deltaR = R - gam * R_c, which training should maximize.
        return loss_discrimn - self.gam * loss_compress
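
# A minimal usage sketch (illustrative shapes and values, not from the repo):
#
#   mcr2 = MaximalCodingRateReduction(eps=0.5, gam=1.)
#   Z = torch.nn.functional.normalize(
#       torch.randn(128, 32, requires_grad=True), dim=1)   # unit-norm features
#   y = torch.randint(0, 10, (128,))                       # integer labels
#   delta_r = mcr2(Z, y)    # rate reduction, to be maximized
#   (-delta_r).backward()   # minimize the negative in a training loop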

def compute_mcr2(Z, y, eps):
    # Dispatch on the feature layout: (m, d) vectors, (m, C, T) 1-D signals,
    # or (m, C, H, W) 2-D feature maps.
    if len(Z.shape) == 2:
        loss_func = compute_loss_vec
    elif len(Z.shape) == 3:
        loss_func = compute_loss_1d
    elif len(Z.shape) == 4:
        loss_func = compute_loss_2d
    else:
        raise ValueError(f'unsupported feature shape: {tuple(Z.shape)}')
    return loss_func(Z, y, eps)

def compute_loss_vec(Z, y, eps):
    # MCR^2 statistics for vector features Z of shape (m, d); y has shape (m, 1).
    m, d = Z.shape
    I = torch.eye(d).to(Z.device)
    c = d / (m * eps)
    loss_expd = logdet(c * covariance(Z) + I) / 2.
    loss_comp = 0.
    for j in y.unique():
        Z_j = Z[(y == int(j))[:, 0]]
        m_j = Z_j.shape[0]
        c_j = d / (m_j * eps)
        logdet_j = logdet(I + c_j * Z_j.T @ Z_j)
        loss_comp += logdet_j * m_j / (2 * m)
    loss_expd, loss_comp = loss_expd.item(), loss_comp.item()
    return loss_expd - loss_comp, loss_expd, loss_comp

def compute_loss_1d(V, y, eps):
    # MCR^2 statistics for 1-D features V of shape (m, C, T), computed
    # independently at each of the T indices and averaged.
    m, C, T = V.shape
    I = torch.eye(C).unsqueeze(-1).to(V.device)
    alpha = C / (m * eps)
    cov = alpha * covariance(V) + I
    loss_expd = logdet(cov.permute(2, 0, 1)).sum() / (2 * T)
    loss_comp = 0.
    for j in y.unique():
        V_j = V[(y == int(j)).flatten()]
        m_j = V_j.shape[0]
        alpha_j = C / (m_j * eps)
        cov_j = alpha_j * covariance(V_j) + I
        loss_comp += m_j / m * logdet(cov_j.permute(2, 0, 1)).sum() / (2 * T)
    loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item()
    return loss_expd - loss_comp, loss_expd, loss_comp

def compute_loss_2d(V, y, eps):
    # MCR^2 statistics for 2-D feature maps V of shape (m, C, H, W), computed
    # independently at each spatial index and averaged.
    m, C, H, W = V.shape
    I = torch.eye(C).unsqueeze(-1).unsqueeze(-1).to(V.device)
    alpha = C / (m * eps)
    cov = alpha * covariance(V) + I
    loss_expd = logdet(cov.permute(2, 3, 0, 1)).sum() / (2 * H * W)
    loss_comp = 0.
    for j in y.unique():
        V_j = V[(y == int(j)).flatten()]
        m_j = V_j.shape[0]
        alpha_j = C / (m_j * eps)
        cov_j = alpha_j * covariance(V_j) + I
        loss_comp += m_j / m * logdet(cov_j.permute(2, 3, 0, 1)).sum() / (2 * H * W)
    loss_expd, loss_comp = loss_expd.real.item(), loss_comp.real.item()
    return loss_expd - loss_comp, loss_expd, loss_comp

def covariance(X):
    # Second-moment matrix X^H X over the sample axis; any trailing
    # temporal/spatial dimensions are preserved and treated independently.
    return contract('ji...,jk...->ik...', X, X.conj())


def logdet(X):
    # Sign-stabilized log-determinant; X is assumed (close to) positive
    # definite, possibly complex Hermitian, so the sign is ~1.
    sgn, logdet = torch.linalg.slogdet(X)
    return sgn * logdet
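
# A minimal sanity check for the functional API (shapes and eps are
# illustrative assumptions, not values from the repo):
if __name__ == '__main__':
    torch.manual_seed(0)
    Z = torch.randn(64, 16)           # (m, d) vector features
    y = torch.randint(0, 4, (64, 1))  # labels of shape (m, 1)
    delta_r, r_expd, r_comp = compute_mcr2(Z, y, eps=0.5)
    print(f'deltaR={delta_r:.4f}  R={r_expd:.4f}  Rc={r_comp:.4f}')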