
Commit c844ece

Author: 李闯 (committed)
add new demo.py
1 parent fb01b11 commit c844ece

File tree: 1 file changed (+161 −0 lines changed)


chatbotv4/demo.py

Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
# coding:utf-8
import sys
import numpy as np
import tensorflow as tf
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq

# Input sequence length
input_seq_len = 5
# Output sequence length
output_seq_len = 5
# Empty positions are padded with 0
PAD_ID = 0
# Start-of-output-sequence marker
GO_ID = 1
# End-of-sequence marker
EOS_ID = 2
# LSTM cell size
size = 8
# Maximum number of encoder (input) symbols
num_encoder_symbols = 10
# Maximum number of decoder (output) symbols
num_decoder_symbols = 16
# Learning rate
learning_rate = 0.1


def get_samples():
    """Build the sample data

    :return:
    encoder_inputs: [array([0, 0], dtype=int32), array([0, 0], dtype=int32), array([5, 5], dtype=int32),
                     array([7, 7], dtype=int32), array([9, 9], dtype=int32)]
    decoder_inputs: [array([1, 1], dtype=int32), array([11, 11], dtype=int32), array([13, 13], dtype=int32),
                     array([15, 15], dtype=int32), array([2, 2], dtype=int32)]
    """
    train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[5, 7, 9], [11, 13, 15, EOS_ID]]]
    encoder_input_0 = [PAD_ID] * (input_seq_len - len(train_set[0][0])) + train_set[0][0]
    encoder_input_1 = [PAD_ID] * (input_seq_len - len(train_set[1][0])) + train_set[1][0]
    decoder_input_0 = [GO_ID] + train_set[0][1] + [PAD_ID] * (output_seq_len - len(train_set[0][1]) - 1)
    decoder_input_1 = [GO_ID] + train_set[1][1] + [PAD_ID] * (output_seq_len - len(train_set[1][1]) - 1)

    encoder_inputs = []
    decoder_inputs = []
    target_weights = []
    for length_idx in xrange(input_seq_len):
        encoder_inputs.append(np.array([encoder_input_0[length_idx], encoder_input_1[length_idx]], dtype=np.int32))
    for length_idx in xrange(output_seq_len):
        decoder_inputs.append(np.array([decoder_input_0[length_idx], decoder_input_1[length_idx]], dtype=np.int32))
        # Weight is 0.0 for the final time step and for PAD positions, 1.0 otherwise
        target_weights.append(np.array([
            0.0 if length_idx == output_seq_len - 1 or decoder_input_0[length_idx] == PAD_ID else 1.0,
            0.0 if length_idx == output_seq_len - 1 or decoder_input_1[length_idx] == PAD_ID else 1.0,
        ], dtype=np.float32))
    return encoder_inputs, decoder_inputs, target_weights


def get_model(feed_previous=False):
    """Build the model
    """
    encoder_inputs = []
    decoder_inputs = []
    target_weights = []
    for i in xrange(input_seq_len):
        encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i)))
    for i in xrange(output_seq_len + 1):
        decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i)))
    for i in xrange(output_seq_len):
        target_weights.append(tf.placeholder(tf.float32, shape=[None], name="weight{0}".format(i)))

    # Shift decoder_inputs left by one time step to form the targets
    targets = [decoder_inputs[i + 1] for i in xrange(output_seq_len)]

    cell = tf.contrib.rnn.BasicLSTMCell(size)

    # The returned decoder state is not needed here
    outputs, _ = seq2seq.embedding_attention_seq2seq(
        encoder_inputs,
        decoder_inputs[:output_seq_len],
        cell,
        num_encoder_symbols=num_encoder_symbols,
        num_decoder_symbols=num_decoder_symbols,
        embedding_size=size,
        output_projection=None,
        feed_previous=feed_previous,
        dtype=tf.float32)

    # Weighted cross-entropy loss over the output sequence
    loss = seq2seq.sequence_loss(outputs, targets, target_weights)
    # Gradient descent optimizer
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    # Optimization objective: minimize the loss
    update = opt.apply_gradients(opt.compute_gradients(loss))
    # Saver for model persistence
    saver = tf.train.Saver(tf.global_variables())
    return encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver


def train():
    """
    Training procedure
    """
    with tf.Session() as sess:
        sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples()
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver = get_model()

        input_feed = {}
        for l in xrange(input_seq_len):
            input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
        for l in xrange(output_seq_len):
            input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
            input_feed[target_weights[l].name] = sample_target_weights[l]
        input_feed[decoder_inputs[output_seq_len].name] = np.zeros([2], dtype=np.int32)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # Train for 200 iterations, printing the loss every 10 steps
        for step in xrange(200):
            [loss_ret, _] = sess.run([loss, update], input_feed)
            if step % 10 == 0:
                print 'step=', step, 'loss=', loss_ret

        # Persist the model
        saver.save(sess, './model/demo')


def predict():
    """
    Prediction procedure
    """
    with tf.Session() as sess:
        sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples()
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver = get_model(feed_previous=True)
        # Restore the model from file
        saver.restore(sess, './model/demo')

        input_feed = {}
        for l in xrange(input_seq_len):
            input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
        for l in xrange(output_seq_len):
            input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
            input_feed[target_weights[l].name] = sample_target_weights[l]
        input_feed[decoder_inputs[output_seq_len].name] = np.zeros([2], dtype=np.int32)

        # Predicted outputs
        outputs = sess.run(outputs, input_feed)
        # There are 2 test samples in total, so iterate over each one
        for sample_index in xrange(2):
            # Each output step is a num_decoder_symbols-dimensional logit vector,
            # so the argmax over it gives the predicted id
            outputs_seq = [int(np.argmax(logit[sample_index], axis=0)) for logit in outputs]
            # If the end-of-sequence marker appears, drop everything after it
            if EOS_ID in outputs_seq:
                outputs_seq = outputs_seq[:outputs_seq.index(EOS_ID)]
            outputs_seq = [str(v) for v in outputs_seq]
            print " ".join(outputs_seq)


if __name__ == "__main__":
    if sys.argv[1] == 'train':
        train()
    else:
        predict()
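
For reference, the script is driven by its first command-line argument: `python demo.py train` runs the training loop and saves a checkpoint to ./model/demo, and any other argument (for example `python demo.py predict`) restores that checkpoint and prints the decoded output ids. The sketch below is a minimal, hypothetical sanity check that is not part of this commit; it assumes the file above is saved and importable as demo.py, and simply inspects the padded batches that get_samples() builds from train_set.

# Hypothetical sanity check (not part of this commit): inspect the padded
# batches that get_samples() builds for the two identical training samples.
from demo import get_samples  # assumes this file is importable as demo.py

encoder_inputs, decoder_inputs, target_weights = get_samples()

# Encoder inputs are left-padded with PAD_ID (0) up to input_seq_len = 5;
# time-major values per step: [0, 0], [0, 0], [5, 5], [7, 7], [9, 9]
print encoder_inputs

# Decoder inputs start with GO_ID (1) and end with EOS_ID (2);
# per step: [1, 1], [11, 11], [13, 13], [15, 15], [2, 2]
print decoder_inputs

# Target weights are 1.0 for real target tokens and 0.0 for the last step
# (and for any PAD positions): [1, 1], [1, 1], [1, 1], [1, 1], [0, 0]
print target_weights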
