forked from franroldans/tfm-franroldan-wav2pix
-
Notifications
You must be signed in to change notification settings - Fork 37
/
loss_estimator.py
27 lines (22 loc) · 1012 Bytes
/
loss_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch import nn
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
class generator_loss(torch.nn.Module):
    """Adversarial loss for the generator.

    Scores the discriminator's outputs on generated samples with binary
    cross-entropy against an all-ones ("real") target. Minimizing it pushes
    the generator toward samples the discriminator labels as real.
    """

    def __init__(self):
        super().__init__()
        # nn.BCELoss expects probabilities in [0, 1] (e.g. post-sigmoid
        # outputs) and averages over all elements by default.
        self.estimator = nn.BCELoss()

    def forward(self, fake):
        """Return the BCE of ``fake`` against a target of ones.

        Args:
            fake: Discriminator probabilities for generated samples. The
                leading dimension is presumably the batch; confirm with the
                caller. Any shape accepted by ``nn.BCELoss`` works.

        Returns:
            Scalar loss tensor (mean reduction).
        """
        # Build the target to match ``fake`` exactly (shape, dtype, device).
        # The loss then runs on CPU or any GPU, and BCELoss never sees a
        # target/input size mismatch such as (batch,) vs (batch, 1).
        # It is stored on the instance so it can be inspected after the call.
        self.labels = torch.ones_like(fake)
        return self.estimator(fake, self.labels)
class discriminator_loss(torch.nn.Module):
    """Adversarial loss for the discriminator.

    Combines three binary cross-entropy terms:
    ``BCE(real, 1) + 0.5 * (BCE(wrong, 0) + BCE(fake, 0))``.
    ``real`` is scored against a target of ones. ``wrong`` and ``fake`` are
    each scored against a target of zeros, and their terms share half
    weight. ``wrong`` presumably holds scores for mismatched/incorrect pairs
    (conditional-GAN style); confirm against the training loop.
    """

    def __init__(self):
        super().__init__()
        # nn.BCELoss expects probabilities in [0, 1] and averages over all
        # elements by default.
        self.estimator = nn.BCELoss()

    def forward(self, real, wrong, fake):
        """Return the combined discriminator loss.

        Args:
            real: Discriminator probabilities for real samples.
            wrong: Discriminator probabilities for the "wrong" inputs.
            fake: Discriminator probabilities for generated samples.

        Returns:
            Scalar loss tensor.
        """
        # Each target mirrors its input's shape, dtype and device. This
        # avoids the CPU crash a hard-coded ``.cuda()`` causes and any
        # target/input size mismatch, including when ``wrong`` or ``fake``
        # differ in batch size from ``real``.
        # The targets are stored on the instance so they can be inspected
        # after the call.
        self.real_labels = torch.ones_like(real)
        self.fake_labels = torch.zeros_like(fake)
        wrong_labels = torch.zeros_like(wrong)
        real_term = self.estimator(real, self.real_labels)
        wrong_term = self.estimator(wrong, wrong_labels)
        fake_term = self.estimator(fake, self.fake_labels)
        return real_term + 0.5 * (wrong_term + fake_term)