
Commit 84cce6f

author 李闯 committed: add chatbotv5
1 parent c844ece commit 84cce6f

File tree

4 files changed, +2295 -0 lines


chatbotv5/demo.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# coding:utf-8
# author: lichuang

import sys
import numpy as np
import tensorflow as tf
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq
import word_token
import jieba

# input sequence length
input_seq_len = 5
# output sequence length
output_seq_len = 5
# empty slots are padded with 0
PAD_ID = 0
# start marker for the output sequence
GO_ID = 1
# end-of-sequence marker
EOS_ID = 2
# LSTM cell size
size = 8
# maximum number of input symbols
num_encoder_symbols = 32
# maximum number of output symbols
num_decoder_symbols = 32
# initial learning rate
init_learning_rate = 1

wordToken = word_token.WordToken()

# done at module level so that num_encoder_symbols and num_decoder_symbols
# can be derived dynamically from the vocabulary
max_token_id = wordToken.load_file_list(['./samples/question', './samples/answer'])
num_encoder_symbols = max_token_id + 5
num_decoder_symbols = max_token_id + 5


def get_id_list_from(sentence):
    """Tokenize a sentence with jieba and map each known word to its token id."""
    sentence_id_list = []
    seg_list = jieba.cut(sentence)
    for word in seg_list:
        word_id = wordToken.word2id(word)
        if word_id:
            sentence_id_list.append(word_id)
    return sentence_id_list


def get_train_set():
    """Read the question/answer sample files line by line and build the training set."""
    train_set = []
    with open('./samples/question', 'r') as question_file, \
            open('./samples/answer', 'r') as answer_file:
        while True:
            question = question_file.readline()
            answer = answer_file.readline()
            if question and answer:
                question = question.strip()
                answer = answer.strip()

                question_id_list = get_id_list_from(question)
                answer_id_list = get_id_list_from(answer)
                answer_id_list.append(EOS_ID)
                train_set.append([question_id_list, answer_id_list])
            else:
                break
    return train_set


def get_samples(train_set):
    """Build the sample data.

    :return:
    encoder_inputs: [array([0, 0], dtype=int32), array([0, 0], dtype=int32), array([5, 5], dtype=int32),
                     array([7, 7], dtype=int32), array([9, 9], dtype=int32)]
    decoder_inputs: [array([1, 1], dtype=int32), array([11, 11], dtype=int32), array([13, 13], dtype=int32),
                     array([15, 15], dtype=int32), array([2, 2], dtype=int32)]
    """
    # train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]], [[15, 17, 19], [21, 23, 25, EOS_ID]]]
    raw_encoder_input = []
    raw_decoder_input = []
    for sample in train_set:
        # left-pad the question to input_seq_len; prepend GO and right-pad the answer to output_seq_len
        raw_encoder_input.append([PAD_ID] * (input_seq_len - len(sample[0])) + sample[0])
        raw_decoder_input.append([GO_ID] + sample[1] + [PAD_ID] * (output_seq_len - len(sample[1]) - 1))

    encoder_inputs = []
    decoder_inputs = []
    target_weights = []

    # transpose from batch-major to time-major: one array per time step
    for length_idx in xrange(input_seq_len):
        encoder_inputs.append(np.array([encoder_input[length_idx] for encoder_input in raw_encoder_input], dtype=np.int32))
    for length_idx in xrange(output_seq_len):
        decoder_inputs.append(np.array([decoder_input[length_idx] for decoder_input in raw_decoder_input], dtype=np.int32))
        # PAD positions and the final step carry zero weight in the loss
        target_weights.append(np.array([
            0.0 if length_idx == output_seq_len - 1 or decoder_input[length_idx] == PAD_ID else 1.0 for decoder_input in raw_decoder_input
        ], dtype=np.float32))
    return encoder_inputs, decoder_inputs, target_weights


def seq_to_encoder(input_seq):
    """Convert a space-separated string of token ids into the encoder inputs,
    decoder inputs and target weights used for prediction (batch size 1).
    """
    input_seq_array = [int(v) for v in input_seq.split()]
    encoder_input = [PAD_ID] * (input_seq_len - len(input_seq_array)) + input_seq_array
    decoder_input = [GO_ID] + [PAD_ID] * (output_seq_len - 1)
    encoder_inputs = [np.array([v], dtype=np.int32) for v in encoder_input]
    decoder_inputs = [np.array([v], dtype=np.int32) for v in decoder_input]
    target_weights = [np.array([1.0], dtype=np.float32)] * output_seq_len
    return encoder_inputs, decoder_inputs, target_weights


def get_model(feed_previous=False):
    """Build the seq2seq model."""
    learning_rate = tf.Variable(float(init_learning_rate), trainable=False, dtype=tf.float32)
    learning_rate_decay_op = learning_rate.assign(learning_rate * 0.9)

    encoder_inputs = []
    decoder_inputs = []
    target_weights = []
    for i in xrange(input_seq_len):
        encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i)))
    for i in xrange(output_seq_len + 1):
        decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i)))
    for i in xrange(output_seq_len):
        target_weights.append(tf.placeholder(tf.float32, shape=[None], name="weight{0}".format(i)))

    # targets are decoder_inputs shifted left by one time step
    targets = [decoder_inputs[i + 1] for i in xrange(output_seq_len)]

    cell = tf.contrib.rnn.BasicLSTMCell(size)

    # we don't need the final state returned here
    outputs, _ = seq2seq.embedding_attention_seq2seq(
                        encoder_inputs,
                        decoder_inputs[:output_seq_len],
                        cell,
                        num_encoder_symbols=num_encoder_symbols,
                        num_decoder_symbols=num_decoder_symbols,
                        embedding_size=size,
                        output_projection=None,
                        feed_previous=feed_previous,
                        dtype=tf.float32)

    # weighted cross-entropy loss over the output sequence
    loss = seq2seq.sequence_loss(outputs, targets, target_weights)
    # gradient descent optimizer
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    # training op: minimize the loss
    update = opt.apply_gradients(opt.compute_gradients(loss))
    # saver for persisting the model
    saver = tf.train.Saver(tf.global_variables())

    return encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate


def train():
    """Training loop."""
    # train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]],
    #              [[15, 17, 19], [21, 23, 25, EOS_ID]]]
    train_set = get_train_set()
    with tf.Session() as sess:

        sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set)
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate = get_model()

        input_feed = {}
        for l in xrange(input_seq_len):
            input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
        for l in xrange(output_seq_len):
            input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
            input_feed[target_weights[l].name] = sample_target_weights[l]
        input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32)

        # initialize all variables
        sess.run(tf.global_variables_initializer())

        # run many training steps, printing the loss every 10 steps;
        # feel free to stop with Ctrl+C once the loss looks good enough
        previous_losses = []
        for step in xrange(20700):
            [loss_ret, _] = sess.run([loss, update], input_feed)
            if step % 10 == 0:
                print 'step=', step, 'loss=', loss_ret, 'learning_rate=', learning_rate.eval()

                # decay the learning rate whenever the loss stops improving
                if len(previous_losses) > 5 and loss_ret > max(previous_losses[-5:]):
                    sess.run(learning_rate_decay_op)
                previous_losses.append(loss_ret)

        # persist the model
        saver.save(sess, './model/demo')


def predict():
    """Interactive prediction loop."""
    with tf.Session() as sess:
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate = get_model(feed_previous=True)
        saver.restore(sess, './model/demo')
        sys.stdout.write("> ")
        sys.stdout.flush()
        input_seq = sys.stdin.readline()
        while input_seq:
            input_seq = input_seq.strip()
            input_id_list = get_id_list_from(input_seq)
            if len(input_id_list):
                sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = seq_to_encoder(' '.join([str(v) for v in input_id_list]))

                input_feed = {}
                for l in xrange(input_seq_len):
                    input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
                for l in xrange(output_seq_len):
                    input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
                    input_feed[target_weights[l].name] = sample_target_weights[l]
                # the extra decoder slot only exists to form the shifted targets
                # and is unused by outputs; batch size is 1 at prediction time
                input_feed[decoder_inputs[output_seq_len].name] = np.zeros([1], dtype=np.int32)

                # run the model to get the predicted logits
                outputs_seq = sess.run(outputs, input_feed)
                # each output is a num_decoder_symbols-dimensional logit vector,
                # so argmax yields the predicted token id
                outputs_seq = [int(np.argmax(logit[0], axis=0)) for logit in outputs_seq]
                # stop at the EOS marker; everything after it is dropped
                if EOS_ID in outputs_seq:
                    outputs_seq = outputs_seq[:outputs_seq.index(EOS_ID)]
                outputs_seq = [wordToken.id2word(v) for v in outputs_seq]
                print " ".join(outputs_seq)
            else:
                print "WARN: input words are out of vocabulary"

            sys.stdout.write("> ")
            sys.stdout.flush()
            input_seq = sys.stdin.readline()


if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] == 'train':
        train()
    else:
        predict()
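
Note: demo.py depends on the local word_token module imported above, which is one of the other files in this commit and is not shown in this excerpt. As a rough orientation only, below is a minimal, hypothetical sketch of the interface demo.py assumes: load_file_list builds a vocabulary from the sample files and returns the largest assigned token id, word2id returns None for out-of-vocabulary words, and id2word maps ids back to words. The actual word_token.py in this commit may differ.

# coding:utf-8
# Hypothetical sketch of the word_token interface assumed by demo.py;
# not the actual word_token.py from this commit.
import jieba


class WordToken(object):
    def __init__(self):
        # ids 0, 1, 2 are reserved for PAD, GO and EOS
        self.START_ID = 3
        self.word2id_dict = {}
        self.id2word_dict = {}

    def load_file_list(self, file_list):
        """Build the vocabulary from the given files; return the max token id."""
        words_count = {}
        for file_name in file_list:
            with open(file_name, 'r') as f:
                for line in f:
                    for word in jieba.cut(line.strip()):
                        words_count[word] = words_count.get(word, 0) + 1
        # more frequent words get smaller ids
        sorted_list = sorted(words_count.items(), key=lambda x: x[1], reverse=True)
        for index, (word, _) in enumerate(sorted_list):
            self.word2id_dict[word] = self.START_ID + index
            self.id2word_dict[self.START_ID + index] = word
        return self.START_ID + len(sorted_list) - 1

    def word2id(self, word):
        # None for out-of-vocabulary words, so callers can filter them out
        return self.word2id_dict.get(word)

    def id2word(self, word_id):
        return self.id2word_dict.get(word_id, '?')

To try the demo (assuming a TensorFlow 1.x environment with contrib, jieba installed, and the ./samples/question and ./samples/answer files in place): python demo.py train trains the model and checkpoints it to ./model/demo; running the script with any other (or no) argument starts the interactive "> " prompt.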
