support AMP for In-batch Negatives (PaddlePaddle#1374)
tianxin authored Dec 8, 2021
1 parent 9f79af1 commit c3a467b
Showing 1 changed file with 28 additions and 12 deletions.
examples/semantic_indexing/train_batch_neg.py: 28 additions & 12 deletions (40 changed lines)
@@ -37,18 +37,20 @@
 parser.add_argument("--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization. "
     "Sequences longer than this will be truncated, sequences shorter will be padded.")
 parser.add_argument("--batch_size", default=32, type=int, help="Batch size per GPU/CPU for training.")
-parser.add_argument("--output_emb_size", default=None, type=int, help="output_embedding_size")
+parser.add_argument("--output_emb_size", default=None, type=int, help="output_embedding_size.")
 parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
 parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
 parser.add_argument("--epochs", default=10, type=int, help="Total number of training epochs to perform.")
 parser.add_argument("--warmup_proportion", default=0.0, type=float, help="Linear warmup proportion over the training process.")
 parser.add_argument("--init_from_ckpt", type=str, default=None, help="The path of checkpoint to be loaded.")
-parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization")
+parser.add_argument("--seed", type=int, default=1000, help="random seed for initialization.")
 parser.add_argument('--device', choices=['cpu', 'gpu'], default="gpu", help="Select which device to train model, defaults to gpu.")
-parser.add_argument('--save_steps', type=int, default=10000, help="Inteval steps to save checkpoint")
-parser.add_argument("--train_set_file", type=str, required=True, help="The full path of train_set_file")
-parser.add_argument("--margin", default=0.3, type=float, help="Margin beteween pos_sample and neg_samples")
+parser.add_argument('--save_steps', type=int, default=10000, help="Interval steps to save checkpoint.")
+parser.add_argument("--train_set_file", type=str, required=True, help="The full path of train_set_file.")
+parser.add_argument("--margin", default=0.3, type=float, help="Margin between pos_sample and neg_samples.")
 parser.add_argument("--scale", default=30, type=int, help="Scale for pair-wise margin_rank_loss")
+parser.add_argument("--use_amp", action="store_true", help="Whether to use AMP.")
+parser.add_argument("--amp_loss_scale", default=32768, type=float, help="The value of scale_loss for fp16. This is only used for AMP training.")


 args = parser.parse_args()
@@ -133,17 +135,31 @@ def do_train():
         weight_decay=args.weight_decay,
         apply_decay_param_fun=lambda x: x in decay_params)

+    if args.use_amp:
+        scaler = paddle.amp.GradScaler(init_loss_scaling=args.amp_loss_scale)
+
     global_step = 0
     tic_train = time.time()
     for epoch in range(1, args.epochs + 1):
         for step, batch in enumerate(train_data_loader, start=1):
             query_input_ids, query_token_type_ids, title_input_ids, title_token_type_ids = batch

-            loss = model(
-                query_input_ids=query_input_ids,
-                title_input_ids=title_input_ids,
-                query_token_type_ids=query_token_type_ids,
-                title_token_type_ids=title_token_type_ids)
+            with paddle.amp.auto_cast(
+                    args.use_amp,
+                    custom_white_list=["layer_norm", "softmax", "gelu"]):
+                loss = model(
+                    query_input_ids=query_input_ids,
+                    title_input_ids=title_input_ids,
+                    query_token_type_ids=query_token_type_ids,
+                    title_token_type_ids=title_token_type_ids)
+
+            if args.use_amp:
+                scaled = scaler.scale(loss)
+                scaled.backward()
+                scaler.minimize(optimizer, scaled)
+            else:
+                loss.backward()
+                optimizer.step()

             global_step += 1
             if global_step % 10 == 0 and rank == 0:
@@ -152,10 +168,10 @@ def do_train():
                     % (global_step, epoch, step, loss,
                        10 / (time.time() - tic_train)))
                 tic_train = time.time()
-            loss.backward()
-            optimizer.step()
+
             lr_scheduler.step()
             optimizer.clear_grad()
+
             if global_step % args.save_steps == 0 and rank == 0:
                 save_dir = os.path.join(args.save_dir, "model_%d" % global_step)
                 if not os.path.exists(save_dir):
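
For readers unfamiliar with PaddlePaddle's AMP API, the sketch below isolates the pattern this commit adds to train_batch_neg.py: run the forward pass under paddle.amp.auto_cast and route the backward pass through a paddle.amp.GradScaler when --use_amp is set. Only auto_cast, the custom white list, GradScaler, scale, and minimize come from the diff above; the toy model, data, and loss are placeholders invented for illustration.

# Minimal AMP training sketch; the model and data here are illustrative, not from the repo.
import paddle

use_amp = True
model = paddle.nn.Linear(16, 1)  # stand-in for the in-batch-negatives dual encoder
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-5, parameters=model.parameters())
# 32768 mirrors the default of the new --amp_loss_scale argument.
scaler = paddle.amp.GradScaler(init_loss_scaling=32768) if use_amp else None

for _ in range(3):
    x = paddle.randn([8, 16])  # fake batch
    y = paddle.randn([8, 1])

    # Forward pass runs in fp16 where safe; the white list mirrors the commit.
    with paddle.amp.auto_cast(
            use_amp, custom_white_list=["layer_norm", "softmax", "gelu"]):
        loss = paddle.nn.functional.mse_loss(model(x), y)

    if use_amp:
        scaled = scaler.scale(loss)  # scale the loss to avoid fp16 gradient underflow
        scaled.backward()
        scaler.minimize(optimizer, scaled)  # unscales gradients and applies the optimizer step
    else:
        loss.backward()
        optimizer.step()

    optimizer.clear_grad()

In the training script itself, the same switch is exposed on the command line: passing --use_amp enables the auto_cast/GradScaler path, and --amp_loss_scale sets the initial loss scaling (default 32768).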
