diff --git a/paddlenlp/seq2vec/encoder.py b/paddlenlp/seq2vec/encoder.py index 709b6c9483e8..1f828408a0c3 100644 --- a/paddlenlp/seq2vec/encoder.py +++ b/paddlenlp/seq2vec/encoder.py @@ -24,14 +24,59 @@ class BoWEncoder(nn.Layer): - """ + r""" A `BoWEncoder` takes as input a sequence of vectors and returns a single vector, which simply sums the embeddings of a sequence across the time dimension. - The input to this module is of shape `(batch_size, num_tokens, emb_dim)`, + The input to this encoder is of shape `(batch_size, num_tokens, emb_dim)`, and the output is of shape `(batch_size, emb_dim)`. Args: - emb_dim(int): It is the input dimension to the encoder. + emb_dim(int): + The dimension of each vector in the input sequence. + + Example: + .. code-block:: + + import paddle + import paddle.nn as nn + import paddlenlp as nlp + + class BoWModel(nn.Layer): + def __init__(self, + vocab_size, + num_classes, + emb_dim=128, + padding_idx=0, + hidden_size=128, + fc_hidden_size=96): + super().__init__() + self.embedder = nn.Embedding( + vocab_size, emb_dim, padding_idx=padding_idx) + self.bow_encoder = nlp.seq2vec.BoWEncoder(emb_dim) + self.fc1 = nn.Linear(self.bow_encoder.get_output_dim(), hidden_size) + self.fc2 = nn.Linear(hidden_size, fc_hidden_size) + self.output_layer = nn.Linear(fc_hidden_size, num_classes) + + def forward(self, text): + # Shape: (batch_size, num_tokens, embedding_dim) + embedded_text = self.embedder(text) + + # Shape: (batch_size, embedding_dim) + summed = self.bow_encoder(embedded_text) + encoded_text = paddle.tanh(summed) + + # Shape: (batch_size, hidden_size) + fc1_out = paddle.tanh(self.fc1(encoded_text)) + # Shape: (batch_size, fc_hidden_size) + fc2_out = paddle.tanh(self.fc2(fc1_out)) + # Shape: (batch_size, num_classes) + logits = self.output_layer(fc2_out) + return logits + + model = BoWModel(vocab_size=100, num_classes=2) + + text = paddle.randint(low=1, high=10, shape=[1,10], dtype='int32') + logits = model(text) """ def __init__(self, emb_dim): @@ -39,7 +84,7 @@ def __init__(self, emb_dim): self._emb_dim = emb_dim def get_input_dim(self): - """ + r""" Returns the dimension of the vector input for each element in the sequence input to a `BoWEncoder`. This is not the shape of the input tensor, but the last element of that shape. @@ -47,24 +92,29 @@ def get_input_dim(self): return self._emb_dim def get_output_dim(self): - """ + r""" Returns the dimension of the final vector output by this `BoWEncoder`. This is not the shape of the returned tensor, but the last element of that shape. """ return self._emb_dim def forward(self, inputs, mask=None): - """ + r""" It simply sums the embeddings of a sequence across the time dimension. Args: - inputs (paddle.Tensor): Shape as `(batch_size, num_tokens, emb_dim)` - mask (obj: `paddle.Tensor`, optional, defaults to `None`): Shape same as `inputs`. Its each elements identify whether is padding token or not. + inputs (Tensor): + Shape as `(batch_size, num_tokens, emb_dim)` and dtype as `float32` or `float64`. + The sequence length of the input sequence. + mask (Tensor, optional): + Shape same as `inputs`. + Its each elements identify whether the corresponding input token is padding or not. If True, not padding token. If False, padding token. + Defaults to `None`. Returns: - summed (paddle.Tensor): Shape of `(batch_size, emb_dim)`. The result vector of BagOfEmbedding. - + Tensor: + Shape as `(batch_size, emb_dim)`, and dtype is same as `inputs`. The result vector of BagOfEmbedding. 
""" if mask is not None: inputs = inputs * mask @@ -75,10 +125,10 @@ def forward(self, inputs, mask=None): class CNNEncoder(nn.Layer): - """ + r""" A `CNNEncoder` takes as input a sequence of vectors and returns a single vector, a combination of multiple convolution layers and max pooling layers. - The input to this module is of shape `(batch_size, num_tokens, emb_dim)`, + The input to this encoder is of shape `(batch_size, num_tokens, emb_dim)`, and the output is of shape `(batch_size, ouput_dim)` or `(batch_size, len(ngram_filter_sizes) * num_filter)`. The CNN has one convolution layer for each ngram filter size. Each convolution operation gives @@ -91,27 +141,71 @@ class CNNEncoder(nn.Layer): (optionally) projected down to a lower dimensional output, specified by `output_dim`. We then use a fully connected layer to project in back to the desired output_dim. For more - details, refer to "A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural - Networks for Sentence Classification", Zhang and Wallace 2016, particularly Figure 1. - ref: https://arxiv.org/abs/1510.03820 + details, refer to `A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural + Networks for Sentence Classification `__ , + Zhang and Wallace 2016, particularly Figure 1. Args: emb_dim(int): - This is the input dimension to the encoder. + The dimension of each vector in the input sequence. num_filter(int): This is the output dim for each convolutional layer, which is the number of "filters" learned by that layer. - ngram_filter_sizes(Tuple[int]): + ngram_filter_sizes(Tuple[int], optinal): This specifies both the number of convolutional layers we will create and their sizes. The default of `(2, 3, 4, 5)` will have four convolutional layers, corresponding to encoding ngrams of size 2 to 5 with some number of filters. - conv_layer_activation(str): + conv_layer_activation(Layer, optional): Activation to use after the convolution layers. - output_dim(int): + Defaults to `paddle.nn.Tanh()`. + output_dim(int, optional): After doing convolutions and pooling, we'll project the collected features into a vector of this size. If this value is `None`, we will just return the result of the max pooling, giving an output of shape `len(ngram_filter_sizes) * num_filter`. - + Defaults to `None`. + + Example: + .. 
code-block::

+            import paddle
+            import paddle.nn as nn
+            import paddlenlp as nlp
+
+            class CNNModel(nn.Layer):
+                def __init__(self,
+                             vocab_size,
+                             num_classes,
+                             emb_dim=128,
+                             padding_idx=0,
+                             num_filter=128,
+                             ngram_filter_sizes=(3, ),
+                             fc_hidden_size=96):
+                    super().__init__()
+                    self.embedder = nn.Embedding(
+                        vocab_size, emb_dim, padding_idx=padding_idx)
+                    self.encoder = nlp.seq2vec.CNNEncoder(
+                        emb_dim=emb_dim,
+                        num_filter=num_filter,
+                        ngram_filter_sizes=ngram_filter_sizes)
+                    self.fc = nn.Linear(self.encoder.get_output_dim(), fc_hidden_size)
+                    self.output_layer = nn.Linear(fc_hidden_size, num_classes)
+
+                def forward(self, text):
+                    # Shape: (batch_size, num_tokens, embedding_dim)
+                    embedded_text = self.embedder(text)
+                    # Shape: (batch_size, len(ngram_filter_sizes)*num_filter)
+                    encoder_out = self.encoder(embedded_text)
+                    encoder_out = paddle.tanh(encoder_out)
+                    # Shape: (batch_size, fc_hidden_size)
+                    fc_out = self.fc(encoder_out)
+                    # Shape: (batch_size, num_classes)
+                    logits = self.output_layer(fc_out)
+                    return logits
+
+            model = CNNModel(vocab_size=100, num_classes=2)
+
+            text = paddle.randint(low=1, high=10, shape=[1,10], dtype='int32')
+            logits = model(text)
     """

     def __init__(self,
@@ -145,7 +239,7 @@ def __init__(self,
         self._output_dim = maxpool_output_dim

     def get_input_dim(self):
-        """
+        r"""
         Returns the dimension of the vector input for each element in the sequence input
         to a `CNNEncoder`. This is not the shape of the input tensor, but the
         last element of that shape.
@@ -153,26 +247,31 @@ def get_input_dim(self):
         return self._emb_dim

     def get_output_dim(self):
-        """
+        r"""
         Returns the dimension of the final vector output by this `CNNEncoder`. This is not
         the shape of the returned tensor, but the last element of that shape.
         """
         return self._output_dim

     def forward(self, inputs, mask=None):
-        """
+        r"""
         The combination of multiple convolution layers and max pooling layers.

         Args:
-            inputs (paddle.Tensor): Shape as `(batch_size, num_tokens, emb_dim)`
-            mask (obj: `paddle.Tensor`, optional, defaults to `None`): Shape same as `inputs`.
-                Its each elements identify whether is padding token or not.
-                If True, not padding token. If False, padding token.
+            inputs (Tensor):
+                Shape as `(batch_size, num_tokens, emb_dim)` and dtype as `float32` or `float64`.
+                Tensor containing the features of the input sequence.
+            mask (Tensor, optional):
+                Shape should be same as `inputs` and dtype as `int32`, `int64`, `float32` or `float64`.
+                Each element identifies whether the corresponding input token is padding or not.
+                If True, not padding token. If False, padding token.
+                Defaults to `None`.

         Returns:
-            result (paddle.Tensor): If output_dim is None, the result shape
-                is of `(batch_size, output_dim)`; if not, the result shape
-                is of `(batch_size, len(ngram_filter_sizes) * num_filter)`.
+            Tensor:
+                If output_dim is None, the result shape is of `(batch_size, len(ngram_filter_sizes) * num_filter)`
+                and dtype is `float`;
+                if not, the result shape is of `(batch_size, output_dim)`.

         """
         if mask is not None:
@@ -198,11 +297,13 @@ def forward(self, inputs, mask=None):


 class GRUEncoder(nn.Layer):
-    """
+    r"""
     A GRUEncoder takes as input a sequence of vectors and returns a
-    single vector, which is a combination of multiple GRU layers.
-    The input to this module is of shape `(batch_size, num_tokens, input_size)`,
-    The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirection;
+    single vector, which is a combination of multiple `paddle.nn.GRU
+    `__ subclass.
+ The input to this encoder is of shape `(batch_size, num_tokens, input_size)`, + The output is of shape `(batch_size, hidden_size * 2)` if GRU is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Paddle's GRU have two outputs: the hidden state for every time step at last layer, @@ -211,24 +312,87 @@ class GRUEncoder(nn.Layer): step at last layer to create a single vector. If not None, we use the hidden state of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`); And if direction is bidirection, the we concat the hidden state of the last forward - gru and backward gru layer to create a single vector (shape of `(batch_size, hidden_size*2)`). + gru and backward gru layer to create a single vector (shape of `(batch_size, hidden_size * 2)`). Args: - input_size (obj:`int`, required): The number of expected features in the input (the last dimension). - hidden_size (obj:`int`, required): The number of features in the hidden state. - num_layers (obj:`int`, optional, defaults to 1): Number of recurrent layers. + input_size (int): + The number of expected features in the input (the last dimension). + hidden_size (int): + The number of features in the hidden state. + num_layers (int, optional): + Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. - direction (obj:`str`, optional, defaults to obj:`forward`): The direction of the network. - It can be `forward` and `bidirect` (it means bidirection network). - If `biderect`, it is a birectional GRU, and returns the concat output from both directions. - dropout (obj:`float`, optional, defaults to 0.0): If non-zero, introduces a Dropout layer - on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. - pooling_type (obj: `str`, optional, defaults to obj:`None`): If `pooling_type` is None, - then the GRUEncoder will return the hidden state of the last time step at last layer as a single vector. - If pooling_type is not None, it must be one of `sum`, `max` and `mean`. Then it will be pooled on - the GRU output (the hidden state of every time step at last layer) to create a single vector. - + Defaults to 1. + direction (str, optional): + The direction of the network. It can be "forward" and "bidirect" + (it means bidirection network). If "bidirect", it is a birectional GRU, + and returns the concat output from both directions. + Defaults to "forward". + dropout (float, optional): + If non-zero, introduces a Dropout layer on the outputs of each GRU layer + except the last layer, with dropout probability equal to dropout. + Defaults to 0.0. + pooling_type (str, optional): + If `pooling_type` is None, then the GRUEncoder will return the hidden state of + the last time step at last layer as a single vector. + If pooling_type is not None, it must be one of "sum", "max" and "mean". + Then it will be pooled on the GRU output (the hidden state of every time + step at last layer) to create a single vector. + Defaults to `None` + + Example: + .. 
code-block:: + + import paddle + import paddle.nn as nn + import paddlenlp as nlp + + class GRUModel(nn.Layer): + def __init__(self, + vocab_size, + num_classes, + emb_dim=128, + padding_idx=0, + gru_hidden_size=198, + direction='forward', + gru_layers=1, + dropout_rate=0.0, + pooling_type=None, + fc_hidden_size=96): + super().__init__() + self.embedder = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=emb_dim, + padding_idx=padding_idx) + self.gru_encoder = nlp.seq2vec.GRUEncoder( + emb_dim, + gru_hidden_size, + num_layers=gru_layers, + direction=direction, + dropout=dropout_rate, + pooling_type=pooling_type) + self.fc = nn.Linear(self.gru_encoder.get_output_dim(), fc_hidden_size) + self.output_layer = nn.Linear(fc_hidden_size, num_classes) + + def forward(self, text, seq_len): + # Shape: (batch_size, num_tokens, embedding_dim) + embedded_text = self.embedder(text) + # Shape: (batch_size, num_tokens, num_directions*gru_hidden_size) + # num_directions = 2 if direction is 'bidirect' + # if not, num_directions = 1 + text_repr = self.gru_encoder(embedded_text, sequence_length=seq_len) + # Shape: (batch_size, fc_hidden_size) + fc_out = paddle.tanh(self.fc(text_repr)) + # Shape: (batch_size, num_classes) + logits = self.output_layer(fc_out) + return logits + + model = GRUModel(vocab_size=100, num_classes=2) + + text = paddle.randint(low=1, high=10, shape=[1,10], dtype='int32') + seq_len = paddle.to_tensor([10]) + logits = model(text, seq_len) """ def __init__(self, @@ -253,7 +417,7 @@ def __init__(self, **kwargs) def get_input_dim(self): - """ + r""" Returns the dimension of the vector input for each element in the sequence input to a `GRUEncoder`. This is not the shape of the input tensor, but the last element of that shape. @@ -261,7 +425,7 @@ def get_input_dim(self): return self._input_size def get_output_dim(self): - """ + r""" Returns the dimension of the final vector output by this `GRUEncoder`. This is not the shape of the returned tensor, but the last element of that shape. """ @@ -271,19 +435,21 @@ def get_output_dim(self): return self._hidden_size def forward(self, inputs, sequence_length): - """ - GRUEncoder takes the a sequence of vectors and and returns a - single vector, which is a combination of multiple GRU layers. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, - The output is of shape `(batch_size, hidden_size*2)` if GRU is bidirection; + r""" + GRUEncoder takes the a sequence of vectors and and returns a single vector, + which is a combination of multiple GRU layers. The input to this + encoder is of shape `(batch_size, num_tokens, input_size)`, + The output is of shape `(batch_size, hidden_size * 2)` if GRU is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Args: - inputs (paddle.Tensor): Shape as `(batch_size, num_tokens, input_size)`. - sequence_length (paddle.Tensor): Shape as `(batch_size)`. + inputs (Tensor): Shape as `(batch_size, num_tokens, input_size)`. + Tensor containing the features of the input sequence. + sequence_length (Tensor): Shape as `(batch_size)`. + The sequence length of the input sequence. Returns: - last_hidden (paddle.Tensor): Shape as `(batch_size, hidden_size)`. + Tensor: Shape as `(batch_size, hidden_size)` and dtype is `float`. The hidden state at the last time step for every layer. """ @@ -295,7 +461,7 @@ def forward(self, inputs, sequence_length): # If gru is not bidirection, then output is the hidden state of the last time step # at last layer. 
Output is shape of `(batch_size, hidden_size)`. # If gru is bidirection, then output is concatenation of the forward and backward hidden state - # of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`. + # of the last time step at last layer. Output is shape of `(batch_size, hidden_size * 2)`. if self._direction != 'bidirect': output = last_hidden[-1, :, :] else: @@ -304,8 +470,8 @@ def forward(self, inputs, sequence_length): else: # We exploit the `encoded_text` (the hidden state at the every time step for last layer) # to create a single vector. We perform pooling on the encoded text. - # The output shape is `(batch_size, hidden_size*2)` if use bidirectional GRU, - # otherwise the output shape is `(batch_size, hidden_size*2)`. + # The output shape is `(batch_size, hidden_size * 2)` if use bidirectional GRU, + # otherwise the output shape is `(batch_size, hidden_size * 2)`. if self._pooling_type == 'sum': output = paddle.sum(encoded_text, axis=1) elif self._pooling_type == 'max': @@ -321,11 +487,13 @@ def forward(self, inputs, sequence_length): class LSTMEncoder(nn.Layer): - """ - A LSTMEncoder takes as input a sequence of vectors and returns a - single vector, which is a combination of multiple LSTM layers. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, - The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirection; + r""" + An LSTMEncoder takes as input a sequence of vectors and returns a + single vector, which is a combination of multiple `paddle.nn.LSTM + `__ subclass. + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`. + The output is of shape `(batch_size, hidden_size * 2)` if LSTM is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Paddle's LSTM have two outputs: the hidden state for every time step at last layer, @@ -334,24 +502,86 @@ class LSTMEncoder(nn.Layer): step at last layer to create a single vector. If not None, we use the hidden state of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`); And if direction is bidirection, the we concat the hidden state of the last forward - lstm and backward lstm layer to create a single vector (shape of `(batch_size, hidden_size*2)`). + lstm and backward lstm layer to create a single vector (shape of `(batch_size, hidden_size * 2)`). Args: - input_size (int): The number of expected features in the input (the last dimension). - hidden_size (int): The number of features in the hidden state. - num_layers (int): Number of recurrent layers. + input_size (int): + The number of expected features in the input (the last dimension). + hidden_size (int): + The number of features in the hidden state. + num_layers (int, optional): + Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. - direction (str): The direction of the network. - It can be `forward` or `bidirect` (it means bidirection network). - If `biderect`, it is a birectional LSTM, and returns the concat output from both directions. - dropout (float): If non-zero, introduces a Dropout layer - on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. - pooling_type (str): If `pooling_type` is None, - then the LSTMEncoder will return the hidden state of the last time step at last layer as a single vector. 
- If pooling_type is not None, it must be one of `sum`, `max` and `mean`. Then it will be pooled on - the LSTM output (the hidden state of every time step at last layer) to create a single vector. - + Defaults to 1. + direction (str, optional): + The direction of the network. It can be "forward" or "bidirect" (it means bidirection network). + If "bidirect", it is a birectional LSTM, and returns the concat output from both directions. + Defaults to "forward". + dropout (float, optional): + If non-zero, introduces a Dropout layer on the outputs of each LSTM layer + except the last layer, with dropout probability equal to dropout. + Defaults to 0.0 . + pooling_type (str, optional): + If `pooling_type` is None, then the LSTMEncoder will return + the hidden state of the last time step at last layer as a single vector. + If pooling_type is not None, it must be one of "sum", "max" and "mean". + Then it will be pooled on the LSTM output (the hidden state of every + time step at last layer) to create a single vector. + Defaults to `None`. + + Example: + .. code-block:: + + import paddle + import paddle.nn as nn + import paddlenlp as nlp + + class LSTMModel(nn.Layer): + def __init__(self, + vocab_size, + num_classes, + emb_dim=128, + padding_idx=0, + lstm_hidden_size=198, + direction='forward', + lstm_layers=1, + dropout_rate=0.0, + pooling_type=None, + fc_hidden_size=96): + super().__init__() + self.embedder = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=emb_dim, + padding_idx=padding_idx) + self.lstm_encoder = nlp.seq2vec.LSTMEncoder( + emb_dim, + lstm_hidden_size, + num_layers=lstm_layers, + direction=direction, + dropout=dropout_rate, + pooling_type=pooling_type) + self.fc = nn.Linear(self.lstm_encoder.get_output_dim(), fc_hidden_size) + self.output_layer = nn.Linear(fc_hidden_size, num_classes) + + def forward(self, text, seq_len): + # Shape: (batch_size, num_tokens, embedding_dim) + embedded_text = self.embedder(text) + # Shape: (batch_size, num_tokens, num_directions*lstm_hidden_size) + # num_directions = 2 if direction is 'bidirect' + # if not, num_directions = 1 + text_repr = self.lstm_encoder(embedded_text, sequence_length=seq_len) + # Shape: (batch_size, fc_hidden_size) + fc_out = paddle.tanh(self.fc(text_repr)) + # Shape: (batch_size, num_classes) + logits = self.output_layer(fc_out) + return logits + + model = LSTMModel(vocab_size=100, num_classes=2) + + text = paddle.randint(low=1, high=10, shape=[1,10], dtype='int32') + seq_len = paddle.to_tensor([10]) + logits = model(text, seq_len) """ def __init__(self, @@ -377,7 +607,7 @@ def __init__(self, **kwargs) def get_input_dim(self): - """ + r""" Returns the dimension of the vector input for each element in the sequence input to a `LSTMEncoder`. This is not the shape of the input tensor, but the last element of that shape. @@ -385,7 +615,7 @@ def get_input_dim(self): return self._input_size def get_output_dim(self): - """ + r""" Returns the dimension of the final vector output by this `LSTMEncoder`. This is not the shape of the returned tensor, but the last element of that shape. """ @@ -395,19 +625,22 @@ def get_output_dim(self): return self._hidden_size def forward(self, inputs, sequence_length): - """ + r""" LSTMEncoder takes the a sequence of vectors and and returns a single vector, which is a combination of multiple LSTM layers. 
- The input to this module is of shape `(batch_size, num_tokens, input_size)`, - The output is of shape `(batch_size, hidden_size*2)` if LSTM is bidirection; + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`, + The output is of shape `(batch_size, hidden_size * 2)` if LSTM is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Args: - inputs (paddle.Tensor): Shape as `(batch_size, num_tokens, input_size)`. - sequence_length (paddle.Tensor): Shape as `(batch_size)`. + inputs (Tensor): Shape as `(batch_size, num_tokens, input_size)`. + Tensor containing the features of the input sequence. + sequence_length (Tensor): Shape as `(batch_size)`. + The sequence length of the input sequence. Returns: - last_hidden (paddle.Tensor): Shape as `(batch_size, hidden_size)`. + Tensor: + Shape as `(batch_size, hidden_size)` and dtype as float. The hidden state at the last time step for every layer. """ @@ -419,7 +652,7 @@ def forward(self, inputs, sequence_length): # If lstm is not bidirection, then output is the hidden state of the last time step # at last layer. Output is shape of `(batch_size, hidden_size)`. # If lstm is bidirection, then output is concatenation of the forward and backward hidden state - # of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`. + # of the last time step at last layer. Output is shape of `(batch_size, hidden_size * 2)`. if self._direction != 'bidirect': output = last_hidden[-1, :, :] else: @@ -428,8 +661,8 @@ def forward(self, inputs, sequence_length): else: # We exploit the `encoded_text` (the hidden state at the every time step for last layer) # to create a single vector. We perform pooling on the encoded text. - # The output shape is `(batch_size, hidden_size*2)` if use bidirectional LSTM, - # otherwise the output shape is `(batch_size, hidden_size*2)`. + # The output shape is `(batch_size, hidden_size * 2)` if use bidirectional LSTM, + # otherwise the output shape is `(batch_size, hidden_size * 2)`. if self._pooling_type == 'sum': output = paddle.sum(encoded_text, axis=1) elif self._pooling_type == 'max': @@ -445,11 +678,13 @@ def forward(self, inputs, sequence_length): class RNNEncoder(nn.Layer): - """ + r""" A RNNEncoder takes as input a sequence of vectors and returns a - single vector, which is a combination of multiple RNN layers. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, - The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirection; + single vector, which is a combination of multiple `paddle.nn.RNN + `__ subclass. + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`, + The output is of shape `(batch_size, hidden_size * 2)` if RNN is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Paddle's RNN have two outputs: the hidden state for every time step at last layer, @@ -458,24 +693,86 @@ class RNNEncoder(nn.Layer): step at last layer to create a single vector. If not None, we use the hidden state of the last time step at last layer as a single output (shape of `(batch_size, hidden_size)`); And if direction is bidirection, the we concat the hidden state of the last forward - rnn and backward rnn layer to create a single vector (shape of `(batch_size, hidden_size*2)`). + rnn and backward rnn layer to create a single vector (shape of `(batch_size, hidden_size * 2)`). Args: - input_size (obj:`int`, required): The number of expected features in the input (the last dimension). 
- hidden_size (obj:`int`, required): The number of features in the hidden state. - num_layers (obj:`int`, optional, defaults to 1): Number of recurrent layers. + input_size (int): + The number of expected features in the input (the last dimension). + hidden_size (int): + The number of features in the hidden state. + num_layers (int, optional): + Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. - direction (obj:`str`, optional, defaults to obj:`forward`): The direction of the network. - It can be "forward" and "bidirect" (it means bidirection network). - If `biderect`, it is a birectional RNN, and returns the concat output from both directions. - dropout (obj:`float`, optional, defaults to 0.0): If non-zero, introduces a Dropout layer - on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. - pooling_type (obj: `str`, optional, defaults to obj:`None`): If `pooling_type` is None, - then the RNNEncoder will return the hidden state of the last time step at last layer as a single vector. - If pooling_type is not None, it must be one of `sum`, `max` and `mean`. Then it will be pooled on - the RNN output (the hidden state of every time step at last layer) to create a single vector. - + Defaults to 1. + direction (str, optional): + The direction of the network. It can be "forward" and "bidirect" + (it means bidirection network). If "biderect", it is a birectional RNN, + and returns the concat output from both directions. Defaults to "forward" + dropout (float, optional): + If non-zero, introduces a Dropout layer on the outputs of each RNN layer + except the last layer, with dropout probability equal to dropout. + Defaults to 0.0. + pooling_type (str, optional): + If `pooling_type` is None, then the RNNEncoder will return the hidden state + of the last time step at last layer as a single vector. + If pooling_type is not None, it must be one of "sum", "max" and "mean". + Then it will be pooled on the RNN output (the hidden state of every time + step at last layer) to create a single vector. + Defaults to `None`. + + Example: + .. 
code-block:: + + import paddle + import paddle.nn as nn + import paddlenlp as nlp + + class RNNModel(nn.Layer): + def __init__(self, + vocab_size, + num_classes, + emb_dim=128, + padding_idx=0, + rnn_hidden_size=198, + direction='forward', + rnn_layers=1, + dropout_rate=0.0, + pooling_type=None, + fc_hidden_size=96): + super().__init__() + self.embedder = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=emb_dim, + padding_idx=padding_idx) + self.rnn_encoder = nlp.seq2vec.RNNEncoder( + emb_dim, + rnn_hidden_size, + num_layers=rnn_layers, + direction=direction, + dropout=dropout_rate, + pooling_type=pooling_type) + self.fc = nn.Linear(self.rnn_encoder.get_output_dim(), fc_hidden_size) + self.output_layer = nn.Linear(fc_hidden_size, num_classes) + + def forward(self, text, seq_len): + # Shape: (batch_size, num_tokens, embedding_dim) + embedded_text = self.embedder(text) + # Shape: (batch_size, num_tokens, num_directions*rnn_hidden_size) + # num_directions = 2 if direction is 'bidirect' + # if not, num_directions = 1 + text_repr = self.rnn_encoder(embedded_text, sequence_length=seq_len) + # Shape: (batch_size, fc_hidden_size) + fc_out = paddle.tanh(self.fc(text_repr)) + # Shape: (batch_size, num_classes) + logits = self.output_layer(fc_out) + return logits + + model = RNNModel(vocab_size=100, num_classes=2) + + text = paddle.randint(low=1, high=10, shape=[1,10], dtype='int32') + seq_len = paddle.to_tensor([10]) + logits = model(text, seq_len) """ def __init__(self, @@ -501,7 +798,7 @@ def __init__(self, **kwargs) def get_input_dim(self): - """ + r""" Returns the dimension of the vector input for each element in the sequence input to a `RNNEncoder`. This is not the shape of the input tensor, but the last element of that shape. @@ -509,7 +806,7 @@ def get_input_dim(self): return self._input_size def get_output_dim(self): - """ + r""" Returns the dimension of the final vector output by this `RNNEncoder`. This is not the shape of the returned tensor, but the last element of that shape. """ @@ -519,19 +816,22 @@ def get_output_dim(self): return self._hidden_size def forward(self, inputs, sequence_length): - """ + r""" RNNEncoder takes the a sequence of vectors and and returns a single vector, which is a combination of multiple RNN layers. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, - The output is of shape `(batch_size, hidden_size*2)` if RNN is bidirection; + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`. + The output is of shape `(batch_size, hidden_size * 2)` if RNN is bidirection; If not, output is of shape `(batch_size, hidden_size)`. Args: - inputs (paddle.Tensor): Shape as `(batch_size, num_tokens, input_size)`. - sequence_length (paddle.Tensor): Shape as `(batch_size)`. + inputs (Tensor): Shape as `(batch_size, num_tokens, input_size)`. + Tensor containing the features of the input sequence. + sequence_length (Tensor): Shape as `(batch_size)`. + The sequence length of the input sequence. Returns: - last_hidden (paddle.Tensor): Shape as `(batch_size, hidden_size)`. + last_hidden (Tensor): + Shape as `(batch_size, hidden_size)` and dtype as `float`. The hidden state at the last time step for every layer. """ @@ -543,7 +843,7 @@ def forward(self, inputs, sequence_length): # If rnn is not bidirection, then output is the hidden state of the last time step # at last layer. Output is shape of `(batch_size, hidden_size)`. 
# If rnn is bidirection, then output is concatenation of the forward and backward hidden state - # of the last time step at last layer. Output is shape of `(batch_size, hidden_size*2)`. + # of the last time step at last layer. Output is shape of `(batch_size, hidden_size * 2)`. if self._direction != 'bidirect': output = last_hidden[-1, :, :] else: @@ -552,8 +852,8 @@ def forward(self, inputs, sequence_length): else: # We exploit the `encoded_text` (the hidden state at the every time step for last layer) # to create a single vector. We perform pooling on the encoded text. - # The output shape is `(batch_size, hidden_size*2)` if use bidirectional RNN, - # otherwise the output shape is `(batch_size, hidden_size*2)`. + # The output shape is `(batch_size, hidden_size * 2)` if use bidirectional RNN, + # otherwise the output shape is `(batch_size, hidden_size * 2)`. if self._pooling_type == 'sum': output = paddle.sum(encoded_text, axis=1) elif self._pooling_type == 'max': @@ -573,7 +873,7 @@ class Chomp1d(nn.Layer): Remove the elements on the right. Args: - chomp_size ([int]): The number of elements removed. + chomp_size (int): The number of elements removed. """ def __init__(self, chomp_size): @@ -663,7 +963,7 @@ class TCNEncoder(nn.Layer): r""" A `TCNEncoder` takes as input a sequence of vectors and returns a single vector, which is the last one time step in the feature map. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`, and the output is of shape `(batch_size, num_channels[-1])` with a receptive filed: @@ -723,7 +1023,7 @@ def forward(self, inputs): r""" TCNEncoder takes as input a sequence of vectors and returns a single vector, which is the last one time step in the feature map. - The input to this module is of shape `(batch_size, num_tokens, input_size)`, + The input to this encoder is of shape `(batch_size, num_tokens, input_size)`, and the output is of shape `(batch_size, num_channels[-1])` with a receptive filed: diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py index 2e75be6e3b61..794fc633de8d 100644 --- a/paddlenlp/transformers/ernie/modeling.py +++ b/paddlenlp/transformers/ernie/modeling.py @@ -25,8 +25,8 @@ class ErnieEmbeddings(nn.Layer): - """ - Include embeddings from word, position and token_type embeddings + r""" + Include embeddings from word, position and token_type embeddings. """ def __init__(self, @@ -67,9 +67,6 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None): class ErniePooler(nn.Layer): - """ - """ - def __init__(self, hidden_size): super(ErniePooler, self).__init__() self.dense = nn.Linear(hidden_size, hidden_size) @@ -85,11 +82,13 @@ def forward(self, hidden_states): class ErniePretrainedModel(PretrainedModel): - """ + r""" An abstract class for pretrained ERNIE models. It provides ERNIE related `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`, `pretrained_init_configuration`, `base_model_prefix` for downloading and - loading pretrained models. See `PretrainedModel` for more details. + loading pretrained models. + Refer to :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. 
+ """ model_config_file = "model_config.json" @@ -183,7 +182,55 @@ def init_weights(self, layer): @register_base_model class ErnieModel(ErniePretrainedModel): - """ + r""" + The bare ERNIE Model transformer outputting raw hidden-states without any specific head on top. + + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. + + Args: + + vocab_size (int): + Vocabulary size of the ERNIE model. Also is the vocab size of token embedding matrix. + hidden_size (int, optional): + Dimension of the encoder layers and the pooler layer. Defaults to ``768``. + num_hidden_layers (int, optional): + Number of hidden layers in the Transformer encoder. Defaults to ``12``. + num_attention_heads (int, optional): + Number of attention heads for each attention layer in the Transformer encoder. + Defaults to ``12``. + intermediate_size (int, optional): + Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + Defaults to ``3072``. + hidden_act (str, optional): + The non-linear activation function in the feed-forward layer. + ``"gelu"``, ``"relu"`` and any other paddle supported activation functions + are supported. Defaults to ``"gelu"``. + hidden_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the embeddings and encoder. + Defaults to ``0.1``. + attention_probs_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the pooler. + Defaults to ``0.1``. + max_position_embeddings (int, optional): + The max position index of an input sequence. Defaults to ``512``. + type_vocab_size (int, optional): + The vocabulary size of the `token_type_ids` passed when calling `~transformers.ErnieModel`. + Defaults to ``2``. + initializer_range (float, optional): + The standard deviation of the normal initializer. Defaults to 0.02. + + .. note:: + A normal_initializer initializes weight matrices as normal distributions. + See :meth:`ErniePretrainedModel._init_weights()` for how weights are initialized in `ErnieModel`. + + pad_token_id(int, optional): + The pad token index in the token vocabulary. + """ def __init__(self, @@ -196,7 +243,7 @@ def __init__(self, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, - type_vocab_size=16, + type_vocab_size=2, initializer_range=0.02, pad_token_id=0): super(ErnieModel, self).__init__() @@ -222,6 +269,64 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None): + r""" + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + It's data type should be `int64` and has a shape of [batch_size, sequence_length]. + token_type_ids (Tensor, optional): + Segment token indices to indicate first and second portions of the inputs. + Indices can be either 0 or 1: + + - 0 corresponds to a **sentence A** token, + - 1 corresponds to a **sentence B** token. + + It's data type should be `int64` and has a shape of [batch_size, sequence_length]. + Defaults to None, which means no segment embeddings is added to token embeddings. + position_ids (Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings. 
Selected in the range ``[0,
+                config.max_position_embeddings - 1]``.
+                Defaults to `None`. Shape as `(batch_size, num_tokens)` and dtype as `int32` or `int64`.
+            attention_mask (Tensor, optional):
+                Mask to indicate whether to perform attention on each input token or not.
+                The values should be either 0 or 1. The attention scores will be set
+                to **-infinity** for any positions in the mask that are **0**, and will be
+                **unchanged** for positions that are **1**.
+
+                - **1** for tokens that are **not masked**,
+                - **0** for tokens that are **masked**.
+
+                It's data type should be `float32` and has a shape of [batch_size, sequence_length].
+                Defaults to `None`.
+
+        Returns:
+            A tuple of (``sequence_output``, ``pooled_output``).
+
+            With the fields:
+            - sequence_output (Tensor):
+                Sequence of hidden-states at the last layer of the model.
+                It's data type should be `float` and has a shape of `(batch_size, seq_lens, hidden_size)`.
+                ``seq_lens`` corresponds to the length of input sequence.
+            - pooled_output (Tensor):
+                A Tensor of the first token representation.
+                It's data type should be `float` and has a shape of `[batch_size, hidden_size]`.
+                We "pool" the model by simply taking the hidden state corresponding to the first token.
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers import ErnieModel, ErnieTokenizer
+
+                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+                model = ErnieModel.from_pretrained('ernie-1.0')
+
+                inputs = tokenizer("这是个测试样例")
+                inputs = {k:paddle.to_tensor(v) for (k, v) in inputs.items()}
+                sequence_output, pooled_output = model(**inputs)
+
+        """
         if attention_mask is None:
             attention_mask = paddle.unsqueeze(
                 (input_ids == self.pad_token_id
@@ -238,14 +343,18 @@ def forward(self,


 class ErnieForSequenceClassification(ErniePretrainedModel):
-    """
+    r"""
     Model for sentence (pair) classification task with ERNIE.
+
     Args:
-        ernie (ErnieModel): An instance of `ErnieModel`.
-        num_classes (int, optional): The number of classes. Default 2
-        dropout (float, optional): The dropout probability for output of ERNIE.
-            If None, use the same value as `hidden_dropout_prob` of `ErnieModel`
-            instance `Ernie`. Default None
+        ernie (ErnieModel):
+            An instance of `paddlenlp.transformers.ErnieModel`.
+        num_classes (int, optional):
+            The number of classes. Defaults to `2`.
+        dropout (float, optional):
+            The dropout probability for output of ERNIE.
+            If None, use the same value as `hidden_dropout_prob`
+            of `paddlenlp.transformers.ErnieModel` instance. Defaults to `None`.
     """

     def __init__(self, ernie, num_classes=2, dropout=None):
@@ -263,6 +372,57 @@ def forward(self,
                 token_type_ids=None,
                 position_ids=None,
                 attention_mask=None):
+        r"""
+        Args:
+            input_ids (Tensor):
+                Indices of input sequence tokens in the vocabulary. They are
+                numerical representations of tokens that build the input sequence.
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+            token_type_ids (Tensor, optional):
+                Segment token indices to indicate first and second portions of the inputs.
+                Indices can be either 0 or 1:
+
+                - 0 corresponds to a **sentence A** token,
+                - 1 corresponds to a **sentence B** token.
+
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+                Defaults to None, which means no segment embeddings is added to token embeddings.
+            position_ids (Tensor, optional):
+                Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, + config.max_position_embeddings - 1]``. + Defaults to `None`. Shape as `(batch_sie, num_tokens)` and dtype as `int32` or `int64`. + attention_mask (Tensor, optional): + Mask to indicate whether to perform attention on each input token or not. + The values should be either 0 or 1. The attention scores will be set + to **-infinity** for any positions in the mask that are **0**, and will be + **unchanged** for positions that are **1**. + + - **1** for tokens that are **not masked**, + - **0** for tokens that are **masked**. + + It's data type should be `float32` and has a shape of [batch_size, sequence_length]. + Defaults to `None`. + + + Returns: + logits (Tensor): + A Tensor of the input text classification logits. + Shape as `(batch_size, num_classes)` and dtype as `float`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer + + tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') + model = ErnieForSequenceClassification.from_pretrained('ernie-1.0') + + inputs = tokenizer("这是个测试样例") + inputs = {k:paddle.to_tensor(v) for (k, v) in inputs.items()} + logits = model(**inputs) + + """ _, pooled_output = self.ernie( input_ids, token_type_ids=token_type_ids, @@ -275,6 +435,15 @@ def forward(self, class ErnieForQuestionAnswering(ErniePretrainedModel): + """ + Model for Question and Answering task with ERNIE. + + + Args: + ernie (`ErnieModel`): + An instance of `ErnieModel`. + """ + def __init__(self, ernie): super(ErnieForQuestionAnswering, self).__init__() self.ernie = ernie # allow ernie to be config @@ -286,6 +455,59 @@ def forward(self, token_type_ids=None, position_ids=None, attention_mask=None): + r""" + Args: + input_ids (Tensor): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + It's data type should be `int64` and has a shape of [batch_size, sequence_length]. + token_type_ids (Tensor, optional): + Segment token indices to indicate first and second portions of the inputs. + Indices can be either 0 or 1: + + - 0 corresponds to a **sentence A** token, + - 1 corresponds to a **sentence B** token. + + It's data type should be `int64` and has a shape of [batch_size, sequence_length]. + Defaults to None, which means no segment embeddings is added to token embeddings. + position_ids (Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + Defaults to `None`. Shape as `(batch_sie, num_tokens)` and dtype as `int32` or `int64`. + attention_mask (Tensor, optional): + Mask to indicate whether to perform attention on each input token or not. + The values should be either 0 or 1. The attention scores will be set + to **-infinity** for any positions in the mask that are **0**, and will be + **unchanged** for positions that are **1**. + + - **1** for tokens that are **not masked**, + - **0** for tokens that are **masked**. + + It's data type should be `float32` and has a shape of [batch_size, sequence_length]. + Defaults to `None`. + + + Returns: + A tuple of shape (``start_logits``, ``end_logits``). + + With the fields: + - start_logits(Tensor): The logits of start position of prediction answer. + - end_logits(Tensor): The logits of end position of prediction answer. + + Example: + .. 
code-block::

+                import paddle
+                from paddlenlp.transformers import ErnieForQuestionAnswering, ErnieTokenizer
+
+                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+                model = ErnieForQuestionAnswering.from_pretrained('ernie-1.0')
+
+                inputs = tokenizer("这是个测试样例")
+                inputs = {k:paddle.to_tensor(v) for (k, v) in inputs.items()}
+                start_logits, end_logits = model(**inputs)
+        """
+
         sequence_output, _ = self.ernie(
             input_ids,
             token_type_ids=token_type_ids,
@@ -300,6 +522,22 @@ def forward(self,


 class ErnieForTokenClassification(ErniePretrainedModel):
+    r"""
+    ERNIE Model with a token classification head on top
+    (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks.
+
+
+    Args:
+        ernie (`ErnieModel`):
+            An instance of `ErnieModel`.
+        num_classes (int, optional):
+            The number of classes. Defaults to `2`.
+        dropout (float, optional):
+            The dropout probability for output of ERNIE.
+            If None, use the same value as `hidden_dropout_prob`
+            of `ErnieModel` instance. Defaults to `None`.
+    """
+
     def __init__(self, ernie, num_classes=2, dropout=None):
         super(ErnieForTokenClassification, self).__init__()
         self.num_classes = num_classes
@@ -315,6 +553,57 @@ def forward(self,
                 token_type_ids=None,
                 position_ids=None,
                 attention_mask=None):
+        r"""
+        Args:
+            input_ids (Tensor):
+                Indices of input sequence tokens in the vocabulary. They are
+                numerical representations of tokens that build the input sequence.
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+            token_type_ids (Tensor, optional):
+                Segment token indices to indicate first and second portions of the inputs.
+                Indices can be either 0 or 1:
+
+                - 0 corresponds to a **sentence A** token,
+                - 1 corresponds to a **sentence B** token.
+
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+                Defaults to None, which means no segment embeddings is added to token embeddings.
+            position_ids (Tensor, optional):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+                config.max_position_embeddings - 1]``.
+                Defaults to `None`. Shape as `(batch_size, num_tokens)` and dtype as `int32` or `int64`.
+            attention_mask (Tensor, optional):
+                Mask to indicate whether to perform attention on each input token or not.
+                The values should be either 0 or 1. The attention scores will be set
+                to **-infinity** for any positions in the mask that are **0**, and will be
+                **unchanged** for positions that are **1**.
+
+                - **1** for tokens that are **not masked**,
+                - **0** for tokens that are **masked**.
+
+                It's data type should be `float32` and has a shape of [batch_size, sequence_length].
+                Defaults to `None`.
+
+
+        Returns:
+            logits (Tensor):
+                A Tensor of the input token classification logits, shape as `(batch_size, seq_lens, num_classes)`.
+                `seq_lens` means the number of tokens of the input sequence.
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
+
+                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+                model = ErnieForTokenClassification.from_pretrained('ernie-1.0')
+
+                inputs = tokenizer("这是个测试样例")
+                inputs = {k:paddle.to_tensor(v) for (k, v) in inputs.items()}
+                logits = model(**inputs)
+
+        """
         sequence_output, _ = self.ernie(
             input_ids,
             token_type_ids=token_type_ids,
@@ -327,6 +616,10 @@ def forward(self,


 class ErnieLMPredictionHead(nn.Layer):
+    r"""
+    The `language modeling` head on top of the ERNIE model.
+    """
+
     def __init__(self,
                  hidden_size,
                  vocab_size,
@@ -377,6 +670,12 @@ def forward(self, sequence_output, pooled_output, masked_positions=None):


 class ErnieForPretraining(ErniePretrainedModel):
+    r"""
+    ERNIE Model with two heads on top as done during the pretraining:
+    a `masked language modeling` head and a `next sentence prediction (classification)` head.
+
+    """
+
     def __init__(self, ernie):
         super(ErnieForPretraining, self).__init__()
         self.ernie = ernie
@@ -394,6 +693,59 @@ def forward(self,
                 position_ids=None,
                 attention_mask=None,
                 masked_positions=None):
+        r"""
+        Args:
+            input_ids (Tensor):
+                Indices of input sequence tokens in the vocabulary. They are
+                numerical representations of tokens that build the input sequence.
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+            token_type_ids (Tensor, optional):
+                Segment token indices to indicate first and second portions of the inputs.
+                Indices can be either 0 or 1:
+
+                - 0 corresponds to a **sentence A** token,
+                - 1 corresponds to a **sentence B** token.
+
+                It's data type should be `int64` and has a shape of [batch_size, sequence_length].
+                Defaults to None, which means no segment embeddings is added to token embeddings.
+            position_ids (Tensor, optional):
+                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+                config.max_position_embeddings - 1]``.
+                Defaults to `None`. Shape as `(batch_size, num_tokens)` and dtype as `int32` or `int64`.
+            attention_mask (Tensor, optional):
+                Mask to indicate whether to perform attention on each input token or not.
+                The values should be either 0 or 1. The attention scores will be set
+                to **-infinity** for any positions in the mask that are **0**, and will be
+                **unchanged** for positions that are **1**.
+
+                - **1** for tokens that are **not masked**,
+                - **0** for tokens that are **masked**.
+
+                It's data type should be `float32` and has a shape of [batch_size, sequence_length].
+                Defaults to `None`.
+
+
+        Returns:
+            A tuple of (``prediction_scores``, ``seq_relationship_score``).
+
+            With the fields:
+            - prediction_scores(Tensor): The scores of prediction on masked token.
+            - seq_relationship_score(Tensor): The scores of next sentence prediction.
+
+        Example:
+            .. code-block::
+
+                import paddle
+                from paddlenlp.transformers import ErnieForPretraining, ErnieTokenizer
+
+                tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
+                model = ErnieForPretraining.from_pretrained('ernie-1.0')
+
+                inputs = tokenizer("这是个测试样例")
+                inputs = {k:paddle.to_tensor(v) for (k, v) in inputs.items()}
+                prediction_scores, seq_relationship_score = model(**inputs)
+
+        """
         with paddle.static.amp.fp16_guard():
             outputs = self.ernie(
                 input_ids,
@@ -407,6 +759,12 @@ def forward(self,


 class ErniePretrainingCriterion(paddle.nn.Layer):
+    r"""
+    The loss output of ERNIE Model during the pretraining:
+    a `masked language modeling` head and a `next sentence prediction (classification)` head.
+
+    """
+
     def __init__(self, vocab_size):
         super(ErniePretrainingCriterion, self).__init__()
         self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=-1)
diff --git a/paddlenlp/transformers/ernie/tokenizer.py b/paddlenlp/transformers/ernie/tokenizer.py
index c0e451d9c299..16fb927ea4d2 100644
--- a/paddlenlp/transformers/ernie/tokenizer.py
+++ b/paddlenlp/transformers/ernie/tokenizer.py
@@ -27,29 +27,45 @@


 class ErnieTokenizer(PretrainedTokenizer):
-    """
+    r"""
     Constructs an ERNIE tokenizer.
It uses a basic tokenizer to do punctuation splitting, lower casing and so on, and follows a WordPiece tokenizer to tokenize as subwords. + Args: - vocab_file (str): file path of the vocabulary - do_lower_case (bool): Whether the text strips accents and convert to - lower case. Default: `True`. - Default: True. - unk_token (str): The special token for unkown words. Default: "[UNK]". - sep_token (str): The special token for separator token . Default: "[SEP]". - pad_token (str): The special token for padding. Default: "[PAD]". - cls_token (str): The special token for cls. Default: "[CLS]". - mask_token (str): The special token for mask. Default: "[MASK]". + vocab_file (str): + file path of the vocabulary. + do_lower_case (str, optional): + Whether the text strips accents and convert to lower case. + Defaults to `True`. + unk_token (str, optional): + The special token for unknown words. + Defaults to "[UNK]". + sep_token (str, optional): + The special token for separator token. + Defaults to "[SEP]". + pad_token (str, optional): + The special token for padding. + Defaults to "[PAD]". + cls_token (str, optional): + The special token for cls. + Defaults to "[CLS]". + mask_token (str, optional): + The special token for mask. + Defaults to "[MASK]". Examples: .. code-block:: python from paddlenlp.transformers import ErnieTokenizer - tokenizer = ErnieTokenizer.from_pretrained('ernie') - # the following line get: ['he', 'was', 'a', 'puppet', '##eer'] - tokens = tokenizer('He was a puppeteer') - # the following line get: 'he was a puppeteer' - tokenizer.convert_tokens_to_string(tokens) + tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0') + encoded_inputs = tokenizer('这是一个测试样例') + # encoded_inputs: + # { + # 'input_ids': [1, 47, 10, 7, 27, 558, 525, 314, 656, 2], + # 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # } + + """ resource_files_names = {"vocab_file": "vocab.txt"} # for save_pretrained pretrained_resource_files_map = { @@ -111,21 +127,23 @@ def __init__(self, @property def vocab_size(self): - """ + r""" return the size of vocabulary. + Returns: int: the size of vocabulary. """ return len(self.vocab) def _tokenize(self, text): - """ + r""" End-to-end tokenization for ERNIE models. + Args: text (str): The text to be tokenized. Returns: - list: A list of string representing converted tokens. + List[str]: A list of string representing converted tokens. """ split_tokens = [] for token in self.basic_tokenizer.tokenize(text): @@ -134,23 +152,26 @@ def _tokenize(self, text): return split_tokens def tokenize(self, text): - """ + r""" End-to-end tokenization for ERNIE models. + Args: text (str): The text to be tokenized. Returns: - list: A list of string representing converted tokens. + List[str]: A list of string representing converted tokens. """ return self._tokenize(text) def convert_tokens_to_string(self, tokens): - """ + r""" Converts a sequence of tokens (list of string) in a single string. Since the usage of WordPiece introducing `##` to concat subwords, also remove `##` when converting. + Args: - tokens (list): A list of string representing tokens to be converted. + tokens (List[str]): A list of string representing tokens to be converted. + Returns: str: Converted string from tokens. """ @@ -158,19 +179,20 @@ def convert_tokens_to_string(self, tokens): return out_string def num_special_tokens_to_add(self, pair=False): - """ + r""" Returns the number of added tokens when encoding a sequence with special tokens. 
Note: - This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this - inside your training loop. + This encodes inputs and checks the number of added tokens, and is therefore not efficient. + Do not put this inside your training loop. Args: - pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the - number of added tokens in the case of a single sequence if set to False. + pair (str, optional): Returns the number of added tokens in the case of a sequence + pair if set to True, returns the number of added tokens in the case of a single sequence + if set to False. Defaults to False. Returns: - Number of tokens added to sequences + `int`: Number of tokens added to sequences """ token_ids_0 = [] token_ids_1 = [] @@ -179,7 +201,7 @@ def num_special_tokens_to_add(self, pair=False): if pair else None)) def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - """ + r""" Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. @@ -189,13 +211,14 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - pair of sequences: ``[CLS] A [SEP] B [SEP]`` Args: - token_ids_0 (:obj:`List[int]`): + token_ids_0 (List[int]): List of IDs to which the special tokens will be added. - token_ids_1 (:obj:`List[int]`, `optional`): + token_ids_1 (List[int], optional): Optional second list of IDs for sequence pairs. + Defaults to `None`. Returns: - :obj:`List[int]`: List of input_id with the appropriate special tokens. + List[int]: List of input_id with the appropriate special tokens. """ if token_ids_1 is None: return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] @@ -206,7 +229,7 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_offset_mapping_with_special_tokens(self, offset_mapping_0, offset_mapping_1=None): - """ + r""" Build offset map from a pair of offset map by concatenating and adding offsets of special tokens. A ERNIE offset_mapping has the following format: @@ -215,13 +238,14 @@ def build_offset_mapping_with_special_tokens(self, - pair of sequences: `(0,0) A (0,0) B (0,0)`` Args: - offset_mapping_ids_0 (:obj:`List[tuple]`): + offset_mapping_ids_0 (List[tuple]): List of char offsets to which the special tokens will be added. - offset_mapping_ids_1 (:obj:`List[tuple]`, `optional`): + offset_mapping_ids_1 (List[tuple], optional): Optional second list of char offsets for offset mapping pairs. + Defaults to `None`. Returns: - :obj:`List[tuple]`: List of char offsets with the appropriate offsets of special tokens. + List[tuple]: List of char offsets with the appropriate offsets of special tokens. """ if offset_mapping_1 is None: return [(0, 0)] + offset_mapping_0 + [(0, 0)] @@ -232,7 +256,7 @@ def build_offset_mapping_with_special_tokens(self, def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): - """ + r""" Create a mask from the two sequences passed to be used in a sequence-pair classification task. A ERNIE sequence pair mask has the following format: @@ -241,16 +265,17 @@ def create_token_type_ids_from_sequences(self, 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 | first sequence | second sequence | - If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). + If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s). 
@@ -232,7 +256,7 @@ def create_token_type_ids_from_sequences(self,
                                             token_ids_0,
                                             token_ids_1=None):
-        """
+        r"""
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.

        A ERNIE sequence pair mask has the following format:
@@ -241,16 +265,17 @@ def create_token_type_ids_from_sequences,
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |

-        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
-            token_ids_0 (:obj:`List[int]`):
+            token_ids_0 (List[int]):
                List of IDs.
-            token_ids_1 (:obj:`List[int]`, `optional`):
-                Optional second list of IDs for sequence pairs.
+            token_ids_1 (List[int], optional):
+                Optional second list of IDs for sequence pairs.
+                Defaults to `None`.

        Returns:
-            :obj:`List[int]`: List of token_type_id according to the given sequence(s).
+            List[int]: List of token_type_id according to the given sequence(s).
        """
        _sep = [self.sep_token_id]
        _cls = [self.cls_token_id]
@@ -261,27 +286,47 @@ class ErnieTinyTokenizer(PretrainedTokenizer):
-    """
+    r"""
    Constructs a ErnieTiny tokenizer. It uses the `dict.wordseg.pickle` cut the text to words, and
-    use the `sentencepiece` tools to cut the words to sub-words.
+    use the `sentencepiece` tools to cut the words to sub-words.
+
    Args:
-        vocab_file (str): file path of the vocabulary
-        do_lower_case (bool): Whether the text strips accents and convert to
-            lower case. Default: `True`.
-        unk_token (str): The special token for unkown words. Default: "[UNK]".
-        sep_token (str): The special token for separator token . Default: "[SEP]".
-        pad_token (str): The special token for padding. Default: "[PAD]".
-        cls_token (str): The special token for cls. Default: "[CLS]".
-        mask_token (str): The special token for mask. Default: "[MASK]".
-
+        vocab_file (str):
+            The file path of the vocabulary.
+        sentencepiece_model_file (str):
+            The file path of the sentencepiece model.
+        word_dict (str):
+            The file path of the word vocabulary,
+            which is used to do Chinese word segmentation.
+        do_lower_case (bool, optional):
+            Whether or not to convert the input text to lower case and strip accents.
+            Defaults to `True`.
+        unk_token (str, optional):
+            The special token for unknown words.
+            Defaults to "[UNK]".
+        sep_token (str, optional):
+            The special token used as the separator.
+            Defaults to "[SEP]".
+        pad_token (str, optional):
+            The special token for padding.
+            Defaults to "[PAD]".
+        cls_token (str, optional):
+            The special token for the classification task.
+            Defaults to "[CLS]".
+        mask_token (str, optional):
+            The special token for masked positions.
+            Defaults to "[MASK]".
+
    Examples:
        .. code-block:: python

            from paddlenlp.transformers import ErnieTinyTokenizer
-            tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny)
-            # the following line get: ['he', 'was', 'a', 'puppet', '##eer']
-            tokens = tokenizer('He was a puppeteer')
-            # the following line get: 'he was a puppeteer'
-            tokenizer.convert_tokens_to_string(tokens)
+            tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')
+            inputs = tokenizer('这是个测试样例')
+            # inputs:
+            # {
+            #    'input_ids': [3, 509, 79, 5822, 2340, 4734, 8886, 5],
+            #    'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]
+            # }
    """

    resource_files_names = {
        "sentencepiece_model_file": "spm_cased_simp_sampled.model",
@@ -342,8 +387,9 @@ def __init__(self,

    @property
    def vocab_size(self):
-        """
+        r"""
        return the size of vocabulary.
+
        Returns:
            int: the size of vocabulary.
        """
@@ -369,13 +415,16 @@ def cut(self, chars):
        return words

    def _tokenize(self, text):
-        """
+        r"""
        End-to-end tokenization for ErnieTiny models.
+
        Args:
-            text (str): The text to be tokenized.
+            text (str):
+                The text to be tokenized.

        Returns:
-            list: A list of string representing converted tokens.
+            List[str]:
+                A list of string representing converted tokens.
        """
        if len(text) == 0:
            return []
@@ -397,32 +446,39 @@ def _tokenize(self, text):
        return in_vocab_tokens
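A rough usage sketch of the word-cut plus sentencepiece pipeline documented above; the produced sub-word pieces depend entirely on the downloaded `ernie-tiny` resources, so none are hard-coded here:

    .. code-block:: python

        from paddlenlp.transformers import ErnieTinyTokenizer

        tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')
        # First cut into words with the word dict, then into sentencepiece sub-words.
        tokens = tokenizer.tokenize('这是个测试样例')
        # Join the pieces back into a single space-separated string
        # (see convert_tokens_to_string below).
        text = tokenizer.convert_tokens_to_string(tokens)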

    def tokenize(self, text):
-        """
+        r"""
        End-to-end tokenization for ERNIE Tiny models.
+
        Args:
-            text (str): The text to be tokenized.
+            text (str):
+                The text to be tokenized.

        Returns:
-            list: A list of string representing converted tokens.
+            List[str]:
+                A list of string representing converted tokens.
        """
        return self._tokenize(text)

    def convert_tokens_to_string(self, tokens):
-        """
+        r"""
        Converts a sequence of tokens (list of string) in a single string. Since the usage
        of WordPiece introducing `##` to concat subwords, also remove `##` when converting.
+
        Args:
-            tokens (list): A list of string representing tokens to be converted.
+            tokens (list):
+                A list of string representing tokens to be converted.

        Returns:
-            str: Converted string from tokens.
+            str:
+                Converted string from tokens.
        """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def save_resources(self, save_directory):
-        """
+        r"""
        Save tokenizer related resources to files under `save_directory`.
+
        Args:
            save_directory (str): Directory to save files into.
        """
@@ -434,7 +490,7 @@ def save_resources(self, save_directory):
            shutil.copyfile(source_path, save_path)

    def num_special_tokens_to_add(self, pair=False):
-        """
+        r"""
        Returns the number of added tokens when encoding a sequence with special tokens.

        Note:
@@ -442,11 +498,13 @@ def num_special_tokens_to_add(self, pair=False):
            inside your training loop.

        Args:
-            pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
+            pair (bool, optional):
+                Returns the number of added tokens in the case of a sequence pair if set to True, returns the
                number of added tokens in the case of a single sequence if set to False.
+                Defaults to `False`.

        Returns:
-            Number of tokens added to sequences
+            int: Number of tokens added to sequences.
        """
        token_ids_0 = []
        token_ids_1 = []
@@ -455,7 +513,7 @@ def num_special_tokens_to_add(self, pair=False):
            if pair else None))

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        """
+        r"""
        Build model inputs from a sequence or a pair of sequence for sequence classification
        tasks by concatenating and adding special tokens.
@@ -465,13 +523,14 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
-            token_ids_0 (:obj:`List[int]`):
+            token_ids_0 (List[int]):
                List of IDs to which the special tokens will be added.
-            token_ids_1 (:obj:`List[int]`, `optional`):
+            token_ids_1 (List[int], optional):
                Optional second list of IDs for sequence pairs.
+                Defaults to `None`.

        Returns:
-            :obj:`List[int]`: List of input_id with the appropriate special tokens.
+            List[int]: List of input_id with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
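For the `pair` flag documented above, a hedged sanity check; the expected counts follow from the `[CLS] A [SEP]` and `[CLS] A [SEP] B [SEP]` formats described in these docstrings rather than from running the code:

    .. code-block:: python

        tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')

        tokenizer.num_special_tokens_to_add(pair=False)
        # expected 2: one [CLS] and one [SEP] around a single sequence
        tokenizer.num_special_tokens_to_add(pair=True)
        # expected 3: [CLS] A [SEP] B [SEP] adds one [CLS] and two [SEP]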
@@ -482,7 +541,7 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):

    def build_offset_mapping_with_special_tokens(self,
                                                 offset_mapping_0,
                                                 offset_mapping_1=None):
-        """
+        r"""
        Build offset map from a pair of offset map by concatenating and adding offsets
        of special tokens.

        A ERNIE offset_mapping has the following format:
@@ -491,13 +550,14 @@ def build_offset_mapping_with_special_tokens(self,
        - pair of sequences: `(0,0) A (0,0) B (0,0)``

        Args:
-            offset_mapping_ids_0 (:obj:`List[tuple]`):
+            offset_mapping_ids_0 (List[tuple]):
                List of char offsets to which the special tokens will be added.
-            offset_mapping_ids_1 (:obj:`List[tuple]`, `optional`):
+            offset_mapping_ids_1 (List[tuple], optional):
                Optional second list of char offsets for offset mapping pairs.
+                Defaults to `None`.

        Returns:
-            :obj:`List[tuple]`: List of char offsets with the appropriate offsets of special tokens.
+            List[tuple]: List of char offsets with the appropriate offsets of special tokens.
        """
        if offset_mapping_1 is None:
            return [(0, 0)] + offset_mapping_0 + [(0, 0)]
@@ -508,7 +568,7 @@ def build_offset_mapping_with_special_tokens(self,

    def create_token_type_ids_from_sequences(self,
                                             token_ids_0,
                                             token_ids_1=None):
-        """
+        r"""
        Create a mask from the two sequences passed to be used in a sequence-pair classification task.

        A ERNIE sequence pair mask has the following format:
@@ -517,16 +577,17 @@ def create_token_type_ids_from_sequences,
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |

-        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
+        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).

        Args:
-            token_ids_0 (:obj:`List[int]`):
+            token_ids_0 (List[int]):
                List of IDs.
-            token_ids_1 (:obj:`List[int]`, `optional`):
+            token_ids_1 (List[int], optional):
                Optional second list of IDs for sequence pairs.
+                Defaults to `None`.

        Returns:
-            :obj:`List[int]`: List of token_type_id according to the given sequence(s).
+            List[int]: List of token_type_id according to the given sequence(s).
        """
        _sep = [self.sep_token_id]
        _cls = [self.cls_token_id]
@@ -539,18 +600,23 @@ def get_special_tokens_mask(self,
                                token_ids_0,
                                token_ids_1=None,
                                already_has_special_tokens=False):
-        """
+        r"""
        Retrieves sequence ids from a token list that has no special tokens added. This method is
        called when adding special tokens using the tokenizer ``encode`` methods.

        Args:
-            token_ids_0 (List[int]): List of ids of the first sequence.
-            token_ids_1 (List[int], optinal): List of ids of the second sequence.
-            already_has_special_tokens (bool, optional): Whether or not the token list is already
-                formatted with special tokens for the model. Defaults to None.
+            token_ids_0 (List[int]):
+                List of ids of the first sequence.
+            token_ids_1 (List[int], optional):
+                List of ids of the second sequence.
+                Defaults to `None`.
+            already_has_special_tokens (bool, optional):
+                Whether or not the token list is already formatted with special tokens for the model.
+                Defaults to `False`.

        Returns:
-            results (List[int]): The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+            List[int]:
+                The list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens: