-
Notifications
You must be signed in to change notification settings - Fork 28
/
gumbelmodule.py
71 lines (60 loc) · 2.64 KB
/
gumbelmodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
import torch.nn.functional as F
from torch.autograd import Variable
"""
Gumbel Softmax Sampler
Requires 2D input [batchsize, number of categories]
Does not support a single binary category. Use two dimensions with softmax instead.
"""
class GumbleSoftmax(torch.nn.Module):
    """Gumbel-Softmax sampling layer.

    Adds Gumbel(0, 1) noise to the incoming logits and applies a
    temperature-scaled softmax over the class axis.  Expects 2D input of
    shape ``[batch_size, n_class]``.  In training mode ``forward`` returns
    the soft sample; otherwise (or when ``force_hard`` is set) it returns a
    one-hot sample whose gradient flows through the soft sample.
    """

    def __init__(self, hard=False):
        super(GumbleSoftmax, self).__init__()
        # Stored on the instance only; forward() picks soft/hard from the
        # training flag and ``force_hard`` and does not read this attribute.
        self.hard = hard
        # Tells sample_gumbel() whether to move its noise to the GPU.
        self.gpu = False

    def cuda(self, *args, **kwargs):
        """Mark the module as GPU-resident and return ``self``.

        Arguments are forwarded to ``torch.nn.Module.cuda`` so the usual
        ``module = module.cuda()`` idiom keeps working.
        """
        self.gpu = True
        return super(GumbleSoftmax, self).cuda(*args, **kwargs)

    def cpu(self):
        """Mark the module as CPU-resident and return ``self``."""
        self.gpu = False
        return super(GumbleSoftmax, self).cpu()

    def sample_gumbel(self, shape, eps=1e-10):
        """Sample from Gumbel(0, 1).

        Args:
            shape: shape of the noise tensor to draw.
            eps: small constant added inside both logarithms to avoid log(0).

        Returns:
            A ``Variable`` of the given shape, moved to the GPU when
            ``self.gpu`` is set.
        """
        uniform = torch.rand(shape)
        # Inverse-CDF transform: g = -log(-log(u + eps) + eps).
        noise = -torch.log(-torch.log(uniform + eps) + eps)
        if self.gpu:
            return Variable(noise).cuda()
        return Variable(noise)

    def sample_gumbel_like(self, template_tensor, eps=1e-10):
        """Sample Gumbel(0, 1) noise with the shape/type of ``template_tensor``.

        The template is cloned before being filled with uniform samples, so
        the caller's tensor is left untouched.
        """
        uniform_samples_tensor = template_tensor.clone().uniform_()
        gumble_samples_tensor = -torch.log(
            eps - torch.log(uniform_samples_tensor + eps)
        )
        return gumble_samples_tensor

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a soft sample from the Gumbel-Softmax distribution.

        Args:
            logits: [batch_size, n_class] unnormalized log-probs.
            temperature: positive scalar; the perturbed logits are divided
                by it before the softmax.

        Returns:
            [batch_size, n_class] tensor whose rows sum to 1.
        """
        gumble_samples_tensor = self.sample_gumbel_like(logits.data)
        gumble_trick_log_prob_samples = logits + Variable(gumble_samples_tensor)
        # Normalize across classes (dim 1 of the 2D input).  The softmax
        # ``dim`` argument is an axis index, not the number of classes.
        soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, dim=1)
        return soft_samples

    def gumbel_softmax(self, logits, temperature, hard=False):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.

        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: positive scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y

        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise
            it will be a probability distribution that sums to 1 across
            classes.
        """
        y = self.gumbel_softmax_sample(logits, temperature)
        if hard:
            _, max_value_indexes = y.data.max(1, keepdim=True)
            y_hard = logits.data.clone().zero_().scatter_(1, max_value_indexes, 1)
            # Straight-through estimator: the forward value is y_hard, while
            # the gradient is taken with respect to the soft sample y.
            y = Variable(y_hard - y.data) + y
        return y

    def forward(self, logits, temp=1, force_hard=False):
        """Return a Gumbel-Softmax sample for ``logits``.

        Args:
            logits: [batch_size, n_class] unnormalized log-probs.
            temp: positive softmax temperature (default 1).
            force_hard: if True, return a one-hot sample even in training
                mode.

        Returns:
            A soft sample when the module is in training mode and
            ``force_hard`` is False; a one-hot (straight-through) sample
            otherwise.
        """
        hard = force_hard or not self.training
        return self.gumbel_softmax(logits, temperature=temp, hard=hard)