diff --git a/paddlenlp/transformers/distilbert/modeling.py b/paddlenlp/transformers/distilbert/modeling.py index a0684e744ba3..93d94bf6b4dc 100644 --- a/paddlenlp/transformers/distilbert/modeling.py +++ b/paddlenlp/transformers/distilbert/modeling.py @@ -64,7 +64,7 @@ def forward(self, input_ids, position_ids=None): class DistilBertPretrainedModel(PretrainedModel): """ - An abstract class for pretrained DistilBERT models. It provides DistilBERT related + An abstract class for pretrained DistilBert models. It provides DistilBert related `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`, `pretrained_init_configuration`, `base_model_prefix` for downloading and loading pretrained models. See `PretrainedModel` for more details. @@ -131,6 +131,62 @@ def init_weights(self, layer): @register_base_model class DistilBertModel(DistilBertPretrainedModel): + """ + The bare DistilBert Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + vocab_size (int): + Vocabulary size of `inputs_ids` in `DistilBertModel`. Defines the number of different tokens that can + be represented by the `inputs_ids` passed when calling `DistilBertModel`. + hidden_size (int, optional): + Dimensionality of the embedding layer, encoder layers and the pooler layer. Defaults to `768`. + num_hidden_layers (int, optional): + Number of hidden layers in the Transformer encoder. Defaults to `12`. + num_attention_heads (int, optional): + Number of attention heads for each attention layer in the Transformer encoder. + Defaults to `12`. + intermediate_size (int, optional): + Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors + to ff layers are firstly projected from `hidden_size` to `intermediate_size`, + and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. + Defaults to `3072`. + hidden_act (str, optional): + The non-linear activation function in the feed-forward layer. + ``"gelu"``, ``"relu"`` and any other paddle supported activation functions + are supported. Defaults to `"gelu"`. + hidden_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the embeddings and encoder. + Defaults to `0.1`. + attention_probs_dropout_prob (float, optional): + The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. + Defaults to `0.1`. + max_position_embeddings (int, optional): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input + sequence. Defaults to `512`. + type_vocab_size (int, optional): + The vocabulary size of `token_type_ids`. + Defaults to `16`. + initializer_range (float, optional): + The standard deviation of the normal initializer. + Defaults to `0.02`. + + .. note:: + A normal_initializer initializes weight matrices as normal distributions. + See :meth:`DistilBertPretrainedModel.init_weights()` for how weights are initialized in `DistilBertModel`. + + pad_token_id (int, optional): + The index of padding token in the token vocabulary. + Defaults to `0`. 
+ + """ + def __init__(self, vocab_size, hidden_size=768, @@ -162,6 +218,44 @@ def __init__(self, self.apply(self.init_weights) def forward(self, input_ids, attention_mask=None): + r''' + The DistilBertModel forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. + attention_mask (Tensor, optional): + Mask used in multi-head attention to avoid performing attention to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. + When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values. + It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. + For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], + [batch_size, num_attention_heads, sequence_length, sequence_length]. + Defaults to `None`, which means nothing needed to be prevented attention to. + + Returns: + Tensor: Returns tensor `encoder_output`, which means the sequence of hidden-states at the last layer of the model. + Its data type should be float32 and its shape is [batch_size, sequence_length, hidden_size]. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import DistilBertModel, DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertModel.from_pretrained('distilbert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + output = model(**inputs) + ''' + if attention_mask is None: attention_mask = paddle.unsqueeze( (input_ids == self.pad_token_id @@ -174,6 +268,21 @@ def forward(self, input_ids, attention_mask=None): class DistilBertForSequenceClassification(DistilBertPretrainedModel): + """ + DistilBert Model with a linear layer on top of the output layer, designed for + sequence classification/regression tasks like GLUE tasks. + + Args: + distilbert (:class:`DistilBertModel`): + An instance of DistilBertModel. + num_classes (int, optional): + The number of classes. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of DistilBert. + If None, use the same value as `hidden_dropout_prob` of `DistilBertModel` + instance `distilbert`. Defaults to None. + """ + def __init__(self, distilbert, num_classes=2, dropout=None): super(DistilBertForSequenceClassification, self).__init__() self.num_classes = num_classes @@ -188,6 +297,36 @@ def __init__(self, distilbert, num_classes=2, dropout=None): self.apply(self.init_weights) def forward(self, input_ids, attention_mask=None): + r""" + The DistilBertForSequenceClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`DistilBertModel`. + attention_mask (list, optional): + See :class:`DistilBertModel`. + + Returns: + Tensor: Returns tensor `logits`, a tensor of the input text classification logits. + Shape as `[batch_size, num_classes]` and dtype as `float32`. 
+ + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers.distilbert.modeling import DistilBertForSequenceClassification + from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + logits = outputs[0] + """ + distilbert_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask) @@ -202,6 +341,19 @@ def forward(self, input_ids, attention_mask=None): class DistilBertForQuestionAnswering(DistilBertPretrainedModel): + """ + DistilBert Model with a linear layer on top of the hidden-states output to + compute `span_start_logits` and `span_end_logits`, designed for question-answering tasks like SQuAD. + + Args: + distilbert (:class:`DistilBertModel`): + An instance of DistilBertModel. + dropout (float, optional): + The dropout probability for output of DistilBert. + If None, use the same value as `hidden_dropout_prob` of `DistilBertModel` + instance `distilbert`. Defaults to None. + """ + def __init__(self, distilbert, dropout=None): super(DistilBertForQuestionAnswering, self).__init__() self.distilbert = distilbert # allow bert to be config @@ -211,6 +363,46 @@ def __init__(self, distilbert, dropout=None): self.apply(self.init_weights) def forward(self, input_ids, attention_mask=None): + r""" + The DistilBertForQuestionAnswering forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`DistilBertModel`. + attention_mask (list, optional): + See :class:`DistilBertModel`. + + Returns: + tuple: Returns tuple (`start_logits`, `end_logits`). + + With the fields: + + - start_logits(Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - end_logits(Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers.distilbert.modeling import DistilBertForQuestionAnswering + from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + start_logits = outputs[0] + end_logits =outputs[1] + """ + sequence_output = self.distilbert( input_ids, attention_mask=attention_mask) sequence_output = self.dropout(sequence_output) @@ -221,6 +413,21 @@ def forward(self, input_ids, attention_mask=None): class DistilBertForTokenClassification(DistilBertPretrainedModel): + """ + DistilBert Model with a linear layer on top of the hidden-states output layer, + designed for token classification tasks like NER tasks. + + Args: + distilbert (:class:`DistilBertModel`): + An instance of DistilBertModel. + num_classes (int, optional): + The number of classes. Defaults to `2`. 
+ dropout (float, optional): + The dropout probability for output of DistilBert. + If None, use the same value as `hidden_dropout_prob` of `DistilBertModel` + instance `distilbert`. Defaults to None. + """ + def __init__(self, distilbert, num_classes=2, dropout=None): super(DistilBertForTokenClassification, self).__init__() self.num_classes = num_classes @@ -232,6 +439,36 @@ def __init__(self, distilbert, num_classes=2, dropout=None): self.apply(self.init_weights) def forward(self, input_ids, attention_mask=None): + r""" + The DistilBertForTokenClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`DistilBertModel`. + attention_mask (list, optional): + See :class:`DistilBertModel`. + + Returns: + Tensor: Returns tensor `logits`, a tensor of the input token classification logits. + Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers.distilbert.modeling import DistilBertForTokenClassification + from paddlenlp.transformers.distilbert.tokenizer import DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + logits = outputs[0] + """ + sequence_output = self.distilbert( input_ids, attention_mask=attention_mask) @@ -241,6 +478,14 @@ def forward(self, input_ids, attention_mask=None): class DistilBertForMaskedLM(DistilBertPretrainedModel): + """ + DistilBert Model with a `language modeling` head on top. + + Args: + distilbert (:class:`DistilBertModel`): + An instance of DistilBertModel. + """ + def __init__(self, distilbert): super(DistilBertForMaskedLM, self).__init__() self.distilbert = distilbert @@ -255,6 +500,33 @@ def __init__(self, distilbert): self.apply(self.init_weights) def forward(self, input_ids=None, attention_mask=None): + r''' + The DistilBertForMaskedLM forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + See :class:`DistilBertModel`. + attention_mask (Tensor, optional): + See :class:`DistilBertModel`. + + Returns: + Tensor: Returns tensor `prediction_logits`, the scores of masked token prediction. + Its data type should be float32 and its shape is [batch_size, sequence_length, vocab_size]. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import DistilBertForMaskedLM, DistilBertTokenizer + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + prediction_logits = model(**inputs) + ''' + distilbert_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask) prediction_logits = self.vocab_transform(distilbert_output) diff --git a/paddlenlp/transformers/distilbert/tokenizer.py b/paddlenlp/transformers/distilbert/tokenizer.py index f804b607c955..f0017f8cc3f0 100644 --- a/paddlenlp/transformers/distilbert/tokenizer.py +++ b/paddlenlp/transformers/distilbert/tokenizer.py @@ -19,9 +19,9 @@ class DistilBertTokenizer(BertTokenizer): """ - Constructs a DistilBERT tokenizer. 
It uses a basic tokenizer to do punctuation - splitting, lower casing and so on, and follows a WordPiece tokenizer to - tokenize as subwords. + Constructs a DistilBertTokenizer. + The usage of DistilBertTokenizer is the same as + `BertTokenizer `__. """ resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained pretrained_resource_files_map = { diff --git a/paddlenlp/transformers/nezha/modeling.py b/paddlenlp/transformers/nezha/modeling.py index 58f5552f54d3..2308679b9926 100644 --- a/paddlenlp/transformers/nezha/modeling.py +++ b/paddlenlp/transformers/nezha/modeling.py @@ -8,15 +8,10 @@ from paddlenlp.transformers import PretrainedModel, register_base_model - __all__ = [ - 'NeZhaModel', - "NeZhaPretrainedModel", - 'NeZhaForPretraining', - 'NeZhaForSequenceClassification', - 'NeZhaPretrainingHeads', - 'NeZhaForTokenClassification', - 'NeZhaForQuestionAnswering', + 'NeZhaModel', "NeZhaPretrainedModel", 'NeZhaForPretraining', + 'NeZhaForSequenceClassification', 'NeZhaPretrainingHeads', + 'NeZhaForTokenClassification', 'NeZhaForQuestionAnswering', 'NeZhaForMultipleChoice' ] @@ -46,7 +41,8 @@ def gelu_new(x): Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 """ - return 0.5 * x * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0)))) + return 0.5 * x * (1.0 + paddle.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * paddle.pow(x, 3.0)))) ACT2FN = { @@ -63,12 +59,12 @@ def gelu_new(x): class NeZhaAttention(nn.Layer): def __init__(self, - hidden_size, - num_attention_heads, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps): + hidden_size=768, + num_attention_heads=12, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_relative_position=64, + layer_norm_eps=1e-12): super(NeZhaAttention, self).__init__() if hidden_size % num_attention_heads != 0: raise ValueError( @@ -82,51 +78,54 @@ def __init__(self, self.key = nn.Linear(hidden_size, self.all_head_size) self.value = nn.Linear(hidden_size, self.all_head_size) self.relative_positions_embeddings = self.generate_relative_positions_embeddings( - length=512, depth=self.attention_head_size, max_relative_position=max_relative_position - ) + length=512, + depth=self.attention_head_size, + max_relative_position=max_relative_position) self.attention_dropout = nn.Dropout(attention_probs_dropout_prob) self.dense = nn.Linear(hidden_size, hidden_size) self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) self.output_dropout = nn.Dropout(hidden_dropout_prob) - def generate_relative_positions_embeddings(self, length, depth, max_relative_position=127): + def generate_relative_positions_embeddings(self, + length, + depth, + max_relative_position=127): vocab_size = max_relative_position * 2 + 1 range_vec = paddle.arange(length) range_mat = paddle.tile( - range_vec, repeat_times=[length] - ).reshape((length, length)) + range_vec, repeat_times=[length]).reshape((length, length)) distance_mat = range_mat - paddle.t(range_mat) distance_mat_clipped = paddle.clip( - distance_mat.astype( 'float32'), - -max_relative_position, - max_relative_position - ) + distance_mat.astype('float32'), -max_relative_position, + max_relative_position) final_mat = distance_mat_clipped + max_relative_position embeddings_table = np.zeros([vocab_size, depth]) for pos in range(vocab_size): for i in range(depth // 2): - 
embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth)) - embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth)) - - embeddings_table_tensor = paddle.to_tensor(embeddings_table, dtype='float32') - flat_relative_positions_matrix = final_mat.reshape((-1,)) + embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * + i / depth)) + embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power( + 10000, 2 * i / depth)) + + embeddings_table_tensor = paddle.to_tensor( + embeddings_table, dtype='float32') + flat_relative_positions_matrix = final_mat.reshape((-1, )) one_hot_relative_positions_matrix = paddle.nn.functional.one_hot( - flat_relative_positions_matrix.astype('int64'), - num_classes=vocab_size - ) - embeddings = paddle.matmul( - one_hot_relative_positions_matrix, - embeddings_table_tensor - ) + flat_relative_positions_matrix.astype('int64'), + num_classes=vocab_size) + embeddings = paddle.matmul(one_hot_relative_positions_matrix, + embeddings_table_tensor) my_shape = final_mat.shape my_shape.append(depth) embeddings = embeddings.reshape(my_shape) return embeddings def transpose_for_scores(self, x): - new_x_shape = x.shape[:-1] + [self.num_attention_heads, self.attention_head_size] + new_x_shape = x.shape[:-1] + [ + self.num_attention_heads, self.attention_head_size + ] x = x.reshape(new_x_shape) return x.transpose((0, 2, 1, 3)) @@ -140,29 +139,25 @@ def forward(self, hidden_states, attention_mask): value_layer = self.transpose_for_scores(mixed_value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = paddle.matmul( - query_layer, - key_layer.transpose((0, 1, 3, 2)) - ) + attention_scores = paddle.matmul(query_layer, + key_layer.transpose((0, 1, 3, 2))) batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.shape - relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :] + relations_keys = self.relative_positions_embeddings.detach().clone( + )[:to_seq_length, :to_seq_length, :] query_layer_t = query_layer.transpose((2, 0, 1, 3)) query_layer_r = query_layer_t.reshape( - (from_seq_length, batch_size * - num_attention_heads, self.attention_head_size) - ) - key_position_scores = paddle.matmul( - query_layer_r, - relations_keys.transpose((0, 2, 1)) - ) + (from_seq_length, batch_size * num_attention_heads, + self.attention_head_size)) + key_position_scores = paddle.matmul(query_layer_r, + relations_keys.transpose((0, 2, 1))) key_position_scores_r = key_position_scores.reshape( - (from_seq_length, batch_size, num_attention_heads, from_seq_length) - ) + (from_seq_length, batch_size, num_attention_heads, from_seq_length)) key_position_scores_r_t = key_position_scores_r.transpose((1, 2, 0, 3)) attention_scores = attention_scores + key_position_scores_r_t - attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_scores = attention_scores / math.sqrt( + self.attention_head_size) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. 
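For reference while reading the hunk above: the table built by `generate_relative_positions_embeddings` can be sketched with NumPy alone. This is a minimal illustration of the same logic, not part of the patch; the function name below is invented for the sketch, and plain fancy indexing stands in for the one-hot matmul used in the Paddle code (the resulting table is identical).

    import numpy as np

    def relative_position_table(length, depth, max_relative_position=64):
        # Pairwise relative distances, entry (i, j) = j - i, clipped to +/- max_relative_position
        positions = np.arange(length)
        distance = positions[None, :] - positions[:, None]
        clipped = np.clip(distance, -max_relative_position, max_relative_position)
        index = clipped + max_relative_position            # shift into [0, 2 * max_relative_position]

        # Sinusoidal embedding for each of the 2 * max_relative_position + 1 possible distances
        vocab_size = 2 * max_relative_position + 1
        table = np.zeros((vocab_size, depth))
        for pos in range(vocab_size):
            for i in range(depth // 2):
                table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth))
                table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth))

        # Gather one depth-sized embedding per (query, key) position pair
        return table[index]                                # shape: (length, length, depth)

    # e.g. relative_position_table(512, 64).shape == (512, 512, 64),
    # matching the buffer created in NeZhaAttention (length=512, depth=attention_head_size).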
@@ -174,60 +169,63 @@ def forward(self, hidden_states, attention_mask): context_layer = paddle.matmul(attention_probs, value_layer) - relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :] + relations_values = self.relative_positions_embeddings.clone( + )[:to_seq_length, :to_seq_length, :] attention_probs_t = attention_probs.transpose((2, 0, 1, 3)) attentions_probs_r = attention_probs_t.reshape( - (from_seq_length, batch_size * num_attention_heads, to_seq_length) - ) - value_position_scores = paddle.matmul(attentions_probs_r, relations_values) + (from_seq_length, batch_size * num_attention_heads, to_seq_length)) + value_position_scores = paddle.matmul(attentions_probs_r, + relations_values) value_position_scores_r = value_position_scores.reshape( - (from_seq_length, batch_size, - num_attention_heads, self.attention_head_size) - ) - value_position_scores_r_t = value_position_scores_r.transpose((1, 2, 0, 3)) + (from_seq_length, batch_size, num_attention_heads, + self.attention_head_size)) + value_position_scores_r_t = value_position_scores_r.transpose( + (1, 2, 0, 3)) context_layer = context_layer + value_position_scores_r_t context_layer = context_layer.transpose((0, 2, 1, 3)) - new_context_layer_shape = context_layer.shape[:-2] + [self.all_head_size] + new_context_layer_shape = context_layer.shape[:-2] + [ + self.all_head_size + ] context_layer = context_layer.reshape(new_context_layer_shape) projected_context_layer = self.dense(context_layer) - projected_context_layer_dropout = self.output_dropout(projected_context_layer) + projected_context_layer_dropout = self.output_dropout( + projected_context_layer) layer_normed_context_layer = self.layer_norm( - hidden_states + projected_context_layer_dropout - ) + hidden_states + projected_context_layer_dropout) return layer_normed_context_layer, attention_scores class NeZhaLayer(nn.Layer): def __init__(self, - hidden_size, - num_attention_heads, - intermediate_size, - hidden_act, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps): + hidden_size=768, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_relative_position=64, + layer_norm_eps=1e-12): super(NeZhaLayer, self).__init__() self.seq_len_dim = 1 self.layer_norm = nn.LayerNorm(hidden_size, epsilon=layer_norm_eps) self.attention = NeZhaAttention( - hidden_size, - num_attention_heads, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps - ) + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_relative_position=max_relative_position, + layer_norm_eps=layer_norm_eps) self.ffn = nn.Linear(hidden_size, intermediate_size) self.ffn_output = nn.Linear(intermediate_size, hidden_size) self.activation = ACT2FN[hidden_act] self.dropout = nn.Dropout(hidden_dropout_prob) def forward(self, hidden_states, attention_mask=None): - attention_output, layer_att = self.attention(hidden_states, attention_mask) + attention_output, layer_att = self.attention(hidden_states, + attention_mask) ffn_output = self.ffn(attention_output) ffn_output = self.activation(ffn_output) @@ -241,34 +239,35 @@ def forward(self, hidden_states, attention_mask=None): class NeZhaEncoder(nn.Layer): def __init__(self, - hidden_size, - num_hidden_layers, - num_attention_heads, - intermediate_size, - 
hidden_act, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps): + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_relative_position=64, + layer_norm_eps='1e-12'): super(NeZhaEncoder, self).__init__() layer = NeZhaLayer( - hidden_size, - num_attention_heads, - intermediate_size, - hidden_act, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps - ) - self.layer = nn.LayerList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_relative_position=max_relative_position, + layer_norm_eps=layer_norm_eps) + self.layer = nn.LayerList( + [copy.deepcopy(layer) for _ in range(num_hidden_layers)]) def forward(self, hidden_states, attention_mask): all_encoder_layers = [] all_encoder_att = [] for i, layer_module in enumerate(self.layer): all_encoder_layers.append(hidden_states) - hidden_states, layer_att = layer_module(all_encoder_layers[i], attention_mask) + hidden_states, layer_att = layer_module(all_encoder_layers[i], + attention_mask) all_encoder_att.append(layer_att) all_encoder_layers.append(hidden_states) return all_encoder_layers, all_encoder_att @@ -288,8 +287,8 @@ def __init__(self, self.word_embeddings = nn.Embedding(vocab_size, hidden_size) if not use_relative_position: - self.position_embeddings = nn.Embedding( - max_position_embeddings, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, + hidden_size) self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) self.layer_norm = nn.LayerNorm(hidden_size) @@ -335,6 +334,13 @@ def forward(self, hidden_states): class NeZhaPretrainedModel(PretrainedModel): + """ + An abstract class for pretrained NeZha models. It provides NeZha related + `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`, + `pretrained_init_configuration`, `base_model_prefix` for downloading and + loading pretrained models. See `PretrainedModel` for more details. + """ + model_config_file = "model_config.json" pretrained_init_configuration = { "nezha-base-chinese": { @@ -432,6 +438,68 @@ def init_weights(self, layer): @register_base_model class NeZhaModel(NeZhaPretrainedModel): + """ + The bare NeZha Model transformer outputting raw hidden-states. + + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + vocab_size (int): + Vocabulary size of `inputs_ids` in `DistilBertModel`. Defines the number of different tokens that can + be represented by the `inputs_ids` passed when calling `DistilBertModel`. + hidden_size (int, optional): + Dimensionality of the embedding layer, encoder layers and the pooler layer. Defaults to `768`. + num_hidden_layers (int, optional): + Number of hidden layers in the Transformer encoder. Defaults to `12`. + num_attention_heads (int, optional): + Number of attention heads for each attention layer in the Transformer encoder. 
+ Defaults to `12`. + intermediate_size (int, optional): + Dimensionality of the feed-forward (ff) layer in the encoder. Input tensors + to ff layers are firstly projected from `hidden_size` to `intermediate_size`, + and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. + Defaults to `3072`. + hidden_act (str, optional): + The non-linear activation function in the feed-forward layer. + ``"gelu"``, ``"relu"`` and any other paddle supported activation functions + are supported. Defaults to `"gelu"`. + hidden_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the embeddings and encoder. + Defaults to `0.1`. + attention_probs_dropout_prob (float, optional): + The dropout probability used in MultiHeadAttention in all encoder layers to drop some attention target. + Defaults to `0.1`. + max_position_embeddings (int, optional): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input + sequence. Defaults to `512`. + type_vocab_size (int, optional): + The vocabulary size of `token_type_ids`. + Defaults to `16`. + initializer_range (float, optional): + The standard deviation of the normal initializer. + Defaults to `0.02`. + + .. note:: + A normal_initializer initializes weight matrices as normal distributions. + See :meth:`NeZhaPretrainedModel.init_weights()` for how weights are initialized in `NeZhaModel`. + + max_relative_embeddings (int, optional): + The maximum value of the dimensionality of relative encoding, which dictates the maximum supported + relative distance of two sentences. + Defaults to `64`. + layer_norm_eps (float, optional): + The small value added to the variance in `LayerNorm` to prevent division by zero. + Defaults to `1e-12`. + use_relative_position (bool, optional): + Whether or not to use relative position embedding. Defaults to `True`. + + """ + def __init__(self, vocab_size, hidden_size=768, @@ -451,30 +519,88 @@ def __init__(self, self.initializer_range = initializer_range self.embeddings = NeZhaEmbeddings( - vocab_size, - hidden_size, - hidden_dropout_prob, - max_position_embeddings, - type_vocab_size, - use_relative_position - ) + vocab_size=vocab_size, + hidden_size=hidden_size, + hidden_dropout_prob=hidden_dropout_prob, + max_position_embeddings=max_position_embeddings, + type_vocab_size=type_vocab_size, + use_relative_position=use_relative_position) self.encoder = NeZhaEncoder( - hidden_size, - num_hidden_layers, - num_attention_heads, - intermediate_size, - hidden_act, - hidden_dropout_prob, - attention_probs_dropout_prob, - max_relative_position, - layer_norm_eps - ) + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + intermediate_size=intermediate_size, + hidden_act=hidden_act, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, + max_relative_position=max_relative_position, + layer_norm_eps=layer_norm_eps) self.pooler = NeZhaPooler(hidden_size) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): + r''' + The NeZhaModel forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. 
+ token_type_ids (Tensor, optional): + Segment token indices to indicate different portions of the inputs. + Selected in the range ``[0, type_vocab_size - 1]``. + If `type_vocab_size` is 2, which means the inputs have two portions. + Indices can either be 0 or 1: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. + Defaults to `None`, which means we don't add segment embeddings. + attention_mask (Tensor, optional): + Mask used in multi-head attention to avoid performing attention to some unwanted positions, + usually the paddings or the subsequent positions. + Its data type can be int, float and bool. + When the data type is bool, the `masked` tokens have `False` values and the others have `True` values. + When the data type is int, the `masked` tokens have `0` values and the others have `1` values. + When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values. + It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. + For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], + [batch_size, num_attention_heads, sequence_length, sequence_length]. + We use whole-word-mask in NeZha, so the whole word will have the same value. For example, "使用" as a word, + "使" and "用" will have the same value. + Defaults to `None`, which means nothing needed to be prevented attention to. + + Returns: + tuple: Returns tuple (`sequence_output`, `pooled_output`). + + With the fields: + + - `sequence_output` (Tensor): + Sequence of hidden-states at the last layer of the model. + It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size]. + + - `pooled_output` (Tensor): + The output of first token (`[CLS]`) in sequence. + We "pool" the model by simply taking the hidden state corresponding to the first token. + Its data type should be float32 and its shape is [batch_size, hidden_size]. + + Example: + .. 
code-block:: + + import paddle + from paddlenlp.transformers import NeZhaModel, NeZhaTokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + model = NeZhaModel.from_pretrained('nezha-base-chinese') + + inputs = tokenizer("欢迎使用百度飞浆!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + output = model(**inputs) + ''' if attention_mask is None: attention_mask = paddle.ones_like(input_ids) if token_type_ids is None: @@ -485,7 +611,8 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): embedding_output = self.embeddings(input_ids, token_type_ids) - encoder_outputs, _ = self.encoder(embedding_output, extended_attention_mask) + encoder_outputs, _ = self.encoder(embedding_output, + extended_attention_mask) sequence_output = encoder_outputs[-1] pooled_output = self.pooler(sequence_output) @@ -507,10 +634,7 @@ def __init__(self, self.decoder_weight = embedding_weights self.decoder_bias = self.create_parameter( - shape=[vocab_size], - dtype=self.decoder_weight.dtype, - is_bias=True - ) + shape=[vocab_size], dtype=self.decoder_weight.dtype, is_bias=True) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -518,65 +642,154 @@ def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) hidden_states = paddle.tensor.matmul( - hidden_states, - self.decoder_weight, - transpose_y=True - ) + self.decoder_bias + hidden_states, self.decoder_weight, + transpose_y=True) + self.decoder_bias return hidden_states class NeZhaPretrainingHeads(nn.Layer): + """ + Perform language modeling task and next sentence classification task. + + Args: + hidden_size (int): + See :class:`NeZhaModel`. + vocab_size (int): + See :class:`NeZhaModel`. + hidden_act (str): + Activation function used in the language modeling task. + embedding_weights (Tensor, optional): + Decoding weights used to map hidden_states to logits of the masked token prediction. + Its data type should be float32 and its shape is [vocab_size, hidden_size]. + Defaults to `None`, which means use the same weights of the embedding layer. + + """ + def __init__(self, hidden_size, vocab_size, hidden_act, embedding_weights=None): super(NeZhaPretrainingHeads, self).__init__() - self.predictions = NeZhaLMPredictionHead( - hidden_size, - vocab_size, - hidden_act, - embedding_weights - ) + self.predictions = NeZhaLMPredictionHead(hidden_size, vocab_size, + hidden_act, embedding_weights) self.seq_relationship = nn.Linear(hidden_size, 2) def forward(self, sequence_output, pooled_output): + """ + Args: + sequence_output(Tensor): + Sequence of hidden-states at the last layer of the model. + It's data type should be float32 and its shape is [batch_size, sequence_length, hidden_size]. + pooled_output(Tensor): + The output of first token (`[CLS]`) in sequence. + We "pool" the model by simply taking the hidden state corresponding to the first token. + Its data type should be float32 and its shape is [batch_size, hidden_size]. + + Returns: + tuple: Returns tuple (``prediction_scores``, ``seq_relationship_score``). + + With the fields: + + - `prediction_scores` (Tensor): + The scores of masked token prediction. Its data type should be float32. + If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size]. + Otherwise, its shape is [batch_size, mask_token_num, vocab_size]. + + - `seq_relationship_score` (Tensor): + The scores of next sentence prediction. + Its data type should be float32 and its shape is [batch_size, 2]. 
+ + """ prediction_scores = self.predictions(sequence_output) seq_relationship_score = self.seq_relationship(pooled_output) return prediction_scores, seq_relationship_score class NeZhaForPretraining(NeZhaPretrainedModel): + """ + NeZha Model with pretraining tasks on top. + + Args: + nezha (:class:`NeZhaModel`): + An instance of :class:`NeZhaModel`. + + """ + def __init__(self, nezha): super(NeZhaForPretraining, self).__init__() self.nezha = nezha self.cls = NeZhaPretrainingHeads( - self.nezha.config["hidden_size"], - self.nezha.config["vocab_size"], + self.nezha.config["hidden_size"], self.nezha.config["vocab_size"], self.nezha.config["hidden_act"], - self.nezha.embeddings.word_embeddings.weight - ) + self.nezha.embeddings.word_embeddings.weight) self.apply(self.init_weights) - def forward(self, input_ids, token_type_ids=None, attention_mask=None, - masked_lm_labels=None, next_sentence_label=None): - sequence_output, pooled_output = self.nezha(input_ids, token_type_ids, attention_mask) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + next_sentence_label=None): + r""" + + Args: + input_ids (Tensor): + See :class:`NeZhaModel`. + token_type_ids (Tensor, optional): + See :class:`NeZhaModel`. + attention_mask (Tensor, optional): + See :class:`NeZhaModel`. + masked_lm_labels (Tensor, optional): + The labels of the masked language modeling, its dimensionality is equal to `prediction_scores`. + Its data type should be int64 and its shape is [batch_size, sequence_length, 1]. + next_sentence_label (Tensor, optional): + The labels of the next sentence prediction task, the dimensionality of `next_sentence_labels` + is equal to `seq_relation_labels`. Its data type should be int64 and its shape is [batch_size, 1]. + + Returns: + Tensor or tuple: Returns Tensor ``total_loss`` if `masked_lm_labels` is not None. + Returns tuple (``prediction_scores``, ``seq_relationship_score``) if `masked_lm_labels` is None. + + With the fields: + + - `total_loss` (Tensor): + + + - `prediction_scores` (Tensor): + The scores of masked token prediction. Its data type should be float32. + If `masked_positions` is None, its shape is [batch_size, sequence_length, vocab_size]. + Otherwise, its shape is [batch_size, mask_token_num, vocab_size]. + + - `seq_relationship_score` (Tensor): + The scores of next sentence prediction. + Its data type should be float32 and its shape is [batch_size, 2]. 
+ + """ + sequence_output, pooled_output = self.nezha(input_ids, token_type_ids, + attention_mask) + prediction_scores, seq_relationship_score = self.cls(sequence_output, + pooled_output) if masked_lm_labels is not None and next_sentence_label is not None: loss_fct = nn.CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.reshape( - (-1, self.nezha.config["vocab_size"])), masked_lm_labels.reshape((-1,))) - next_sentence_loss = loss_fct(seq_relationship_score.reshape( - (-1, 2)), next_sentence_label.reshape((-1,))) + masked_lm_loss = loss_fct( + prediction_scores.reshape( + (-1, self.nezha.config["vocab_size"])), + masked_lm_labels.reshape((-1, ))) + next_sentence_loss = loss_fct( + seq_relationship_score.reshape((-1, 2)), + next_sentence_label.reshape((-1, ))) total_loss = masked_lm_loss + next_sentence_loss return total_loss elif masked_lm_labels is not None: loss_fct = nn.CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.reshape( - (-1, self.nezha.config["vocab_size"])), masked_lm_labels.reshape((-1,))) + masked_lm_loss = loss_fct( + prediction_scores.reshape( + (-1, self.nezha.config["vocab_size"])), + masked_lm_labels.reshape((-1, ))) total_loss = masked_lm_loss return total_loss else: @@ -584,6 +797,20 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, class NeZhaForQuestionAnswering(NeZhaPretrainedModel): + """ + NeZha Model with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and + `span end logits`). + + Args: + nezha (:class:`BertModel`): + An instance of NeZhaModel. + dropout (float, optional): + The dropout probability for output of NeZha. + If None, use the same value as `hidden_dropout_prob` of `NeZhaModel` + instance `nezha`. Defaults to `None`. + """ + def __init__(self, nezha, dropout=None): super(NeZhaForQuestionAnswering, self).__init__() self.nezha = nezha @@ -591,7 +818,49 @@ def __init__(self, nezha, dropout=None): self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): - sequence_output, _ = self.nezha(input_ids, token_type_ids, attention_mask) + r""" + The NeZhaForQuestionAnswering forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`NeZhaModel`. + token_type_ids (Tensor, optional): + See :class:`NeZhaModel`. + attention_mask (Tensor, optional): + See :class:`NeZhaModel`. + + Returns: + tuple: Returns tuple (`start_logits`, `end_logits`). + + With the fields: + + - `start_logits` (Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `end_logits` (Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + Example: + .. 
code-block:: + + import paddle + from paddlenlp.transformers import NeZhaForQuestionAnswering + from paddlenlp.transformers import NeZhaTokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + model = NeZhaForQuestionAnswering.from_pretrained('nezha-base-chinese') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + start_logits = outputs[0] + end_logits =outputs[1] + """ + sequence_output, _ = self.nezha(input_ids, token_type_ids, + attention_mask) logits = self.classifier(sequence_output) logits = paddle.transpose(logits, perm=[2, 0, 1]) @@ -602,15 +871,64 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): class NeZhaForSequenceClassification(NeZhaPretrainedModel): + """ + NeZha Model with a linear layer on top of the output layer, designed for + sequence classification/regression tasks like GLUE tasks. + + Args: + nezha (:class:`NeZhaModel`): + An instance of NeZhaModel. + num_classes (int, optional): + The number of classes. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of NeZha. + If None, use the same value as `hidden_dropout_prob` of `NeZhaModel` + instance `nezha`. Defaults to None. + """ + def __init__(self, nezha, num_classes=2, dropout=None): super(NeZhaForSequenceClassification, self).__init__() self.num_classes = num_classes self.nezha = nezha - self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.nezha.config["hidden_size"], num_classes) + self.dropout = nn.Dropout(dropout if dropout is not None else + self.nezha.config["hidden_dropout_prob"]) + self.classifier = nn.Linear(self.nezha.config["hidden_size"], + num_classes) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): + r""" + The NeZhaForSequenceClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`NeZhaModel`. + token_type_ids (Tensor, optional): + See :class:`NeZhaModel`. + attention_mask (Tensor, optional): + See :class:`NeZhaModel`. + + Returns: + Tensor: Returns tensor `logits`, a tensor of the input text classification logits. + Shape as `[batch_size, num_classes]` and dtype as float32. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import NeZhaForSequenceClassification + from paddlenlp.transformers import NeZhaTokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + model = NeZhaForSequenceClassification.from_pretrained('nezha-base-chinese') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + logits =outputs[0] + + """ _, pooled_output = self.nezha(input_ids, token_type_ids, attention_mask) pooled_output = self.dropout(pooled_output) @@ -620,16 +938,65 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): class NeZhaForTokenClassification(NeZhaPretrainedModel): + """ + NeZha Model with a linear layer on top of the hidden-states output layer, + designed for token classification tasks like NER tasks. + + Args: + nezha (:class:`NeZhaModel`): + An instance of NeZhaModel. + num_classes (int, optional): + The number of classes. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of NeZha. 
+ If None, use the same value as `hidden_dropout_prob` of `NeZhaModel` + instance `nezha`. Defaults to `None`. + """ + def __init__(self, nezha, num_classes=2, dropout=None): super(NeZhaForTokenClassification, self).__init__() self.num_classes = num_classes - self.nezha = nezha - self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"]) - self.classifier = nn.Linear(self.nezha.config["hidden_size"], num_classes) + self.nezha = nezha + self.dropout = nn.Dropout(dropout if dropout is not None else + self.nezha.config["hidden_dropout_prob"]) + self.classifier = nn.Linear(self.nezha.config["hidden_size"], + num_classes) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): - sequence_output, _ = self.nezha(input_ids, token_type_ids, attention_mask) + r""" + The NeZhaForTokenClassification forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`NeZhaModel`. + token_type_ids (Tensor, optional): + See :class:`NeZhaModel`. + attention_mask (list, optional): + See :class:`NeZhaModel`. + + Returns: + Tensor: Returns tensor `logits`, a tensor of the input token classification logits. + Shape as `[batch_size, sequence_length, num_classes]` and dtype as `float32`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import NeZhaForTokenClassification + from paddlenlp.transformers import NeZhaTokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + model = NeZhaForTokenClassification.from_pretrained('nezha-base-chinese') + + inputs = tokenizer("欢迎使用百度飞桨!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + logits = outputs[0] + """ + sequence_output, _ = self.nezha(input_ids, token_type_ids, + attention_mask) sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) @@ -638,27 +1005,62 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None): class NeZhaForMultipleChoice(NeZhaPretrainedModel): + """ + NeZha Model with a multiple choice classification head on top. + + Args: + nezha (:class:`NeZhaModel`): + An instance of NeZhaModel. + num_choices (int, optional): + The number of choices. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of NeZha. + If None, use the same value as `hidden_dropout_prob` of `NeZhaModel` + instance `nezha`. Defaults to `None`. + """ + def __init__(self, nezha, num_choices=2, dropout=None): super(NeZhaForMultipleChoice, self).__init__() self.num_choices = num_choices self.nezha = nezha - self.dropout = nn.Dropout(dropout if dropout is not None else self.nezha.config["hidden_dropout_prob"]) + self.dropout = nn.Dropout(dropout if dropout is not None else + self.nezha.config["hidden_dropout_prob"]) self.classifier = nn.Linear(self.nezha.config["hidden_size"], 1) self.apply(self.init_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None): + r""" + The NeZhaForMultipleChoice forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`NeZhaModel`. + token_type_ids (Tensor, optional): + See :class:`NeZhaModel`. + attention_mask (list, optional): + See :class:`NeZhaModel`. + + Returns: + Tensor: Returns tensor `reshaped_logits`, a tensor of the input multiple choice classification logits. + Shape as `[batch_size, num_classes]` and dtype as `float32`. 
+ """ + # input_ids: [bs, num_choice, seq_l] - input_ids = input_ids.reshape((-1, input_ids.shape[-1])) # flat_input_ids: [bs*num_choice,seq_l] - + input_ids = input_ids.reshape( + (-1, input_ids.shape[-1])) # flat_input_ids: [bs*num_choice,seq_l] + if token_type_ids: - token_type_ids = token_type_ids.reshape((-1, token_type_ids.shape[-1])) + token_type_ids = token_type_ids.reshape( + (-1, token_type_ids.shape[-1])) if attention_mask: - attention_mask = attention_mask.reshape((-1, attention_mask.shape[-1])) + attention_mask = attention_mask.reshape( + (-1, attention_mask.shape[-1])) _, pooled_output = self.nezha(input_ids, token_type_ids, attention_mask) pooled_output = self.dropout(pooled_output) - + logits = self.classifier(pooled_output) # logits: (bs*num_choice,1) - reshaped_logits = logits.reshape((-1, self.num_choices)) # logits: (bs, num_choice) + reshaped_logits = logits.reshape( + (-1, self.num_choices)) # logits: (bs, num_choice) return reshaped_logits diff --git a/paddlenlp/transformers/nezha/tokenizer.py b/paddlenlp/transformers/nezha/tokenizer.py index cc2428e5779b..d29af95a7172 100644 --- a/paddlenlp/transformers/nezha/tokenizer.py +++ b/paddlenlp/transformers/nezha/tokenizer.py @@ -7,35 +7,54 @@ from paddlenlp.transformers import PretrainedTokenizer, BasicTokenizer, WordpieceTokenizer - __all__ = ['NeZhaTokenizer'] class NeZhaTokenizer(PretrainedTokenizer): """ - Constructs a BERT tokenizer. It uses a basic tokenizer to do punctuation + Constructs a NeZha tokenizer. It uses a basic tokenizer to do punctuation splitting, lower casing and so on, and follows a WordPiece tokenizer to tokenize as subwords. + Args: - vocab_file (str): file path of the vocabulary - do_lower_case (bool): Whether the text strips accents and convert to - lower case. If you use the BERT pretrained model, lower is set to - Flase when using the cased model, otherwise it is set to True. - Default: True. - unk_token (str): The special token for unkown words. Default: "[UNK]". - sep_token (str): The special token for separator token . Default: "[SEP]". - pad_token (str): The special token for padding. Default: "[PAD]". - cls_token (str): The special token for cls. Default: "[CLS]". - mask_token (str): The special token for mask. Default: "[MASK]". + vocab_file (str): + The vocabulary file path (ends with '.txt') required to instantiate + a `WordpieceTokenizer`. + do_lower_case (bool): + Whether or not to lowercase the input when tokenizing. + Defaults to`True`. + unk_token (str): + A special token representing the *unknown (out-of-vocabulary)* token. + An unknown token is set to be `unk_token` inorder to be converted to an ID. + Defaults to "[UNK]". + sep_token (str): + A special token separating two different sentences in the same input. + Defaults to "[SEP]". + pad_token (str): + A special token used to make arrays of tokens the same size for batching purposes. + Defaults to "[PAD]". + cls_token (str): + A special token used for sequence classification. It is the last token + of the sequence when built with special tokens. Defaults to "[CLS]". + mask_token (str): + A special token representing a masked token. This is the token used + in the masked language modeling task which the model tries to predict the original unmasked ones. + Defaults to "[MASK]". Examples: - .. 
code-block:: python - from paddle.hapi.text import BertTokenizer - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') - # the following line get: ['he', 'was', 'a', 'puppet', '##eer'] - tokens = tokenizer('He was a puppeteer') - # the following line get: 'he was a puppeteer' - tokenizer.convert_tokens_to_string(tokens) + .. code-block:: + + from paddlenlp.transformers import NeZhaTokenizer + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + + inputs = tokenizer('欢迎使用百度飞桨!') + print(inputs) + + ''' + {'input_ids': [101, 3614, 6816, 886, 4500, 4636, 2428, 7607, 3444, 8013, 102], + 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} + ''' + """ resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained pretrained_resource_files_map = { @@ -89,15 +108,16 @@ def __init__(self, @property def vocab_size(self): """ - return the size of vocabulary. + Return the size of vocabulary. + Returns: - int: the size of vocabulary. + int: The size of vocabulary. """ return len(self.vocab) def _tokenize(self, text): """ - End-to-end tokenization for BERT models. + End-to-end tokenization for NeZha models. Args: text (str): The text to be tokenized. @@ -112,24 +132,55 @@ def _tokenize(self, text): def tokenize(self, text): """ - End-to-end tokenization for BERT models. + Converts a string to a list of tokens. + Args: text (str): The text to be tokenized. - + Returns: - list: A list of string representing converted tokens. + List(str): A list of string representing converted tokens. + + Examples: + .. code-block:: + + from paddlenlp.transformers import NeZhaokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('nezha-base-chinese') + tokens = tokenizer.tokenize('欢迎使用百度飞桨!') + + ''' + ['欢', '迎', '使', '用', '百', '度', '飞', '桨', '!'] + ''' + """ return self._tokenize(text) def convert_tokens_to_string(self, tokens): """ - Converts a sequence of tokens (list of string) in a single string. Since - the usage of WordPiece introducing `##` to concat subwords, also remove + Converts a sequence of tokens (list of string) to a single string. Since + the usage of WordPiece introducing `##` to concat subwords, also removes `##` when converting. + Args: tokens (list): A list of string representing tokens to be converted. + Returns: str: Converted string from tokens. + + Examples: + .. code-block:: + + from paddlenlp.transformers import NeZhaTokenizer + + tokenizer = NeZhaTokenizer.from_pretrained('bert-base-uncased') + tokens = tokenizer.tokenize('欢迎使用百度飞桨!') + ''' + ['欢', '迎', '使', '用', '百', '度', '飞', '桨', '!'] + ''' + strings = tokenizer.convert_tokens_to_string(tokens) + ''' + 欢 迎 使 用 百 度 飞 桨 ! + ''' """ out_string = " ".join(tokens).replace(" ##", "").strip() return out_string @@ -137,14 +188,14 @@ def convert_tokens_to_string(self, tokens): def num_special_tokens_to_add(self, pair=False): """ Returns the number of added tokens when encoding a sequence with special tokens. - Note: - This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this - inside your training loop. + Args: - pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the - number of added tokens in the case of a single sequence if set to False. + pair(bool): + Whether the input is a sequence pair or a single sequence. + Defaults to `False` and the input is a single sequence. + Returns: - Number of tokens added to sequences + int: Number of tokens added to sequences. 
""" token_ids_0 = [] token_ids_1 = [] @@ -157,17 +208,19 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. - A BERT sequence has the following format: - :: - - single sequence: ``[CLS] X [SEP]`` - - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + A NeZha sequence has the following format: + + - single sequence: ``[CLS] X [SEP]`` + - pair of sequences: ``[CLS] A [SEP] B [SEP]`` + Args: - token_ids_0 (:obj:`List[int]`): + token_ids_0 (List[int]): List of IDs to which the special tokens will be added. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. + token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. Defaults to `None`. + Returns: - :obj:`List[int]`: List of input_id with the appropriate special tokens. + List[int]: List of input_id with the appropriate special tokens. """ if token_ids_1 is None: return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] @@ -179,20 +232,21 @@ def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None): """ - Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. - - A BERT offset_mapping has the following format: - :: - - single sequence: ``(0,0) X (0,0)`` - - pair of sequences: `(0,0) A (0,0) B (0,0)`` - + Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. + + A NeZha offset_mapping has the following format: + + - single sequence: ``(0,0) X (0,0)`` + - pair of sequences: ``(0,0) A (0,0) B (0,0)`` + Args: - offset_mapping_ids_0 (:obj:`List[tuple]`): - List of char offsets to which the special tokens will be added. - offset_mapping_ids_1 (:obj:`List[tuple]`, `optional`): - Optional second list of char offsets for offset mapping pairs. + offset_mapping_ids_0 (List[tuple]): + List of wordpiece offsets to which the special tokens will be added. + offset_mapping_ids_1 (List[tuple], optional): + Optional second list of wordpiece offsets for offset mapping pairs. Defaults to `None`. + Returns: - :obj:`List[tuple]`: List of char offsets with the appropriate offsets of special tokens. + List[tuple]: A list of wordpiece offsets with the appropriate offsets of special tokens. """ if offset_mapping_1 is None: return [(0, 0)] + offset_mapping_0 + [(0, 0)] @@ -205,18 +259,23 @@ def create_token_type_ids_from_sequences(self, token_ids_1=None): """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. - A BERT sequence pair mask has the following format: + + A NeZha sequence pair mask has the following format: :: + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | + If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + Args: - token_ids_0 (:obj:`List[int]`): - List of IDs. - token_ids_1 (:obj:`List[int]`, `optional`): - Optional second list of IDs for sequence pairs. + token_ids_0 (List[int]): + A list of `inputs_ids` for the first sequence. + token_ids_1 (List[int], optional): + Optional second list of IDs for sequence pairs. Defaults to None. + Returns: - :obj:`List[int]`: List of token_type_id according to the given sequence(s). + List[int]: List of token_type_id according to the given sequence(s). 
""" _sep = [self.sep_token_id] _cls = [self.cls_token_id] @@ -232,13 +291,18 @@ def get_special_tokens_mask(self, """ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``encode`` methods. + Args: - token_ids_0 (List[int]): List of ids of the first sequence. - token_ids_1 (List[int], optinal): List of ids of the second sequence. - already_has_special_tokens (bool, optional): Whether or not the token list is already - formatted with special tokens for the model. Defaults to None. + token_ids_0 (List[int]): + A list of `inputs_ids` for the first sequence. + token_ids_1 (List[int], optinal): + Optional second list of IDs for sequence pairs. Defaults to `None`. + already_has_special_tokens (bool, optional): + Whether or not the token list is already formatted with special tokens for the model. + Defaults to `False`. + Returns: - results (List[int]): The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + List[int]: The list of integers either be 0 or 1: 1 for a special token, 0 for a sequence token. """ if already_has_special_tokens: