Modify Distilbert docstring (PaddlePaddle#949)
* modify distilbert

* modify distilbert

* modify nezha tokenizer

* modify nezha

* update

* modify models

* modify args

* fix errors

* modify args
huhuiwen99 authored Sep 22, 2021
1 parent 15693c7 commit 499cbe4
Showing 4 changed files with 968 additions and 230 deletions.
274 changes: 273 additions & 1 deletion paddlenlp/transformers/distilbert/modeling.py
@@ -64,7 +64,7 @@ def forward(self, input_ids, position_ids=None):

class DistilBertPretrainedModel(PretrainedModel):
"""
An abstract class for pretrained DistilBERT models. It provides DistilBERT related
An abstract class for pretrained DistilBert models. It provides DistilBert related
`model_config_file`, `resource_files_names`, `pretrained_resource_files_map`,
`pretrained_init_configuration`, `base_model_prefix` for downloading and
loading pretrained models. See `PretrainedModel` for more details.
@@ -131,6 +131,62 @@ def init_weights(self, layer):

@register_base_model
class DistilBertModel(DistilBertPretrainedModel):
"""
The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top.
This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`.
Refer to the superclass documentation for the generic methods.
This model is also a Paddle `paddle.nn.Layer <https://www.paddlepaddle.org.cn/documentation
/docs/en/api/paddle/fluid/dygraph/layers/Layer_en.html>`__ subclass. Use it as a regular Paddle Layer
and refer to the Paddle documentation for all matter related to general usage and behavior.
Args:
vocab_size (int):
Vocabulary size of `input_ids` in `DistilBertModel`. Defines the number of different tokens that can
be represented by the `input_ids` passed when calling `DistilBertModel`.
hidden_size (int, optional):
Dimensionality of the embedding layer, encoder layers and the pooler layer. Defaults to `768`.
num_hidden_layers (int, optional):
Number of hidden layers in the Transformer encoder. Defaults to `12`.
num_attention_heads (int, optional):
Number of attention heads for each attention layer in the Transformer encoder.
Defaults to `12`.
intermediate_size (int, optional):
Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors
to ff layers are firstly projected from `hidden_size` to `intermediate_size`,
and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`.
Defaults to `3072`.
hidden_act (str, optional):
The non-linear activation function in the feed-forward layer.
``"gelu"``, ``"relu"`` and any other paddle supported activation functions
are supported. Defaults to `"gelu"`.
hidden_dropout_prob (float, optional):
The dropout probability for all fully connected layers in the embeddings and encoder.
Defaults to `0.1`.
attention_probs_dropout_prob (float, optional):
The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target.
Defaults to `0.1`.
max_position_embeddings (int, optional):
The maximum length of the position embeddings, which dictates the maximum supported length of an input
sequence. Defaults to `512`.
type_vocab_size (int, optional):
The vocabulary size of `token_type_ids`.
Defaults to `16`.
initializer_range (float, optional):
The standard deviation of the normal initializer.
Defaults to `0.02`.
.. note::
A normal_initializer initializes weight matrices as normal distributions.
See :meth:`DistilBertPretrainedModel.init_weights()` for how weights are initialized in `DistilBertModel`.
pad_token_id (int, optional):
The index of padding token in the token vocabulary.
Defaults to `0`.
"""

def __init__(self,
vocab_size,
hidden_size=768,
@@ -162,6 +218,44 @@ def __init__(self,
self.apply(self.init_weights)

def forward(self, input_ids, attention_mask=None):
r'''
The DistilBertModel forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
attention_mask (Tensor, optional):
Mask used in multi-head attention to avoid performing attention to some unwanted positions,
usually the paddings or the subsequent positions.
Its data type can be int, float and bool.
When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length],
[batch_size, num_attention_heads, sequence_length, sequence_length].
Defaults to `None`, which means no positions are masked.
Returns:
Tensor: Returns tensor `encoder_output`, which is the sequence of hidden-states at the last layer of the model.
Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size].
Example:
.. code-block::
import paddle
from paddlenlp.transformers import DistilBertModel, DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
output = model(**inputs)
'''

if attention_mask is None:
attention_mask = paddle.unsqueeze(
(input_ids == self.pad_token_id
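For reference, a minimal sketch of building the int-style mask described above and passing it to the model. This mirrors the model's internal default of masking `pad_token_id == 0` positions; the checkpoint and sample sentence are only illustrative:

import paddle
from paddlenlp.transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased')

encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
# int mask: 1 for tokens to attend to, 0 for padding; unsqueeze so it
# broadcasts to [batch_size, num_attention_heads, seq_len, seq_len]
attention_mask = paddle.unsqueeze(
    (input_ids != 0).astype("int64"), axis=[1, 2])
encoder_output = model(input_ids, attention_mask=attention_mask)
print(encoder_output.shape)  # [batch_size, sequence_length, hidden_size]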
@@ -174,6 +268,21 @@ def forward(self, input_ids, attention_mask=None):


class DistilBertForSequenceClassification(DistilBertPretrainedModel):
"""
DistilBert Model with a linear layer on top of the output layer, designed for
sequence classification/regression tasks like GLUE tasks.
Args:
distilbert (:class:`DistilBertModel`):
An instance of DistilBertModel.
num_classes (int, optional):
The number of classes. Defaults to `2`.
dropout (float, optional):
The dropout probability for output of DistilBert.
If None, use the same value as `hidden_dropout_prob` of `DistilBertModel`
instance `distilbert`. Defaults to None.
"""

def __init__(self, distilbert, num_classes=2, dropout=None):
super(DistilBertForSequenceClassification, self).__init__()
self.num_classes = num_classes
@@ -188,6 +297,36 @@ def __init__(self, distilbert, num_classes=2, dropout=None):
self.apply(self.init_weights)

def forward(self, input_ids, attention_mask=None):
r"""
The DistilBertForSequenceClassification forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
See :class:`DistilBertModel`.
attention_mask (Tensor, optional):
See :class:`DistilBertModel`.
Returns:
Tensor: Returns tensor `logits`, a tensor of the input text classification logits.
Its shape is `[batch_size, num_classes]` and its dtype is `float32`.
Example:
.. code-block::
import paddle
from paddlenlp.transformers.distilbert.modeling import DistilBertForSequenceClassification
from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
outputs = model(**inputs)
logits = outputs[0]
"""

distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask)
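As a rough illustration of consuming the `logits` documented above (the checkpoint here is not fine-tuned, so the predicted class is meaningless; the sketch only shows the shapes and calls involved):

import paddle
from paddlenlp.transformers import DistilBertForSequenceClassification, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
logits = model(input_ids)               # [batch_size, num_classes]
pred = paddle.argmax(logits, axis=-1)   # highest-scoring class per example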

@@ -202,6 +341,19 @@ def forward(self, input_ids, attention_mask=None):


class DistilBertForQuestionAnswering(DistilBertPretrainedModel):
"""
DistilBert Model with a linear layer on top of the hidden-states output to
compute `span_start_logits` and `span_end_logits`, designed for question-answering tasks like SQuAD.
Args:
distilbert (:class:`DistilBertModel`):
An instance of DistilBertModel.
dropout (float, optional):
The dropout probability for output of DistilBert.
If None, use the same value as `hidden_dropout_prob` of `DistilBertModel`
instance `distilbert`. Defaults to None.
"""

def __init__(self, distilbert, dropout=None):
super(DistilBertForQuestionAnswering, self).__init__()
self.distilbert = distilbert # allow bert to be config
Expand All @@ -211,6 +363,46 @@ def __init__(self, distilbert, dropout=None):
self.apply(self.init_weights)

def forward(self, input_ids, attention_mask=None):
r"""
The DistilBertForQuestionAnswering forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
See :class:`DistilBertModel`.
attention_mask (Tensor, optional):
See :class:`DistilBertModel`.
Returns:
tuple: Returns tuple (`start_logits`, `end_logits`).
With the fields:
- start_logits(Tensor):
A tensor of the input token classification logits, indicating the start position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
- end_logits(Tensor):
A tensor of the input token classification logits, indicating the end position of the labelled span.
Its data type should be float32 and its shape is [batch_size, sequence_length].
Example:
.. code-block::
import paddle
from paddlenlp.transformers.distilbert.modeling import DistilBertForQuestionAnswering
from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
outputs = model(**inputs)
start_logits = outputs[0]
end_logits = outputs[1]
"""

sequence_output = self.distilbert(
input_ids, attention_mask=attention_mask)
sequence_output = self.dropout(sequence_output)
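As a hedged sketch of turning `start_logits`/`end_logits` into an answer span: a naive argmax decode on an un-fine-tuned checkpoint, intended only to show the call pattern (real SQuAD post-processing is more involved):

import paddle
from paddlenlp.transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')

encoded = tokenizer("Who develops PaddleNLP?", "PaddlePaddle develops PaddleNLP.")
input_ids = paddle.to_tensor([encoded["input_ids"]])
start_logits, end_logits = model(input_ids)
start = paddle.argmax(start_logits, axis=-1).numpy()[0]   # most likely span start
end = paddle.argmax(end_logits, axis=-1).numpy()[0]       # most likely span end
answer_tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][start:end + 1])
print(answer_tokens)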
@@ -221,6 +413,21 @@ def forward(self, input_ids, attention_mask=None):


class DistilBertForTokenClassification(DistilBertPretrainedModel):
"""
DistilBert Model with a linear layer on top of the hidden-states output layer,
designed for token classification tasks like NER tasks.
Args:
distilbert (:class:`DistilBertModel`):
An instance of DistilBertModel.
num_classes (int, optional):
The number of classes. Defaults to `2`.
dropout (float, optional):
The dropout probability for output of DistilBert.
If None, use the same value as `hidden_dropout_prob` of `DistilBertModel`
instance `distilbert`. Defaults to None.
"""

def __init__(self, distilbert, num_classes=2, dropout=None):
super(DistilBertForTokenClassification, self).__init__()
self.num_classes = num_classes
@@ -232,6 +439,36 @@ def __init__(self, distilbert, num_classes=2, dropout=None):
self.apply(self.init_weights)

def forward(self, input_ids, attention_mask=None):
r"""
The DistilBertForTokenClassification forward method, overrides the __call__() special method.
Args:
input_ids (Tensor):
See :class:`DistilBertModel`.
attention_mask (Tensor, optional):
See :class:`DistilBertModel`.
Returns:
Tensor: Returns tensor `logits`, a tensor of the input token classification logits.
Its shape is `[batch_size, sequence_length, num_classes]` and its dtype is `float32`.
Example:
.. code-block::
import paddle
from paddlenlp.transformers.distilbert.modeling import DistilBertForTokenClassification
from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
outputs = model(**inputs)
logits = outputs[0]
"""

sequence_output = self.distilbert(
input_ids, attention_mask=attention_mask)
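A minimal sketch for the per-token `logits` described above, pairing each wordpiece with its argmax class (with a head that has not been fine-tuned this only demonstrates the shapes, not meaningful NER labels):

import paddle
from paddlenlp.transformers import DistilBertForTokenClassification, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')

encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
input_ids = paddle.to_tensor([encoded["input_ids"]])
logits = model(input_ids)                 # [batch_size, sequence_length, num_classes]
preds = paddle.argmax(logits, axis=-1)    # one class id per token
tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"])
for token, label in zip(tokens, preds.numpy()[0]):
    print(token, int(label))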

@@ -241,6 +478,14 @@ def forward(self, input_ids, attention_mask=None):


class DistilBertForMaskedLM(DistilBertPretrainedModel):
"""
DistilBert Model with a `language modeling` head on top.
Args:
distilbert (:class:`DistilBertModel`):
An instance of DistilBertModel.
"""

def __init__(self, distilbert):
super(DistilBertForMaskedLM, self).__init__()
self.distilbert = distilbert
@@ -255,6 +500,33 @@ def __init__(self, distilbert):
self.apply(self.init_weights)

def forward(self, input_ids=None, attention_mask=None):
r'''
The DistilBertForMaskedLM forward method, overrides the `__call__()` special method.
Args:
input_ids (Tensor):
See :class:`DistilBertModel`.
attention_mask (Tensor, optional):
See :class:`DistilBertModel`.
Returns:
Tensor: Returns tensor `prediction_logits`, the scores of masked token prediction.
Its data type should be float32 and its shape is [batch_size, sequence_length, vocab_size].
Example:
.. code-block::
import paddle
from paddlenlp.transformers import DistilBertForMaskedLM, DistilBertTokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
prediction_logits = model(**inputs)
'''

distilbert_output = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask)
prediction_logits = self.vocab_transform(distilbert_output)
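To show how `prediction_logits` from the masked-LM head can be consumed, a sketch that masks one position by hand and reads back the argmax token (the masked position index is arbitrary and purely illustrative):

import paddle
from paddlenlp.transformers import DistilBertForMaskedLM, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')

ids = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")["input_ids"]
mask_pos = 4                                         # arbitrary position to mask
ids[mask_pos] = tokenizer.convert_tokens_to_ids("[MASK]")
input_ids = paddle.to_tensor([ids])
prediction_logits = model(input_ids)                 # [batch_size, sequence_length, vocab_size]
pred_id = int(paddle.argmax(prediction_logits[0, mask_pos]).numpy())
print(tokenizer.convert_ids_to_tokens([pred_id]))    # model's guess for the masked token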
6 changes: 3 additions & 3 deletions paddlenlp/transformers/distilbert/tokenizer.py
@@ -19,9 +19,9 @@

class DistilBertTokenizer(BertTokenizer):
"""
Constructs a DistilBERT tokenizer. It uses a basic tokenizer to do punctuation
splitting, lower casing and so on, and follows a WordPiece tokenizer to
tokenize as subwords.
Constructs a DistilBertTokenizer.
The usage of DistilBertTokenizer is the same as
`BertTokenizer <https://paddlenlp.readthedocs.io/zh/latest/source/paddlenlp.transformers.bert.tokenizer.html>`__.
"""
resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained
pretrained_resource_files_map = {
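Since the new docstring points to `BertTokenizer` for usage, a brief sketch of the two common entry points (the sample sentence is just an example):

from paddlenlp.transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
tokens = tokenizer.tokenize("Welcome to use PaddlePaddle and PaddleNLP!")   # wordpiece tokens
token_ids = tokenizer.convert_tokens_to_ids(tokens)                         # ids without special tokens
encoded = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")           # adds [CLS]/[SEP]
print(tokens)
print(encoded["input_ids"])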
