Skip to content

Commit

Permalink
Remove pack_inputs/unpack_inputs in bert
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 291082425
  • Loading branch information
saberkun authored and tensorflower-gardener committed Jan 23, 2020
1 parent df15a27 commit c813d85
Showing 1 changed file with 3 additions and 21 deletions.
24 changes: 3 additions & 21 deletions official/nlp/bert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,6 @@ def __init__(self, vocab_size, **kwargs):
'vocab_size': vocab_size,
}

def __call__(self,
lm_output,
sentence_output=None,
lm_label_ids=None,
lm_label_weights=None,
sentence_labels=None,
**kwargs):
inputs = tf_utils.pack_inputs([
lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels
])
return super(BertPretrainLossAndMetricLayer,
self).__call__(inputs, **kwargs)

def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
lm_example_loss, sentence_output, sentence_labels,
next_sentence_loss):
Expand All @@ -110,14 +96,10 @@ def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
self.add_metric(
next_sentence_loss, name='next_sentence_loss', aggregation='mean')

def call(self, inputs):
def call(self, lm_output, sentence_output, lm_label_ids, lm_label_weights,
sentence_labels):
"""Implements call() for the layer."""
unpacked_inputs = tf_utils.unpack_inputs(inputs)
lm_output = unpacked_inputs[0]
sentence_output = unpacked_inputs[1]
lm_label_ids = unpacked_inputs[2]
lm_label_weights = tf.keras.backend.cast(unpacked_inputs[3], tf.float32)
sentence_labels = unpacked_inputs[4]
lm_label_weights = tf.keras.backend.cast(lm_label_weights, tf.float32)

mask_label_loss = losses.weighted_sparse_categorical_crossentropy_loss(
labels=lm_label_ids, predictions=lm_output, weights=lm_label_weights)
Expand Down

0 comments on commit c813d85

Please sign in to comment.