Attention layer: add dropout opt, add default reset_parameters, add out_dim, rename "SelfAttention" to "NoQueryAttention"
songyouwei committed Jun 26, 2018
1 parent 23f8406 commit 4d9a274
Showing 2 changed files with 29 additions and 10 deletions.
38 changes: 28 additions & 10 deletions layers/attention.py
@@ -10,30 +10,42 @@


 class Attention(nn.Module):
-    def __init__(self, embed_dim, hidden_dim=None, n_head=1, score_function='scaled_dot_product', dropout=0.1):
+    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='scaled_dot_product', dropout=0):
         ''' Attention Mechanism
         :param embed_dim:
         :param hidden_dim:
+        :param out_dim:
         :param n_head: num of head (Multi-Head Attention)
         :param score_function: scaled_dot_product / mlp (concat) / bi_linear (general dot)
+        :return (?, q_len, out_dim,)
         '''
         super(Attention, self).__init__()
         if hidden_dim is None:
             hidden_dim = embed_dim // n_head
+        if out_dim is None:
+            out_dim = embed_dim
         self.embed_dim = embed_dim
         self.hidden_dim = hidden_dim
         self.n_head = n_head
         self.score_function = score_function
         self.w_kx = nn.Parameter(torch.FloatTensor(n_head, embed_dim, hidden_dim))
         self.w_qx = nn.Parameter(torch.FloatTensor(n_head, embed_dim, hidden_dim))
-        self.proj = nn.Linear(n_head * hidden_dim, embed_dim)
+        self.proj = nn.Linear(n_head * hidden_dim, out_dim)
+        self.dropout = nn.Dropout(dropout)
         if score_function == 'mlp':
-            self.weight = nn.Parameter(torch.Tensor(hidden_dim*2, 1))
+            self.weight = nn.Parameter(torch.Tensor(hidden_dim*2))
         elif self.score_function == 'bi_linear':
             self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
         else:
             self.register_parameter('weight', None)
+        self.reset_parameters()

+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.hidden_dim)
+        self.w_kx.data.uniform_(-stdv, stdv)
+        self.w_qx.data.uniform_(-stdv, stdv)
+        if self.weight is not None:
+            self.weight.data.uniform_(-stdv, stdv)

     def forward(self, k, q):
         if len(q.shape) == 2: # q_len missing
@@ -48,7 +60,7 @@ def forward(self, k, q):
         # kx: (n_head, ?*k_len, embed_dim) -> (n_head*?, k_len, hidden_dim)
         # qx: (n_head, ?*q_len, embed_dim) -> (n_head*?, q_len, hidden_dim)
         # score: (n_head*?, q_len, k_len,)
-        # output: (?, q_len, embed_dim,)
+        # output: (?, q_len, out_dim,)
         kx = k.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim) # (n_head, ?*k_len, embed_dim)
         qx = q.repeat(self.n_head, 1, 1).view(self.n_head, -1, self.embed_dim) # (n_head, ?*q_len, embed_dim)
         kx = torch.bmm(kx, self.w_kx).view(-1, k_len, self.hidden_dim) # (n_head*?, k_len, hidden_dim)
@@ -61,7 +73,7 @@ def forward(self, k, q):
             kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1)
             qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1)
             kq = torch.cat((kxx, qxx), dim=-1) # (n_head*?, q_len, k_len, hidden_dim*2)
-            score = F.tanh(torch.matmul(kq, self.weight).squeeze(dim=-1))
+            score = F.tanh(torch.matmul(kq, self.weight))
         elif self.score_function == 'bi_linear':
             qw = torch.matmul(qx, self.weight)
             kt = kx.permute(0, 2, 1)
@@ -71,19 +83,25 @@ def forward(self, k, q):
         score = F.softmax(score, dim=-1)
         output = torch.bmm(score, kx) # (n_head*?, q_len, hidden_dim)
         output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1) # (?, q_len, n_head*hidden_dim)
-        output = self.proj(output) # (?, q_len, embed_dim)
+        output = self.proj(output) # (?, q_len, out_dim)
+        output = self.dropout(output)
         return output


-class SelfAttention(Attention):
+class NoQueryAttention(Attention):
     '''q is a parameter'''
-    def __init__(self, embed_dim, hidden_dim=None, n_head=1, score_function='scaled_dot_product', q_len=1, dropout=0.1):
-        super(SelfAttention, self).__init__(embed_dim, hidden_dim, n_head, score_function, dropout)
+    def __init__(self, embed_dim, hidden_dim=None, out_dim=None, n_head=1, score_function='scaled_dot_product', q_len=1, dropout=0):
+        super(NoQueryAttention, self).__init__(embed_dim, hidden_dim, out_dim, n_head, score_function, dropout)
         self.q_len = q_len
         self.q = nn.Parameter(torch.FloatTensor(q_len, embed_dim))
+        self.reset_parameters()

+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.embed_dim)
+        self.q.data.uniform_(-stdv, stdv)
+        super(NoQueryAttention, self).reset_parameters()

     def forward(self, k, **kwargs):
         mb_size = k.shape[0]
         q = self.q.expand(mb_size, -1, -1)
-        return super(SelfAttention, self).forward(k, q)
+        return super(NoQueryAttention, self).forward(k, q)
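
For reference, a minimal usage sketch of the layer after this change (not part of the commit; it assumes the file is importable as layers.attention, a contemporary PyTorch install, and made-up tensor sizes):

# Illustrative only: the new out_dim and dropout options and the renamed
# NoQueryAttention class in use.
import torch
from layers.attention import Attention, NoQueryAttention

k = torch.rand(16, 20, 300)   # (batch, k_len, embed_dim)
q = torch.rand(16, 5, 300)    # (batch, q_len, embed_dim)

attn = Attention(embed_dim=300, out_dim=128, n_head=4, score_function='mlp', dropout=0.1)
out = attn(k, q)              # (16, 5, 128), i.e. (?, q_len, out_dim)

# NoQueryAttention (formerly SelfAttention) learns its own query of length q_len,
# so only k is passed at call time; out_dim defaults to embed_dim.
nq = NoQueryAttention(embed_dim=300, score_function='bi_linear', q_len=1)
out2 = nq(k)                  # (16, 1, 300)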
1 change: 1 addition & 0 deletions train.py
@@ -112,6 +112,7 @@ def run(self):
 parser.add_argument('--optimizer', default='adam', type=str)
 parser.add_argument('--initializer', default='xavier_uniform_', type=str)
 parser.add_argument('--learning_rate', default=0.001, type=float)
+parser.add_argument('--dropout', default=0, type=float)
 parser.add_argument('--num_epoch', default=20, type=int)
 parser.add_argument('--batch_size', default=128, type=int)
 parser.add_argument('--log_step', default=5, type=int)
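
The new option is only useful once it reaches the model; a hedged sketch of that wiring (the rest of train.py is not shown in this diff, so the names below are illustrative rather than the script's actual code):

# Illustrative only: parse the new flag and hand it to the renamed layer.
import argparse
from layers.attention import NoQueryAttention

parser = argparse.ArgumentParser()
parser.add_argument('--dropout', default=0, type=float)
opt = parser.parse_args(['--dropout', '0.1'])

attention = NoQueryAttention(embed_dim=300, dropout=opt.dropout)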
