forked from Leavingseason/xDeepFM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
304 lines (283 loc) · 12.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""define train, infer, eval, test process"""
import numpy as np
import os, time, collections
import tensorflow as tf
from IO.iterator import FfmIterator #, DinIterator, CCCFNetIterator
#from IO.din_cache import DinCache
from IO.ffm_cache import FfmCache
#from IO.cccfnet_cache import CCCFNetCache
#from src.deep_fm import DeepfmModel
#from src.deep_wide import DeepWideModel
#from src.fm import FmModel
#from src.dnn import DnnModel
#from src.opnn import OpnnModel
#from src.ipnn import IpnnModel
#from src.lr import LrModel
#from src.din import DinModel
#from src.cccfnet import CCCFModel
#from src.deepcross import DeepCrossModel
from src.exDeepFM import ExtremeDeepFMModel
from src.CIN import CINModel
#from src.cross import CrossModel
import utils.util as util
import utils.metric as metric
# from utils.log import Log
# log = Log(hparams)
class TrainModel(collections.namedtuple("TrainModel", ("graph", "model", "iterator", "filenames"))):
    """Bundle of the objects needed to drive one training graph.

    Fields:
        graph: the ``tf.Graph`` the input pipeline and model were built in.
        model: the model object returned by the model creator.
        iterator: the input iterator that feeds the model.
        filenames: ``tf.string`` placeholder fed with the TFRecord file name(s)
            whenever the iterator's initializer is run.
    """
    # An empty __slots__ keeps instances as lean as the underlying tuple;
    # without it every instance gets its own (unused) __dict__.
    __slots__ = ()
def create_train_model(model_creator, hparams, scope=None):
    """Build a fresh graph containing the input pipeline and the model.

    Args:
        model_creator: model class/factory, called as
            ``model_creator(hparams, iterator=..., scope=...)``.
        hparams: hyper-parameters; ``hparams.data_format`` selects the input
            iterator ('ffm', 'din' or 'cccfnet').
        scope: optional scope forwarded unchanged to ``model_creator``.

    Returns:
        TrainModel with the new graph, the model, the iterator and the
        ``filenames`` placeholder that is fed when initializing the iterator.

    Raises:
        ValueError: if ``hparams.data_format`` is not supported.
        ImportError: if the 'din'/'cccfnet' iterator cannot be imported in
            this checkout.
    """
    graph = tf.Graph()
    with graph.as_default():
        # tf.set_random_seed only affects the *current default* graph, so the
        # graph-level seed must be set here, before any op exists in `graph`.
        # Calling it after this function returns (as train() does with the
        # same value, 1234) leaves this graph unseeded.
        tf.set_random_seed(1234)
        # Fed with the train, eval, test or infer cache file name.
        filenames = tf.placeholder(tf.string, shape=[None])
        src_dataset = tf.data.TFRecordDataset(filenames)
        data_format = hparams.data_format
        if data_format == 'ffm':
            batch_input = FfmIterator(src_dataset)
        elif data_format == 'din':
            # Imported lazily: the module-level import is commented out at the
            # top of this file, so the bare name used to raise NameError.
            from IO.iterator import DinIterator
            batch_input = DinIterator(src_dataset)
        elif data_format == 'cccfnet':
            from IO.iterator import CCCFNetIterator
            batch_input = CCCFNetIterator(src_dataset)
        else:
            raise ValueError("not support {0} format data".format(data_format))
        model = model_creator(hparams, iterator=batch_input, scope=scope)
    return TrainModel(
        graph=graph,
        model=model,
        iterator=batch_input,
        filenames=filenames)
# run evaluation and get the evaluated metrics
def run_eval(load_model, load_sess, filename, sample_num_file, hparams, flag):
    """Evaluate ``load_model`` on one cached file and return its metrics.

    The iterator is re-initialized on ``filename`` and drained batch by batch.
    Predictions and labels are flattened and concatenated, then truncated to
    the sample count stored on the first line of ``sample_num_file``
    (presumably to drop padding in the final batch -- confirm against the
    iterator implementation).

    Args:
        load_model: TrainModel whose iterator and model are evaluated.
        load_sess: session running ``load_model.graph``.
        filename: cached TFRecord file to read.
        sample_num_file: text file whose first line holds the true sample count.
        hparams: hyper-parameters; ``hparams.logger`` receives the sample count.
        flag: label forwarded to ``metric.cal_metric`` (e.g. 'eval', 'test').

    Returns:
        The result of ``metric.cal_metric(labels, preds, hparams, flag)``.
    """
    with open(sample_num_file, 'r') as count_file:
        count_lines = count_file.readlines()
    sample_num = int(count_lines[0].strip())

    feed = {load_model.filenames: [filename]}
    load_sess.run(load_model.iterator.initializer, feed_dict=feed)

    all_preds = []
    all_labels = []
    while True:
        try:
            _, _, batch_preds, batch_labels = load_model.model.eval(load_sess)
        except tf.errors.OutOfRangeError:
            # Iterator exhausted: every batch of the file has been scored.
            break
        all_preds.extend(np.reshape(batch_preds, -1))
        all_labels.extend(np.reshape(batch_labels, -1))

    kept_preds = all_preds[:sample_num]
    kept_labels = all_labels[:sample_num]
    hparams.logger.info("data num:{0:d}".format(len(kept_labels)))
    return metric.cal_metric(kept_labels, kept_preds, hparams, flag)
# run infer
def run_infer(load_model, load_sess, filename, hparams, sample_num_file):
    """Score every sample of a cached file and write one prediction per line.

    Args:
        load_model: TrainModel used for inference.
        load_sess: session running ``load_model.graph``.
        filename: cached TFRecord file to score.
        hparams: hyper-parameters; ``hparams.res_name`` is set to the output
            path, which is derived from ``hparams.infer_file`` (not from
            ``filename``).
        sample_num_file: text file whose first line holds the true sample
            count; predictions beyond it are dropped.

    Side effects:
        Creates ``util.RES_DIR`` if missing and overwrites ``hparams.res_name``.
    """
    with open(sample_num_file, 'r') as count_file:
        sample_num = int(count_file.readlines()[0].strip())

    if not os.path.exists(util.RES_DIR):
        os.mkdir(util.RES_DIR)

    load_sess.run(load_model.iterator.initializer,
                  feed_dict={load_model.filenames: [filename]})

    scores = []
    while True:
        try:
            batch_scores = load_model.model.infer(load_sess)
        except tf.errors.OutOfRangeError:
            # No more batches in the file.
            break
        scores.extend(np.reshape(batch_scores, -1))
    scores = scores[:sample_num]

    hparams.res_name = util.convert_res_name(hparams.infer_file)
    with open(hparams.res_name, 'w') as result_file:
        result_file.write('\n'.join(str(score) for score in scores))
# cache data
def cache_data(hparams, filename, flag):
    """Convert ``filename`` into a cached TFRecord file unless it already exists.

    The cache path is computed from ``hparams.<flag>_file`` and
    ``hparams.batch_size`` (not from ``filename``; train() passes the same
    value for both) and stored on ``hparams.<flag>_file_cache`` so later
    stages can find it. When the cache is (re)built, the sample count and
    impression ids returned by the cache writer are written to the per-flag
    files ``util.<FLAG>_NUM`` and ``util.<FLAG>_IMPRESSION_ID``.

    NOTE(review): those side files live at fixed paths that do not depend on
    the data file, and they are only rewritten when the cache is rebuilt.
    Switching between two datasets whose caches both already exist therefore
    keeps the count of whichever dataset was cached last -- confirm this is
    acceptable.

    Args:
        hparams: hyper-parameters; ``data_format`` selects the cache writer.
        filename: raw data file to convert.
        flag: one of 'train', 'eval', 'test', 'infer'.

    Raises:
        ValueError: unknown ``hparams.data_format`` or ``flag``.
        ImportError: the 'din'/'cccfnet' cache writer cannot be imported in
            this checkout.
    """
    data_format = hparams.data_format
    if data_format == 'ffm':
        cache_obj = FfmCache()
    elif data_format == 'din':
        # Imported lazily: the module-level import is commented out at the
        # top of this file, so the bare name used to raise NameError.
        from IO.din_cache import DinCache
        cache_obj = DinCache()
    elif data_format == 'cccfnet':
        from IO.cccfnet_cache import CCCFNetCache
        cache_obj = CCCFNetCache()
    else:
        raise ValueError(
            "data format must be ffm, din, cccfnet, this format not defined {0}".format(data_format))

    if not os.path.exists(util.CACHE_DIR):
        os.mkdir(util.CACHE_DIR)

    if flag == 'train':
        hparams.train_file_cache = util.convert_cached_name(hparams.train_file, hparams.batch_size)
        cached_name = hparams.train_file_cache
        sample_num_path = util.TRAIN_NUM
        impression_id_path = util.TRAIN_IMPRESSION_ID
    elif flag == 'eval':
        hparams.eval_file_cache = util.convert_cached_name(hparams.eval_file, hparams.batch_size)
        cached_name = hparams.eval_file_cache
        sample_num_path = util.EVAL_NUM
        impression_id_path = util.EVAL_IMPRESSION_ID
    elif flag == 'test':
        hparams.test_file_cache = util.convert_cached_name(hparams.test_file, hparams.batch_size)
        cached_name = hparams.test_file_cache
        sample_num_path = util.TEST_NUM
        impression_id_path = util.TEST_IMPRESSION_ID
    elif flag == 'infer':
        hparams.infer_file_cache = util.convert_cached_name(hparams.infer_file, hparams.batch_size)
        cached_name = hparams.infer_file_cache
        sample_num_path = util.INFER_NUM
        impression_id_path = util.INFER_IMPRESSION_ID
    else:
        raise ValueError("flag must be train, eval, test, infer")

    print('cache filename:', filename)
    if os.path.isfile(cached_name):
        # Cache already built; reuse it as-is.
        return

    print('has not cached file, begin cached...')
    start_time = time.time()
    sample_num, impression_id_list = cache_obj.write_tfrecord(filename, cached_name, hparams)
    util.print_time("caced file used time", start_time)
    print("data sample num:{0}".format(sample_num))
    with open(sample_num_path, 'w') as f:
        f.write(str(sample_num) + '\n')
    with open(impression_id_path, 'w') as f:
        f.writelines(str(impression_id) + '\n' for impression_id in impression_id_list)
# model_type -> (module, class name, message printed when selected).
# Classes are imported on demand because most model imports at the top of this
# file are commented out; referencing those names directly raised NameError
# for every model type except exDeepFM and CIN.
_MODEL_REGISTRY = {
    'deepFM': ('src.deep_fm', 'DeepfmModel', "run deepfm model!"),
    'deepWide': ('src.deep_wide', 'DeepWideModel', "run deepWide model!"),
    'dnn': ('src.dnn', 'DnnModel', "run dnn model!"),
    'ipnn': ('src.ipnn', 'IpnnModel', "run ipnn model!"),
    'opnn': ('src.opnn', 'OpnnModel', "run opnn model!"),
    'din': ('src.din', 'DinModel', "run din model!"),
    'fm': ('src.fm', 'FmModel', "run fm model!"),
    'lr': ('src.lr', 'LrModel', "run lr model!"),
    'cccfnet': ('src.cccfnet', 'CCCFModel', "run cccfnet model!"),
    'deepcross': ('src.deepcross', 'DeepCrossModel', "run deepcross model!"),
    'exDeepFM': ('src.exDeepFM', 'ExtremeDeepFMModel', "run extreme deepFM model!"),
    'cross': ('src.cross', 'CrossModel', "run extreme cross model!"),
    'CIN': ('src.CIN', 'CINModel', "run extreme cin model!"),
}


def _resolve_model_creator(model_type):
    """Print the banner for ``model_type`` and return its model class."""
    import importlib
    if model_type not in _MODEL_REGISTRY:
        raise ValueError("model type should be one of: {0}".format(
            ', '.join(sorted(_MODEL_REGISTRY))))
    module_name, class_name, banner = _MODEL_REGISTRY[model_type]
    print(banner)
    return getattr(importlib.import_module(module_name), class_name)


def _format_metrics(res):
    """Render a metrics dict as 'key:value, ...' with keys in sorted order."""
    return ', '.join(
        str(key) + ':' + str(value)
        for key, value in sorted(res.items(), key=lambda item: item[0]))


def _train_one_epoch(train_model, train_sess, hparams, writer, global_step):
    """Run one pass over the training cache.

    Returns:
        (summed total loss, number of steps, updated global summary step).
    """
    train_sess.run(train_model.iterator.initializer,
                   feed_dict={train_model.filenames: [hparams.train_file_cache]})
    epoch_loss = 0
    step = 0
    while True:
        try:
            _, step_loss, step_data_loss, summary = train_model.model.train(train_sess)
        except tf.errors.OutOfRangeError:
            print('finish one epoch!')
            break
        # Use a step counter that keeps growing across epochs; the per-epoch
        # `step` made every epoch's summaries restart at step 0.
        writer.add_summary(summary, global_step)
        global_step += 1
        epoch_loss += step_loss
        step += 1
        if step % hparams.show_step == 0:
            print('step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}'
                  .format(step, step_loss, step_data_loss))
    return epoch_loss, step, global_step


def _train_epochs(train_model, train_sess, hparams, writer):
    """Epoch loop: train, checkpoint, evaluate, log, and stop early on AUC drop."""
    best_auc = 0
    global_step = 0
    for epoch in range(hparams.epochs):
        train_start = time.time()
        epoch_loss, num_steps, global_step = _train_one_epoch(
            train_model, train_sess, hparams, writer, global_step)
        train_time = time.time() - train_start

        if epoch % hparams.save_epoch == 0:
            train_model.model.saver.save(
                sess=train_sess,
                save_path=util.MODEL_DIR + 'epoch_' + str(epoch))

        # An epoch without batches (empty training cache) used to raise
        # ZeroDivisionError here; report a NaN loss instead.
        train_res = {"loss": epoch_loss / num_steps if num_steps else float('nan')}

        eval_start = time.time()
        # Requires hparams.eval_file: caching it (flag='eval') is what sets
        # hparams.eval_file_cache.
        eval_res = run_eval(train_model, train_sess, hparams.eval_file_cache,
                            util.EVAL_NUM, hparams, flag='eval')
        epoch_info = ('at epoch {0:d}'.format(epoch)
                      + ' train info: ' + _format_metrics(train_res)
                      + ' eval info: ' + _format_metrics(eval_res))
        if hparams.test_file is not None:
            test_res = run_eval(train_model, train_sess, hparams.test_file_cache,
                                util.TEST_NUM, hparams, flag='test')
            epoch_info += ' test info: ' + _format_metrics(test_res)
        eval_time = time.time() - eval_start

        print(epoch_info)
        hparams.logger.info(epoch_info)
        timing_info = 'at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}'.format(
            epoch, train_time, eval_time)
        print(timing_info)
        hparams.logger.info(timing_info)
        hparams.logger.info('\n')

        # Early stopping: stop once eval AUC falls more than 0.003 below the
        # best AUC seen so far (which starts at 0).
        if eval_res["auc"] - best_auc < -0.003:
            break
        if eval_res["auc"] > best_auc:
            best_auc = eval_res["auc"]


def train(hparams, scope=None, target_session=""):
    """Cache the data files, build and train the model, then optionally infer.

    Each of ``hparams.train_file``, ``eval_file``, ``test_file`` and
    ``infer_file`` that is not None is cached first. ``hparams.model_type``
    selects the model class. Training runs for at most ``hparams.epochs``
    epochs. After each epoch, a checkpoint is saved when
    ``epoch % hparams.save_epoch == 0``. The model is then evaluated on the
    eval cache, and on the test cache when ``hparams.test_file`` is set.
    Training stops early once eval AUC drops more than 0.003 below the best
    seen so far. If ``hparams.infer_file`` is set, predictions are written
    for it after training.

    Args:
        hparams: hyper-parameters (also provides ``hparams.logger``).
        scope: optional scope forwarded to the model creator.
        target_session: passed to ``tf.Session(target=...)``.

    Raises:
        ValueError: unknown ``hparams.model_type``.
        IOError: the checkpoint ``hparams.load_model_name`` cannot be restored.
    """
    for key, val in hparams.values().items():
        hparams.logger.info(str(key) + ':' + str(val))

    print('load and cache data...')
    for data_file, flag in ((hparams.train_file, 'train'),
                            (hparams.eval_file, 'eval'),
                            (hparams.test_file, 'test'),
                            (hparams.infer_file, 'infer')):
        if data_file is not None:
            cache_data(hparams, data_file, flag=flag)

    model_creator = _resolve_model_creator(hparams.model_type)

    train_model = create_train_model(model_creator, hparams, scope)
    gpuconfig = tf.ConfigProto()
    gpuconfig.gpu_options.allow_growth = True
    # NOTE(review): tf.set_random_seed sets the graph-level seed of the
    # *default* graph. train_model.graph is a separate tf.Graph whose ops
    # already exist at this point, so this call alone does not make training
    # reproducible. The seed must be set under train_model.graph.as_default()
    # before the model is built.
    tf.set_random_seed(1234)
    train_sess = tf.Session(target=target_session, graph=train_model.graph, config=gpuconfig)
    try:
        train_sess.run(train_model.model.init_op)

        if hparams.load_model_name is not None:
            checkpoint_path = hparams.load_model_name
            try:
                train_model.model.saver.restore(train_sess, checkpoint_path)
            except Exception:
                # Previously a bare `except:`, which also turned
                # KeyboardInterrupt/SystemExit into IOError.
                raise IOError("Failed to find any matching files for {0}".format(checkpoint_path))
            print('load model', checkpoint_path)

        print('total_loss = data_loss+regularization_loss, data_loss = {rmse or logloss ..}')
        writer = tf.summary.FileWriter(util.SUMMARIES_DIR, train_sess.graph)
        try:
            _train_epochs(train_model, train_sess, hparams, writer)
        finally:
            writer.close()

        # after train, run infer
        if hparams.infer_file is not None:
            run_infer(train_model, train_sess, hparams.infer_file_cache, hparams, util.INFER_NUM)
    finally:
        # The session was never closed before; close() frees its resources.
        train_sess.close()