Skip to content

Commit b24b8a5

Browse files
author
lichuang
committed
arg
1 parent db94252 commit b24b8a5

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

chatbotv2/my_seq2seq.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,17 @@ def generate_trainig_data(self):
121121
init_seq()
122122
xy_data = []
123123
y_data = []
124-
for i in range(10,30,10):
124+
for i in range(30,40,10):
125125
# 问句、答句都是16字,所以取32个
126126
start = i*self.max_seq_len*2
127127
middle = i*self.max_seq_len*2 + self.max_seq_len
128128
end = (i+1)*self.max_seq_len*2
129129
sequence_xy = seq[start:end]
130130
sequence_y = seq[middle:end]
131+
print "right answer"
132+
for w in sequence_y:
133+
(match_word, max_cos) = vector2word(w)
134+
print match_word
131135
sequence_y = [np.ones(self.word_vec_dim)] + sequence_y
132136
xy_data.append(sequence_xy)
133137
y_data.append(sequence_y)
@@ -179,7 +183,7 @@ def model(self, feed_previous=False):
179183
def train(self):
180184
trainXY, trainY = self.generate_trainig_data()
181185
model = self.model(feed_previous=False)
182-
model.fit(trainXY, trainY, n_epoch=100, snapshot_epoch=False)
186+
model.fit(trainXY, trainY, n_epoch=1000, snapshot_epoch=False)
183187
model.save('./model/model')
184188
return model
185189

@@ -189,6 +193,16 @@ def load(self):
189193
return model
190194

191195
if __name__ == '__main__':
196+
phrase = sys.argv[1]
192197
my_seq2seq = MySeq2Seq(word_vec_dim=word_vec_dim, max_seq_len=max_seq_len)
193-
my_seq2seq.train()
194-
#model = my_seq2seq.load()
198+
if phrase == 'train':
199+
my_seq2seq.train()
200+
else:
201+
model = my_seq2seq.load()
202+
trainXY, trainY = my_seq2seq.generate_trainig_data()
203+
predict = model.predict(trainXY)
204+
for sample in predict:
205+
print "predict answer"
206+
for w in sample[1:]:
207+
(match_word, max_cos) = vector2word(w)
208+
print match_word, max_cos

0 commit comments

Comments
 (0)