# coding:utf-8
# author: lichuang

import sys
import numpy as np
import tensorflow as tf
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq
import word_token
import jieba

# input sequence length
input_seq_len = 5
# output sequence length
output_seq_len = 5
# empty positions are padded with 0
PAD_ID = 0
# marker for the start of the output sequence
GO_ID = 1
# marker for the end of a sequence
EOS_ID = 2
# LSTM cell size
size = 8
# maximum number of input symbols
num_encoder_symbols = 32
# maximum number of output symbols
num_decoder_symbols = 32
# initial learning rate
init_learning_rate = 1

wordToken = word_token.WordToken()

# done at module level so that num_encoder_symbols and num_decoder_symbols
# can be computed dynamically from the corpus; the +5 leaves room for the
# reserved ids (PAD_ID, GO_ID, EOS_ID) plus some slack
max_token_id = wordToken.load_file_list(['./samples/question', './samples/answer'])
num_encoder_symbols = max_token_id + 5
num_decoder_symbols = max_token_id + 5


def get_id_list_from(sentence):
    """Segment a sentence with jieba and map each word to its id,
    dropping words that are not in the vocabulary.
    """
    sentence_id_list = []
    seg_list = jieba.cut(sentence)
    for word in seg_list:
        word_id = wordToken.word2id(word)
        if word_id:
            sentence_id_list.append(word_id)
    return sentence_id_list
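
# Example (hypothetical ids): if the vocabulary maps u'你好' to 5 and u'吗' to 7,
# get_id_list_from(u'你好吗') returns [5, 7]; out-of-vocabulary words are
# silently dropped rather than mapped to an unknown id.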


def get_train_set():
    """Read question/answer pairs line by line and convert them to id lists."""
    train_set = []
    with open('./samples/question', 'r') as question_file:
        with open('./samples/answer', 'r') as answer_file:
            while True:
                question = question_file.readline()
                answer = answer_file.readline()
                if question and answer:
                    question = question.strip()
                    answer = answer.strip()

                    question_id_list = get_id_list_from(question)
                    answer_id_list = get_id_list_from(answer)
                    answer_id_list.append(EOS_ID)
                    train_set.append([question_id_list, answer_id_list])
                else:
                    break
    return train_set


def get_samples(train_set):
    """Construct the sample data.

    :return:
    encoder_inputs: [array([0, 0], dtype=int32), array([0, 0], dtype=int32), array([5, 5], dtype=int32),
                     array([7, 7], dtype=int32), array([9, 9], dtype=int32)]
    decoder_inputs: [array([1, 1], dtype=int32), array([11, 11], dtype=int32), array([13, 13], dtype=int32),
                     array([15, 15], dtype=int32), array([2, 2], dtype=int32)]
    """
    # train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]], [[15, 17, 19], [21, 23, 25, EOS_ID]]]
    raw_encoder_input = []
    raw_decoder_input = []
    for sample in train_set:
        # left-pad each question to input_seq_len; prepend GO and append PADs to each answer
        raw_encoder_input.append([PAD_ID] * (input_seq_len - len(sample[0])) + sample[0])
        raw_decoder_input.append([GO_ID] + sample[1] + [PAD_ID] * (output_seq_len - len(sample[1]) - 1))

    encoder_inputs = []
    decoder_inputs = []
    target_weights = []

    # transpose from batch-major to time-major: one array of shape [batch_size] per time step
    for length_idx in xrange(input_seq_len):
        encoder_inputs.append(np.array([encoder_input[length_idx] for encoder_input in raw_encoder_input], dtype=np.int32))
    for length_idx in xrange(output_seq_len):
        decoder_inputs.append(np.array([decoder_input[length_idx] for decoder_input in raw_decoder_input], dtype=np.int32))
        target_weights.append(np.array([
            0.0 if length_idx == output_seq_len - 1 or decoder_input[length_idx] == PAD_ID else 1.0 for decoder_input in raw_decoder_input
        ], dtype=np.float32))
    return encoder_inputs, decoder_inputs, target_weights
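
# With the toy train_set shown in the comment above, the returned lists are
# time-major: encoder_inputs[2] would be array([5, 7, 15], dtype=int32), i.e.
# the third input token of every sample in the batch, and target_weights puts
# weight 0.0 on the last time step and wherever the decoder input is PAD.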


def seq_to_encoder(input_seq):
    """Convert a space-separated string of word ids into the encoder_inputs,
    decoder_inputs and target_weights needed for prediction.
    """
    input_seq_array = [int(v) for v in input_seq.split()]
    encoder_input = [PAD_ID] * (input_seq_len - len(input_seq_array)) + input_seq_array
    decoder_input = [GO_ID] + [PAD_ID] * (output_seq_len - 1)
    encoder_inputs = [np.array([v], dtype=np.int32) for v in encoder_input]
    decoder_inputs = [np.array([v], dtype=np.int32) for v in decoder_input]
    target_weights = [np.array([1.0], dtype=np.float32)] * output_seq_len
    return encoder_inputs, decoder_inputs, target_weights
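
# e.g. seq_to_encoder('5 7 9') left-pads the encoder input to [0, 0, 5, 7, 9]
# (a batch of one) and seeds the decoder with [GO_ID, 0, 0, 0, 0]; at
# prediction time feed_previous=True makes the decoder consume its own
# previous output at each step, so only the leading GO_ID is actually used.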


def get_model(feed_previous=False):
    """Build the model graph.
    """

    learning_rate = tf.Variable(float(init_learning_rate), trainable=False, dtype=tf.float32)
    learning_rate_decay_op = learning_rate.assign(learning_rate * 0.9)

    encoder_inputs = []
    decoder_inputs = []
    target_weights = []
    for i in xrange(input_seq_len):
        encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i)))
    for i in xrange(output_seq_len + 1):
        decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i)))
    for i in xrange(output_seq_len):
        target_weights.append(tf.placeholder(tf.float32, shape=[None], name="weight{0}".format(i)))

    # targets are the decoder_inputs shifted left by one time step
    targets = [decoder_inputs[i + 1] for i in xrange(output_seq_len)]
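    # e.g. if the fed decoder inputs are [GO, 11, 13, 15, EOS, 0] (the extra
    # 0 comes from the output_seq_len-th placeholder), the targets become
    # [11, 13, 15, EOS, 0]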

    cell = tf.contrib.rnn.BasicLSTMCell(size)

    # the final decoder state returned here is not needed
    outputs, _ = seq2seq.embedding_attention_seq2seq(
        encoder_inputs,
        decoder_inputs[:output_seq_len],
        cell,
        num_encoder_symbols=num_encoder_symbols,
        num_decoder_symbols=num_decoder_symbols,
        embedding_size=size,
        output_projection=None,
        feed_previous=feed_previous,
        dtype=tf.float32)

    # weighted cross-entropy loss over the output sequence
    loss = seq2seq.sequence_loss(outputs, targets, target_weights)
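    # sequence_loss averages the per-step softmax cross-entropy over both time
    # steps and batch, with each step scaled by its target weight, so PAD
    # positions contribute nothing to the loss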
    # gradient descent optimizer
    opt = tf.train.GradientDescentOptimizer(learning_rate)
    # optimization objective: minimize the loss
    update = opt.apply_gradients(opt.compute_gradients(loss))
    # saver for checkpointing the model
    saver = tf.train.Saver(tf.global_variables())

    return encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate
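
# Note: train() builds this graph with feed_previous=False (teacher forcing:
# the decoder sees the ground-truth previous token), while predict() rebuilds
# it with feed_previous=True so the decoder feeds its own previous prediction
# back in at each step.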


def train():
    """
    Training procedure.
    """
    # train_set looks like [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]],
    #                       [[15, 17, 19], [21, 23, 25, EOS_ID]]]
    train_set = get_train_set()
    with tf.Session() as sess:

        sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set)
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate = get_model()

        input_feed = {}
        for l in xrange(input_seq_len):
            input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
        for l in xrange(output_seq_len):
            input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
            input_feed[target_weights[l].name] = sample_target_weights[l]
        # the extra decoder placeholder only serves as the final shift target, so feed it PADs
        input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32)

        # initialize all variables
        sess.run(tf.global_variables_initializer())

        # run many iterations, printing the loss every 10 steps; stop early
        # with Ctrl+C whenever the loss looks good enough
        previous_losses = []
        for step in xrange(20700):
            [loss_ret, _] = sess.run([loss, update], input_feed)
            if step % 10 == 0:
                print 'step=', step, 'loss=', loss_ret, 'learning_rate=', learning_rate.eval()

                # decay the learning rate when the loss stops improving over the last 5 logged values
                if len(previous_losses) > 5 and loss_ret > max(previous_losses[-5:]):
                    sess.run(learning_rate_decay_op)
                previous_losses.append(loss_ret)

        # persist the model
        saver.save(sess, './model/demo')


def predict():
    """
    Prediction loop: read sentences from stdin and print the model's replies.
    """
    with tf.Session() as sess:
        encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate = get_model(feed_previous=True)
        saver.restore(sess, './model/demo')
        sys.stdout.write("> ")
        sys.stdout.flush()
        input_seq = sys.stdin.readline()
        while input_seq:
            input_seq = input_seq.strip()
            input_id_list = get_id_list_from(input_seq)
            if input_id_list:
                sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = seq_to_encoder(' '.join([str(v) for v in input_id_list]))

                input_feed = {}
                for l in xrange(input_seq_len):
                    input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
                for l in xrange(output_seq_len):
                    input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
                    input_feed[target_weights[l].name] = sample_target_weights[l]
                # batch size is 1 at prediction time; this placeholder is only a shift target
                input_feed[decoder_inputs[output_seq_len].name] = np.zeros([1], dtype=np.int32)

                # run the decoder to get the predicted output sequence
                outputs_seq = sess.run(outputs, input_feed)
                # each step's output is a num_decoder_symbols-dimensional logit vector,
                # so argmax picks the id with the highest score as the predicted word
                outputs_seq = [int(np.argmax(logit[0], axis=0)) for logit in outputs_seq]
                # stop printing at the end-of-sequence marker, if one appears
                if EOS_ID in outputs_seq:
                    outputs_seq = outputs_seq[:outputs_seq.index(EOS_ID)]
                outputs_seq = [wordToken.id2word(v) for v in outputs_seq]
                print " ".join(outputs_seq)
            else:
                print "WARN: no known words in the input"

            sys.stdout.write("> ")
            sys.stdout.flush()
            input_seq = sys.stdin.readline()


if __name__ == "__main__":
    if len(sys.argv) > 1 and sys.argv[1] == 'train':
        train()
    else:
        predict()