Skip to content

Commit

Permalink
Merge pull request NVIDIA#910 from szmigacz/smigacz/mha_xavier_init_gain_fix
Browse files Browse the repository at this point in the history

Fixed weight init for fused weight matrices in fused MHA by adding correct gain factor
  • Loading branch information
thorjohnsen authored Jul 16, 2020
2 parents 4027bcb + a0d99fd commit 3104fd5
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 7 additions & 1 deletion apex/contrib/multihead_attn/encdec_multihead_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
from torch import nn
from torch.nn import Parameter
Expand Down Expand Up @@ -76,7 +78,11 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_ad

def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_kv)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
Expand Down
8 changes: 7 additions & 1 deletion apex/contrib/multihead_attn/self_multihead_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
from torch import nn
from torch.nn import Parameter
Expand Down Expand Up @@ -98,7 +100,11 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
nn.init.xavier_uniform_(self.in_proj_weight)
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
Expand Down

0 comments on commit 3104fd5

Please sign in to comment.