Fix on memnet, ram and tnet (songyouwei#16)
* Minor fix on TNet-LF to bring it in line with the original implementation

* Fix the position weighting of ram and memnet

* Upload AoA model

* AOA typo fix
GeneZC authored and songyouwei committed Dec 8, 2018
1 parent 8a2de4e commit 62d372c
Showing 8 changed files with 79 additions and 20 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -32,6 +32,11 @@ tensorboard --logdir=./ian_logs

## Implemented models

### AOA ([aoa.py](./models/aoa.py)) [[pdf]](https://arxiv.org/pdf/1804.06536.pdf)
Huang, Binxuan, et al. "Aspect Level Sentiment Classification with Attention-over-Attention Neural Networks." arXiv preprint arXiv:1804.06536 (2018).

![aoa](assets/aoa.png)

### TNet ([tnet_lf.py](./models/tnet_lf.py)) [[pdf]](https://arxiv.org/pdf/1805.01086)
Li, Xin, et al. "Transformation Networks for Target-Oriented Sentiment Classification." arXiv preprint arXiv:1805.01086 (2018).

Binary file added assets/aoa.png
1 change: 1 addition & 0 deletions models/__init__.py
@@ -11,3 +11,4 @@
from models.cabasc import Cabasc
from models.atae_lstm import ATAE_LSTM
from models.tnet_lf import TNet_LF
+from models.aoa import AOA
38 changes: 38 additions & 0 deletions models/aoa.py
@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# file: aoa.py
# author: gene_zc <[email protected]>
# Copyright (C) 2018. All Rights Reserved.

from layers.dynamic_rnn import DynamicLSTM
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class AOA(nn.Module):
    def __init__(self, embedding_matrix, opt):
        super(AOA, self).__init__()
        self.opt = opt
        self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
        self.ctx_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.asp_lstm = DynamicLSTM(opt.embed_dim, opt.hidden_dim, num_layers=1, batch_first=True, bidirectional=True)
        self.dense = nn.Linear(2 * opt.hidden_dim, opt.polarities_dim)

    def forward(self, inputs):
        text_raw_indices = inputs[0]  # batch_size x seq_len
        aspect_indices = inputs[1]  # batch_size x seq_len
        ctx_len = torch.sum(text_raw_indices != 0, dim=1)
        asp_len = torch.sum(aspect_indices != 0, dim=1)
        ctx = self.embed(text_raw_indices)  # batch_size x seq_len x embed_dim
        asp = self.embed(aspect_indices)  # batch_size x seq_len x embed_dim
        ctx_out, (_, _) = self.ctx_lstm(ctx, ctx_len)  # batch_size x (ctx) seq_len x 2*hidden_dim
        asp_out, (_, _) = self.asp_lstm(asp, asp_len)  # batch_size x (asp) seq_len x 2*hidden_dim
        interaction_mat = torch.matmul(ctx_out, torch.transpose(asp_out, 1, 2))  # batch_size x (ctx) seq_len x (asp) seq_len
        alpha = F.softmax(interaction_mat, dim=1)  # col-wise, batch_size x (ctx) seq_len x (asp) seq_len
        beta = F.softmax(interaction_mat, dim=2)  # row-wise, batch_size x (ctx) seq_len x (asp) seq_len
        beta_avg = beta.mean(dim=1, keepdim=True)  # batch_size x 1 x (asp) seq_len
        gamma = torch.matmul(alpha, beta_avg.transpose(1, 2))  # batch_size x (ctx) seq_len x 1
        weighted_sum = torch.matmul(torch.transpose(ctx_out, 1, 2), gamma).squeeze(-1)  # batch_size x 2*hidden_dim
        out = self.dense(weighted_sum)  # batch_size x polarities_dim

        return out
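The attention-over-attention step above is self-contained enough to verify in isolation. A minimal sketch with plain tensors standing in for the DynamicLSTM outputs (toy shapes, not the repo's defaults):

```python
import torch
import torch.nn.functional as F

batch_size, ctx_len, asp_len, hidden_dim = 2, 7, 3, 4  # toy sizes

ctx_out = torch.randn(batch_size, ctx_len, 2 * hidden_dim)  # stand-in for context BiLSTM states
asp_out = torch.randn(batch_size, asp_len, 2 * hidden_dim)  # stand-in for aspect BiLSTM states

# One compatibility score per (context word, aspect word) pair.
interaction_mat = torch.matmul(ctx_out, asp_out.transpose(1, 2))  # B x ctx_len x asp_len

alpha = F.softmax(interaction_mat, dim=1)   # column-wise: which context words matter per aspect word
beta = F.softmax(interaction_mat, dim=2)    # row-wise: which aspect words matter per context word
beta_avg = beta.mean(dim=1, keepdim=True)   # B x 1 x asp_len

# Attention over attention: alpha's columns weighted by the averaged beta.
gamma = torch.matmul(alpha, beta_avg.transpose(1, 2))                    # B x ctx_len x 1
weighted_sum = torch.matmul(ctx_out.transpose(1, 2), gamma).squeeze(-1)  # B x 2*hidden_dim

assert weighted_sum.shape == (batch_size, 2 * hidden_dim)
```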
13 changes: 3 additions & 10 deletions models/memnet.py
@@ -12,18 +12,11 @@

class MemNet(nn.Module):

-    def locationed_memory(self, memory, memory_len, left_len, aspect_len):
+    def locationed_memory(self, memory, memory_len):
        # here we just simply calculate the location vector in Model2's manner
-        '''
-        Updated to calculate location as the absolute difference between context word and aspect
-        '''
        for i in range(memory.size(0)):
            for idx in range(memory_len[i]):
-                aspect_start = left_len[i] - aspect_len[i]
-                if idx < aspect_start: l = aspect_start.item() - idx  # l = absolute distance to the aspect
-                else: l = idx + 1 - aspect_start.item()
-                memory[i][idx] *= (1 - float(l) / int(memory_len[i]))
-
+                memory[i][idx] *= (1 - float(idx) / int(memory_len[i]))
        return memory
return memory

def __init__(self, embedding_matrix, opt):
@@ -44,7 +37,7 @@ def forward(self, inputs):

        memory = self.embed(text_raw_without_aspect_indices)
        memory = self.squeeze_embedding(memory, memory_len)
-        # memory = self.locationed_memory(memory, memory_len, left_len, aspect_len)
+        # memory = self.locationed_memory(memory, memory_len)
        aspect = self.embed(aspect_indices)
        aspect = torch.sum(aspect, dim=1)
        aspect = torch.div(aspect, nonzeros_aspect.view(nonzeros_aspect.size(0), 1))
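The simplified weighting decays each word's embedding linearly with its index rather than with its distance to the aspect (and the call stays commented out in forward). A toy demonstration of the formula, with hypothetical lengths:

```python
import torch

memory = torch.ones(1, 5, 3)    # one sample, 5 context words, toy embed_dim=3
memory_len = torch.tensor([5])

# Same formula as the updated locationed_memory: weight = 1 - idx / len
for idx in range(int(memory_len[0])):
    memory[0][idx] *= 1 - float(idx) / int(memory_len[0])

print(memory[:, :, 0])  # tensor([[1.0000, 0.8000, 0.6000, 0.4000, 0.2000]])
```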
13 changes: 9 additions & 4 deletions models/ram.py
@@ -10,11 +10,15 @@


class RAM(nn.Module):
-    def locationed_memory(self, memory, memory_len):
+    def locationed_memory(self, memory, memory_len, left_len):
        # here we just simply calculate the location vector in Model2's manner
        for i in range(memory.size(0)):
            for idx in range(memory_len[i]):
-                memory[i][idx] *= (1 - float(idx) / int(memory_len[i]))
+                aspect_start = left_len[i]
+                if idx < aspect_start: l = aspect_start.item() - idx  # l = absolute distance to the aspect
+                else: l = idx + 1 - aspect_start.item()
+                memory[i][idx] *= (1 - float(l) / int(memory_len[i]))
+
        return memory

def __init__(self, embedding_matrix, opt):
@@ -28,14 +32,15 @@ def __init__(self, embedding_matrix, opt):
        self.dense = nn.Linear(opt.hidden_dim * 2, opt.polarities_dim)

    def forward(self, inputs):
-        text_raw_indices, aspect_indices = inputs[0], inputs[1]
+        text_raw_indices, aspect_indices, text_left_indices = inputs[0], inputs[1], inputs[2]
+        left_len = torch.sum(text_left_indices != 0, dim=-1)
        memory_len = torch.sum(text_raw_indices != 0, dim=-1)
        aspect_len = torch.sum(aspect_indices != 0, dim=-1)
        nonzeros_aspect = torch.tensor(aspect_len, dtype=torch.float).to(self.opt.device)

        memory = self.embed(text_raw_indices)
        memory, (_, _) = self.bi_lstm_context(memory, memory_len)
-        # memory = self.locationed_memory(memory, memory_len)
+        # memory = self.locationed_memory(memory, memory_len, left_len)
        aspect = self.embed(aspect_indices)
        aspect, (_, _) = self.bi_lstm_aspect(aspect, aspect_len)
        aspect = torch.sum(aspect, dim=1)
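Unlike the index-based decay memnet now uses, RAM's weight depends on each word's absolute distance to the aspect span, with left_len giving the aspect's start position. The same loop on a toy sample (hypothetical lengths; as in memnet, the call itself remains commented out in forward):

```python
import torch

memory = torch.ones(1, 6, 3)   # one sample, 6 words, toy embed_dim=3
memory_len = torch.tensor([6])
left_len = torch.tensor([2])   # two context words precede the aspect

for idx in range(int(memory_len[0])):
    aspect_start = left_len[0]
    if idx < aspect_start:
        l = aspect_start.item() - idx        # distance for words left of the aspect
    else:
        l = idx + 1 - aspect_start.item()    # distance for the aspect and words right of it
    memory[0][idx] *= 1 - float(l) / int(memory_len[0])

print(memory[:, :, 0])  # tensor([[0.6667, 0.8333, 0.8333, 0.6667, 0.5000, 0.3333]])
```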
21 changes: 18 additions & 3 deletions models/tnet_lf.py
@@ -21,6 +21,7 @@ def forward(self, x, pos_inx):
        x = weight.unsqueeze(2) * x
        return x

+    '''
    def weight_matrix(self, pos_inx, batch_size, seq_len):
        pos_inx = pos_inx.cpu().numpy()
        weight = [[] for i in range(batch_size)]
@@ -39,6 +40,20 @@ def weight_matrix(self, pos_inx, batch_size, seq_len):
                weight[i].append(1 - relative_pos / sentence_len)
        weight = torch.tensor(weight)
        return weight
+    '''
+
+    def weight_matrix(self, pos_inx, batch_size, seq_len):
+        pos_inx = pos_inx.cpu().numpy()
+        weight = [[] for i in range(batch_size)]
+        for i in range(batch_size):
+            for j in range(pos_inx[i][1]):
+                relative_pos = pos_inx[i][1] - j
+                weight[i].append(1 - relative_pos / 40)
+            for j in range(pos_inx[i][1], seq_len):
+                relative_pos = j - pos_inx[i][0]
+                weight[i].append(1 - relative_pos / 40)
+        weight = torch.tensor(weight)
+        return weight

class TNet_LF(nn.Module):
def __init__(self, embedding_matrix, opt):
@@ -71,10 +86,10 @@ def forward(self, inputs):
        a = F.softmax(a, 1)  # (aspect_len,context_len)
        aspect_mid = torch.bmm(e, a)
        aspect_mid = torch.cat((aspect_mid, v), dim=1).transpose(1, 2)
-        aspect_mid = self.fc1(aspect_mid).transpose(1, 2)
+        aspect_mid = F.relu(self.fc1(aspect_mid).transpose(1, 2))
        v = aspect_mid + v
-        z = F.relu(self.convs3(
-            self.position(v.transpose(1, 2), aspect_in_text).transpose(1, 2)))  # [(N,Co,L), ...]*len(Ks)
+        v = self.position(v.transpose(1, 2), aspect_in_text).transpose(1, 2)
+        z = F.relu(self.convs3(v))  # [(N,Co,L), ...]*len(Ks)
        z = F.max_pool1d(z, z.size(2)).squeeze(2)
        out = self.fc(z)
        return out
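The replacement weight_matrix divides the relative position by a fixed 40 instead of by the sentence length, which is what the commit message means by matching the original implementation. A standalone sketch of the new scheme (toy aspect span; `pos_inx` rows are assumed to hold the aspect's start and end indices):

```python
import torch

def weight_matrix(pos_inx, batch_size, seq_len):
    # Mirrors the updated TNet-LF version: positions are scaled by the
    # fixed constant 40 rather than by each sentence's length.
    weight = [[] for _ in range(batch_size)]
    for i in range(batch_size):
        for j in range(pos_inx[i][1]):            # words before the aspect's end
            weight[i].append(1 - (pos_inx[i][1] - j) / 40)
        for j in range(pos_inx[i][1], seq_len):   # the rest, measured from the start
            weight[i].append(1 - (j - pos_inx[i][0]) / 40)
    return torch.tensor(weight)

# Aspect occupying positions 2..3 of a 6-word sentence (toy values):
print(weight_matrix([[2, 3]], batch_size=1, seq_len=6))
# tensor([[0.9250, 0.9500, 0.9750, 0.9750, 0.9500, 0.9250]])
```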
8 changes: 5 additions & 3 deletions train.py
@@ -12,7 +12,7 @@
import argparse
import math

-from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF
+from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA


class Instructor:
@@ -169,16 +169,18 @@ def run(self, repeats=1):
        'ram': RAM,
        'cabasc': Cabasc,
        'tnet_lf': TNet_LF,
+        'aoa': AOA,
    }
    input_colses = {
        'lstm': ['text_raw_indices'],
        'td_lstm': ['text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
        'atae_lstm': ['text_raw_indices', 'aspect_indices'],
        'ian': ['text_raw_indices', 'aspect_indices'],
-        'memnet': ['text_raw_without_aspect_indices', 'aspect_indices', 'text_left_with_aspect_indices'],
-        'ram': ['text_raw_indices', 'aspect_indices'],
+        'memnet': ['text_raw_without_aspect_indices', 'aspect_indices'],
+        'ram': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
        'cabasc': ['text_raw_indices', 'aspect_indices', 'text_left_with_aspect_indices', 'text_right_with_aspect_indices'],
        'tnet_lf': ['text_raw_indices', 'aspect_indices', 'aspect_in_text'],
+        'aoa': ['text_raw_indices', 'aspect_indices']
    }
    initializers = {
        'xavier_uniform_': torch.nn.init.xavier_uniform_,
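The input_colses mapping above determines which tensor columns each model's forward receives, so the memnet/ram changes here must stay in sync with the forward signatures edited earlier. Roughly, the inputs are assembled like this (a sketch, assuming a dict-like sample batch keyed by column name and a hypothetical `device`):

```python
# Hypothetical batch and device; the column list registered for the chosen
# model selects and orders its forward inputs.
inputs = [sample_batched[col].to(device) for col in input_colses['ram']]
outputs = model(inputs)  # RAM now unpacks inputs[0..2], including text_left_indices
```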
