Skip to content

Commit

Permalink
Change the logic in the pipeline test regarding TF (huggingface#20710)
Browse files Browse the repository at this point in the history
* Fix the pipeline test regarding TF

* Fix the pipeline test regarding TF

* update comment

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Dec 13, 2022
1 parent 1af4bee commit a12c5cb
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions tests/pipelines/test_pipelines_summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
SummarizationPipeline,
TFPreTrainedModel,
pipeline,
)
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
from transformers.testing_utils import get_gpu_count, require_tf, require_torch, slow, torch_device
from transformers.tokenization_utils import TruncationStrategy

from .test_pipelines_common import ANY, PipelineTestCaseMeta
Expand Down Expand Up @@ -51,6 +52,7 @@ def run_pipeline_test(self, summarizer, _):
)
self.assertEqual(outputs, [{"summary_text": ANY(str)}])

# Some models (Switch Transformers, LED, T5, LongT5, etc) can handle long sequences.
model_can_handle_longer_seq = [
"SwitchTransformersConfig",
"T5Config",
Expand All @@ -62,10 +64,16 @@ def run_pipeline_test(self, summarizer, _):
"ProphetNetConfig", # positional embeddings up to a fixed maximum size (otherwise clamping the values)
]
if model.config.__class__.__name__ not in model_can_handle_longer_seq:
# Switch Transformers, LED, T5, LongT5 can handle it.
# Too long.
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
# Too long and exception is expected.
# For TF models, if the weights are initialized in GPU context, we won't get expected index error from
# the embedding layer.
if not (
isinstance(model, TFPreTrainedModel)
and get_gpu_count() > 0
and len(summarizer.model.trainable_weights) > 0
):
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
outputs = summarizer("This " * 1000, truncation=TruncationStrategy.ONLY_FIRST)

@require_torch
Expand Down

0 comments on commit a12c5cb

Please sign in to comment.