Merging run_classifier_with_tfhub into run_classifier because the implementations diverged.

Requiring tensorflow 1.15 for einsum grad issue.

PiperOrigin-RevId: 286643632
0x0539 authored and albert-copybara committed Dec 20, 2019
1 parent eda4dc2 commit 941051a
Showing 6 changed files with 98 additions and 342 deletions.
README.md: 10 changes (6 additions, 4 deletions)
@@ -125,7 +125,9 @@ For XNLI, COLA, MNLI, and MRPC, use `run_classifier_sp.py`:

 ```
 pip install -r albert/requirements.txt
-python -m albert.run_classifier_sp \
+python -m albert.run_classifier \
+  --albert_config_file=albert_config.json \
+  --init_checkpoint=/path/to/ckpt \
   --task_name=MNLI \
   <additional flags>
 ```
@@ -142,12 +144,12 @@ You should see some output like this:
 sentence_order_loss = ...
 ```

-You can also fine-tune the model starting from TF-Hub modules using
-`run_classifier_with_tfhub.py`:
+You can also fine-tune the model starting from TF-Hub modules:

 ```
 pip install -r albert/requirements.txt
-python -m albert.run_classifier_with_tfhub \
+python -m albert.run_classifier \
   --albert_hub_module_handle=https://tfhub.dev/google/albert_base/1 \
   --task_name=MNLI \
   <additional flags>
 ```
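For orientation, here is a minimal standalone sketch, not part of this commit, of what the classifier does internally when `--albert_hub_module_handle` is set; it mirrors the `_create_model_from_hub` helper added to `classifier_utils.py` below and assumes TF 1.15, `tensorflow_hub`, and the public `albert_base` handle:

```
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

# Build the ALBERT encoder from TF-Hub, as _create_model_from_hub does.
albert = hub.Module(
    "https://tfhub.dev/google/albert_base/1", tags={"train"}, trainable=True)
albert_outputs = albert(
    inputs=dict(
        input_ids=tf.zeros([1, 128], tf.int32),     # dummy batch of token ids
        input_mask=tf.ones([1, 128], tf.int32),     # every position is a real token
        segment_ids=tf.zeros([1, 128], tf.int32)),  # single-segment input
    signature="tokens",
    as_dict=True)
# The pooled [batch, hidden] output is what the classification head consumes.
pooled_output = albert_outputs["pooled_output"]
```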
classifier_utils.py: 54 changes (43 additions, 11 deletions)
@@ -25,6 +25,7 @@
 import optimization
 import tokenization
 import tensorflow.compat.v1 as tf
+import tensorflow_hub as hub
 from tensorflow.contrib import data as contrib_data
 from tensorflow.contrib import metrics as contrib_metrics
 from tensorflow.contrib import tpu as contrib_tpu
@@ -765,23 +766,53 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
       tokens_b.pop()


-def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,
-                 labels, num_labels, use_one_hot_embeddings, task_name):
-  """Creates a classification model."""
+def _create_model_from_hub(hub_module, is_training, input_ids, input_mask,
+                           segment_ids):
+  """Creates an ALBERT model from TF-Hub."""
+  tags = set()
+  if is_training:
+    tags.add("train")
+  albert_module = hub.Module(hub_module, tags=tags, trainable=True)
+  albert_inputs = dict(
+      input_ids=input_ids,
+      input_mask=input_mask,
+      segment_ids=segment_ids)
+  albert_outputs = albert_module(
+      inputs=albert_inputs,
+      signature="tokens",
+      as_dict=True)
+  output_layer = albert_outputs["pooled_output"]
+  return output_layer
+
+
+def _create_model_from_scratch(albert_config, is_training, input_ids,
+                               input_mask, segment_ids, use_one_hot_embeddings):
+  """Creates an ALBERT model from scratch (as opposed to hub)."""
   model = modeling.AlbertModel(
       config=albert_config,
       is_training=is_training,
       input_ids=input_ids,
       input_mask=input_mask,
       token_type_ids=segment_ids,
       use_one_hot_embeddings=use_one_hot_embeddings)

   # In the demo, we are doing a simple classification task on the entire
   # segment.
   #
   # If you want to use the token-level output, use model.get_sequence_output()
   # instead.
   output_layer = model.get_pooled_output()
+  return output_layer
+
+
+def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,
+                 labels, num_labels, use_one_hot_embeddings, task_name,
+                 hub_module):
+  """Creates a classification model."""
+  if hub_module:
+    tf.logging.info("creating model from hub_module: %s", hub_module)
+    output_layer = _create_model_from_hub(hub_module, is_training, input_ids,
+                                          input_mask, segment_ids)
+  else:
+    tf.logging.info("creating model from albert_config")
+    output_layer = _create_model_from_scratch(albert_config, is_training,
+                                              input_ids, input_mask,
+                                              segment_ids,
+                                              use_one_hot_embeddings)

   hidden_size = output_layer.shape[-1].value

@@ -818,7 +849,8 @@ def create_model(albert_config, is_training, input_ids, input_mask, segment_ids,

 def model_fn_builder(albert_config, num_labels, init_checkpoint, learning_rate,
                      num_train_steps, num_warmup_steps, use_tpu,
-                     use_one_hot_embeddings, task_name, optimizer="adamw"):
+                     use_one_hot_embeddings, task_name, hub_module=None,
+                     optimizer="adamw"):
   """Returns `model_fn` closure for TPUEstimator."""

   def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
@@ -843,7 +875,7 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
     (total_loss, per_example_loss, probabilities, logits, predictions) = \
         create_model(albert_config, is_training, input_ids, input_mask,
                      segment_ids, label_ids, num_labels,
-                     use_one_hot_embeddings, task_name)
+                     use_one_hot_embeddings, task_name, hub_module)

     tvars = tf.trainable_variables()
     initialized_variable_names = {}
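To make the merged entry point concrete, here is a hedged usage sketch, not part of the diff, of calling the new `create_model` signature both ways; the shapes, `num_labels=3` (MNLI), the config path, and the hub handle are placeholder assumptions:

```
import tensorflow.compat.v1 as tf

import classifier_utils
import modeling

# Placeholder classifier inputs (batch x seq_len), purely for illustration.
input_ids = tf.placeholder(tf.int32, [None, 128])
input_mask = tf.placeholder(tf.int32, [None, 128])
segment_ids = tf.placeholder(tf.int32, [None, 128])
label_ids = tf.placeholder(tf.int32, [None])

# Path 1: build from a local config file, with hub_module=None (pre-merge behaviour).
albert_config = modeling.AlbertConfig.from_json_file("albert_config.json")
outputs_scratch = classifier_utils.create_model(
    albert_config, True, input_ids, input_mask, segment_ids, label_ids,
    num_labels=3, use_one_hot_embeddings=False, task_name="mnli",
    hub_module=None)

# Path 2: build from a TF-Hub handle; albert_config may be None in this case.
outputs_hub = classifier_utils.create_model(
    None, True, input_ids, input_mask, segment_ids, label_ids,
    num_labels=3, use_one_hot_embeddings=False, task_name="mnli",
    hub_module="https://tfhub.dev/google/albert_base/1")

# Each call returns (total_loss, per_example_loss, probabilities, logits,
# predictions), which model_fn unpacks above.
```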
requirements.txt: 4 changes (2 additions, 2 deletions)
@@ -1,3 +1,3 @@
-tensorflow >=1.11.0,<2.0.0 # CPU Version of TensorFlow.
-# tensorflow-gpu >=1.11.0,<2.0.0 # GPU version of TensorFlow.
+tensorflow==1.15 # CPU Version of TensorFlow (Python2-only)
+# tensorflow-gpu==1.15 # GPU version of TensorFlow (Python2-only)
 sentencepiece
run_classifier.py: 38 changes (27 additions, 11 deletions)
@@ -65,6 +65,10 @@
     "init_checkpoint", None,
     "Initial checkpoint (usually from a pre-trained BERT model).")

+flags.DEFINE_string(
+    "albert_hub_module_handle", None,
+    "If set, the ALBERT hub module to use.")
+
 flags.DEFINE_bool(
     "do_lower_case", True,
     "Whether to lower case the input text. Should be True for uncased "
@@ -160,13 +164,20 @@ def main(_):
     raise ValueError(
         "At least one of `do_train`, `do_eval` or `do_predict' must be True.")

-  albert_config = modeling.AlbertConfig.from_json_file(FLAGS.albert_config_file)
-
-  if FLAGS.max_seq_length > albert_config.max_position_embeddings:
-    raise ValueError(
-        "Cannot use sequence length %d because the ALBERT model "
-        "was only trained up to sequence length %d" %
-        (FLAGS.max_seq_length, albert_config.max_position_embeddings))
+  if not FLAGS.albert_config_file and not FLAGS.albert_hub_module_handle:
+    raise ValueError("At least one of `--albert_config_file` and "
+                     "`--albert_hub_module_handle` must be set")
+
+  if FLAGS.albert_config_file:
+    albert_config = modeling.AlbertConfig.from_json_file(
+        FLAGS.albert_config_file)
+    if FLAGS.max_seq_length > albert_config.max_position_embeddings:
+      raise ValueError(
+          "Cannot use sequence length %d because the ALBERT model "
+          "was only trained up to sequence length %d" %
+          (FLAGS.max_seq_length, albert_config.max_position_embeddings))
+  else:
+    albert_config = None  # Get the config from TF-Hub.

   tf.gfile.MakeDirs(FLAGS.output_dir)

@@ -181,9 +192,14 @@ def main(_):

   label_list = processor.get_labels()

-  tokenizer = tokenization.FullTokenizer(
-      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case,
-      spm_model_file=FLAGS.spm_model_file)
+  if FLAGS.albert_hub_module_handle:
+    tokenizer = tokenization.FullTokenizer.from_hub_module(
+        hub_module=FLAGS.albert_hub_module_handle,
+        spm_model_file=FLAGS.spm_model_file)
+  else:
+    tokenizer = tokenization.FullTokenizer.from_scratch(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case,
+        spm_model_file=FLAGS.spm_model_file)

   tpu_cluster_resolver = None
   if FLAGS.use_tpu and FLAGS.tpu_name:
@@ -220,6 +236,7 @@ def main(_):
       use_tpu=FLAGS.use_tpu,
       use_one_hot_embeddings=FLAGS.use_tpu,
       task_name=task_name,
+      hub_module=FLAGS.albert_hub_module_handle,
       optimizer=FLAGS.optimizer)

   # If TPU is not available, this will fall back to normal Estimator on CPU
@@ -456,6 +473,5 @@ def _find_valid_cands(curr_step):
   flags.mark_flag_as_required("data_dir")
   flags.mark_flag_as_required("task_name")
   flags.mark_flag_as_required("vocab_file")
-  flags.mark_flag_as_required("albert_config_file")
   flags.mark_flag_as_required("output_dir")
   tf.app.run()
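Finally, a small sketch, hypothetical rather than taken from the commit, that mirrors the tokenizer branch added to `main()` above; the hub handle is the public base model and the file-path arguments are intentionally left as `None` placeholders:

```
import tokenization

def build_tokenizer(albert_hub_module_handle, vocab_file, do_lower_case,
                    spm_model_file):
  """Mirrors the tokenizer selection now performed in run_classifier.main()."""
  if albert_hub_module_handle:
    # The hub module carries its own vocab/sentencepiece assets.
    return tokenization.FullTokenizer.from_hub_module(
        hub_module=albert_hub_module_handle,
        spm_model_file=spm_model_file)
  # Otherwise fall back to local files, as before the merge.
  return tokenization.FullTokenizer.from_scratch(
      vocab_file=vocab_file, do_lower_case=do_lower_case,
      spm_model_file=spm_model_file)

tokenizer = build_tokenizer(
    albert_hub_module_handle="https://tfhub.dev/google/albert_base/1",
    vocab_file=None, do_lower_case=True, spm_model_file=None)
print(tokenizer.tokenize("ALBERT merges the two classifier entry points."))
```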