Showing 21 changed files with 2,503 additions and 0 deletions.
@@ -0,0 +1 @@
__pycache__
Empty file.
@@ -0,0 +1,29 @@
import torch.nn as nn

from utils import constant


def init_rnn_wt(rnn):
    for names in rnn._all_weights:
        for name in names:
            if name.startswith('weight_'):
                wt = getattr(rnn, name)
                nn.init.xavier_uniform_(wt)
                # wt.data.uniform_(-constant.rand_unif_init_mag, constant.rand_unif_init_mag)
            elif name.startswith('bias_'):
                # set forget bias to 1
                bias = getattr(rnn, name)
                n = bias.size(0)
                start, end = n // 4, n // 2
                bias.data.fill_(0.)
                bias.data[start:end].fill_(1.)

def init_linear_wt(linear):
    # linear.weight.data.normal_(std=constant.trunc_norm_init_std)
    nn.init.xavier_uniform_(linear.weight)
    if linear.bias is not None:
        n = linear.bias.size(0)
        start, end = n // 4, n // 2
        linear.bias.data.fill_(0.)
        linear.bias.data[start:end].fill_(1.)
        # linear.bias.data.normal_(std=constant.trunc_norm_init_std)
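A minimal usage sketch of the two initializers above; the LSTM and Linear sizes are illustrative assumptions, not taken from this commit.

import torch.nn as nn

# Hypothetical modules; any nn.LSTM/nn.GRU and nn.Linear work the same way.
rnn = nn.LSTM(input_size=300, hidden_size=512, num_layers=1, batch_first=True)
proj = nn.Linear(512, 512)

init_rnn_wt(rnn)      # Xavier-uniform weights; the forget-gate slice of each bias is set to 1
init_linear_wt(proj)  # Xavier-uniform weight; the bias gets the same zero/one slice pattern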
@@ -0,0 +1,71 @@
import os
import math
import time
import pprint
import random

from tqdm import tqdm
import dill as pickle
import numpy as np
from numpy import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence

from utils import constant


def gumbel_softmax(logits, dim, tau=1.0):
    """
    Sample z ~ log p(z) + G(0, 1)
    """
    eps = 1e-20
    noise = torch.rand(logits.size())
    noise = -torch.log(-torch.log(noise + eps) + eps)  # Gumbel noise
    if constant.USE_CUDA:
        noise = noise.float().cuda()
    return F.softmax((logits + noise) / tau, dim=dim)

def reparameterization(mu, logvar, z_dim):
    """
    Reparameterization trick: z = mu + std*eps; eps ~ N(0, I)
    """
    eps = torch.randn(z_dim)
    eps = eps.cuda() if constant.USE_CUDA else eps
    return mu + torch.exp(logvar / 2) * eps

def split_z(z, B, M, K):
    return z.view(B, M, K)

def merge_z(z, B, M, K):
    return z.view(B, M * K)

def cat_mi(p, q):
    pass

def cat_kl(logp, logq, dim=1):
    """
    \sum q * log(q/p)
    """
    if logq.dim() > 2:
        logq = logq.squeeze()

    q = torch.exp(logq)
    kl = torch.sum(q * (logq - logp), dim=dim)
    return torch.mean(kl)

def norm_kl(recog_mu, recog_logvar, prior_mu=None, prior_logvar=None):
    # find the KL divergence between two Gaussian distributions (defaults to standard normal for prior)
    if prior_mu is None:
        prior_mu = torch.zeros(1)
        prior_logvar = torch.zeros(1)  # log-variance 0 means unit variance, i.e. a standard normal prior
        if constant.USE_CUDA:
            prior_mu = prior_mu.cuda()
            prior_logvar = prior_logvar.cuda()
    loss = 1.0 + (recog_logvar - prior_logvar)
    loss -= torch.div(torch.pow(prior_mu - recog_mu, 2), torch.exp(prior_logvar))
    loss -= torch.div(torch.exp(recog_logvar), torch.exp(prior_logvar))
    kl_loss = -0.5 * torch.mean(loss, dim=1)
    return torch.mean(kl_loss)
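A small sketch of how these helpers might fit together in a VAE-style step, assuming a CPU run (constant.USE_CUDA == False); the batch size, latent size, and number of classes are illustrative.

import torch

B, z_dim = 8, 64
mu = torch.zeros(B, z_dim)      # recognition mean, e.g. produced by an encoder
logvar = torch.zeros(B, z_dim)  # recognition log-variance

z = reparameterization(mu, logvar, z_dim)   # z = mu + std * eps, eps ~ N(0, I)
kl = norm_kl(mu, logvar)                    # scalar KL(q(z|x) || N(0, I))

logits = torch.randn(B, 10)                 # unnormalized class scores
y_soft = gumbel_softmax(logits, dim=1)      # relaxed one-hot samples; each row sums to 1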
Empty file.
@@ -0,0 +1,74 @@
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from models.commons.attention import Attention
from models.commons.initializer import init_rnn_wt, init_linear_wt
from utils import constant


class RNNDecoder(nn.Module):
    def __init__(self, V, D, H, L=1, embedding=None):
        super(RNNDecoder, self).__init__()
        self.V = V
        self.H = H
        self.L = L
        self.D = D
        if constant.attn != 'none':
            self.attention = Attention(H, constant.attn)
        # self.dropout = nn.Dropout(constant.dropout)

        self.cuda = constant.USE_CUDA
        self.embeddings_cpu = constant.embeddings_cpu

        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = nn.Embedding(V, D)
            self.embedding.weight.requires_grad = True

        if constant.lstm:
            self.rnn = nn.LSTM(D, H, L, batch_first=True, bidirectional=False)
        else:
            self.rnn = nn.GRU(D, H, L, batch_first=True, bidirectional=False)

        self.out = nn.Linear(H, V)
        if constant.weight_tie:
            self.out.weight = self.embedding.weight  # assuming H == D: the layers share the weight and are updated together

    def forward(self, x_t, last_h, src_hs=None, use_attn=False):
        # Note: we run this in a for loop (multiple batches, one token at a time)
        # batch_size = x_t.size(0)
        x = self.embedding(x_t)
        if self.cuda and self.embeddings_cpu:
            x = x.cuda()
        # x = self.dropout(x)
        # x = x.view(1, batch_size, self.H)  # S=1 x B x N
        outputs, dec_h_t = self.rnn(x.unsqueeze(1), last_h)  # [B, 1, H] & [1, B, H]

        if use_attn:
            h, _ = self.attention(src_hs, src_hs, outputs)
            # output = self.out(self.linear(h))
            output = self.out(h)
        else:
            # output = self.out(self.linear(outputs))
            output = self.out(outputs)

        return output.squeeze(), dec_h_t

    def predict_one(self, x_t, last_h, src_hs=None, use_attn=False):
        with torch.no_grad():
            x = self.embedding(x_t)
            outputs, dec_h_t = self.rnn(x.unsqueeze(1), last_h)  # [B, 1, H] & [1, B, H]
            if use_attn:
                h, _ = self.attention(src_hs, src_hs, outputs)
                output = self.out(h)
            else:
                output = self.out(outputs)
            return output.squeeze(), dec_h_t
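A greedy-decoding sketch for RNNDecoder, assuming a CPU GRU configuration (constant.lstm == False, so last_h is a single tensor) with attention disabled; the batch size, vocabulary size, and start-token id are illustrative assumptions.

import torch

B, V, D, H = 4, 10000, 300, 300              # D == H, so weight tying would also be valid
decoder = RNNDecoder(V, D, H)

src_hs = torch.randn(B, 20, H)               # encoder outputs (unused while use_attn=False)
last_h = torch.zeros(1, B, H)                # initial decoder hidden state
x_t = torch.full((B,), 2, dtype=torch.long)  # assumed SOS index

generated = []
for _ in range(30):                          # fixed maximum length for the sketch
    logits, last_h = decoder(x_t, last_h, src_hs=src_hs, use_attn=False)
    x_t = logits.argmax(dim=-1)              # greedy choice per batch element, shape [B]
    generated.append(x_t)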
Empty file.
@@ -0,0 +1,80 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_sequence, pack_padded_sequence, pad_packed_sequence

from models.commons.initializer import init_rnn_wt
from utils import constant


class RNNEncoder(nn.Module):
    def __init__(self, V, D, H, L=1, embedding=None):
        super(RNNEncoder, self).__init__()
        self.V = V
        self.H = H
        self.L = L
        self.D = D
        self.bi = True if constant.bi == 'bi' else False
        self.use_lstm = constant.lstm
        # self.dropout = nn.Dropout(constant.dropout)

        self.cuda = constant.USE_CUDA

        if embedding is not None:
            self.embedding = embedding
        else:
            self.embedding = nn.Embedding(V, D)
            self.embedding.weight.requires_grad = True

        self.embedding_dropout = nn.Dropout(constant.dropout)

        if constant.lstm:
            self.rnn = nn.LSTM(D, H, L, batch_first=True, bidirectional=self.bi)
        else:
            self.rnn = nn.GRU(D, H, L, batch_first=True, bidirectional=self.bi)

    def soft_embed(self, x):
        # x: (T, B, V), (B, V) or (V)
        return (x.unsqueeze(len(x.shape)) * self.embedding.weight).sum(dim=len(x.shape) - 1)

    def forward(self, seqs, lens, soft_encode=False, logits=None):
        # Note: we run this all at once (over multiple batches of multiple sequences)
        # x, lens = pad_packed_sequence(pack_sequence(seqs))
        if not soft_encode:
            x = self.embedding(seqs)
            x = self.embedding_dropout(x)
        else:
            x = self.soft_embed(logits).transpose(0, 1).contiguous()
        x = pack_padded_sequence(x, lens, batch_first=True)
        outputs, hidden = self.rnn(x)
        outputs, _ = pad_packed_sequence(outputs, batch_first=True)

        if self.use_lstm:
            h, c = hidden

        if self.bi:
            # [2, B, H] => [B, 2H]
            if self.use_lstm:
                h = h.transpose(0, 1).contiguous().view(-1, 2 * self.H)
                c = c.transpose(0, 1).contiguous().view(-1, 2 * self.H)
                # h = torch.cat((h[0], h[1]), 1)
                # c = torch.cat((c[0], c[1]), 1)
                return outputs, h.squeeze(), c.squeeze()
            else:
                h = torch.cat((hidden[0], hidden[1]), 1)
                return outputs, h.squeeze()
        else:
            if self.use_lstm:
                # unidirectional LSTM: hidden is a (h, c) tuple, so return both states
                return outputs, h.squeeze(), c.squeeze()
            return outputs, hidden.squeeze()

    def predict_one(self, seq):
        with torch.no_grad():
            x = self.embedding(seq)
            outputs, hidden = self.rnn(x)
            if self.bi:
                # [2, B, H] => [B, 2H]
                hidden = torch.cat((hidden[0], hidden[1]), 1)
                return outputs, hidden
            else:
                return outputs, hidden
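A sketch of encoding a padded batch with RNNEncoder, assuming a bidirectional GRU on CPU (constant.bi == 'bi', constant.lstm == False); the sizes are illustrative, and the lengths are sorted in descending order because pack_padded_sequence expects that by default.

import torch

B, T, V, D, H = 4, 12, 10000, 300, 300
encoder = RNNEncoder(V, D, H)

seqs = torch.randint(1, V, (B, T))   # padded index matrix, batch_first
lens = torch.tensor([12, 10, 7, 5])  # true sequence lengths, descending

outputs, h = encoder(seqs, lens)     # outputs: [B, T, 2H], h: [B, 2H]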
@@ -0,0 +1,9 @@
from .utils import *
from .dataset import *
from .sentiment_dataset import *
from .lang import *
from .bleu import *
from .beam_omt import *
from .rouge import *
from .masked_cross_entropy import *
from .embedding_metrics import *