Skip to content
This repository has been archived by the owner on Dec 26, 2024. It is now read-only.

Commit

Permalink
Merge internal changes (facebookresearch#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
myleott authored May 24, 2018
1 parent 29153e2 commit ec0031d
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 57 deletions.
5 changes: 3 additions & 2 deletions fairseq/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def add(self, ref, pred):
raise TypeError('pred must be a torch.IntTensor(got {})'
.format(type(pred)))

assert self.unk > 0, 'unknown token index must be >0'
# don't match unknown words
rref = ref.clone()
rref.apply_(lambda x: x if x != self.unk else -x)
assert not rref.lt(0).any()
rref[rref.eq(self.unk)] = -999

rref = rref.contiguous().view(-1)
pred = pred.contiguous().view(-1)
Expand Down
25 changes: 21 additions & 4 deletions fairseq/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def add_symbol(self, word, n=1):
self.count.append(n)
return idx

def update(self, new_dict):
"""Updates counts from new dictionary."""
for word in new_dict.symbols:
idx2 = new_dict.indices[word]
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + new_dict.count[idx2]
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(new_dict.count[idx2])

def finalize(self):
"""Sort symbols by frequency in descending order, ignoring special ones."""
self.count, self.symbols = zip(
Expand All @@ -102,7 +115,7 @@ def unk(self):
return self.unk_index

@classmethod
def load(cls, f):
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
Expand All @@ -114,8 +127,12 @@ def load(cls, f):

if isinstance(f, str):
try:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
return cls.load(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
return cls.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except Exception:
Expand All @@ -141,6 +158,6 @@ def save(self, f, threshold=3, nwords=-1):
cnt = 0
for i, t in enumerate(zip(self.symbols, self.count)):
if i >= self.nspecial and t[1] >= threshold \
and (nwords < 0 or cnt < nwords):
and (nwords <= 0 or cnt < nwords):
print('{} {}'.format(t[0], t[1]), file=f)
cnt += 1
167 changes: 122 additions & 45 deletions fairseq/models/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ def add_args(parser):
help='encoder embedding dimension')
parser.add_argument('--encoder-embed-path', default=None, type=str, metavar='STR',
help='path to pre-trained encoder embedding')
parser.add_argument('--encoder-hidden-size', type=int, metavar='N',
help='encoder hidden size')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='number of encoder layers')
parser.add_argument('--encoder-bidirectional', action='store_true',
help='make all layers of encoder bidirectional')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-embed-path', default=None, type=str, metavar='STR',
help='path to pre-trained decoder embedding')
parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
help='decoder hidden size')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='number of decoder layers')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
Expand All @@ -60,68 +66,102 @@ def build_model(cls, args, src_dict, dst_dict):
args.encoder_embed_path = None
if not hasattr(args, 'decoder_embed_path'):
args.decoder_embed_path = None

encoder_embed_dict = None
if not hasattr(args, 'encoder_hidden_size'):
args.encoder_hidden_size = args.encoder_embed_dim
if not hasattr(args, 'decoder_hidden_size'):
args.decoder_hidden_size = args.decoder_embed_dim
if not hasattr(args, 'encoder_bidirectional'):
args.encoder_bidirectional = False

def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
embed_dict = utils.parse_embedding(embed_path)
utils.print_embed_overlap(embed_dict, dictionary)
return utils.load_embedding(embed_dict, dictionary, embed_tokens)

pretrained_encoder_embed = None
if args.encoder_embed_path:
encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
utils.print_embed_overlap(encoder_embed_dict, src_dict)

decoder_embed_dict = None
pretrained_encoder_embed = load_pretrained_embedding_from_file(
args.encoder_embed_path, src_dict, args.encoder_embed_dim)
pretrained_decoder_embed = None
if args.decoder_embed_path:
decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
utils.print_embed_overlap(decoder_embed_dict, dst_dict)
pretrained_decoder_embed = load_pretrained_embedding_from_file(
args.decoder_embed_path, dst_dict, args.decoder_embed_dim)

