Commit: updating run_xlnet_classifier
thomwolf committed Jun 24, 2019
1 parent f6081f2 commit 24ed0b9
Showing 5 changed files with 587 additions and 125 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -123,4 +123,7 @@ tensorflow_code
 
 # Models
 models
-proc_data
\ No newline at end of file
+proc_data
+
+# examples
+examples/runs
117 changes: 42 additions & 75 deletions examples/run_xlnet_classifier.py
@@ -54,91 +54,58 @@ def main():
     parser = argparse.ArgumentParser()
 
     ## Required parameters
-    parser.add_argument("--data_dir",
-                        default=None,
-                        type=str,
-                        required=True,
+    parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
-                        help="XLNet pre-trained model: currently only xlnet-large-cased.")
-    parser.add_argument("--task_name",
-                        default=None,
-                        type=str,
-                        required=True,
+    parser.add_argument("--task_name", default=None, type=str, required=True,
                         help="The name of the task to train.")
-    parser.add_argument("--output_dir",
-                        default=None,
-                        type=str,
-                        required=True,
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
-    parser.add_argument("--cache_dir",
-                        default="",
-                        type=str,
-                        help="Where do you want to store the pre-trained models downloaded from s3")
-    parser.add_argument("--max_seq_length",
-                        default=128,
-                        type=int,
-                        help="The maximum total input sequence length after WordPiece tokenization. \n"
-                             "Sequences longer than this will be truncated, and sequences shorter \n"
-                             "than this will be padded.")
-    parser.add_argument("--do_train",
-                        action='store_true',
+    # training
+    parser.add_argument("--do_train", action='store_true',
                         help="Whether to run training.")
-    parser.add_argument("--do_eval",
-                        action='store_true',
-                        help="Whether to run eval on the dev set.")
-    parser.add_argument("--do_lower_case",
-                        action='store_true',
-                        help="Set this flag if you are using an uncased model.")
-    parser.add_argument("--train_batch_size",
-                        default=32,
-                        type=int,
-                        help="Total batch size for training.")
-    parser.add_argument("--eval_batch_size",
-                        default=8,
-                        type=int,
-                        help="Total batch size for eval.")
-    parser.add_argument("--learning_rate",
-                        default=5e-5,
-                        type=float,
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
                         help="The initial learning rate for Adam.")
-    parser.add_argument("--num_train_epochs",
-                        default=3.0,
-                        type=float,
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                         help="Total number of training epochs to perform.")
-    parser.add_argument("--warmup_proportion",
-                        default=0.1,
-                        type=float,
+    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                         help="Proportion of training to perform linear learning rate warmup for. "
                              "E.g., 0.1 = 10%% of training.")
-    parser.add_argument("--no_cuda",
-                        action='store_true',
-                        help="Whether not to use CUDA when available")
-    parser.add_argument('--overwrite_output_dir',
-                        action='store_true',
-                        help="Overwrite the content of the output directory")
-    parser.add_argument("--local_rank",
-                        type=int,
-                        default=-1,
-                        help="local_rank for distributed training on gpus")
-    parser.add_argument('--seed',
-                        type=int,
-                        default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--gradient_accumulation_steps',
-                        type=int,
-                        default=1,
+    parser.add_argument("--train_batch_size", default=32, type=int,
+                        help="Total batch size for training.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                         help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument('--fp16',
-                        action='store_true',
+    parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit float precision instead of 32-bit")
-    parser.add_argument('--loss_scale',
-                        type=float, default=0,
+    parser.add_argument('--loss_scale', type=float, default=0,
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
+    # evaluation
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the dev set.")
+    parser.add_argument("--eval_batch_size", default=8, type=int,
+                        help="Total batch size for eval.")
+    # Model
+    parser.add_argument("--xlnet_model", default="xlnet-large-cased", type=str,
+                        help="XLNet pre-trained model: currently only xlnet-large-cased.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+    parser.add_argument("--cache_dir", default="", type=str,
+                        help="Where do you want to store the pre-trained models downloaded from s3")
+    # task specific
+    parser.add_argument("--max_seq_length", default=128, type=int,
+                        help="The maximum total input sequence length after WordPiece tokenization. \n"
+                             "Sequences longer than this will be truncated, and sequences shorter \n"
+                             "than this will be padded.")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    # Misc
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Whether not to use CUDA when available")
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on gpus")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
     args = parser.parse_args()
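
A note on the regrouped training flags: per the help text above, --gradient_accumulation_steps accumulates gradients over several forward passes before each optimizer update. Below is a minimal sketch of the usual pattern in these example scripts, assuming (as in the BERT examples) that the requested --train_batch_size is split across accumulation steps; all names are illustrative stand-ins, not the file's exact code.

import torch

train_batch_size = 32              # --train_batch_size
gradient_accumulation_steps = 4    # --gradient_accumulation_steps
per_step_batch = train_batch_size // gradient_accumulation_steps  # 8 examples per forward pass

w = torch.randn(3, requires_grad=True)         # stand-in for model parameters
optimizer = torch.optim.SGD([w], lr=5e-5)      # --learning_rate

for step in range(8):                          # 8 small batches -> 2 optimizer updates
    x = torch.randn(per_step_batch, 3)         # stand-in mini-batch
    loss = (x @ w).pow(2).mean()
    loss = loss / gradient_accumulation_steps  # rescale so summed grads match a full-batch average
    loss.backward()                            # gradients accumulate in w.grad
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()                       # one update per effective batch of 32
        optimizer.zero_grad()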
@@ -306,7 +273,7 @@ def main():
                 input_ids, input_mask, segment_ids, label_ids = batch
 
                 # define a new function to compute loss values for both output_modes
-                logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
+                logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
 
                 if output_mode == "classification":
                     loss_fct = CrossEntropyLoss()
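
The functional change in this hunk (and in the two below) is the left-hand side: logits = model(...) becomes logits, _ = model(...), so the forward pass now returns a tuple rather than a bare tensor, presumably the classification logits plus XLNet's cached memory states, which this script discards. A minimal sketch of that calling convention with a stand-in module; TupleReturningClassifier and its pooling are illustrative, not the library's implementation.

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

class TupleReturningClassifier(nn.Module):
    # Stand-in for a head whose forward returns (logits, extras), as the
    # updated example expects; real XLNet would also return cached memories.
    def __init__(self, vocab_size=32000, hidden=16, num_labels=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.classifier = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        pooled = self.embed(input_ids).mean(dim=1)   # crude sequence pooling
        return self.classifier(pooled), None         # (logits, placeholder extras)

model = TupleReturningClassifier()
input_ids = torch.randint(0, 32000, (4, 128))
label_ids = torch.randint(0, 2, (4,))

logits, _ = model(input_ids)                         # unpack the tuple, as in this commit
loss = CrossEntropyLoss()(logits.view(-1, 2), label_ids.view(-1))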
@@ -420,7 +387,7 @@ def main():
             label_ids = label_ids.to(device)
 
             with torch.no_grad():
-                logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
+                logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
 
             # create eval loss and other metric required by the task
             if output_mode == "classification":
@@ -501,7 +468,7 @@ def main():
             label_ids = label_ids.to(device)
 
             with torch.no_grad():
-                logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
+                logits, _ = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=None)
 
             loss_fct = CrossEntropyLoss()
             tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
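
After tmp_eval_loss is computed per batch, evaluation scripts in this family typically sum the loss over batches, average it, and take an argmax over the logits for classification predictions. A short sketch with dummy tensors, assuming that convention; the variable names follow the run_classifier examples and are not the exact contents of this file.

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

num_labels = 2
logits = torch.randn(8, num_labels)                # stand-in for one eval batch's logits
label_ids = torch.randint(0, num_labels, (8,))

loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))

eval_loss = tmp_eval_loss.item()                   # summed and divided by nb_eval_steps in the script
preds = np.argmax(logits.numpy(), axis=1)          # predicted class per example
accuracy = (preds == label_ids.numpy()).mean()     # simple accuracy metric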