Allow run_classifier to export fine-tuned model as a SavedModel.
PiperOrigin-RevId: 315433171
albert-copybara committed Jun 9, 2020
1 parent c21d8a3 commit 7d8e66b
Showing 1 changed file with 46 additions and 3 deletions.
run_classifier.py
@@ -140,6 +140,36 @@
     "num_tpu_cores", 8,
     "Only used if `use_tpu` is True. Total number of TPU cores to use.")
 
+flags.DEFINE_string(
+    "export_dir", None,
+    "The directory where the exported SavedModel will be stored.")
+
+
+def serving_input_receiver_fn():
+  """Creates an input function for serving."""
+  seq_len = FLAGS.max_seq_length
+  serialized_example = tf.placeholder(
+      dtype=tf.string, shape=[None], name="serialized_example")
+  features = {
+      "input_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
+      "input_mask": tf.FixedLenFeature([seq_len], dtype=tf.int64),
+      "segment_ids": tf.FixedLenFeature([seq_len], dtype=tf.int64),
+  }
+  feature_map = tf.parse_example(serialized_example, features=features)
+  feature_map["is_real_example"] = tf.constant(1, dtype=tf.int32)
+  feature_map["label_ids"] = tf.constant(0, dtype=tf.int32)
+
+  # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+  # So cast all int64 to int32.
+  for name in feature_map.keys():
+    t = feature_map[name]
+    if t.dtype == tf.int64:
+      t = tf.to_int32(t)
+    feature_map[name] = t
+
+  return tf.estimator.export.ServingInputReceiver(
+      features=feature_map, receiver_tensors=serialized_example)
+
+
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
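
For reference, the new serving signature consumes serialized tf.Example protos whose three int64 features each have length --max_seq_length. A minimal client-side sketch of building one such proto; the helper name, the toy token ids, and the 128-token length are illustrative assumptions, not part of this commit:

import tensorflow as tf

def make_serialized_example(input_ids, input_mask, segment_ids):
  """Packs pre-tokenized inputs into a serialized tf.Example (hypothetical helper)."""
  def int64_list(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  example = tf.train.Example(features=tf.train.Features(feature={
      "input_ids": int64_list(input_ids),
      "input_mask": int64_list(input_mask),
      "segment_ids": int64_list(segment_ids),
  }))
  return example.SerializeToString()

# Toy token ids padded to an assumed max_seq_length of 128.
seq_len = 128
ids = [2, 14, 30, 3] + [0] * (seq_len - 4)
mask = [1] * 4 + [0] * (seq_len - 4)
segments = [0] * seq_len
serialized = make_serialized_example(ids, mask, segments)
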
@@ -157,9 +187,11 @@ def main(_):
       "wnli": classifier_utils.WnliProcessor,
   }
 
-  if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
+  if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_predict or
+          FLAGS.export_dir):
     raise ValueError(
-        "At least one of `do_train`, `do_eval` or `do_predict' must be True.")
+        "At least one of `do_train`, `do_eval`, `do_predict' or `export_dir` "
+        "must be True.")
 
   if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
     raise ValueError("At least one of `--albert_config_file` and "
@@ -241,7 +273,8 @@ def main(_):
       config=run_config,
       train_batch_size=FLAGS.train_batch_size,
       eval_batch_size=FLAGS.eval_batch_size,
-      predict_batch_size=FLAGS.predict_batch_size)
+      predict_batch_size=FLAGS.predict_batch_size,
+      export_to_tpu=False)  # http://yaqs/4707241341091840
 
   if FLAGS.do_train:
     cached_dir = FLAGS.cached_dir
@@ -479,6 +512,16 @@ def _find_valid_cands(curr_step):
         num_written_lines += 1
     assert num_written_lines == num_actual_predict_examples
 
+  if FLAGS.export_dir:
+    tf.gfile.MakeDirs(FLAGS.export_dir)
+    checkpoint_path = os.path.join(FLAGS.output_dir, "model.ckpt-best")
+    tf.logging.info("Starting to export model.")
+    subfolder = estimator.export_saved_model(
+        export_dir_base=FLAGS.export_dir,
+        serving_input_receiver_fn=serving_input_receiver_fn,
+        checkpoint_path=checkpoint_path)
+    tf.logging.info("Model exported to %s.", subfolder)
+
 
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
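Each call to estimator.export_saved_model writes a timestamped subdirectory under --export_dir, whose path the final log line prints. A hedged sketch of querying that SavedModel with the TF 1.x contrib predictor; the input key "input" is ServingInputReceiver's default name for a single receiver tensor, and the output keys depend on the model_fn's export_outputs, so verify both with `saved_model_cli show --dir <subfolder> --all` before relying on this:

import tensorflow as tf

# Path printed by the "Model exported to ..." log line above (illustrative).
subfolder = "/tmp/albert_export/1591700000"

predict_fn = tf.contrib.predictor.from_saved_model(subfolder)

# `serialized` is a serialized tf.Example such as the one sketched earlier.
outputs = predict_fn({"input": [serialized]})
print(outputs)  # Dict of numpy arrays; keys depend on export_outputs.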