encoder = LSTMEncoder(
src_dict,
dictionary=src_dict,
embed_dim=args.encoder_embed_dim,
embed_dict=encoder_embed_dict,
hidden_size=args.encoder_hidden_size,
num_layers=args.encoder_layers,
dropout_in=args.encoder_dropout_in,
dropout_out=args.encoder_dropout_out,
bidirectional=args.encoder_bidirectional,
pretrained_embed=pretrained_encoder_embed,
)
try:
attention = bool(eval(args.decoder_attention))
except TypeError:
attention = bool(args.decoder_attention)
decoder = LSTMDecoder(
dst_dict,
encoder_embed_dim=args.encoder_embed_dim,
dictionary=dst_dict,
embed_dim=args.decoder_embed_dim,
embed_dict=decoder_embed_dict,
hidden_size=args.decoder_hidden_size,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
attention=bool(eval(args.decoder_attention)),
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
attention=attention,
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
)
return cls(encoder, decoder)


class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(self, dictionary, embed_dim=512, embed_dict=None,
num_layers=1, dropout_in=0.1, dropout_out=0.1):
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
pretrained_embed=None,
padding_value=0.,
):
super().__init__(dictionary)
self.num_layers = num_layers
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.bidirectional = bidirectional
self.hidden_size = hidden_size

num_embeddings = len(dictionary)
self.padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
if embed_dict:
self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
if pretrained_embed is None:
self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
else:
self.embed_tokens = pretrained_embed

self.lstm = LSTM(
input_size=embed_dim,
hidden_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=self.dropout_out,
bidirectional=False,
bidirectional=bidirectional,
)
self.left_pad_source = left_pad_source
self.padding_value = padding_value

self.output_units = hidden_size
if bidirectional:
self.output_units *= 2

def forward(self, src_tokens, src_lengths):
if LanguagePairDataset.LEFT_PAD_SOURCE:
if self.left_pad_source:
# convert left-padding to right-padding
src_tokens = utils.convert_padding_direction(
src_tokens,
src_lengths,
self.padding_idx,
left_to_right=True,
)
Expand All @@ -131,7 +171,6 @@ def forward(self, src_tokens, src_lengths):
# embed tokens
x = self.embed_tokens(src_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
embed_dim = x.size(2)

# B x T x C -> T x B x C
x = x.transpose(0, 1)
Expand All @@ -140,17 +179,35 @@ def forward(self, src_tokens, src_lengths):
packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())

# apply LSTM
h0 = Variable(x.data.new(self.num_layers, bsz, embed_dim).zero_())
c0 = Variable(x.data.new(self.num_layers, bsz, embed_dim).zero_())
if self.bidirectional:
state_size = 2 * self.num_layers, bsz, self.hidden_size
else:
state_size = self.num_layers, bsz, self.hidden_size
h0 = Variable(x.data.new(*state_size).zero_())
c0 = Variable(x.data.new(*state_size).zero_())
packed_outs, (final_hiddens, final_cells) = self.lstm(
packed_x,
(h0, c0),
)

# unpack outputs and apply dropout
x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=0.)
x, _ = nn.utils.rnn.pad_packed_sequence(
packed_outs, padding_value=self.padding_value)
x = F.dropout(x, p=self.dropout_out, training=self.training)
assert list(x.size()) == [seqlen, bsz, embed_dim]
assert list(x.size()) == [seqlen, bsz, self.output_units]

if self.bidirectional:
bi_final_hiddens, bi_final_cells = [], []
for i in range(self.num_layers):
bi_final_hiddens.append(
torch.cat(
(final_hiddens[2 * i], final_hiddens[2 * i + 1]),
dim=0).view(bsz, self.output_units))
bi_final_cells.append(
torch.cat(
(final_cells[2 * i], final_cells[2 * i + 1]),
dim=0).view(bsz, self.output_units))
return x, bi_final_hiddens, bi_final_cells

return x, final_hiddens, final_cells

Expand All @@ -166,7 +223,7 @@ def __init__(self, input_embed_dim, output_embed_dim):
self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)

