-
Notifications
You must be signed in to change notification settings - Fork 2
/
hubconf.py
81 lines (64 loc) · 3.44 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
dependencies = ['paddle', 'jieba', 'colorlog', 'colorama', 'seqeval']
import os
from paddlenlp.transformers import BertForPretraining, BertModel, BertForSequenceClassification
from paddlenlp.transformers import BertForTokenClassification, BertForQuestionAnswering
from paddlenlp.transformers import BertTokenizer
_BERT_MODEL_CLASSES = {
"bert": BertModel,
"sequence_classification": BertForSequenceClassification,
"token_classification": BertForTokenClassification,
"question_answering": BertForQuestionAnswering,
"pretraining": BertForPretraining
}
_BERT_PRETRAINED_MODELS = [
'bert-base-uncased', 'bert-large-uncased', 'bert-base-multilingual-uncased',
'bert-base-cased', 'bert-base-chinese', 'bert-large-cased',
'bert-base-multilingual-cased', 'bert-wwm-chinese', 'bert-wwm-ext-chinese'
]
def bert(model_name_or_path='bert-base-uncased',
model_select='sequence_classification'):
"""
Returns BERT model and tokenizer from given pretrained model name or path
and class type of tasks, such as sequence classification.
Args:
model_name_or_path (str, optional): A name of or a file path to a
pretrained model. It could be 'bert-base-uncased',
'bert-large-uncased', 'bert-base-multilingual-uncased',
'bert-base-cased', 'bert-base-chinese', 'bert-large-cased',
'bert-base-multilingual-cased', 'bert-wwm-chinese' or
'bert-wwm-ext-chinese'. Default: 'bert-base-uncased'.
model_select (str, optional): model class to select. It could be
'bert', 'sequence_classification', 'token_classification',
'question_answering' or 'pretraining'. If 'sequence_classification'
is chosen, model class would be `BertForSequenceClassification`.
The document of BERT model could be seen at `bert.modeling
<https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.bert.modeling.html>`_
Default: 'sequence_classification'.
Returns:
tuple: Returns the pretrained bert model and bert tokenizer.
Example:
.. code-block:: python
import paddle.hub as hub
model, tokenizer = hub.load('PaddlePaddle/PaddleNLP:develop', model='bert', model_name_or_path='bert-base-cased')
"""
assert model_name_or_path in _BERT_PRETRAINED_MODELS or os.path.isdir(model_name_or_path), \
"Please check your model name or path. Supported model names are: {}.".format(tuple(_BERT_PRETRAINED_MODELS))
assert model_select in _BERT_MODEL_CLASSES.keys(), \
"Please check `model_select`, it should be in {}.".format(tuple(_BERT_MODEL_CLASSES.keys()))
model_class = _BERT_MODEL_CLASSES[model_select]
model = model_class.from_pretrained(model_name_or_path)
tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
return model, tokenizer