
Commit

modify
liuysong committed Oct 28, 2018
1 parent 6394e8f commit 8d7e4b7
Showing 3 changed files with 31 additions and 10 deletions.
2 changes: 2 additions & 0 deletions DNN/flags.py
@@ -28,6 +28,8 @@ def parse_args(check=True):
parser.add_argument('--dnn_log_path', type=str, default=config.dnn_log_path)
parser.add_argument('--encod_cat_index_begin', type=int, default=config.encod_cat_index_begin)
parser.add_argument('--encod_cat_index_end', type=int, default=config.encod_cat_index_end)
+# Log verbosity level
+parser.add_argument('--debug_level', type=str, default=config.debug_level)

FLAGS, unparsed = parser.parse_known_args()
return FLAGS, unparsed
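
Note: the new --debug_level flag follows the same pattern as the surrounding arguments, with config.debug_level supplying the default. A minimal stand-alone sketch of the round trip (the default value and the choices= guard are assumptions, not part of this commit):

    import argparse
    import logging

    default_level = 'INFO'  # stand-in for config.debug_level

    parser = argparse.ArgumentParser()
    # choices= is not in the committed code; it would reject typos at parse time.
    parser.add_argument('--debug_level', type=str, default=default_level,
                        choices=['DEBUG', 'INFO'])
    FLAGS, unparsed = parser.parse_known_args([])

    # The flag string resolves by name to a stdlib logging constant.
    print(getattr(logging, FLAGS.debug_level))  # 20, i.e. logging.INFO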
36 changes: 27 additions & 9 deletions DNN/train.py
@@ -7,6 +7,7 @@
import os
import numpy as np
import re
+import math
from DNN import flags

slim = tf.contrib.slim
@@ -16,7 +17,10 @@
os.mkdir(FLAGS.dnn_log_dir)
# Set the log output format
logger = logging.getLogger()
-logger.setLevel(logging.DEBUG)
+if FLAGS.debug_level == 'DEBUG':
+    logger.setLevel(logging.DEBUG)
+elif FLAGS.debug_level == 'INFO':
+    logger.setLevel(logging.INFO)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter('%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s'))
fl = logging.FileHandler(FLAGS.dnn_log_path)
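
Side note on the new if/elif above: it only recognizes 'DEBUG' and 'INFO'; any other flag value silently leaves the root logger at its default WARNING level. A hedged sketch of an equivalent, more general lookup (the INFO fallback is an assumption, not behavior of this commit):

    import logging

    debug_level = 'INFO'  # stands in for FLAGS.debug_level

    logger = logging.getLogger()
    # logging exposes its levels as module attributes, so the flag string
    # can be resolved by name; unrecognized values fall back to INFO here.
    level = getattr(logging, debug_level.upper(), logging.INFO)
    logger.setLevel(level if isinstance(level, int) else logging.INFO)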
@@ -81,6 +85,9 @@ def train_model(batch_size=FLAGS.batch_size):
while 1 == 1:
# Train the model on the training data
batches = utils.genbatch(inputs, lables, batch_size=FLAGS.batch_size)
+train_loss_list = []
+train_auc_list = []
+train_accuracy_list = []
for step in range(len(inputs) // batch_size):
batch_inputs,batch_lables = next(batches)
continous_inputs = batch_inputs[:, 0:FLAGS.encod_cat_index_begin]
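
The three lists added at the top of the epoch loop set up a standard accumulate-then-average pattern: one metric value is appended per batch, and the epoch-level summary later reduces each list to its mean. A tiny illustration with invented batch losses:

    import numpy as np

    train_loss_list = []
    for batch_loss in [0.71, 0.64, 0.58]:  # stand-in per-batch losses
        train_loss_list.append(batch_loss)

    # One number per epoch instead of one per logged step.
    train_loss = np.mean(train_loss_list)
    print(round(float(train_loss), 4))  # 0.6433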
@@ -89,12 +96,15 @@
#with tf.Session() as sess:
global_step, _, logits, loss, accuracy, summaries, auc, end_points, labels = sess.run([dnnmodel.global_step, dnnmodel.train_step, dnnmodel.logits, dnnmodel.loss, dnnmodel.accuracy, dnnmodel.train_summary_op, dnnmodel.auc, dnnmodel.end_points, dnnmodel.label], feed_dict=feed_dict)
train_summary_writer.add_summary(summaries, step)
+train_loss_list.append(loss)
+train_auc_list.append(auc[0])
+train_accuracy_list.append(accuracy)
#np.savetxt('./log/tlogits.log', end_points['logits'])
#np.savetxt('./log/tpre.log', end_points['prediction'])
#np.savetxt('./log/tlabels.log', labels)
if global_step % FLAGS.logfrequency == 0:
# At the configured frequency, log progress and save a checkpoint file
-logging.info('train: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc, accuracy))
+logging.debug('train: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc, accuracy))
try:
saver.save(sess, os.path.join(FLAGS.model_ouput_dir, "model.ckpt"), global_step=global_step)
except:
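
About the auc[0] indexing introduced above: assuming dnnmodel.auc is the pair returned by tf.metrics.auc (the model code is not part of this diff), sess.run yields a (value, update_op) tuple, so appending auc itself would store tuples and skew the later np.mean. A stand-alone TF 1.x sketch of that shape:

    import tensorflow as tf  # TF 1.x, matching the tf.contrib usage above

    labels = tf.constant([0, 0, 1, 1])
    predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
    auc = tf.metrics.auc(labels, predictions)  # (auc_value, update_op)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())  # AUC state lives in local vars
        sess.run(auc[1])         # update op: accumulate confusion statistics
        print(sess.run(auc[0]))  # scalar AUC value, roughly 0.75 here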
@@ -103,8 +113,11 @@
#if global_step >= FLAGS.Max_step or global_step > epoch * batch_size:
if global_step >= FLAGS.Max_step:
break
+train_loss = np.mean(train_loss_list)
+train_auc = np.mean(train_auc_list, 0)
+train_accuracy = np.mean(train_accuracy_list)

-logging.info('----------------------valid-----------------------')
+logging.debug('----------------------valid-----------------------')
# Evaluate model performance on the validation data
if FLAGS.valid_switch == 0:
valid_continous_inputs = valid_inputs[:, 0:FLAGS.encod_cat_index_begin]
@@ -113,8 +126,8 @@
dnnmodel.continous_inputs: valid_continous_inputs, dnnmodel.label: valid_labels,
dnnmodel.keep_prob: FLAGS.keep_prob}
valid_global_step, logits, loss, accuracy, auc, end_points, labels = sess.run([dnnmodel.global_step, dnnmodel.logits, dnnmodel.loss, dnnmodel.accuracy, dnnmodel.auc, dnnmodel.end_points, dnnmodel.label], feed_dict=feed_dict)
-logging.info(
-'valid0: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc,
+logging.debug(
+'valid: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, loss, auc,
accuracy))
else:
valid_batches = utils.genbatch(valid_inputs, valid_labels, batch_size=FLAGS.batch_size)
@@ -129,7 +142,7 @@
valid_global_step, logits, loss, accuracy, auc, end_points, labels = sess.run([dnnmodel.global_step, dnnmodel.logits, dnnmodel.loss, dnnmodel.accuracy, dnnmodel.auc, dnnmodel.end_points, dnnmodel.label], feed_dict=feed_dict)

loss_list.append(loss)
-auc_list.append(auc)
+auc_list.append(auc[0])
accuracy_list.append(accuracy)
#np.savetxt('./log/logits.log', end_points['logits'])
#np.savetxt('./log/pre.log', end_points['prediction'])
@@ -140,10 +153,15 @@
valid_loss = np.mean(loss_list)
valid_auc = np.mean(auc_list,0)
valid_accuracy = np.mean(accuracy_list)
-logging.info( 'valid1: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, valid_loss, valid_auc, valid_accuracy))
+logging.debug('valid: step [{0}] loss [{1}] auc [{2}] accuracy [{3}]'.format(global_step, valid_loss, valid_auc, valid_accuracy))

-epoch = (global_step * batch_size) // count_data
-logging.info('has completed epoch:{}'.format(epoch))
+#epoch = (global_step * batch_size) // count_data
+epoch = math.ceil((global_step * batch_size) / count_data)
+logging.debug('has completed epoch:{}'.format(epoch))

+logging.info('epoch [{0}] train_loss [{1}] valid_loss [{2}] train_auc [{3}] valid_auc [{4}] train_accuracy [{5}] valid_accuracy [{6}]'.format(
+    epoch, train_loss, valid_loss, train_auc, valid_auc, train_accuracy, valid_accuracy
+))
if epoch >= FLAGS.Max_epoch or global_step >= FLAGS.Max_step:
break
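
The move from floor division to math.ceil changes when a pass counts as an epoch: floor reports 0 until a full count_data samples have been consumed, while ceil counts a partial pass immediately. A worked example with assumed sizes:

    import math

    count_data = 1000  # assumed dataset size
    batch_size = 32
    global_step = 25   # 25 batches = 800 samples seen

    print((global_step * batch_size) // count_data)            # 0  (old: floor)
    print(math.ceil((global_step * batch_size) / count_data))  # 1  (new: ceil)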

3 changes: 2 additions & 1 deletion config.py
@@ -19,6 +19,7 @@
print(test_path)

# DNN-related configuration
+debug_level = 'INFO'
# Path and file configuration
encod_train_path = os.path.join(BASE_DIR, "output/model_data/train.txt")
encod_vaild_path = os.path.join(BASE_DIR, "output/model_data/valid.txt")
@@ -38,7 +39,7 @@
keep_prob = 0.8
logfrequency = 10
Max_step = 2000000000
-Max_epoch = 6
+Max_epoch = 50
embed_dim = 128
learning_rate = 0.01
decay_rate = 0.96