def forward(self, input, source_hids):
def forward(self, input, source_hids, src_lengths=None):
# input: bsz x input_embed_dim
# source_hids: srclen x bsz x output_embed_dim

Expand All @@ -186,27 +243,39 @@ def forward(self, input, source_hids):

class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
def __init__(self, dictionary, encoder_embed_dim=512,
embed_dim=512, embed_dict=None,
out_embed_dim=512, num_layers=1, dropout_in=0.1,
dropout_out=0.1, attention=True):
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
encoder_embed_dim=512, encoder_output_units=512,
pretrained_embed=None,
):
super().__init__(dictionary)
self.dropout_in = dropout_in
self.dropout_out = dropout_out
self.hidden_size = hidden_size

num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
if embed_dict:
self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
if pretrained_embed is None:
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
else:
self.embed_tokens = pretrained_embed

self.encoder_output_units = encoder_output_units
assert encoder_output_units == hidden_size, \
'{} {}'.format(encoder_output_units, hidden_size)
# TODO another Linear layer if not equal

self.layers = nn.ModuleList([
LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
LSTMCell(
input_size=encoder_output_units + embed_dim if layer == 0 else hidden_size,
hidden_size=hidden_size,
)
for layer in range(num_layers)
])
self.attention = AttentionLayer(encoder_embed_dim, embed_dim) if attention else None
if embed_dim != out_embed_dim:
self.additional_fc = Linear(embed_dim, out_embed_dim)
self.attention = AttentionLayer(encoder_output_units, hidden_size) if attention else None
if hidden_size != out_embed_dim:
self.additional_fc = Linear(hidden_size, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
Expand All @@ -215,13 +284,12 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
bsz, seqlen = prev_output_tokens.size()

# get outputs from encoder
encoder_outs, _, _ = encoder_out
encoder_outs, _, _ = encoder_out[:3]
srclen = encoder_outs.size(0)

# embed tokens
x = self.embed_tokens(prev_output_tokens)
x = F.dropout(x, p=self.dropout_in, training=self.training)
embed_dim = x.size(2)

# B x T x C -> T x B x C
x = x.transpose(0, 1)
Expand All @@ -231,11 +299,11 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
_, encoder_hiddens, encoder_cells = encoder_out
_, encoder_hiddens, encoder_cells = encoder_out[:3]
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())

attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
outs = []
Expand Down Expand Up @@ -272,7 +340,7 @@ def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))

# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

# T x B x C -> B x T x C
x = x.transpose(1, 0)
Expand Down Expand Up @@ -342,10 +410,13 @@ def Linear(in_features, out_features, bias=True, dropout=0):
@register_model_architecture('lstm', 'lstm')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', 512)
args.encoder_layers = getattr(args, 'encoder_layers', 1)
args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False)
args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 512)
args.decoder_layers = getattr(args, 'decoder_layers', 1)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
args.decoder_attention = getattr(args, 'decoder_attention', True)
Expand All @@ -357,10 +428,13 @@ def base_architecture(args):
def lstm_wiseman_iwslt_de_en(args):
base_architecture(args)
args.encoder_embed_dim = 256
args.encoder_hidden_size = 256
args.encoder_layers = 1
args.encoder_bidirectional = False
args.encoder_dropout_in = 0
args.encoder_dropout_out = 0
args.decoder_embed_dim = 256
args.decoder_hidden_size = 256
args.decoder_layers = 1
args.decoder_out_embed_dim = 256
args.decoder_attention = True
Expand All @@ -371,9 +445,12 @@ def lstm_wiseman_iwslt_de_en(args):
def lstm_luong_wmt_en_de(args):
base_architecture(args)
args.encoder_embed_dim = 1000
args.encoder_hidden_size = 1000
args.encoder_layers = 4
args.encoder_dropout_out = 0
args.encoder_bidirectional = False
args.decoder_embed_dim = 1000
args.decoder_hidden_size = 1000
args.decoder_layers = 4
args.decoder_out_embed_dim = 1000
args.decoder_attention = True
Expand Down
Loading

0 comments on commit ec0031d

Please sign in to comment.