Fix inverted conditional in TF common test! (huggingface#22540)
* Fix inverted conditional in TF common test!

* Make the same change in the PT tests file

* Make sure hidden states for GPT2 have the same output shape in PT/TF

* Minor fix to PT implementation of token classification loss

* Skip loss equivalence test for TFHubert because it keeps overflowing to inf

* Compute LM loss for TF the (weird) way it's computed in PT

* Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert

* Fix - don't try to access the hidden states property when output is a tuple
Rocketknight1 authored Apr 4, 2023 · 1 parent 48fbd8f · commit edb704b
Showing 7 changed files with 245 additions and 17 deletions.
11 changes: 1 addition & 10 deletions src/transformers/models/esm/modeling_esm.py
@@ -1228,16 +1228,7 @@ def forward(
         loss = None
         if labels is not None:
             loss_fct = CrossEntropyLoss()
-            # Only keep active parts of the loss
-            if attention_mask is not None:
-                active_loss = attention_mask.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)
-                active_labels = torch.where(
-                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
-                )
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

         if not return_dict:
             output = (logits,) + outputs[2:]
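The removed masking branch was redundant: `CrossEntropyLoss` already skips any position labelled with its `ignore_index` (-100 by default), which is how padding positions are conventionally marked. A minimal sketch of the behaviour the simplified line relies on, with hypothetical shapes and labels:

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical token-classification output: batch=1, seq_len=4, num_labels=2.
logits = torch.randn(1, 4, 2)
# Padding positions carry the label -100, which CrossEntropyLoss ignores by
# default (ignore_index=-100), so no manual attention-mask filtering is needed.
labels = torch.tensor([[0, 1, -100, -100]])

loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, 2), labels.view(-1))  # averaged over the 2 real tokens only
```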
8 changes: 7 additions & 1 deletion src/transformers/models/gpt2/modeling_tf_gpt2.py
@@ -1051,6 +1051,12 @@ def call(
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+        if return_dict and output_hidden_states:
+            # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the
+            # input is rank 3, but all other hidden states remain at rank 3 (with the first 2 dims merged)
+            all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)
+        else:
+            all_hidden_states = None
         lm_logits = self.transformer.wte(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)

@@ -1062,7 +1068,7 @@ def call(
             logits=lm_logits,
             mc_logits=mc_logits,
             past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
+            hidden_states=all_hidden_states,
             attentions=transformer_outputs.attentions,
         )

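The rank mismatch the new branch reproduces is easier to see with concrete shapes. A minimal sketch with hypothetical dimensions (batch=2, num_choices=3, seq_len=5, hidden=8):

```python
import tensorflow as tf

# The double-heads model flattens (batch, num_choices, seq_len) before running
# the transformer, so every intermediate hidden state has rank 3:
intermediate = tf.zeros((2 * 3, 5, 8))  # (batch * num_choices, seq_len, hidden)

# Only the final hidden state is reshaped back to rank 4, matching PT:
final = tf.reshape(intermediate, (2, 3, 5, 8))  # (batch, num_choices, seq_len, hidden)

# The returned hidden_states tuple therefore mixes ranks, with the last entry
# swapped for the reshaped tensor, as the diff above does:
all_hidden_states = (intermediate, intermediate) + (final,)
```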
8 changes: 5 additions & 3 deletions src/transformers/models/xglm/modeling_tf_xglm.py
@@ -953,9 +953,11 @@ def call(
         loss = None
         if labels is not None:
             # shift labels to the left and cut last logit token
-            shifted_logits = lm_logits[:, :-1]
-            labels = labels[:, 1:]
-            loss = self.hf_compute_loss(labels, shifted_logits)
+            labels = tf.concat(
+                [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))],
+                axis=-1,
+            )
+            loss = self.hf_compute_loss(labels, lm_logits)

         if not return_dict:
             output = (lm_logits,) + outputs[1:]
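This is the "(weird) way" from the commit message: rather than dropping the last logit and scoring `lm_logits[:, :-1]` against `labels[:, 1:]`, the PT XGLM implementation keeps every logit, shifts the labels left, and appends a pad token so the final position still has a target. A toy sketch with a hypothetical pad_token_id:

```python
import tensorflow as tf

labels = tf.constant([[11, 12, 13, 14]])  # hypothetical batch=1, seq_len=4
pad_token_id = 1                          # hypothetical pad id

# Shift labels one step left and pad the tail, as in the diff above:
shifted = tf.concat(
    [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(pad_token_id, labels.dtype))],
    axis=-1,
)
print(shifted.numpy())  # [[12 13 14  1]] - position t is scored against token t+1,
                        # and the last position is scored against pad
```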
116 changes: 115 additions & 1 deletion tests/models/hubert/test_modeling_tf_hubert.py
@@ -17,13 +17,15 @@
 import copy
 import inspect
 import math
+import os
+import tempfile
 import unittest

 import numpy as np
 import pytest

 from transformers import is_tf_available
-from transformers.testing_utils import require_soundfile, require_tf, slow
+from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow

 from ...test_configuration_common import ConfigTester
 from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -333,6 +335,62 @@ def test_keras_fit(self):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Hubert models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+

 @require_tf
 class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -458,6 +516,62 @@ def test_keras_fit(self):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Hubert models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+

 @require_tf
 class TFHubertUtilsTest(unittest.TestCase):
115 changes: 115 additions & 0 deletions tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
@@ -19,6 +19,8 @@
 import inspect
 import math
 import multiprocessing
+import os
+import tempfile
 import traceback
 import unittest

@@ -31,6 +33,7 @@
 from transformers.testing_utils import (
     CaptureLogger,
     is_flaky,
+    is_pt_tf_cross_test,
     require_librosa,
     require_pyctcdecode,
     require_tf,
@@ -397,6 +400,62 @@ def test_keras_fit(self):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+

 @require_tf
 class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -524,6 +583,62 @@ def test_keras_fit(self):
         # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
         pass

+    @is_pt_tf_cross_test
+    def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
+        # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
+        # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
+        import torch
+
+        import transformers
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            # Output all for aggressive testing
+            config.output_hidden_states = True
+            config.output_attentions = self.has_attentions
+
+            # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
+            # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
+            # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
+            self._make_attention_mask_non_null(inputs_dict)
+
+            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
+            pt_model_class = getattr(transformers, pt_model_class_name)
+
+            tf_model = model_class(config)
+            pt_model = pt_model_class(config)
+
+            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+
+            # Check we can load pt model in tf and vice-versa with model => model functions
+            tf_model = transformers.load_pytorch_model_in_tf2_model(
+                tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
+            )
+            pt_model = transformers.load_tf2_model_in_pytorch_model(
+                pt_model, tf_model, allow_missing_keys=allow_missing_keys
+            )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+
+            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
+                torch.save(pt_model.state_dict(), pt_checkpoint_path)
+                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
+                    tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
+                tf_model.save_weights(tf_checkpoint_path)
+                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
+                    pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
+                )
+
+            # Original test: check without `labels`
+            self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
+

 @require_tf
 class TFWav2Vec2UtilsTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/test_modeling_common.py
@@ -2030,7 +2030,7 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False):

         # For some models (e.g. base models), there is no label returned.
         # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
-        if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
+        if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
             pt_inputs_dict_with_labels = None

         # Check we can load pt model in tf and vice-versa with model => model functions
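This one-word fix is the point of the whole commit: `symmetric_difference` returns an empty (falsy) set exactly when adding labels changed nothing, and that is the only case where the with-labels dict should be discarded. A toy illustration:

```python
# Stand-ins for the prepared input dicts of a base model (no labels added).
pt_inputs_dict = {"input_ids": 1, "attention_mask": 2}
pt_inputs_dict_with_labels = {"input_ids": 1, "attention_mask": 2}

diff = set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys())
assert diff == set()  # identical keys -> falsy

# The old logic (`if diff:`) nulled the dict whenever labels WERE added, so the
# extra with-labels equivalence check silently never ran; `if not diff:` nulls
# it only in this identical-keys case.
if not diff:
    pt_inputs_dict_with_labels = None
```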
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_common.py
@@ -699,7 +699,7 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False):

         # For some models (e.g. base models), there is no label returned.
         # Set the input dict to `None` to avoid check outputs twice for the same input dicts.
-        if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
+        if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()):
             tf_inputs_dict_with_labels = None

         # Check we can load pt model in tf and vice-versa with model => model functions
