add BERT-based models (songyouwei#29)
* add AEN model

* UPDATE AEN-BERT

* update README
songyouwei authored Apr 2, 2019
1 parent 47c1e37 commit 943f80b
Showing 5 changed files with 166 additions and 8 deletions.
35 changes: 27 additions & 8 deletions README.md
@@ -4,44 +4,63 @@
>
> Aspect-based sentiment analysis, implemented in PyTorch.
![LICENSE](https://img.shields.io/packagist/l/doctrine/orm.svg)
![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)
[![GitHub stars](https://img.shields.io/github/stars/songyouwei/ABSA-PyTorch.svg?logo=github)](https://github.com/songyouwei/ABSA-PyTorch/stargazers)

## Requirement

* pytorch >= 0.4.0
* numpy >= 1.13.3
* tensorboardX >= 1.2
* python 3.6 / 3.7
* GloVe pre-trained word vectors (See [data_utils.py](./data_utils.py) for more detail)
  * Download pre-trained word vectors [here](https://github.com/stanfordnlp/GloVe#download-pre-trained-word-vectors),
  * extract the [glove.twitter.27B.zip](http://nlp.stanford.edu/data/wordvecs/glove.twitter.27B.zip) and [glove.42B.300d.zip](http://nlp.stanford.edu/data/wordvecs/glove.42B.300d.zip) to the root directory
* pytorch-pretrained-bert 0.6.1
  * See [pytorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT) for more detail.

## Usage

### Training

```sh
python train.py --model_name bert_spc --dataset restaurant --logdir bert_spc_logs
```

See [train.py](./train.py) for more detail.

#### Visualizing the training process (requires TensorFlow)

```sh
tensorboard --logdir=./bert_spc_logs
```

### Inference

Please refer to [infer_example.py](./infer_example.py).

### Tips

* BERT-based models are more sensitive to hyperparameters (especially the learning rate) on small datasets; see [this issue](https://github.com/songyouwei/ABSA-PyTorch/issues/27).
* Fine-tuning on the task at hand is essential to unlock the full power of BERT.
* Non-RNN models can be trained with a larger batch size once their inputs are squeezed with [squeeze_embedding.py](./layers/squeeze_embedding.py); see the sketch below.
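
A minimal sketch of the squeeze idea (illustrative only; the actual layer lives in [squeeze_embedding.py](./layers/squeeze_embedding.py) and may differ in detail):

```python
import torch

def squeeze_embedding(x, x_len):
    """Trim padded embeddings (batch, max_seq_len, embed_dim) down to the
    longest real sequence in the batch, so memory tracks actual sequence
    lengths and a larger batch fits on the GPU."""
    return x[:, :int(x_len.max().item()), :]
```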

## BERT-based models

### AEN / AEN-BERT ([aen.py](./models/aen.py))
Song, Youwei, et al. "Attentional Encoder Network for Targeted Sentiment Classification." arXiv preprint arXiv:1902.09314 (2019). [[pdf]](https://arxiv.org/pdf/1902.09314.pdf)

![aen](assets/aen.png)

### BERT for Sentence Pair Classification ([bert_spc.py](./models/bert_spc.py))
Devlin, Jacob, et al. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." arXiv preprint arXiv:1810.04805 (2018). [[pdf]](https://arxiv.org/pdf/1810.04805.pdf)

![bert_spc](assets/bert_spc.png)
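
The sentence-pair input mirrors what [data_utils.py](./data_utils.py) builds: `[CLS] context [SEP] aspect [SEP]`, with segment ids 0 over the first sentence and 1 over the second. A minimal sketch (the whitespace tokenization here is illustrative only):

```python
context = "the food was great but the service was slow"
aspect = "service"
pair = "[CLS] " + context + " [SEP] " + aspect + " [SEP]"

# Segment ids mirror bert_segments_ids in data_utils.py:
# 0 over "[CLS] context [SEP]", 1 over "aspect [SEP]".
segment_ids = [0] * (len(context.split()) + 2) + [1] * (len(aspect.split()) + 1)
```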


## Non-BERT-based models

### MGAN ([mgan.py](./models/mgan.py))
Fan, Feifan, et al. "Multi-grained Attention Network for Aspect-Level Sentiment Classification." Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing. 2018. [[pdf]](http://aclweb.org/anthology/D18-1380)

Binary file added assets/aen.png
5 changes: 5 additions & 0 deletions data_utils.py
@@ -150,9 +150,14 @@ def __init__(self, fname, tokenizer):
            bert_segments_ids = np.asarray([0] * (np.sum(text_raw_indices != 0) + 2) + [1] * (aspect_len + 1))
            bert_segments_ids = pad_and_truncate(bert_segments_ids, tokenizer.max_seq_len)

            # New inputs for the BERT-based models: the full sentence (aspect
            # left in place) and the bare aspect, each wrapped as a single
            # [CLS] ... [SEP] sequence.
            text_raw_bert_indices = tokenizer.text_to_sequence("[CLS] " + text_left + " " + aspect + " " + text_right + " [SEP]")
            aspect_bert_indices = tokenizer.text_to_sequence("[CLS] " + aspect + " [SEP]")

            data = {
                'text_bert_indices': text_bert_indices,
                'bert_segments_ids': bert_segments_ids,
                'text_raw_bert_indices': text_raw_bert_indices,
                'aspect_bert_indices': aspect_bert_indices,
                'text_raw_indices': text_raw_indices,
                'text_raw_without_aspect_indices': text_raw_without_aspect_indices,
                'text_left_indices': text_left_indices,
129 changes: 129 additions & 0 deletions models/aen.py
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
# file: aen.py
# author: songyouwei <[email protected]>
# Copyright (C) 2018. All Rights Reserved.

from layers.dynamic_rnn import DynamicLSTM
from layers.squeeze_embedding import SqueezeEmbedding
from layers.attention import Attention, NoQueryAttention
from layers.point_wise_feed_forward import PositionwiseFeedForward
import torch
import torch.nn as nn
import torch.nn.functional as F


# CrossEntropyLoss for Label Smoothing Regularization
class CrossEntropyLoss_LSR(nn.Module):
    def __init__(self, device, para_LSR=0.2):
        super(CrossEntropyLoss_LSR, self).__init__()
        self.para_LSR = para_LSR
        self.device = device
        self.logSoftmax = nn.LogSoftmax(dim=-1)

    def _toOneHot_smooth(self, label, batchsize, classes):
        # Spread para_LSR uniformly over all classes, then give the
        # remaining (1 - para_LSR) mass to the true class.
        prob = self.para_LSR * 1.0 / classes
        one_hot_label = torch.zeros(batchsize, classes) + prob
        for i in range(batchsize):
            index = label[i]
            one_hot_label[i, index] += (1.0 - self.para_LSR)
        return one_hot_label

    def forward(self, pre, label, size_average=True):
        b, c = pre.size()
        one_hot_label = self._toOneHot_smooth(label, b, c).to(self.device)
        # Cross entropy of the predictions against the smoothed targets.
        loss = torch.sum(-one_hot_label * self.logSoftmax(pre), dim=1)
        if size_average:
            return torch.mean(loss)
        else:
            return torch.sum(loss)
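
# A hedged usage sketch (not part of this file): with para_LSR=0.2 and
# 3 classes, each class receives 0.2/3 ≈ 0.067 probability mass and the
# true class an extra 0.8 (≈ 0.867 total).
#
#   criterion = CrossEntropyLoss_LSR(device=torch.device('cpu'), para_LSR=0.2)
#   logits = torch.randn(4, 3)            # (batch, classes)
#   labels = torch.tensor([0, 2, 1, 0])
#   loss = criterion(logits, labels)      # scalar, mean over the batch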


class AEN(nn.Module):
    def __init__(self, embedding_matrix, opt):
        super(AEN, self).__init__()
        self.opt = opt
        self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
        self.squeeze_embedding = SqueezeEmbedding()

        self.attn_k = Attention(opt.embed_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.attn_q = Attention(opt.embed_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.ffn_c = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.ffn_t = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)

        self.attn_s1 = Attention(opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)

        self.dense = nn.Linear(opt.hidden_dim*3, opt.polarities_dim)

    def forward(self, inputs):
        text_raw_indices, target_indices = inputs[0], inputs[1]
        context_len = torch.sum(text_raw_indices != 0, dim=-1)
        target_len = torch.sum(target_indices != 0, dim=-1)
        # Embed, then trim each batch to its longest real sequence.
        context = self.embed(text_raw_indices)
        context = self.squeeze_embedding(context, context_len)
        target = self.embed(target_indices)
        target = self.squeeze_embedding(target, target_len)

        # Intra-attention over the context and context-to-target attention,
        # each followed by a position-wise feed-forward layer.
        hc, _ = self.attn_k(context, context)
        hc = self.ffn_c(hc)
        ht, _ = self.attn_q(context, target)
        ht = self.ffn_t(ht)

        # Target-specific attention between the two encoded representations.
        s1, _ = self.attn_s1(hc, ht)

        context_len = torch.tensor(context_len, dtype=torch.float).to(self.opt.device)
        target_len = torch.tensor(target_len, dtype=torch.float).to(self.opt.device)

        # Length-normalized mean pooling, then classify the concatenation.
        hc_mean = torch.div(torch.sum(hc, dim=1), context_len.view(context_len.size(0), 1))
        ht_mean = torch.div(torch.sum(ht, dim=1), target_len.view(target_len.size(0), 1))
        s1_mean = torch.div(torch.sum(s1, dim=1), context_len.view(context_len.size(0), 1))

        x = torch.cat((hc_mean, s1_mean, ht_mean), dim=-1)
        out = self.dense(x)
        return out


class AEN_BERT(nn.Module):
    def __init__(self, bert, opt):
        super(AEN_BERT, self).__init__()
        self.opt = opt
        self.bert = bert
        self.squeeze_embedding = SqueezeEmbedding()
        self.dropout = nn.Dropout(opt.dropout)

        self.attn_k = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.attn_q = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.ffn_c = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.ffn_t = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)

        self.attn_s1 = Attention(opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)

        self.dense = nn.Linear(opt.hidden_dim*3, opt.polarities_dim)

    def forward(self, inputs):
        context, target = inputs[0], inputs[1]
        context_len = torch.sum(context != 0, dim=-1)
        target_len = torch.sum(target != 0, dim=-1)
        # Trim to the longest real sequence, encode with BERT, then dropout;
        # BERT replaces the GloVe embedding layer of the plain AEN above.
        context = self.squeeze_embedding(context, context_len)
        context, _ = self.bert(context, output_all_encoded_layers=False)
        context = self.dropout(context)
        target = self.squeeze_embedding(target, target_len)
        target, _ = self.bert(target, output_all_encoded_layers=False)
        target = self.dropout(target)

        hc, _ = self.attn_k(context, context)
        hc = self.ffn_c(hc)
        ht, _ = self.attn_q(context, target)
        ht = self.ffn_t(ht)

        s1, _ = self.attn_s1(hc, ht)

        context_len = torch.tensor(context_len, dtype=torch.float).to(self.opt.device)
        target_len = torch.tensor(target_len, dtype=torch.float).to(self.opt.device)

        # Same length-normalized pooling and classifier as AEN.
        hc_mean = torch.div(torch.sum(hc, dim=1), context_len.view(context_len.size(0), 1))
        ht_mean = torch.div(torch.sum(ht, dim=1), target_len.view(target_len.size(0), 1))
        s1_mean = torch.div(torch.sum(s1, dim=1), context_len.view(context_len.size(0), 1))

        x = torch.cat((hc_mean, s1_mean, ht_mean), dim=-1)
        out = self.dense(x)
        return out
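
# A hedged construction sketch (not part of this file). train.py wires this
# up through its own opt object; 'bert-base-uncased' and the opt fields
# (bert_dim, hidden_dim, dropout, polarities_dim, device) are assumptions:
#   from pytorch_pretrained_bert import BertModel
#   bert = BertModel.from_pretrained('bert-base-uncased')
#   model = AEN_BERT(bert, opt).to(opt.device)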
5 changes: 5 additions & 0 deletions train.py
@@ -15,6 +15,7 @@
from data_utils import build_tokenizer, build_embedding_matrix, Tokenizer4Bert, ABSADataset

from models import LSTM, IAN, MemNet, RAM, TD_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN
from models.aen import CrossEntropyLoss_LSR, AEN, AEN_BERT
from models.bert_spc import BERT_SPC


@@ -203,6 +204,8 @@ def run(self, repeats=1):
        'aoa': AOA,
        'mgan': MGAN,
        'bert_spc': BERT_SPC,
        'aen': AEN,
        'aen_bert': AEN_BERT,
    }
    dataset_files = {
        'twitter': {
@@ -230,6 +233,8 @@ def run(self, repeats=1):
        'aoa': ['text_raw_indices', 'aspect_indices'],
        'mgan': ['text_raw_indices', 'aspect_indices', 'text_left_indices'],
        'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
        'aen': ['text_raw_indices', 'aspect_indices'],
        'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
    }
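    # (inputs_cols maps each model name to the ABSADataset fields passed to
    # its forward(); 'aen_bert' consumes the two BERT index sequences added
    # in data_utils.py above.)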
    initializers = {
        'xavier_uniform_': torch.nn.init.xavier_uniform_,
