Update dataset API usage in distill lstm (PaddlePaddle#85)
* update dataset usage in distill lstm

* update usage of chnsenticorp

* use map fn in data augmentation

* fix paddlenlp readme typo

* convert chnsenticorp to uppercase
LiuChiachi authored Mar 12, 2021
1 parent a093555 commit 50cd4cf
Showing 8 changed files with 151 additions and 174 deletions.
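The headline change, per the commit message, is moving from per-dataset classes to the unified `load_dataset` API and doing data augmentation through a map-style function. Below is a minimal sketch of that pattern; only the `load_dataset` call and the `.map(...)` calling convention are taken from the diffs that follow, and the `augment_example` helper is a hypothetical stand-in, not code from this commit:

```python
from paddlenlp.datasets import load_dataset

# New-style dataset loading, as shown in the README diff below.
train_ds, dev_ds, test_ds = load_dataset(
    "chnsenticorp", splits=["train", "dev", "test"])

def augment_example(example):
    # Hypothetical placeholder: a real augmentation fn would perturb
    # the text (e.g. token masking) before returning the example.
    example["text"] = example["text"]
    return example

# MapDataset.map applies the function across the whole split.
train_ds = train_ds.map(augment_example)
```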
6 changes: 3 additions & 3 deletions README_en.md
@@ -47,7 +47,7 @@ from paddlenlp.datasets import load_dataset
train_ds, dev_ds, test_ds = load_dataset("chnsenticorp", splits=["train", "dev", "test"])
```

- ### Chinese Text Emebdding Loading
+ ### Chinese Text Embedding Loading

```python

@@ -60,7 +60,7 @@ wordemb.cosine_sim("艺术", "火车")
>>> 0.14792643
```

- ### Rich Chinsese Pre-trained Models
+ ### Rich Chinese Pre-trained Models


```python
@@ -129,7 +129,7 @@ Please refer to our official AI Studio account for more interactive tutorials: [
* [Waybill Information Extraction with BiGRU-CRF Model](https://aistudio.baidu.com/aistudio/projectdetail/1317771) shows how to use a Bi-GRU plus CRF model to complete an information extraction task.

* [Waybill Information Extraction with ERNIE](https://aistudio.baidu.com/aistudio/projectdetail/1329361) shows how to use ERNIE, the Chinese pre-trained model, to improve information extraction performance.

* [Use TCN Model to predict COVID-19 confirmed cases](https://aistudio.baidu.com/aistudio/projectdetail/1290873)


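The embedding hunk above starts mid-snippet, after the setup lines were cut by the diff context. As a sketch of how `wordemb` is typically constructed before the `cosine_sim` call; the embedding name here is one of PaddleNLP's pretrained vector sets, picked for illustration, and may differ from the README's actual choice:

```python
from paddlenlp.embeddings import TokenEmbedding

# Load a pretrained Chinese word-embedding table by name
# (assumed name; any TokenEmbedding vector set works the same way).
wordemb = TokenEmbedding("w2v.baidu_encyclopedia.target.word-word.dim300")

# Cosine similarity between two words, as in the hunk above.
print(wordemb.cosine_sim("艺术", "火车"))  # README reports 0.14792643
```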
28 changes: 17 additions & 11 deletions examples/model_compression/distill_lstm/README.md
@@ -64,7 +64,7 @@ python -u ./run_glue.py \
--num_train_epochs 3 \
--logging_steps 10 \
--save_steps 10 \
- --output_dir ../model_compression/distill_lstm/pretrained_modelss/$TASK_NAME/ \
+ --output_dir ../model_compression/distill_lstm/pretrained_models/$TASK_NAME/ \
--n_gpu 1 \

```
@@ -81,7 +81,7 @@ python -u ./run_glue.py \

```shell
CUDA_VISIBLE_DEVICES=0 python small.py \
- --task_name senta \
+ --task_name chnsenticorp \
--max_epoch 20 \
--vocab_size 1256608 \
--batch_size 64 \
@@ -90,7 +90,8 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--lr 3e-4 \
--dropout_prob 0.2 \
--vocab_path senta_word_dict.txt \
- --output_dir small_models/senta/
+ --save_steps 10000 \
+ --output_dir small_models/chnsenticorp/

```

@@ -103,6 +104,7 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--lr 1.0 \
--dropout_prob 0.4 \
--output_dir small_models/SST-2 \
+ --save_steps 10000 \
--embedding_name w2v.google_news.target.word-word.dim300.en

```
@@ -116,6 +118,7 @@ CUDA_VISIBLE_DEVICES=0 python small.py \
--lr 2.0 \
--dropout_prob 0.4 \
--output_dir small_models/QQP \
+ --save_steps 10000 \
--embedding_name w2v.google_news.target.word-word.dim300.en

```
@@ -125,16 +128,17 @@ CUDA_VISIBLE_DEVICES=0 python small.py \

```shell
CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
- --task_name senta \
+ --task_name chnsenticorp \
--vocab_size 1256608 \
--max_epoch 6 \
--lr 1.0 \
--dropout_prob 0.1 \
--batch_size 64 \
--model_name bert-wwm-ext-chinese \
- --teacher_path pretrained_models/senta/best_bert_wwm_ext_model_880/model_state.pdparams \
+ --teacher_path pretrained_models/chnsenticorp/best_bert_wwm_ext_model_880/model_state.pdparams \
--vocab_path senta_word_dict.txt \
- --output_dir distilled_models/senta
+ --output_dir distilled_models/chnsenticorp \
+ --save_steps 10000 \

```

@@ -148,9 +152,10 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
--dropout_prob 0.2 \
--batch_size 128 \
--model_name bert-base-uncased \
- --embedding_name w2v.google_news.target.word-word.dim300.en \
--output_dir distilled_models/SST-2 \
- --teacher_path pretrained_models/SST-2/best_model_610/model_state.pdparams
+ --teacher_path pretrained_models/SST-2/best_model_610/model_state.pdparams \
+ --save_steps 10000 \
+ --embedding_name w2v.google_news.target.word-word.dim300.en \

```

@@ -163,18 +168,19 @@ CUDA_VISIBLE_DEVICES=0 python bert_distill.py \
--dropout_prob 0.2 \
--batch_size 256 \
--model_name bert-base-uncased \
- --embedding_name w2v.google_news.target.word-word.dim300.en \
--n_iter 10 \
--output_dir distilled_models/QQP \
- --teacher_path pretrained_models/QQP/best_model_17000/model_state.pdparams
+ --teacher_path pretrained_models/QQP/best_model_17000/model_state.pdparams \
+ --save_steps 10000 \
+ --embedding_name w2v.google_news.target.word-word.dim300.en \

```

For detailed descriptions of each argument, see `args.py`. Note that the corresponding hyperparameters need to be adjusted when training different tasks.


## Distillation Results
- This distillation experiment uses the GLUE SST-2 and QQP datasets and the Chinese sentiment classification dataset ChnSentiCorp. Results are evaluated on each dataset's dev set; the metric is accuracy (acc), and QQP additionally reports F1. Distilling a BERT-based teacher into a Bi-LSTM student improves over training the Bi-LSTM alone by 3.3%, 1.9%, and 1.4% on SST-2, QQP, and senta (Chinese sentiment classification), respectively.
+ This distillation experiment uses the GLUE SST-2 and QQP datasets and the Chinese sentiment classification dataset ChnSentiCorp. Results are evaluated on each dataset's dev set; the metric is accuracy (acc), and QQP additionally reports F1. Distilling a BERT-based teacher into a Bi-LSTM student improves over training the Bi-LSTM alone by 3.3%, 1.9%, and 1.4% on SST-2, QQP, and ChnSentiCorp (Chinese sentiment classification), respectively.

| Model | SST-2(dev acc) | QQP(dev acc/f1) | ChnSentiCorp(dev acc) | ChnSentiCorp(dev acc) |
| ----------------- | ----------------- | -------------------------- | --------------------- | --------------------- |
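For context on the numbers above: the recipe is logit-matching distillation, where the BERT teacher's output logits supervise the Bi-LSTM student via the `nn.MSELoss` that `bert_distill.py` constructs (visible in the hunk below). A minimal sketch of one training step under that recipe, with stand-in models rather than the repo's actual classes:

```python
import paddle
import paddle.nn as nn

# Hypothetical stand-ins: any teacher/student pair emitting
# [batch_size, num_classes] logits fits this recipe.
teacher = nn.Linear(128, 2)
student = nn.Linear(128, 2)

mse_loss = nn.MSELoss()
optimizer = paddle.optimizer.Adam(
    learning_rate=3e-4, parameters=student.parameters())

features = paddle.rand([64, 128])  # synthetic batch of features

teacher.eval()
with paddle.no_grad():
    teacher_logits = teacher(features)  # soft targets, no gradient

student_logits = student(features)
loss = mse_loss(student_logits, teacher_logits)  # regress teacher logits
loss.backward()
optimizer.step()
optimizer.clear_grad()
```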
11 changes: 5 additions & 6 deletions examples/model_compression/distill_lstm/bert_distill.py
@@ -21,16 +21,15 @@

from paddlenlp.transformers import BertForSequenceClassification
from paddlenlp.metrics import AccuracyAndF1
- from paddlenlp.datasets import GlueSST2, GlueQQP, ChnSentiCorp

from args import parse_args
from small import BiLSTM
from data import create_distill_loader

- TASK_CLASSES = {
-     "sst-2": (GlueSST2, Accuracy),
-     "qqp": (GlueQQP, AccuracyAndF1),
-     "senta": (ChnSentiCorp, Accuracy),
+ METRIC_CLASSES = {
+     "sst-2": Accuracy,
+     "qqp": AccuracyAndF1,
+     "chnsenticorp": Accuracy
}


@@ -98,7 +97,7 @@ def do_train(agrs):
mse_loss = nn.MSELoss()
klloss = nn.KLDivLoss()

-     metric_class = TASK_CLASSES[args.task_name][1]
+     metric_class = METRIC_CLASSES[args.task_name]
metric = metric_class()

teacher = TeacherModel(
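The `TASK_CLASSES` → `METRIC_CLASSES` change above works because datasets now come from `load_dataset`, so the dict only needs to map task names to metric classes. A sketch of how such a lookup is typically consumed, mirroring the `metric_class` lines in the hunk; the batch of logits and labels here is synthetic:

```python
import paddle
from paddle.metric import Accuracy
from paddlenlp.metrics import AccuracyAndF1

METRIC_CLASSES = {
    "sst-2": Accuracy,
    "qqp": AccuracyAndF1,
    "chnsenticorp": Accuracy,
}

metric = METRIC_CLASSES["qqp"]()  # instantiate per task, as in do_train

# Synthetic predictions and labels for illustration.
logits = paddle.rand([8, 2])
labels = paddle.randint(0, 2, [8, 1])

correct = metric.compute(logits, labels)
metric.update(correct)
# AccuracyAndF1.accumulate() returns (acc, precision, recall, f1, (acc+f1)/2).
print(metric.accumulate())
```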
