Skip to content

Commit

Permalink
Fix on details about alignment matrix and attention method in MGAN mo…
Browse files Browse the repository at this point in the history
…del (songyouwei#22)

* Minor fix to support PyTorch 1.0

* Update mgan.py
  • Loading branch information
GeneZC authored and songyouwei committed Jan 16, 2019
1 parent 54b5ca1 commit 7dd9f35
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
6 changes: 3 additions & 3 deletions layers/dynamic_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def forward(self, x, x_len):
:return:
"""
"""sort"""
x_sort_idx = np.argsort(-x_len)
x_unsort_idx = torch.LongTensor(np.argsort(x_sort_idx))
x_sort_idx = torch.argsort(-x_len)
x_unsort_idx = torch.argsort(x_sort_idx).long()
x_len = x_len[x_sort_idx]
x = x[torch.LongTensor(x_sort_idx)]
x = x[x_sort_idx.long()]
"""pack"""
x_emb_p = torch.nn.utils.rnn.pack_padded_sequence(x, x_len, batch_first=self.batch_first)

Expand Down
36 changes: 30 additions & 6 deletions models/mgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

class LocationEncoding(nn.Module):
def __init__(self, opt):
self.opt = opt
super(LocationEncoding, self).__init__()
self.opt = opt

def forward(self, x, pos_inx):
batch_size, seq_len = x.size()[0], x.size()[1]
Expand All @@ -38,20 +38,42 @@ def weight_matrix(self, pos_inx, batch_size, seq_len):
weight = torch.tensor(weight)
return weight

class AlignmentMatrix(nn.Module):
def __init__(self, opt):
super(AlignmentMatrix, self).__init__()
self.opt = opt
self.w_u = nn.Parameter(torch.Tensor(6*opt.hidden_dim, 1))

def forward(self, batch_size, ctx, asp):
ctx_len = ctx.size(1)
asp_len = asp.size(1)
alignment_mat = torch.zeros(batch_size, ctx_len, asp_len).to(self.opt.device)
ctx_chunks = ctx.chunk(ctx_len, dim=1)
asp_chunks = asp.chunk(asp_len, dim=1)
for i, ctx_chunk in enumerate(ctx_chunks):
for j, asp_chunk in enumerate(asp_chunks):
feat = torch.cat([ctx_chunk, asp_chunk, ctx_chunk*asp_chunk], dim=2) # batch_size x 1 x 6*hidden_dim
alignment_mat[:, i, j] = feat.matmul(self.w_u.expand(batch_size, -1, -1)).squeeze(-1).squeeze(-1)
return alignment_mat

class MGAN(nn.Module):
def __init__(self, embedding_matrix, opt):
super(MGAN, self).__init__()
self.opt = opt
self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
self.ctx_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
self.location = LocationEncoding(opt)
self.asp_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
self.dense = nn.Linear(8 * opt.hidden_dim, opt.polarities_dim)
self.location = LocationEncoding(opt)
self.w_a2c = nn.Parameter(torch.Tensor(2*opt.hidden_dim, 2*opt.hidden_dim))
self.w_c2a = nn.Parameter(torch.Tensor(2*opt.hidden_dim, 2*opt.hidden_dim))
self.alignment = AlignmentMatrix(opt)
self.dense = nn.Linear(8*opt.hidden_dim, opt.polarities_dim)

def forward(self, inputs):
text_raw_indices = inputs[0] # batch_size x seq_len
aspect_indices = inputs[1]
text_left_indices= inputs[2]
batch_size = text_raw_indices.size(0)
ctx_len = torch.sum(text_raw_indices != 0, dim=1)
asp_len = torch.sum(aspect_indices != 0, dim=1)
left_len = torch.sum(text_left_indices != 0, dim=-1)
Expand All @@ -69,13 +91,15 @@ def forward(self, inputs):
asp_pool = torch.sum(asp_out, dim=1)
asp_pool = torch.div(asp_pool, asp_len.float().unsqueeze(-1)).unsqueeze(-1) # batch_size x 2*hidden_dim x 1

alignment_mat = torch.matmul(ctx_out, asp_out.transpose(1, 2)) # batch_size x (ctx)seq_len x (asp)seq_len
alignment_mat = self.alignment(batch_size, ctx_out, asp_out) # batch_size x (ctx)seq_len x (asp)seq_len
# batch_size x 2*hidden_dim
f_asp2ctx = torch.matmul(ctx_out.transpose(1, 2), F.softmax(alignment_mat.max(2, keepdim=True)[0], dim=1)).squeeze(-1)
f_ctx2asp = torch.matmul(F.softmax(alignment_mat.max(1, keepdim=True)[0], dim=2), asp_out).transpose(1, 2).squeeze(-1)

c_asp2ctx = torch.matmul(ctx_out.transpose(1, 2), F.softmax(torch.matmul(ctx_out, asp_pool), dim=1)).squeeze(-1)
c_ctx2asp = torch.matmul(asp_out.transpose(1, 2), F.softmax(torch.matmul(asp_out, ctx_pool), dim=1)).squeeze(-1)
c_asp2ctx_alpha = F.softmax(ctx_out.matmul(self.w_a2c.expand(batch_size, -1, -1)).matmul(asp_pool), dim=1)
c_asp2ctx = torch.matmul(ctx_out.transpose(1, 2), c_asp2ctx_alpha).squeeze(-1)
c_ctx2asp_alpha = F.softmax(asp_out.matmul(self.w_c2a.expand(batch_size, -1, -1)).matmul(ctx_pool), dim=1)
c_ctx2asp = torch.matmul(asp_out.transpose(1, 2), c_ctx2asp_alpha).squeeze(-1)

feat = torch.cat([c_asp2ctx, f_asp2ctx, f_ctx2asp, c_ctx2asp], dim=1)
out = self.dense(feat) # bathc_size x polarity_dim
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _evaluate_acc_f1(self):
t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)

test_acc = n_test_correct / n_test_total
f1 = metrics.f1_score(t_targets_all, torch.argmax(t_outputs_all, -1), labels=[0, 1, 2], average='macro')
f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2], average='macro')
return test_acc, f1

def run(self, repeats=1):
Expand Down

0 comments on commit 7dd9f35

Please sign in to comment.