
Commit

Merge pull request #10 from TLX-CTR-Algorithm/liuys
Liuys
liuysong authored Oct 26, 2018
2 parents 03b097d + d4fcd1d commit c33fd9d
Showing 3 changed files with 37 additions and 14 deletions.
1 change: 1 addition & 0 deletions DNN/flags.py
@@ -12,6 +12,7 @@ def parse_args(check=True):
parser.add_argument('--embed_dim', type=int, default=config.embed_dim)
parser.add_argument('--learning_rate', type=float, default=config.learning_rate)
parser.add_argument('--oridata_dim', type=int, default=config.oridata_dim)
parser.add_argument('--valid_switch', type=int, default=config.valid_switch)
# Path and file configuration
parser.add_argument('--encod_train_path', type=str, default=config.encod_train_path)
parser.add_argument('--encod_vaild_path', type=str, default=config.encod_vaild_path)
49 changes: 35 additions & 14 deletions DNN/train.py
@@ -102,21 +102,42 @@ def train_model(batch_size=FLAGS.batch_size):

logging.info('----------------------valid-----------------------')
#Use the validation data to evaluate model performance
valid_batches = utils.genbatch(valid_inputs, valid_labels, batch_size=FLAGS.batch_size)
for step in range(len(valid_inputs) // batch_size):
batch_valid_inputs,batch_valid_lables = next(valid_batches)
valid_continous_inputs = batch_valid_inputs[:, 0:FLAGS.encod_cat_index_begin]
valid_categorial_inputs = batch_valid_inputs[:,FLAGS.encod_cat_index_begin:FLAGS.encod_cat_index_end]
feed_dict = { dnnmodel.categorial_inputs:valid_categorial_inputs, dnnmodel.continous_inputs:valid_continous_inputs, dnnmodel.label:batch_valid_lables, dnnmodel.keep_prob:FLAGS.keep_prob }
#with tf.Session() as sess:
if FLAGS.valid_switch == 0:
valid_continous_inputs = valid_inputs[:, 0:FLAGS.encod_cat_index_begin]
valid_categorial_inputs = valid_inputs[:, FLAGS.encod_cat_index_begin:FLAGS.encod_cat_index_end]
feed_dict = {dnnmodel.categorial_inputs: valid_categorial_inputs,
dnnmodel.continous_inputs: valid_continous_inputs, dnnmodel.label: valid_labels,
dnnmodel.keep_prob: FLAGS.keep_prob}
valid_global_step, logits, loss, accuracy, auc, end_points, labels = sess.run([dnnmodel.global_step, dnnmodel.logits, dnnmodel.loss, dnnmodel.accuracy, dnnmodel.auc, dnnmodel.end_points, dnnmodel.label], feed_dict=feed_dict)
#np.savetxt('./log/logits.log', end_points['logits'])
#np.savetxt('./log/pre.log', end_points['prediction'])
#np.savetxt('./log/labels.log', labels)
if step % FLAGS.logfrequency == 0:
#Print a log line and save a checkpoint file at the specified interval
logging.info('valid: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc, accuracy))
#saver.save(sess, os.path.join(FLAGS.model_ouput_dir, "model.ckpt"), global_step=global_step)
logging.info(
'valid0: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc,
accuracy))
else:
valid_batches = utils.genbatch(valid_inputs, valid_labels, batch_size=FLAGS.batch_size)
loss_list = []
auc_list = []
accuracy_list = []
for step in range(len(valid_inputs) // batch_size):
batch_valid_inputs,batch_valid_lables = next(valid_batches)
valid_continous_inputs = batch_valid_inputs[:, 0:FLAGS.encod_cat_index_begin]
valid_categorial_inputs = batch_valid_inputs[:,FLAGS.encod_cat_index_begin:FLAGS.encod_cat_index_end]
feed_dict = { dnnmodel.categorial_inputs:valid_categorial_inputs, dnnmodel.continous_inputs:valid_continous_inputs, dnnmodel.label:batch_valid_lables, dnnmodel.keep_prob:FLAGS.keep_prob }
valid_global_step, logits, loss, accuracy, auc, end_points, labels = sess.run([dnnmodel.global_step, dnnmodel.logits, dnnmodel.loss, dnnmodel.accuracy, dnnmodel.auc, dnnmodel.end_points, dnnmodel.label], feed_dict=feed_dict)

loss_list.append(loss)
auc_list.append(auc)
accuracy_list.append(accuracy)
#np.savetxt('./log/logits.log', end_points['logits'])
#np.savetxt('./log/pre.log', end_points['prediction'])
#np.savetxt('./log/labels.log', labels)
#if step % FLAGS.logfrequency == 0:
#Print a log line and save a checkpoint file at the specified interval
# logging.info('valid: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc, accuracy))
valid_loss = np.mean(loss_list)
valid_auc = np.mean(auc_list,0)
valid_accuracy = np.mean(accuracy_list)
logging.info( 'valid1: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, valid_loss, valid_auc, valid_accuracy))

epoch = (global_step * batch_size) // count_data
logging.info('has completed epoch:{}'.format(epoch))
if epoch >= FLAGS.Max_epoch or global_step >= FLAGS.Max_step:
1 change: 1 addition & 0 deletions config.py
@@ -31,6 +31,7 @@
dnn_log_path = os.path.join(dnn_log_dir, dnn_log_file)
encod_cat_index_begin = 4
encod_cat_index_end = 30
valid_switch = 1
# Training parameters
batch_size = 1000
keep_prob = 0.8
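
A minimal, self-contained sketch of the behaviour the valid_switch toggle above introduces (not the repository's code): 0 evaluates the whole validation set in a single pass, while any other value walks the set batch by batch and reports the np.mean of the per-batch metrics, as the new else-branch in train.py does with loss_list, auc_list and accuracy_list. The helpers gen_batches and evaluate below are illustrative stand-ins for utils.genbatch and the model's sess.run call, and the toy metric maths is invented for the example. Because config.py defaults valid_switch to 1, the batched, averaged path runs unless --valid_switch 0 is passed.

import numpy as np

def gen_batches(inputs, labels, batch_size):
    # Yield consecutive (inputs, labels) slices of batch_size rows, dropping the remainder.
    for start in range(0, len(inputs) - batch_size + 1, batch_size):
        yield inputs[start:start + batch_size], labels[start:start + batch_size]

def evaluate(inputs, labels):
    # Stand-in for sess.run([loss, accuracy, ...]); returns toy loss/accuracy values.
    preds = (inputs.mean(axis=1) > 0.5).astype(int)
    loss = float(np.mean((preds - labels) ** 2))
    accuracy = float(np.mean(preds == labels))
    return loss, accuracy

def validate(valid_inputs, valid_labels, valid_switch, batch_size=1000):
    if valid_switch == 0:
        # Single pass over the entire validation set.
        return evaluate(valid_inputs, valid_labels)
    # Batched pass: collect per-batch metrics, then report their means.
    losses, accuracies = [], []
    for batch_inputs, batch_labels in gen_batches(valid_inputs, valid_labels, batch_size):
        loss, accuracy = evaluate(batch_inputs, batch_labels)
        losses.append(loss)
        accuracies.append(accuracy)
    return float(np.mean(losses)), float(np.mean(accuracies))

if __name__ == '__main__':
    rng = np.random.default_rng(0)
    x = rng.random((5000, 30))
    y = rng.integers(0, 2, size=5000)
    print('full pass   :', validate(x, y, valid_switch=0))
    print('batched pass:', validate(x, y, valid_switch=1))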
