[T5, MT5, UMT5] Add [T5, MT5, UMT5]ForSequenceClassification (huggingface#24726)

* Initial addition of t5forsequenceclassification

* Adding imports and adding tests

* Formatting

* Running make fix-copies

* Adding mt5forseq

* Formatting

* run make fix-copies

* Adding to docs

* Add model_parallel

* Fix bug

* Fix

* Remove TODO

* Fixing tests for T5ForSequenceClassification

* Undo changes to dependency_versions_table.py

* Change classification head to work with T5Config directly

* Change seq length to let tests pass

* PR comments for formatting

* Formatting

* Initial addition of UMT5ForSequenceClassification

* Adding to inits and formatting

* run make fix-copies

* Add doc for UMT5ForSeqClass

* Update UMT5 config

* Fix docs

* Skip torch fx test for SequenceClassification

* Formatting

* Add skip to UMT5 tests as well

* Fix umt5 tests

* Running make fix-copies

* PR comments

* Fix for change to sentence_representation

* Rename seq_len to hidden_size since that's what it is

* Use base_model to follow format of the rest of the library

* Update docs

* Extract the decoder_input_ids changes and make one liner

* Make one-liner
sjrl authored Jul 25, 2023
1 parent 21150cb commit 8f36ab3
Showing 18 changed files with 960 additions and 21 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/mt5.md
@@ -95,6 +95,10 @@ See [`T5TokenizerFast`] for all details.

[[autodoc]] MT5EncoderModel

## MT5ForSequenceClassification

[[autodoc]] MT5ForSequenceClassification

## MT5ForQuestionAnswering

[[autodoc]] MT5ForQuestionAnswering
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/t5.md
@@ -401,6 +401,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] T5EncoderModel
- forward

## T5ForSequenceClassification

[[autodoc]] T5ForSequenceClassification
- forward

## T5ForQuestionAnswering

[[autodoc]] T5ForQuestionAnswering
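
For context on how the newly documented head is used, here is a minimal sketch (assuming the public `t5-small` checkpoint and an illustrative `num_labels=2`; the classification head added by this commit is randomly initialized until fine-tuned):

```python
import torch
from transformers import T5ForSequenceClassification, T5Tokenizer

# Sketch only: "t5-small" and num_labels=2 are illustrative choices, and the
# classification head weights are newly initialized until the model is fine-tuned.
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForSequenceClassification.from_pretrained("t5-small", num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # decoder_input_ids are derived from input_ids internally

logits = outputs.logits                       # shape: (batch_size, num_labels)
predicted_class = logits.argmax(dim=-1).item()
```
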
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/umt5.md
@@ -92,6 +92,11 @@ The conversion script is also different because the model was saved in t5x's lat
[[autodoc]] UMT5EncoderModel
- forward

## UMT5ForSequenceClassification

[[autodoc]] UMT5ForSequenceClassification
- forward

## UMT5ForQuestionAnswering

[[autodoc]] UMT5ForQuestionAnswering
2 changes: 1 addition & 1 deletion docs/source/en/tasks/sequence_classification.md
@@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model architectures
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->


[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)



14 changes: 13 additions & 1 deletion src/transformers/__init__.py
@@ -2240,7 +2240,14 @@
]
)
_import_structure["models.mt5"].extend(
["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5ForQuestionAnswering", "MT5Model", "MT5PreTrainedModel"]
[
"MT5EncoderModel",
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5Model",
"MT5PreTrainedModel",
]
)
_import_structure["models.musicgen"].extend(
[
@@ -2694,6 +2701,7 @@
"T5EncoderModel",
"T5ForConditionalGeneration",
"T5ForQuestionAnswering",
"T5ForSequenceClassification",
"T5Model",
"T5PreTrainedModel",
"load_tf_weights_in_t5",
@@ -2763,6 +2771,7 @@
"UMT5EncoderModel",
"UMT5ForConditionalGeneration",
"UMT5ForQuestionAnswering",
"UMT5ForSequenceClassification",
"UMT5Model",
"UMT5PreTrainedModel",
]
@@ -5930,6 +5939,7 @@
MT5EncoderModel,
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5Model,
MT5PreTrainedModel,
)
@@ -6303,6 +6313,7 @@
T5EncoderModel,
T5ForConditionalGeneration,
T5ForQuestionAnswering,
T5ForSequenceClassification,
T5Model,
T5PreTrainedModel,
load_tf_weights_in_t5,
@@ -6356,6 +6367,7 @@
UMT5EncoderModel,
UMT5ForConditionalGeneration,
UMT5ForQuestionAnswering,
UMT5ForSequenceClassification,
UMT5Model,
UMT5PreTrainedModel,
)
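
With the import structure updated, the three new heads become importable directly from the top-level package. A quick check (class lookups only, no weights are downloaded; the expected module paths are noted as assumptions in the comments):

```python
from transformers import (
    MT5ForSequenceClassification,
    T5ForSequenceClassification,
    UMT5ForSequenceClassification,
)

# Each class should resolve to its modeling module via the lazy import machinery.
print(T5ForSequenceClassification.__module__)    # transformers.models.t5.modeling_t5
print(MT5ForSequenceClassification.__module__)   # transformers.models.mt5.modeling_mt5
print(UMT5ForSequenceClassification.__module__)  # transformers.models.umt5.modeling_umt5
```
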
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -724,6 +724,7 @@
("mpnet", "MPNetForSequenceClassification"),
("mpt", "MptForSequenceClassification"),
("mra", "MraForSequenceClassification"),
("mt5", "MT5ForSequenceClassification"),
("mvp", "MvpForSequenceClassification"),
("nezha", "NezhaForSequenceClassification"),
("nystromformer", "NystromformerForSequenceClassification"),
@@ -740,8 +741,10 @@
("roc_bert", "RoCBertForSequenceClassification"),
("roformer", "RoFormerForSequenceClassification"),
("squeezebert", "SqueezeBertForSequenceClassification"),
("t5", "T5ForSequenceClassification"),
("tapas", "TapasForSequenceClassification"),
("transfo-xl", "TransfoXLForSequenceClassification"),
("umt5", "UMT5ForSequenceClassification"),
("xlm", "XLMForSequenceClassification"),
("xlm-roberta", "XLMRobertaForSequenceClassification"),
("xlm-roberta-xl", "XLMRobertaXLForSequenceClassification"),
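
Since "t5", "mt5", and "umt5" are now registered in the sequence-classification auto-mapping, `AutoModelForSequenceClassification` resolves T5-family checkpoints to the new heads. A sketch, assuming the `t5-small` checkpoint (the added head is freshly initialized and still needs fine-tuning):

```python
from transformers import AutoModelForSequenceClassification

# "t5-small" and num_labels=3 are illustrative; the classification head is new and untrained.
model = AutoModelForSequenceClassification.from_pretrained("t5-small", num_labels=3)
print(type(model).__name__)  # T5ForSequenceClassification
```
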
2 changes: 2 additions & 0 deletions src/transformers/models/mt5/__init__.py
@@ -51,6 +51,7 @@
"MT5EncoderModel",
"MT5ForConditionalGeneration",
"MT5ForQuestionAnswering",
"MT5ForSequenceClassification",
"MT5Model",
"MT5PreTrainedModel",
"MT5Stack",
@@ -86,6 +87,7 @@
MT5EncoderModel,
MT5ForConditionalGeneration,
MT5ForQuestionAnswering,
MT5ForSequenceClassification,
MT5Model,
MT5PreTrainedModel,
MT5Stack,
4 changes: 4 additions & 0 deletions src/transformers/models/mt5/configuration_mt5.py
@@ -56,6 +56,8 @@ class MT5Config(PretrainedConfig):
The maximum distance of the longer sequences for the bucket separation.
dropout_rate (`float`, *optional*, defaults to 0.1):
The ratio for all dropout layers.
classifier_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for classifier.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (`float`, *optional*, defaults to 1):
@@ -91,6 +93,7 @@ def __init__(
pad_token_id=0,
eos_token_id=1,
decoder_start_token_id=0,
classifier_dropout=0.0,
**kwargs,
):
super().__init__(
@@ -114,6 +117,7 @@ def __init__(
self.relative_attention_num_buckets = relative_attention_num_buckets
self.relative_attention_max_distance = relative_attention_max_distance
self.dropout_rate = dropout_rate
self.classifier_dropout = classifier_dropout
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
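
The new `classifier_dropout` option controls the dropout applied inside the classification head. A small sketch building a randomly initialized model from a hand-written config (the sizes are illustrative, and the `classification_head.dropout` attribute name is an assumption based on the BART-style head):

```python
from transformers import MT5Config, MT5ForSequenceClassification

# Illustrative, tiny configuration; a model built this way is randomly initialized, not pretrained.
config = MT5Config(
    d_model=64,
    d_ff=128,
    num_layers=2,
    num_decoder_layers=2,
    num_heads=4,
    classifier_dropout=0.1,  # dropout used by the classification head
    num_labels=2,
)
model = MT5ForSequenceClassification(config)
print(model.classification_head.dropout.p)  # expected: 0.1 (attribute name assumed)
```
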