Add new AutoModel classes in pipeline (huggingface#6062)
* use new AutoModel classes

* make style and quality
patil-suraj authored Jul 27, 2020
1 parent 5779e54 commit c8bdf7f
Showing 1 changed file with 12 additions and 9 deletions.
src/transformers/pipelines.py: 21 changes (12 additions & 9 deletions)
@@ -60,13 +60,14 @@
     AutoModelForSequenceClassification,
     AutoModelForQuestionAnswering,
     AutoModelForTokenClassification,
-    AutoModelWithLMHead,
     AutoModelForSeq2SeqLM,
-    MODEL_WITH_LM_HEAD_MAPPING,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    MODEL_FOR_MASKED_LM_MAPPING,
 )
 
 if TYPE_CHECKING:
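
For context, AutoModelWithLMHead was the older catch-all class for any model with a language-modeling head; the classes imported above select the head explicitly. A minimal sketch (illustrative, not part of this commit) of loading them directly, using the checkpoints that appear as pipeline defaults later in this file:

    from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM

    # Each task-specific class builds the architecture with the matching LM head.
    masked_lm = AutoModelForMaskedLM.from_pretrained("distilroberta-base")  # fill-mask
    causal_lm = AutoModelForCausalLM.from_pretrained("gpt2")                # text-generation
    seq2seq_lm = AutoModelForSeq2SeqLM.from_pretrained("t5-base")           # translation / summarization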
@@ -1029,7 +1030,7 @@ def __init__(
             task=task,
         )
 
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_MASKED_LM_MAPPING)
 
         self.topk = topk
 
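With this change FillMaskPipeline is checked against the masked-LM mapping on the PyTorch side rather than the generic LM-head mapping. An illustrative usage sketch (not from this commit) that passes an explicitly loaded masked-LM model into the pipeline:

    from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

    model = AutoModelForMaskedLM.from_pretrained("distilroberta-base")
    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")

    # check_model_type() validates the model's class against MODEL_FOR_MASKED_LM_MAPPING
    # (the TF branch keeps the generic TF_MODEL_WITH_LM_HEAD_MAPPING).
    unmasker = pipeline("fill-mask", model=model, tokenizer=tokenizer)
    print(unmasker("Paris is the <mask> of France."))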
@@ -1817,7 +1818,9 @@ class TranslationPipeline(Pipeline):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        self.check_model_type(TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_WITH_LM_HEAD_MAPPING)
+        self.check_model_type(
+            TF_MODEL_WITH_LM_HEAD_MAPPING if self.framework == "tf" else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+        )
 
     def __call__(
         self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
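
Translation is now validated against the seq2seq (encoder-decoder) mapping instead of the generic LM-head mapping. A minimal usage sketch (illustrative, not part of this commit), with t5-base as in the defaults below:

    from transformers import pipeline

    # T5 is an encoder-decoder model, i.e. it loads through AutoModelForSeq2SeqLM.
    translator = pipeline("translation_en_to_fr", model="t5-base")
    result = translator("How old are you?")  # returns [{'translation_text': ...}]
    print(result)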
@@ -1933,7 +1936,7 @@ def __call__(
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForMaskedLM if is_torch_available() else None,
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
     },
     "summarization": {
@@ -1945,25 +1948,25 @@
     "translation_en_to_fr": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_de": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "translation_en_to_ro": {
         "impl": TranslationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForSeq2SeqLM if is_torch_available() else None,
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": TFAutoModelWithLMHead if is_tf_available() else None,
-        "pt": AutoModelWithLMHead if is_torch_available() else None,
+        "pt": AutoModelForCausalLM if is_torch_available() else None,
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
     },
     "zero-shot-classification": {
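In the SUPPORTED_TASKS table above, the "pt" entry is the class pipeline() uses to instantiate the model when a checkpoint name (or the registry default) is given, so text-generation now loads decoder-only models through AutoModelForCausalLM. A small illustrative sketch (not from this commit; gpt2 is the default named above):

    from transformers import pipeline

    # No model argument: the default from the table ("gpt2") is loaded via AutoModelForCausalLM.
    generator = pipeline("text-generation")

    # generate_kwargs such as max_length are forwarded to model.generate().
    print(generator("The new AutoModel classes", max_length=30))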
