add end-of-stack normalizations in case normalize_before has been set
ngimel authored and myleott committed Aug 16, 2018
1 parent f7f2dd0 commit fedc55e
Showing 1 changed file with 25 additions and 0 deletions.
fairseq/models/transformer.py: 25 additions & 0 deletions
@@ -200,6 +200,10 @@ def __init__(self, args, dictionary, embed_tokens, left_pad=True):
             TransformerEncoderLayer(args)
             for i in range(args.encoder_layers)
         ])
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.encoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)

     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
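
Why the end-of-stack normalization is needed: with encoder_normalize_before set (the pre-norm layout), each layer applies LayerNorm to its input and then adds the raw residual back, so the tensor leaving the last layer is never normalized. A minimal sketch of the idea, using a toy pre-norm block rather than fairseq's actual TransformerEncoderLayer:

import torch
import torch.nn as nn

class ToyPreNormLayer(nn.Module):
    """Pre-norm residual block: normalize the input, add the raw residual back."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.ffn(self.norm(x))

dim = 8
x = torch.randn(2, dim)
for layer in [ToyPreNormLayer(dim) for _ in range(4)]:
    x = layer(x)
# here x is an un-normalized sum of residuals; the end-of-stack LayerNorm added
# in this commit (self.layer_norm) is what finally normalizes the encoder output
x = nn.LayerNorm(dim)(x)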
@@ -220,6 +224,9 @@ def forward(self, src_tokens, src_lengths):
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)

+        if self.normalize:
+            x = self.layer_norm(x)
+
         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
@@ -245,6 +252,11 @@ def upgrade_state_dict(self, state_dict):
             if 'encoder.embed_positions.weights' in state_dict:
                 del state_dict['encoder.embed_positions.weights']
             state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
+        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
+            #earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['encoder.version'] = torch.Tensor([1])
         return state_dict
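
The version buffer is what makes the upgrade above safe: register_buffer puts 'version' into the module's state_dict, so checkpoints saved after this commit carry encoder.version == 2, while older checkpoints have no such key, default to 1, and get the new end-of-stack norm disabled so their behavior is unchanged. A simplified, self-contained sketch of the mechanism (TinyEncoder is illustrative only; the 'encoder.version' key name follows the diff above):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, dim, normalize_before):
        super().__init__()
        self.register_buffer('version', torch.Tensor([2]))  # saved into state_dict
        self.normalize = normalize_before
        self.layer_norm = nn.LayerNorm(dim) if normalize_before else None

    def upgrade_state_dict(self, state_dict):
        # checkpoints written before this commit have no version key -> default to 1
        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
            self.layer_norm = None      # they were trained without the final norm
            self.normalize = False
            state_dict['encoder.version'] = torch.Tensor([1])
        return state_dict

enc_old = TinyEncoder(dim=8, normalize_before=True)
enc_old.upgrade_state_dict({})          # pretend pre-commit checkpoint: no version key
assert enc_old.normalize is False       # end-of-stack norm stays off for old checkpoints

enc_new = TinyEncoder(dim=8, normalize_before=True)
enc_new.upgrade_state_dict({'encoder.version': torch.Tensor([2])})
assert enc_new.normalize is True        # post-commit checkpoints keep the norm enabled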


@@ -285,6 +297,10 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_p
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.decoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)

     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # embed positions
@@ -317,6 +333,9 @@ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
                 incremental_state,
             )

+        if self.normalize:
+            x = self.layer_norm(x)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

@@ -354,6 +373,12 @@ def upgrade_state_dict(self, state_dict):
                     if k in state_dict:
                         state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
                         del state_dict[k]
+        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+            #earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['decoder.version'] = torch.Tensor([1])
+

         return state_dict
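
Because 'version' is registered with register_buffer rather than stored as a plain Python attribute, it travels with the checkpoint through save and load; a small round-trip sketch (Stub is illustrative only, not fairseq code):

import io
import torch
import torch.nn as nn

class Stub(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('version', torch.Tensor([2]))

buf = io.BytesIO()
torch.save(Stub().state_dict(), buf)    # buffers are part of state_dict()
buf.seek(0)
restored = torch.load(buf)
print(restored['version'])              # tensor([2.])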
