Skip to content

Commit

Permalink
More flag validation and documentation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 293712482
  • Loading branch information
0x0539 authored and albert-copybara committed Feb 7, 2020
1 parent 0bc6e42 commit fad4bd9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
14 changes: 12 additions & 2 deletions run_squad_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@

flags.DEFINE_string(
"predict_feature_file", None,
"predict feature file.")
"Location of predict features. If it doesn't exist, it will be written. "
"If it does exist, it will be read.")

flags.DEFINE_string(
"predict_feature_left_file", None,
"predict data kept but not pass to tpu.")
"Location of predict features not passed to TPU. If it doesn't exist, it "
"will be written. If it does exist, it will be read.")

flags.DEFINE_string(
"init_checkpoint", None,
Expand Down Expand Up @@ -188,6 +190,14 @@ def validate_flags_or_throw(albert_config):
if not FLAGS.predict_file:
raise ValueError(
"If `do_predict` is True, then `predict_file` must be specified.")
if not FLAGS.predict_feature_file:
raise ValueError(
"If `do_predict` is True, then `predict_feature_file` must be "
"specified.")
if not FLAGS.predict_feature_left_file:
raise ValueError(
"If `do_predict` is True, then `predict_feature_left_file` must be "
"specified.")

if FLAGS.max_seq_length > albert_config.max_position_embeddings:
raise ValueError(
Expand Down
14 changes: 12 additions & 2 deletions run_squad_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@

flags.DEFINE_string(
"predict_feature_file", None,
"predict feature file.")
"Location of predict features. If it doesn't exist, it will be written. "
"If it does exist, it will be read.")

flags.DEFINE_string(
"predict_feature_left_file", None,
"predict data kept but not pass to tpu.")
"Location of predict features not passed to TPU. If it doesn't exist, it "
"will be written. If it does exist, it will be read.")

flags.DEFINE_string(
"init_checkpoint", None,
Expand Down Expand Up @@ -196,6 +198,14 @@ def validate_flags_or_throw(albert_config):
if not FLAGS.predict_file:
raise ValueError(
"If `do_predict` is True, then `predict_file` must be specified.")
if not FLAGS.predict_feature_file:
raise ValueError(
"If `do_predict` is True, then `predict_feature_file` must be "
"specified.")
if not FLAGS.predict_feature_left_file:
raise ValueError(
"If `do_predict` is True, then `predict_feature_left_file` must be "
"specified.")

if FLAGS.max_seq_length > albert_config.max_position_embeddings:
raise ValueError(
Expand Down

0 comments on commit fad4bd9

Please sign in to comment.