
Fix the attention mask description error and a CoLA evaluation calculation error.

PiperOrigin-RevId: 307494726
Danny-Google authored and albert-copybara committed Apr 20, 2020
1 parent a41cf11 commit c21d8a3
Showing 2 changed files with 20 additions and 16 deletions.
14 changes: 9 additions & 5 deletions classifier_utils.py
@@ -916,17 +916,21 @@ def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                   "MSE": mse, "eval_loss": loss,}
       elif task_name == "cola":
         def metric_fn(per_example_loss, label_ids, logits, is_real_example):
-          """Compute Matthew's correlations for STS-B."""
+          """Compute Matthew's correlations for COLA."""
           predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
           # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
           tp, tp_op = tf.metrics.true_positives(
-              predictions, label_ids, weights=is_real_example)
+              labels=label_ids, predictions=predictions,
+              weights=is_real_example)
           tn, tn_op = tf.metrics.true_negatives(
-              predictions, label_ids, weights=is_real_example)
+              labels=label_ids, predictions=predictions,
+              weights=is_real_example)
           fp, fp_op = tf.metrics.false_positives(
-              predictions, label_ids, weights=is_real_example)
+              labels=label_ids, predictions=predictions,
+              weights=is_real_example)
           fn, fn_op = tf.metrics.false_negatives(
-              predictions, label_ids, weights=is_real_example)
+              labels=label_ids, predictions=predictions,
+              weights=is_real_example)
 
           # Compute Matthew's correlation
           mcc = tf.div_no_nan(
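Editor's note: passing labels and predictions by keyword removes the argument-order ambiguity, since tf.metrics.true_positives and its siblings take labels first, then predictions. For context, the truncated tf.div_no_nan(...) call computes Matthews correlation from the four counts above. A minimal sketch of that formula with a hypothetical helper name (the repository's exact expression is not shown in this hunk):

    import tensorflow as tf  # TF 1.x, matching the tf.metrics API used above

    def matthews_correlation(tp, tn, fp, fn):
      """Matthews correlation coefficient from float confusion-matrix counts."""
      numerator = tp * tn - fp * fn
      denominator = tf.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
      # div_no_nan yields 0 instead of NaN when the denominator is 0
      # (e.g. when every prediction falls in a single class).
      return tf.div_no_nan(numerator, denominator)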
22 changes: 11 additions & 11 deletions modeling.py
@@ -854,10 +854,10 @@ def attention_layer(from_tensor,
     from_tensor: float Tensor of shape [batch_size, from_seq_length,
       from_width].
     to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
-    attention_mask: (optional) int32 Tensor of shape [batch_size,
-      from_seq_length, to_seq_length]. The values should be 1 or 0. The
-      attention scores will effectively be set to -infinity for any positions in
-      the mask that are 0, and will be unchanged for positions that are 1.
+    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
+      The values should be 1 or 0. The attention scores will effectively
+      be set to -infinity for any positions in the mask that are 0, and
+      will be unchanged for positions that are 1.
     num_attention_heads: int. Number of attention heads.
     query_act: (optional) Activation function for the query transform.
     key_act: (optional) Activation function for the key transform.
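Editor's note: the corrected shape reflects that the mask is per token rather than per query-key pair. As an illustration of the "-infinity" behaviour the docstring describes, here is a minimal sketch of the standard additive-mask technique (not the repository's exact implementation):

    import tensorflow as tf

    def additive_attention_bias(attention_mask):
      """[batch_size, seq_length] 0/1 mask -> additive bias for attention scores.

      Masked (0) positions get a large negative bias, so their post-softmax
      attention probabilities are effectively zero; unmasked (1) positions
      are left unchanged.
      """
      mask = tf.cast(attention_mask, tf.float32)   # [B, T]
      mask = mask[:, tf.newaxis, tf.newaxis, :]    # broadcast over heads and query positions
      return (1.0 - mask) * -10000.0               # added to raw scores before softmax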
@@ -949,10 +949,10 @@ def attention_ffn_block(layer_input,
     layer_input: float Tensor of shape [batch_size, from_seq_length,
       from_width].
     hidden_size: (optional) int, size of hidden layer.
-    attention_mask: (optional) int32 Tensor of shape [batch_size,
-      from_seq_length, to_seq_length]. The values should be 1 or 0. The
-      attention scores will effectively be set to -infinity for any positions in
-      the mask that are 0, and will be unchanged for positions that are 1.
+    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
+      The values should be 1 or 0. The attention scores will effectively be set
+      to -infinity for any positions in the mask that are 0, and will be
+      unchanged for positions that are 1.
     num_attention_heads: int. Number of attention heads.
     attention_head_size: int. Size of attention head.
     attention_probs_dropout_prob: float. dropout probability for attention_layer
@@ -1042,9 +1042,9 @@ def transformer_model(input_tensor,
   Args:
     input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
-    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
-      seq_length], with 1 for positions that can be attended to and 0 in
-      positions that should not be.
+    attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length],
+      with 1 for positions that can be attended to and 0 in positions that
+      should not be.
     hidden_size: int. Hidden size of the Transformer.
     num_hidden_layers: int. Number of layers (blocks) in the Transformer.
     num_hidden_groups: int. Number of group for the hidden layers, parameters
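Editor's note: a quick numeric check of the "1 attends, 0 does not" wording, using hypothetical values in TF 1.x style:

    import tensorflow as tf

    # One sequence of length 4 whose last token is padding.
    attention_mask = tf.constant([[1, 1, 1, 0]])    # [batch=1, seq_length=4]
    scores = tf.constant([[2.0, 1.0, 0.5, 3.0]])    # raw scores for a single query
    bias = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0
    probs = tf.nn.softmax(scores + bias, axis=-1)
    # probs[0, 3] is ~0: the padded position is ignored, and the remaining
    # probabilities renormalise over the unmasked positions.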
