From 941051a5c8ef2274fe7bc698b0aac94e281db2e7 Mon Sep 17 00:00:00 2001 From: Sebastian Goodman Date: Fri, 20 Dec 2019 14:47:03 -0800 Subject: [PATCH] Merging run_classifier_with_tfhub into run_classifier because the implementations diverged. Requiring tensorflow 1.15 for einsum grad issue. PiperOrigin-RevId: 286643632 --- README.md | 10 +- classifier_utils.py | 54 ++++-- requirements.txt | 4 +- run_classifier.py | 38 +++-- run_classifier_with_tfhub.py | 314 ----------------------------------- tokenization.py | 20 +++ 6 files changed, 98 insertions(+), 342 deletions(-) delete mode 100644 run_classifier_with_tfhub.py diff --git a/README.md b/README.md index ed3b028f..27edcc93 100644 --- a/README.md +++ b/README.md @@ -125,7 +125,9 @@ For XNLI, COLA, MNLI, and MRPC, use `run_classifier_sp.py`: ``` pip install -r albert/requirements.txt -python -m albert.run_classifier_sp \ +python -m albert.run_classifier \ + --albert_config_file=albert_config.json \ + --init_checkpoint=/path/to/ckpt \ --task_name=MNLI \ ``` @@ -142,12 +144,12 @@ You should see some output like this: sentence_order_loss = ... ``` -You can also fine-tune the model starting from TF-Hub modules using -`run_classifier_with_tfhub.py`: +You can also fine-tune the model starting from TF-Hub modules: ``` pip install -r albert/requirements.txt -python -m albert.run_classifier_with_tfhub \ +python -m albert.run_classifier \ --albert_hub_module_handle=https://tfhub.dev/google/albert_base/1 \ + --task_name=MNLI \ ``` diff --git a/classifier_utils.py b/classifier_utils.py index 8a3f42f4..887704b0 100644 --- a/classifier_utils.py +++ b/classifier_utils.py @@ -25,6 +25,7 @@ import optimization import tokenization import tensorflow.compat.v1 as tf +import tensorflow_hub as hub from tensorflow.contrib import data as contrib_data from tensorflow.contrib import metrics as contrib_metrics from tensorflow.contrib import tpu as contrib_tpu @@ -765,9 +766,28 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): tokens_b.pop() -def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, - labels, num_labels, use_one_hot_embeddings, task_name): - """Creates a classification model.""" +def _create_model_from_hub(hub_module, is_training, input_ids, input_mask, + segment_ids): + """Creates an ALBERT model from TF-Hub.""" + tags = set() + if is_training: + tags.add("train") + albert_module = hub.Module(hub_module, tags=tags, trainable=True) + albert_inputs = dict( + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids) + albert_outputs = albert_module( + inputs=albert_inputs, + signature="tokens", + as_dict=True) + output_layer = albert_outputs["pooled_output"] + return output_layer + + +def _create_model_from_scratch(albert_config, is_training, input_ids, + input_mask, segment_ids, use_one_hot_embeddings): + """Creates an ALBERT model from scratch (as opposed to hub).""" model = modeling.AlbertModel( config=albert_config, is_training=is_training, @@ -775,13 +795,24 @@ def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=use_one_hot_embeddings) - - # In the demo, we are doing a simple classification task on the entire - # segment. - # - # If you want to use the token-level output, use model.get_sequence_output() - # instead. 
output_layer = model.get_pooled_output() + return output_layer + + +def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, + labels, num_labels, use_one_hot_embeddings, task_name, + hub_module): + """Creates a classification model.""" + if hub_module: + tf.logging.info("creating model from hub_module: %s", hub_module) + output_layer = _create_model_from_hub(hub_module, is_training, input_ids, + input_mask, segment_ids) + else: + tf.logging.info("creating model from albert_config") + output_layer = _create_model_from_scratch(albert_config, is_training, + input_ids, input_mask, + segment_ids, + use_one_hot_embeddings) hidden_size = output_layer.shape[-1].value @@ -818,7 +849,8 @@ def create_model(albert_config, is_training, input_ids, input_mask, segment_ids, def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, - use_one_hot_embeddings, task_name, optimizer="adamw"): + use_one_hot_embeddings, task_name, hub_module=None, + optimizer="adamw"): """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, labels, mode, params): # pylint: disable=unused-argument @@ -843,7 +875,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument (total_loss, per_example_loss, probabilities, logits, predictions) = \ create_model(albert_config, is_training, input_ids, input_mask, segment_ids, label_ids, num_labels, - use_one_hot_embeddings, task_name) + use_one_hot_embeddings, task_name, hub_module) tvars = tf.trainable_variables() initialized_variable_names = {} diff --git a/requirements.txt b/requirements.txt index 389ff8e0..0b911f24 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tensorflow >=1.11.0,<2.0.0 # CPU Version of TensorFlow. -# tensorflow-gpu >=1.11.0,<2.0.0 # GPU version of TensorFlow. +tensorflow==1.15 # CPU Version of TensorFlow (Python2-only) +# tensorflow-gpu==1.15 # GPU version of TensorFlow (Python2-only) sentencepiece diff --git a/run_classifier.py b/run_classifier.py index fd93f629..8cfa2f62 100644 --- a/run_classifier.py +++ b/run_classifier.py @@ -65,6 +65,10 @@ "init_checkpoint", None, "Initial checkpoint (usually from a pre-trained BERT model).") +flags.DEFINE_string( + "albert_hub_module_handle", None, + "If set, the ALBERT hub module to use.") + flags.DEFINE_bool( "do_lower_case", True, "Whether to lower case the input text. Should be True for uncased " @@ -160,13 +164,20 @@ def main(_): raise ValueError( "At least one of `do_train`, `do_eval` or `do_predict' must be True.") - albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file) - - if FLAGS.max_seq_length > albert_config.max_position_embeddings: - raise ValueError( - "Cannot use sequence length %d because the ALBERT model " - "was only trained up to sequence length %d" % - (FLAGS.max_seq_length, albert_config.max_position_embeddings)) + if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle: + raise ValueError("At least one of `--albert_config_file` and " + "`--albert_hub_module_handle` must be set") + + if FLAGS.albert_config_file: + albert_config = modeling.AlbertConfig.from_json_file( + FLAGS.albert_config_file) + if FLAGS.max_seq_length > albert_config.max_position_embeddings: + raise ValueError( + "Cannot use sequence length %d because the ALBERT model " + "was only trained up to sequence length %d" % + (FLAGS.max_seq_length, albert_config.max_position_embeddings)) + else: + albert_config = None # Get the config from TF-Hub. 
tf.gfile.MakeDirs(FLAGS.output_dir) @@ -181,9 +192,14 @@ def main(_): label_list = processor.get_labels() - tokenizer = tokenization.FullTokenizer( - vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case, - spm_model_file=FLAGS.spm_model_file) + if FLAGS.albert_hub_module_handle: + tokenizer = tokenization.FullTokenizer.from_hub_module( + hub_module=FLAGS.albert_hub_module_handle, + spm_model_file=FLAGS.spm_model_file) + else: + tokenizer = tokenization.FullTokenizer.from_scratch( + vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case, + spm_model_file=FLAGS.spm_model_file) tpu_cluster_resolver = None if FLAGS.use_tpu and FLAGS.tpu_name: @@ -220,6 +236,7 @@ def main(_): use_tpu=FLAGS.use_tpu, use_one_hot_embeddings=FLAGS.use_tpu, task_name=task_name, + hub_module=FLAGS.albert_hub_module_handle, optimizer=FLAGS.optimizer) # If TPU is not available, this will fall back to normal Estimator on CPU @@ -456,6 +473,5 @@ def _find_valid_cands(curr_step): flags.mark_flag_as_required("data_dir") flags.mark_flag_as_required("task_name") flags.mark_flag_as_required("vocab_file") - flags.mark_flag_as_required("albert_config_file") flags.mark_flag_as_required("output_dir") tf.app.run() diff --git a/run_classifier_with_tfhub.py b/run_classifier_with_tfhub.py deleted file mode 100644 index 86aa4d2c..00000000 --- a/run_classifier_with_tfhub.py +++ /dev/null @@ -1,314 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Google AI Team Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""ALBERT finetuning runner with TF-Hub.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import os -import classifier_utils -import optimization -import tokenization -import tensorflow as tf -import tensorflow_hub as hub -from tensorflow.contrib import cluster_resolver as contrib_cluster_resolver -from tensorflow.contrib import tpu as contrib_tpu - -flags = tf.flags - -FLAGS = flags.FLAGS - -flags.DEFINE_string( - "albert_hub_module_handle", None, - "Handle for the ALBERT TF-Hub module.") - - -def create_model(is_training, input_ids, input_mask, segment_ids, labels, - num_labels, albert_hub_module_handle): - """Creates a classification model.""" - tags = set() - if is_training: - tags.add("train") - albert_module = hub.Module(albert_hub_module_handle, tags=tags, - trainable=True) - albert_inputs = dict( - input_ids=input_ids, - input_mask=input_mask, - segment_ids=segment_ids) - albert_outputs = albert_module( - inputs=albert_inputs, - signature="tokens", - as_dict=True) - - # In the demo, we are doing a simple classification task on the entire - # segment. - # - # If you want to use the token-level output, use - # albert_outputs["sequence_output"] instead. 
- output_layer = albert_outputs["pooled_output"] - - hidden_size = output_layer.shape[-1].value - - output_weights = tf.get_variable( - "output_weights", [num_labels, hidden_size], - initializer=tf.truncated_normal_initializer(stddev=0.02)) - - output_bias = tf.get_variable( - "output_bias", [num_labels], initializer=tf.zeros_initializer()) - - with tf.variable_scope("loss"): - if is_training: - # I.e., 0.1 dropout - output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) - - logits = tf.matmul(output_layer, output_weights, transpose_b=True) - logits = tf.nn.bias_add(logits, output_bias) - probabilities = tf.nn.softmax(logits, axis=-1) - log_probs = tf.nn.log_softmax(logits, axis=-1) - - one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) - - per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) - loss = tf.reduce_mean(per_example_loss) - - return (loss, per_example_loss, logits, probabilities) - - -def model_fn_builder(num_labels, learning_rate, num_train_steps, - num_warmup_steps, use_tpu, albert_hub_module_handle): - """Returns `model_fn` closure for TPUEstimator.""" - - def model_fn(features, labels, mode, params): # pylint: disable=unused-argument - """The `model_fn` for TPUEstimator.""" - - tf.logging.info("*** Features ***") - for name in sorted(features.keys()): - tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) - - input_ids = features["input_ids"] - input_mask = features["input_mask"] - segment_ids = features["segment_ids"] - label_ids = features["label_ids"] - - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - (total_loss, per_example_loss, logits, probabilities) = create_model( - is_training, input_ids, input_mask, segment_ids, label_ids, num_labels, - albert_hub_module_handle) - - output_spec = None - if mode == tf.estimator.ModeKeys.TRAIN: - train_op = optimization.create_optimizer( - total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) - - output_spec = contrib_tpu.TPUEstimatorSpec( - mode=mode, loss=total_loss, train_op=train_op) - elif mode == tf.estimator.ModeKeys.EVAL: - - def metric_fn(per_example_loss, label_ids, logits): - predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) - accuracy = tf.metrics.accuracy(label_ids, predictions) - loss = tf.metrics.mean(per_example_loss) - return { - "eval_accuracy": accuracy, - "eval_loss": loss, - } - - eval_metrics = (metric_fn, [per_example_loss, label_ids, logits]) - output_spec = contrib_tpu.TPUEstimatorSpec( - mode=mode, loss=total_loss, eval_metrics=eval_metrics) - elif mode == tf.estimator.ModeKeys.PREDICT: - output_spec = contrib_tpu.TPUEstimatorSpec( - mode=mode, predictions={"probabilities": probabilities}) - else: - raise ValueError( - "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode)) - - return output_spec - - return model_fn - - -def create_tokenizer_from_hub_module(albert_hub_module_handle): - """Get the vocab file and casing info from the Hub module.""" - with tf.Graph().as_default(): - albert_module = hub.Module(albert_hub_module_handle) - tokenization_info = albert_module(signature="tokenization_info", - as_dict=True) - with tf.Session() as sess: - vocab_file, do_lower_case = sess.run([tokenization_info["vocab_file"], - tokenization_info["do_lower_case"]]) - return tokenization.FullTokenizer( - vocab_file=vocab_file, do_lower_case=do_lower_case, - spm_model_file=FLAGS.spm_model_file) - - -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) - - processors = { - "cola": 
classifier_utils.ColaProcessor, - "mnli": classifier_utils.MnliProcessor, - "mrpc": classifier_utils.MrpcProcessor, - } - - if not FLAGS.do_train and not FLAGS.do_eval: - raise ValueError("At least one of `do_train` or `do_eval` must be True.") - - tf.gfile.MakeDirs(FLAGS.output_dir) - - task_name = FLAGS.task_name.lower() - - if task_name not in processors: - raise ValueError("Task not found: %s" % (task_name)) - - processor = processors[task_name]() - - label_list = processor.get_labels() - - tokenizer = create_tokenizer_from_hub_module(FLAGS.albert_hub_module_handle) - - tpu_cluster_resolver = None - if FLAGS.use_tpu and FLAGS.tpu_name: - tpu_cluster_resolver = contrib_cluster_resolver.TPUClusterResolver( - FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - is_per_host = contrib_tpu.InputPipelineConfig.PER_HOST_V2 - run_config = contrib_tpu.RunConfig( - cluster=tpu_cluster_resolver, - master=FLAGS.master, - model_dir=FLAGS.output_dir, - save_checkpoints_steps=FLAGS.save_checkpoints_steps, - tpu_config=contrib_tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations_per_loop, - num_shards=FLAGS.num_tpu_cores, - per_host_input_for_training=is_per_host)) - - train_examples = None - num_train_steps = None - num_warmup_steps = None - if FLAGS.do_train: - train_examples = processor.get_train_examples(FLAGS.data_dir) - num_train_steps = int( - len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) - num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) - - model_fn = model_fn_builder( - num_labels=len(label_list), - learning_rate=FLAGS.learning_rate, - num_train_steps=num_train_steps, - num_warmup_steps=num_warmup_steps, - use_tpu=FLAGS.use_tpu, - albert_hub_module_handle=FLAGS.albert_hub_module_handle) - - # If TPU is not available, this will fall back to normal Estimator on CPU - # or GPU. - estimator = contrib_tpu.TPUEstimator( - use_tpu=FLAGS.use_tpu, - model_fn=model_fn, - config=run_config, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size, - predict_batch_size=FLAGS.predict_batch_size) - - if FLAGS.do_train: - train_features = classifier_utils.convert_examples_to_features( - train_examples, label_list, FLAGS.max_seq_length, tokenizer, task_name) - tf.logging.info("***** Running training *****") - tf.logging.info(" Num examples = %d", len(train_examples)) - tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) - tf.logging.info(" Num steps = %d", num_train_steps) - train_input_fn = classifier_utils.input_fn_builder( - features=train_features, - seq_length=FLAGS.max_seq_length, - is_training=True, - drop_remainder=True) - estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) - - if FLAGS.do_eval: - eval_examples = processor.get_dev_examples(FLAGS.data_dir) - eval_features = classifier_utils.convert_examples_to_features( - eval_examples, label_list, FLAGS.max_seq_length, tokenizer, task_name) - - tf.logging.info("***** Running evaluation *****") - tf.logging.info(" Num examples = %d", len(eval_examples)) - tf.logging.info(" Batch size = %d", FLAGS.eval_batch_size) - - # This tells the estimator to run through the entire set. - eval_steps = None - # However, if running eval on the TPU, you will need to specify the - # number of steps. - if FLAGS.use_tpu: - # Eval will be slightly WRONG on the TPU because it will truncate - # the last batch. 
- eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size) - - eval_drop_remainder = True if FLAGS.use_tpu else False - eval_input_fn = classifier_utils.input_fn_builder( - features=eval_features, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=eval_drop_remainder) - - result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps) - - output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt") - with tf.gfile.GFile(output_eval_file, "w") as writer: - tf.logging.info("***** Eval results *****") - for key in sorted(result.keys()): - tf.logging.info(" %s = %s", key, str(result[key])) - writer.write("%s = %s\n" % (key, str(result[key]))) - - if FLAGS.do_predict: - predict_examples = processor.get_test_examples(FLAGS.data_dir) - if FLAGS.use_tpu: - # Discard batch remainder if running on TPU - n = len(predict_examples) - predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)] - - predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record") - classifier_utils.file_based_convert_examples_to_features( - predict_examples, label_list, FLAGS.max_seq_length, tokenizer, - predict_file) - - tf.logging.info("***** Running prediction*****") - tf.logging.info(" Num examples = %d", len(predict_examples)) - tf.logging.info(" Batch size = %d", FLAGS.predict_batch_size) - - predict_input_fn = classifier_utils.file_based_input_fn_builder( - input_file=predict_file, - seq_length=FLAGS.max_seq_length, - is_training=False, - drop_remainder=FLAGS.use_tpu) - - result = estimator.predict(input_fn=predict_input_fn) - - output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv") - with tf.gfile.GFile(output_predict_file, "w") as writer: - tf.logging.info("***** Predict results *****") - for prediction in result: - probabilities = prediction["probabilities"] - output_line = "\t".join( - str(class_probability) - for class_probability in probabilities) + "\n" - writer.write(output_line) - - -if __name__ == "__main__": - flags.mark_flag_as_required("data_dir") - flags.mark_flag_as_required("task_name") - flags.mark_flag_as_required("albert_hub_module_handle") - flags.mark_flag_as_required("output_dir") - tf.app.run() diff --git a/tokenization.py b/tokenization.py index 4cad9527..30161eae 100644 --- a/tokenization.py +++ b/tokenization.py @@ -26,6 +26,7 @@ import six from six.moves import range import tensorflow.compat.v1 as tf +import tensorflow_hub as hub import sentencepiece as spm SPIECE_UNDERLINE = u"▁".encode("utf-8") @@ -248,6 +249,25 @@ def __init__(self, vocab_file, do_lower_case=True, spm_model_file=None): self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.inv_vocab = {v: k for k, v in self.vocab.items()} + @classmethod + def from_scratch(cls, vocab_file, do_lower_case, spm_model_file): + return FullTokenizer(vocab_file, do_lower_case, spm_model_file) + + @classmethod + def from_hub_module(cls, hub_module, spm_model_file): + """Get the vocab file and casing info from the Hub module.""" + with tf.Graph().as_default(): + albert_module = hub.Module(hub_module) + tokenization_info = albert_module(signature="tokenization_info", + as_dict=True) + with tf.Session() as sess: + vocab_file, do_lower_case = sess.run( + [tokenization_info["vocab_file"], + tokenization_info["do_lower_case"]]) + return FullTokenizer( + vocab_file=vocab_file, do_lower_case=do_lower_case, + spm_model_file=spm_model_file) + def tokenize(self, text): if self.sp_model: split_tokens = encode_pieces(self.sp_model, text, return_unicode=False)
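
After this merge, the TF-Hub and from-scratch paths differ only in how the tokenizer and model get constructed. Below is a minimal sketch of the two tokenizer entry points added to `tokenization.py` above, assuming the hub handle shown in the README diff; the vocab and SentencePiece model paths are placeholders, not files shipped with this patch.

```
# Sketch only: exercises FullTokenizer.from_hub_module and
# FullTokenizer.from_scratch as defined in the tokenization.py diff above.
import tokenization

# Hub path: the vocab file and casing are read from the module's
# "tokenization_info" signature inside from_hub_module, so no local
# vocab file is needed on this path.
hub_tokenizer = tokenization.FullTokenizer.from_hub_module(
    hub_module="https://tfhub.dev/google/albert_base/1",
    spm_model_file=None)  # None falls back to the module's wordpiece vocab

# From-scratch path: equivalent to the direct FullTokenizer(...) call
# that run_classifier.py made before this change.
local_tokenizer = tokenization.FullTokenizer.from_scratch(
    vocab_file="/path/to/30k-clean.vocab",      # placeholder path
    do_lower_case=True,
    spm_model_file="/path/to/30k-clean.model")  # placeholder path

print(hub_tokenizer.tokenize("ALBERT merges the two classifier runners."))
```

Note that `run_classifier.py` still marks `vocab_file` as required even on the hub path; only `albert_config_file` is dropped from the required flags in this change.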