losses.py
import torch
from .llvae import LosslessLatentEncoder


def total_variation(image):
    """
    Compute normalized total variation.

    Inputs:
    - image: PyTorch tensor of shape (N, C, H, W)

    Returns:
    - TV: total variation, summed over the batch and normalized by the
      number of elements per sample (C * H * W)
    """
    n_elements = image.shape[1] * image.shape[2] * image.shape[3]
    return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
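

# Illustrative sketch (not part of the original module): a constant image has
# zero total variation, while random noise has a strictly positive one.
def _demo_total_variation():
    flat = torch.ones(1, 3, 8, 8)
    noisy = torch.rand(1, 3, 8, 8)
    assert total_variation(flat).item() == 0.0
    assert total_variation(noisy).item() > 0.0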


class ComparativeTotalVariation(torch.nn.Module):
    """
    Compute the absolute difference in total variation between two images,
    so that minimizing it pushes the prediction's TV toward the target's.
    """

    def forward(self, pred, target):
        return torch.abs(total_variation(pred) - total_variation(target))
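

# Hedged usage sketch: pairing the TV-matching term with a pixel loss. The MSE
# pairing and the 0.1 weight are illustrative assumptions, not something this
# module prescribes.
def _demo_comparative_tv():
    pred = torch.rand(2, 3, 16, 16)
    target = torch.rand(2, 3, 16, 16)
    tv_match = ComparativeTotalVariation()
    return torch.nn.functional.mse_loss(pred, target) + 0.1 * tv_match(pred, target)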


# Gradient penalty (WGAN-GP style: penalize critic gradients whose L2 norm deviates from 1)
def get_gradient_penalty(critic, real, fake, device):
    # note: autocast is hardcoded to CUDA here, independent of `device`
    with torch.autocast(device_type='cuda'):
        real = real.float()
        fake = fake.float()
        # random per-sample interpolation weight between real and fake
        alpha = torch.rand(real.size(0), 1, 1, 1).to(device).float()
        interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
        if torch.isnan(interpolates).any():
            print('interpolates is nan')
        d_interpolates = critic(interpolates)
        if torch.isnan(d_interpolates).any():
            print('d_interpolates is nan')
        # backprop a gradient of ones through the critic output
        grad_outputs = torch.ones(real.size(0), 1, device=device)
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        if torch.isnan(gradients).any():
            print('gradients is nan')
        gradients = gradients.view(gradients.size(0), -1)
        gradient_norm = gradients.norm(2, dim=1)
        gradient_penalty = ((gradient_norm - 1) ** 2).mean()
        return gradient_penalty.float()
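

# Hedged sketch of how the penalty typically slots into a WGAN-GP critic
# update; the tiny linear critic, the Adam optimizer, and the weight of 10
# are illustrative assumptions rather than this repo's training code.
def _demo_critic_step():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    critic = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(3 * 8 * 8, 1),
    ).to(device)
    opt = torch.optim.Adam(critic.parameters(), lr=1e-4)
    real = torch.rand(4, 3, 8, 8, device=device)
    fake = torch.rand(4, 3, 8, 8, device=device)
    gp = get_gradient_penalty(critic, real, fake, device)
    # critic loss: Wasserstein estimate plus the weighted gradient penalty
    loss = critic(fake).mean() - critic(real).mean() + 10.0 * gp
    opt.zero_grad()
    loss.backward()
    opt.step()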


class PatternLoss(torch.nn.Module):
    """
    Compare per-position statistics of pattern_size x pattern_size pixel
    patterns between a prediction and a target, using a lossless latent
    encoding of each image.
    """

    def __init__(self, pattern_size=4, dtype=torch.float32):
        super().__init__()
        self.pattern_size = pattern_size
        self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype)

    def forward(self, pred, target):
        pred_latents = self.llvae_encoder(pred)
        target_latents = self.llvae_encoder(target)
        matrix_pixels = self.pattern_size * self.pattern_size
        color_chans = pred_latents.shape[1] // 3
        # split the latent channels into their R, G, B groups
        r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans] * 3, 1)
        r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans] * 3, 1)

        def separated_chan_loss(latent_chan):
            # mean absolute deviation of each pattern position's mean from the overall mean
            chan_mean = torch.mean(latent_chan, dim=[1, 2, 3])
            chan_splits = torch.split(latent_chan, [1] * matrix_pixels, 1)
            chan_loss = None
            for chan in chan_splits:
                this_mean = torch.mean(chan, dim=[1, 2, 3])
                this_chan_loss = torch.abs(this_mean - chan_mean)
                if chan_loss is None:
                    chan_loss = this_chan_loss
                else:
                    chan_loss = chan_loss + this_chan_loss
            chan_loss = chan_loss / matrix_pixels
            return chan_loss

        r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target))
        g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
        b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
        # average the three per-color losses
        return (r_chan_loss + g_chan_loss + b_chan_loss) / 3.0
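

# Hedged usage sketch: PatternLoss yields one value per sample. The inputs are
# assumed to be 3-channel images whose height and width are divisible by
# pattern_size, and the sibling `llvae` module must be importable.
def _demo_pattern_loss():
    loss_fn = PatternLoss(pattern_size=4)
    pred = torch.rand(2, 3, 32, 32)
    target = torch.rand(2, 3, 32, 32)
    per_sample = loss_fn(pred, target)  # shape (2,)
    return per_sample.mean()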