Skip to content

Commit

Permalink
Fixed weight init for fused weight matrices in fused MHA by adding correct gain factor.
Browse files Browse the repository at this point in the history
  • Loading branch information
szmigacz committed Jul 9, 2020
1 parent 1ff54b8 commit a0d99fd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 7 additions & 1 deletion apex/contrib/multihead_attn/encdec_multihead_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
from torch import nn
from torch.nn import Parameter
Expand Down Expand Up @@ -76,7 +78,11 @@ def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_ad

def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_kv)
# in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
# therefore xavier_uniform gain should be set to sqrt(1.5).
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
Expand Down
8 changes: 7 additions & 1 deletion apex/contrib/multihead_attn/self_multihead_attn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import torch
from torch import nn
from torch.nn import Parameter
Expand Down Expand Up @@ -98,7 +100,11 @@ def reset_parameters(self):
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
nn.init.xavier_uniform_(self.in_proj_weight)
# in_proj_weight has shape [3 * hidden, hidden] but it should be
# initialized like a [hidden, hidden] matrix.
# sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2)
# therefore xavier_uniform gain should be set to sqrt(2).
nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
Expand Down

0 comments on commit a0d99fd

Please sign in to comment.