@@ -121,13 +121,17 @@ def generate_trainig_data(self):
121
121
init_seq ()
122
122
xy_data = []
123
123
y_data = []
124
- for i in range (10 , 30 ,10 ):
124
+ for i in range (30 , 40 ,10 ):
125
125
# 问句、答句都是16字,所以取32个
126
126
start = i * self .max_seq_len * 2
127
127
middle = i * self .max_seq_len * 2 + self .max_seq_len
128
128
end = (i + 1 )* self .max_seq_len * 2
129
129
sequence_xy = seq [start :end ]
130
130
sequence_y = seq [middle :end ]
131
+ print "right answer"
132
+ for w in sequence_y :
133
+ (match_word , max_cos ) = vector2word (w )
134
+ print match_word
131
135
sequence_y = [np .ones (self .word_vec_dim )] + sequence_y
132
136
xy_data .append (sequence_xy )
133
137
y_data .append (sequence_y )
@@ -179,7 +183,7 @@ def model(self, feed_previous=False):
179
183
def train (self ):
180
184
trainXY , trainY = self .generate_trainig_data ()
181
185
model = self .model (feed_previous = False )
182
- model .fit (trainXY , trainY , n_epoch = 100 , snapshot_epoch = False )
186
+ model .fit (trainXY , trainY , n_epoch = 1000 , snapshot_epoch = False )
183
187
model .save ('./model/model' )
184
188
return model
185
189
@@ -189,6 +193,16 @@ def load(self):
189
193
return model
190
194
191
195
if __name__ == '__main__' :
196
+ phrase = sys .argv [1 ]
192
197
my_seq2seq = MySeq2Seq (word_vec_dim = word_vec_dim , max_seq_len = max_seq_len )
193
- my_seq2seq .train ()
194
- #model = my_seq2seq.load()
198
+ if phrase == 'train' :
199
+ my_seq2seq .train ()
200
+ else :
201
+ model = my_seq2seq .load ()
202
+ trainXY , trainY = my_seq2seq .generate_trainig_data ()
203
+ predict = model .predict (trainXY )
204
+ for sample in predict :
205
+ print "predict answer"
206
+ for w in sample [1 :]:
207
+ (match_word , max_cos ) = vector2word (w )
208
+ print match_word , max_cos
0 commit comments