quantise.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import einsum
from einops import rearrange
class VectorQuantiser(nn.Module):
"""
Improved version over vector quantiser, with the dynamic initialisation
for these unoptimised "dead" points.
num_embed: number of codebook entry
embed_dim: dimensionality of codebook entry
beta: weight for the commitment loss
distance: distance for looking up the closest code
anchor: anchor sampled methods
first_batch: if true, the offline version of our model
contras_loss: if true, use the contras_loss to further improve the performance
"""
def __init__(self, num_embed, embed_dim, beta, distance='cos',
anchor='probrandom', first_batch=False, contras_loss=False):
super().__init__()
self.num_embed = num_embed
self.embed_dim = embed_dim
self.beta = beta
self.distance = distance
self.anchor = anchor
self.first_batch = first_batch
self.contras_loss = contras_loss
self.decay = 0.99
self.init = False
self.pool = FeaturePool(self.num_embed, self.embed_dim)
self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
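        # embed_prob tracks an exponential moving average of how often each code is selected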
self.register_buffer("embed_prob", torch.zeros(self.num_embed))
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
assert rescale_logits==False, "Only for interface compatible with Gumbel"
assert return_logits==False, "Only for interface compatible with Gumbel"
# reshape z -> (batch, height, width, channel) and flatten
z = rearrange(z, 'b c h w -> b h w c').contiguous()
z_flattened = z.view(-1, self.embed_dim)
        # calculate the distance
if self.distance == 'l2':
            # negated squared l2 distance from z to embeddings e_j: -(z - e)^2 = -z^2 - e^2 + 2 e * z (larger means closer)
d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
torch.sum(self.embedding.weight ** 2, dim=1) + \
2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
elif self.distance == 'cos':
            # cosine similarity between z and embeddings e_j (larger means closer)
normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
normed_codebook = F.normalize(self.embedding.weight, dim=1)
d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
# encoding
sort_distance, indices = d.sort(dim=1)
        # the last index of the ascending sort is the closest code (largest similarity)
encoding_indices = indices[:,-1]
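        # build a one-hot matrix of shape (num_features, num_embed) selecting each feature's nearest code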
encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
# quantise and unflatten
z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding: codebook term plus beta-weighted commitment term
loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients (straight-through estimator)
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        # codebook usage statistics: perplexity is the exponential of the entropy of the average code assignment
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
min_encodings = encodings
        # online clustered re-initialisation for unoptimised points
if self.training:
# calculate the average usage of code entries
self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay)
            # running-average re-initialisation of under-used codes towards sampled anchor features
if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
# closest sampling
if self.anchor == 'closest':
sort_distance, indices = d.sort(dim=0)
random_feat = z_flattened.detach()[indices[-1,:]]
# feature pool based random sampling
elif self.anchor == 'random':
random_feat = self.pool.query(z_flattened.detach())
                # probability-based random sampling
elif self.anchor == 'probrandom':
norm_distance = F.softmax(d.t(), dim=1)
prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
random_feat = z_flattened.detach()[prob]
# decay parameter based on the average usage
decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim)
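                # rarely used codes (embed_prob near zero) get decay close to 1 and are pulled
                # almost entirely towards the sampled anchor features; frequently used codes keep their weights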
self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
if self.first_batch:
self.init = True
# contrastive loss
if self.contras_loss:
sort_distance, indices = d.sort(dim=0)
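                # positives: mean similarity of the closest features per code; negatives: the farthest half of the batch
                # an InfoNCE-style cross-entropy with temperature 0.07 pushes each code towards its positive (index 0)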
dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True)
dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:]
dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
loss += contra_loss
return z_q, loss, (perplexity, min_encodings, encoding_indices)
class FeaturePool():
"""
This class implements a feature buffer that stores previously encoded features
This buffer enables us to initialize the codebook using a history of generated features
rather than the ones produced by the latest encoders
"""
def __init__(self, pool_size, dim=64):
"""
Initialize the FeaturePool class
Parameters:
pool_size(int) -- the size of featue buffer
"""
self.pool_size = pool_size
if self.pool_size > 0:
self.nums_features = 0
            self.features = (torch.rand((pool_size, dim)) * 2 - 1) / pool_size
def query(self, features):
"""
return features from the pool
"""
self.features = self.features.to(features.device)
if self.nums_features < self.pool_size:
if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook
random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
self.features = features[random_feat_id]
self.nums_features = self.pool_size
else:
                # if the mini-batch is not large enough, just store it for the next update
num = self.nums_features + features.size(0)
self.features[self.nums_features:num] = features
self.nums_features = num
else:
if features.size(0) > int(self.pool_size):
random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
self.features = features[random_feat_id]
else:
random_id = torch.randperm(self.pool_size)
self.features[random_id[:features.size(0)]] = features
return self.features
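
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the hyper-parameters and tensor
    # shape below are assumptions for demonstration, not values taken from the repo.
    quantiser = VectorQuantiser(num_embed=512, embed_dim=64, beta=0.25,
                                distance='cos', anchor='probrandom',
                                first_batch=False, contras_loss=False)
    quantiser.train()  # the anchor re-initialisation branch only runs in training mode
    z = torch.randn(2, 64, 16, 16)  # hypothetical encoder output: (batch, embed_dim, height, width)
    z_q, loss, (perplexity, min_encodings, encoding_indices) = quantiser(z)
    print(z_q.shape, loss.item(), perplexity.item())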