Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
brightmart committed Nov 24, 2018
1 parent 728506d commit 27ab348
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions a00_Bert/train_bert_multi-label.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def main(_):
loss_total, counter = 0.0, 0
for start, end in zip(range(0, number_of_training_data, batch_size),range(batch_size, number_of_training_data, batch_size)):
iteration = iteration + 1
input_ids_,input_mask_,segment_ids_=get_input_mask_segment_ids(trainX[start:end],cls_id)
input_mask_, segment_ids_, input_ids_=get_input_mask_segment_ids(trainX[start:end],cls_id) # input_ids_,input_mask_,segment_ids_
feed_dict = {input_ids: input_ids_, input_mask: input_mask_, segment_ids:segment_ids_,
label_ids:trainY[start:end],is_training:True}
curr_loss,_ = sess.run([loss,train_op], feed_dict)
Expand Down Expand Up @@ -152,7 +152,7 @@ def do_eval(sess,input_ids,input_mask,segment_ids,label_ids,is_training,loss,pro
f1_score_micro_sklearn_total=0.0
# batch_size=1 # TODO
for start, end in zip(range(0, number_examples, batch_size), range(batch_size, number_examples, batch_size)):
input_ids_,input_mask_, segment_ids_ = get_input_mask_segment_ids(vaildX[start:end],cls_id)
input_mask_, segment_ids_, input_ids_ = get_input_mask_segment_ids(vaildX[start:end],cls_id)
feed_dict = {input_ids: input_ids_,input_mask:input_mask_,segment_ids:segment_ids_,
label_ids:vaildY[start:end],is_training:False}
curr_eval_loss, prob = sess.run([loss, probabilities],feed_dict)
Expand Down

0 comments on commit 27ab348

Please sign in to comment.