Skip to content

Commit

Permalink
Config to Model mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Jan 13, 2020
1 parent cf8a70b commit b803b06
Show file tree
Hide file tree
Showing 17 changed files with 58 additions and 55 deletions.
1 change: 1 addition & 0 deletions examples/summarization/configuration_bertabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class BertAbsConfig(PretrainedConfig):
"""

pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
model_type = "bertabs"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AlbertConfig(PretrainedConfig):
"""

pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "albert"

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
assert unused_kwargs == {'foo': False}
"""
config_dict, _ = PretrainedConfig.resolved_config_dict(
config_dict, _ = PretrainedConfig.get_config_dict(
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class BertConfig(PretrainedConfig):
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "bert"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_camembert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@

class CamembertConfig(RobertaConfig):
pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "camembert"
1 change: 1 addition & 0 deletions src/transformers/configuration_ctrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class CTRLConfig(PretrainedConfig):
"""

pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "ctrl"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

class DistilBertConfig(PretrainedConfig):
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "distilbert"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class GPT2Config(PretrainedConfig):
"""

pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "gpt2"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class OpenAIGPTConfig(PretrainedConfig):
"""

pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "openai-gpt"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class T5Config(PretrainedConfig):
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "t5"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_transfo_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class TransfoXLConfig(PretrainedConfig):
"""

pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "transfo-xl"

def __init__(
self,
Expand Down
9 changes: 5 additions & 4 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class PretrainedConfig(object):
``output_hidden_states``: string, default `False`. Should the model return all hidden-states.
``torchscript``: string, default `False`. Is the model used with Torchscript.
"""
pretrained_config_archive_map = {}
pretrained_config_archive_map: Dict[str, str] = {}
model_type: str

def __init__(self, **kwargs):
# Attributes with defaults
Expand Down Expand Up @@ -155,11 +156,11 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
assert unused_kwargs == {'foo': False}
"""
config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
return cls.from_dict(config_dict, **kwargs)

@classmethod
def resolved_config_dict(
def get_config_dict(
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
"""
Expand Down Expand Up @@ -257,7 +258,7 @@ def from_dict(cls, config_dict: Dict, **kwargs):

@classmethod
def from_json_file(cls, json_file: str):
"""Constructs a `Config` from a json file of parameters."""
"""Constructs a `Config` from the path to a json file of parameters."""
config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_xlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class XLMConfig(PretrainedConfig):
"""

pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm"

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/configuration_xlm_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@

class XLMRobertaConfig(RobertaConfig):
pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm-roberta"
1 change: 1 addition & 0 deletions src/transformers/configuration_xlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class XLNetConfig(PretrainedConfig):
"""

pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlnet"

def __init__(
self,
Expand Down
88 changes: 38 additions & 50 deletions src/transformers/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


import logging
from collections import OrderedDict
from typing import Type

from .configuration_auto import (
AlbertConfig,
Expand Down Expand Up @@ -76,6 +78,7 @@
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_utils import PreTrainedModel
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMForQuestionAnswering,
Expand Down Expand Up @@ -123,6 +126,35 @@
for key, value, in pretrained_map.items()
)

# Maps each configuration class to the base model class that AutoModel
# instantiates for it.
#
# Order matters: consumers walk this mapping in insertion order and pick the
# first entry for which `isinstance(config, config_class)` is true, so a
# subclass config must be listed BEFORE its parent. CamembertConfig and
# XLMRobertaConfig both derive from RobertaConfig and therefore precede it.
#
# The annotation is quoted because `collections.OrderedDict` only supports
# subscripting at runtime from Python 3.9; module-level annotations are
# evaluated eagerly, so the unquoted form raises TypeError on older versions.
MODEL_MAPPING: "OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]]" = OrderedDict(
    [
        (T5Config, T5Model),
        (DistilBertConfig, DistilBertModel),
        (AlbertConfig, AlbertModel),
        (CamembertConfig, CamembertModel),
        (XLMRobertaConfig, XLMRobertaModel),
        (RobertaConfig, RobertaModel),
        (BertConfig, BertModel),
        (OpenAIGPTConfig, OpenAIGPTModel),
        (GPT2Config, GPT2Model),
        (TransfoXLConfig, TransfoXLModel),
        (XLNetConfig, XLNetModel),
        (XLMConfig, XLMModel),
        (CTRLConfig, CTRLModel),
    ]
)

# Maps each configuration class to its token-classification model class.
#
# Order matters when this mapping is resolved with a first-match `isinstance`
# scan: a subclass config must be listed BEFORE its parent. CamembertConfig
# and XLMRobertaConfig both derive from RobertaConfig and therefore precede it.
#
# The annotation is quoted because `collections.OrderedDict` only supports
# subscripting at runtime from Python 3.9; module-level annotations are
# evaluated eagerly, so the unquoted form raises TypeError on older versions.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: "OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]]" = OrderedDict(
    [
        (DistilBertConfig, DistilBertForTokenClassification),
        (CamembertConfig, CamembertForTokenClassification),
        (XLMRobertaConfig, XLMRobertaForTokenClassification),
        (RobertaConfig, RobertaForTokenClassification),
        (BertConfig, BertForTokenClassification),
        (XLNetConfig, XLNetForTokenClassification),
    ]
)


class AutoModel(object):
r"""
Expand Down Expand Up @@ -183,30 +215,9 @@ def from_config(cls, config):
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
model = AutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
"""
if isinstance(config, DistilBertConfig):
return DistilBertModel(config)
elif isinstance(config, RobertaConfig):
return RobertaModel(config)
elif isinstance(config, BertConfig):
return BertModel(config)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTModel(config)
elif isinstance(config, GPT2Config):
return GPT2Model(config)
elif isinstance(config, TransfoXLConfig):
return TransfoXLModel(config)
elif isinstance(config, XLNetConfig):
return XLNetModel(config)
elif isinstance(config, XLMConfig):
return XLMModel(config)
elif isinstance(config, CTRLConfig):
return CTRLModel(config)
elif isinstance(config, AlbertConfig):
return AlbertModel(config)
elif isinstance(config, CamembertConfig):
return CamembertModel(config)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaModel(config)
for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class(config)
raise ValueError("Unrecognized configuration class {}".format(config))

@classmethod
Expand Down Expand Up @@ -294,32 +305,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

if isinstance(config, T5Config):
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, DistilBertConfig):
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, AlbertConfig):
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CamembertConfig):
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMRobertaConfig):
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, RobertaConfig):
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, BertConfig):
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, OpenAIGPTConfig):
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, GPT2Config):
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, TransfoXLConfig):
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLNetConfig):
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, XLMConfig):
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif isinstance(config, CTRLConfig):
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
raise ValueError(
"Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
Expand Down
1 change: 1 addition & 0 deletions templates/adding_a_new_model/configuration_xxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class XxxConfig(PretrainedConfig):
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xxx"

def __init__(
self,
Expand Down

0 comments on commit b803b06

Please sign in to comment.