Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
yangjianxin1 committed Dec 9, 2019
1 parent f2f21ab commit c9059b9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
34 changes: 16 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# GPT2 for Chinese chitchat

## 项目描述
- 本项目使用GPT2模型对中文闲聊语料进行训练使用 HuggingFace的[transformers](https://github.com/huggingface/transformers)实现GPT2模型的编写与训练。
- 本项目使用GPT2模型对中文闲聊语料进行训练使用 HuggingFace的[transformers](https://github.com/huggingface/transformers)实现GPT2模型的编写与训练。
- 在闲暇时间用 [GPT2-Chinese](https://github.com/Morizeyao/GPT2-Chinese)模型训练了几个长文本的生成模型,并且精读了一遍作者的源码,获益匪浅,加深了自己对GPT2生成模型的一些理解,于是将GPT2模型用于闲聊对话的生成,非常感谢作者的分享。
- 解码器的逻辑可参考论文[The Curious Case of Neural Text Degeneration](https://arxiv.xilesou.top/pdf/1904.09751.pdf)
- 本项目中沿用了原项目中的部分结构和一些命名方式,同时也对很多代码细节做出了自己实现。
- 解码器的逻辑使用了Temperature、Top-k Sampling和Nucleus Sampling等,可参考论文[The Curious Case of Neural Text Degeneration](https://arxiv.xilesou.top/pdf/1904.09751.pdf)
- 代码中给出了许多详细的中文注释,方便大家更好地理解代码(能力有限,可能有些代码或注释有误,望大家不吝赐教)

## 运行环境
Expand All @@ -14,7 +14,7 @@ python3.6、 transformers==2.1.1、pytorch==1.3.1
- config:存放GPT2模型的参数的配置文件
- data
- train.txt:默认的原始训练集文件,存放闲聊语料
- train_tokenized.txt:对原始训练语料进行tokenze之后的文件
- train_tokenized.txt:对原始训练语料进行tokenize之后的文件
- model:存放模型
- sample:存放人机闲聊生成的历史聊天记录
- vocabulary:存放GPT2模型的字典
Expand Down Expand Up @@ -44,7 +44,7 @@ python train.py --epochs 30 --batch_size 8 --device 0,1 --raw(若要对原始训
python train.py --epochs 30 --batch_size 8 --no_cuda --raw(指定--no_cuda参数,则使用CPU训练,速度要慢得多)
python train.py --epochs 30 --batch_size 8 --no_cuda(当已经对原始语料进行tokenize,可以不用指定--raw,避免重复步骤)
python train.py --epochs 30 --batch_size 8 --no_cuda(若已经对原始语料进行tokenize,可以不用指定--raw,避免重复tokenize,节约时间)
```
人机交互:
``` bash
Expand All @@ -53,7 +53,7 @@ python interact.py --no_cuda --model_path path_to_your_model --max_history_len 3
更多的参数介绍,可直接看train.py和interact.py中的setup_train_args()函数中的参数说明

## 生成样例
Sample1:
#### Sample 1:

|Role | Response |
|---------|--------|
Expand Down Expand Up @@ -81,7 +81,7 @@ Sample1:
| chatbot | 不知道啊,我周末没有作业,我想睡觉,然后还得找人 |


Sample2:
#### Sample 2:

|Role | Response |
|---------|--------|
Expand All @@ -108,7 +108,7 @@ Sample2:
|user | 你个智障 |
| chatbot | 哈哈哈 |

Sample3:
#### Sample 3:

|Role | Response |
|---------|--------|
Expand All @@ -125,7 +125,7 @@ Sample3:
|user | 多喝热水 |
| chatbot | 好的咯|

Sample4:
#### Sample 4:

|Role | Response |
|---------|--------|
Expand All @@ -144,7 +144,7 @@ Sample4:
|user | 那先这样,下次再聊,再见了 |
| chatbot | 好哒|

Sample5:
#### Sample 5:

|Role | Response |
|---------|--------|
Expand All @@ -163,17 +163,17 @@ Sample5:
|中文闲聊语料 | 数据集地址 |语料描述|
|---------|--------|--------|
|常见中文闲聊|[chinese_chatbot_corpus](https://github.com/codemayq/chinese_chatbot_corpus)|包含小黄鸡语料、豆瓣语料、电视剧对白语料、贴吧论坛回帖语料、微博语料、PTT八卦语料、青云语料等|
|50w中文闲聊语料训练模型 | [百度网盘【提取码:jk8d](https://pan.baidu.com/s/1mkP59GyF9CZ8_r1F864GEQ) |由作者[GaoQ1](https://github.com/GaoQ1)提供的比较高质量的闲聊数据集,我整理成了包含50w个多轮对话的语料|
|50w中文闲聊语料训练模型 | [百度网盘【提取码:jk8d](https://pan.baidu.com/s/1mkP59GyF9CZ8_r1F864GEQ) |由作者[GaoQ1](https://github.com/GaoQ1)提供的比较高质量的闲聊数据集,整理出了50w个多轮对话的语料|

## 模型分享
|模型 | 百度网盘 |提取码|
|---------|--------|--------|
|50w中文闲聊语料训练模型 | [百度网盘](https://pan.baidu.com/s/1EZMF0QcxXBeWF8HMoNpyfQ) |gi5i|
所用闲聊训练集大小为67M,包含50w个多轮对话,用两块1080Ti,大概跑了五六天(应该没有记错),训练了40个epoch,最终loss在2.0左右,继续训练的话,loss应该还能继续下降。模型下载链接
|模型 | 百度网盘 |提取码|模型描述|
|---------|--------|--------|--------|
|50w中文闲聊语料训练模型 | [百度网盘](https://pan.baidu.com/s/1EZMF0QcxXBeWF8HMoNpyfQ) |gi5i|闲聊语料为67M,包含50w个多轮对话,用两块1080Ti,大概跑了五六天(应该没有记错),训练了40个epoch,最终loss在2.0左右,继续训练的话,loss应该还能继续下降。|


模型使用方法:把下载好的模型放在model目录下(否则需要通过--model_path参数指定模型的路径),执行如下命令:
``` bash
python interact.py --no_cuda --model_path path_to_your_model --max_history_len 3(由于闲聊对话生成的内容长度不是很长,因此生成部分在CPU上跑速度也挺快的)
python interact.py --no_cuda --model_path path_to_your_model --max_history_len 5(由于闲聊对话生成的内容长度不是很长,因此生成部分在CPU上跑速度也挺快的。根据需求调整max_history_len参数)
```
输入Ctrl+Z结束对话之后,聊天记录将保存到sample目录下的sample.txt文件中

Expand Down Expand Up @@ -266,9 +266,7 @@ chatbot偶尔也会"智商离线",生成的内容"惨不忍睹"
|user | 爱你呦芸芸 |

## Future Work
更多地在解码器上下功夫,如:

使用互信息(mutual information):训练一个额外的网络,给定一个reponse,该网络能够计算出P(Source|response),Source为response的上文。该网络的目的就是对于生成的多个response,选出P(Source|response)最大的response作为最终的回复。
更多地在解码器上下功夫,比如使用互信息(mutual information):训练一个额外的网络,给定一个reponse,该网络能够计算出概率P(Source|response),Source为response的上文。该网络的目的就是对于生成的多个response,选出P(Source|response)最大的response作为最终的回复。


## Reference
Expand Down
2 changes: 1 addition & 1 deletion interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def set_interact_args():
help="重复惩罚参数,若生成的对话重复性较高,可适当提高该参数")
parser.add_argument('--seed', type=int, default=None, help='设置种子用于生成随机数,以使得训练的结果是确定的')
parser.add_argument('--max_len', type=int, default=25, help='每个utterance的最大长度,超过指定长度则进行截断')
parser.add_argument('--max_history_len', type=int, default=3, help="dialogue history的最大长度")
parser.add_argument('--max_history_len', type=int, default=5, help="dialogue history的最大长度")
parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测')
return parser.parse_args()

Expand Down

0 comments on commit c9059b9

Please sign in to comment.