Modify GPT Docstring (PaddlePaddle#942)
* first commit of gpt

* modify tokenizer

* modify bart tokenizer

* modify tinybert tokenizer

* modify tinybert modeling

* modify tinybert

* fix gpt errors

* fix errors

* modify gpt

* modify models

* update

* resolve conflicts

* resolve conflict
huhuiwen99 authored Sep 8, 2021
1 parent 0c751f3 commit f3cf4a5
Showing 6 changed files with 752 additions and 35 deletions.
46 changes: 45 additions & 1 deletion paddlenlp/transformers/bart/tokenizer.py
@@ -18,7 +18,51 @@

class BartTokenizer(GPTTokenizer):
r"""
Construct a BART tokenizer based on byte-level Byte-Pair-Encoding.
This tokenizer inherits from :class:`~paddlenlp.transformers.gpt.tokenizer.GPTTokenizer`.
For more information regarding those methods, please refer to this superclass.
Args:
vocab_file (str):
Path to the vocabulary file.
The vocab file contains a mapping from vocabulary strings to indices.
merges_file (str):
Path to the merge file.
The merge file is used to split the input sentence into "subword" units.
The vocab file is then used to encode those units as indices.
errors (str):
Paradigm to follow when decoding bytes to UTF-8.
Defaults to `'replace'`.
max_len (int, optional):
The maximum value of the input sequence length.
Defaults to `None`.
special_tokens (list, optional):
A list of special tokens not in the vocabulary.
Defaults to `None`.
eos_token (str, optional):
A special token representing the end of a sequence that was used during pretraining.
Defaults to `"</s>"`.
pad_token (str, optional):
A special token used to make arrays of tokens the same size for batching purposes.
Defaults to "[PAD]".
eol_token (str, optional):
A special token representing the token of newline.
Defaults to `"\u010a"`.
Examples:
.. code-block::
from paddlenlp.transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('bart-base')
print(tokenizer('He was a puppeteer'))
'''
{'input_ids': [0, 894, 21, 10, 32986, 9306, 254, 2],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
'''
"""
# merges and vocab same as GPT2
resource_files_names = {
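To complement the docstring above, here is a minimal usage sketch of the documented tokenizer. It assumes paddlenlp is installed and the 'bart-base' files can be downloaded; convert_ids_to_tokens is the generic PretrainedTokenizer helper and is not something added by this commit.

from paddlenlp.transformers import BartTokenizer

# Load the byte-level BPE vocab and merges for 'bart-base' (downloaded on first use).
tokenizer = BartTokenizer.from_pretrained('bart-base')

# Encoding returns a dict; per the docstring example it contains
# 'input_ids' and 'attention_mask'.
encoded = tokenizer('He was a puppeteer')
print(encoded['input_ids'])

# Map the ids back to their byte-level BPE subword strings.
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))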
8 changes: 3 additions & 5 deletions paddlenlp/transformers/bert/modeling.py
@@ -447,11 +447,9 @@ def forward(self,
Mask used in multi-head attention to avoid performing attention on some unwanted positions,
usually the paddings or the subsequent positions.
Its data type can be int, float and bool.
When the data type is bool, the `masked` tokens have `False` values and the others have `True` values.
When the data type is int, the `masked` tokens have `0` values and the others have `1` values.
When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
Defaults to `None`, which means nothing needed to be prevented attention to.
output_hidden_states (bool, optional):
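A short sketch of the attention_mask convention documented in the hunk above, for illustration only: it assumes paddle and paddlenlp are installed, and 'bert-base-uncased' is just an example checkpoint, not part of this commit. The integer mask uses 1 for tokens to attend to and 0 for masked tokens, and is unsqueezed so it broadcasts to [batch_size, num_attention_heads, sequence_length, sequence_length].

import paddle
from paddlenlp.transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

inputs = tokenizer('He was a puppeteer')
input_ids = paddle.to_tensor([inputs['input_ids']])

# Integer convention from the docstring: 1 = keep (not masked), 0 = masked.
mask = paddle.ones_like(input_ids)

# Unsqueeze to [batch_size, 1, 1, sequence_length] so the mask broadcasts to
# [batch_size, num_attention_heads, sequence_length, sequence_length].
attention_mask = paddle.unsqueeze(mask, axis=[1, 2])

# BertModel here is expected to return (sequence_output, pooled_output);
# the exact return signature depends on the paddlenlp version.
sequence_output, pooled_output = model(input_ids, attention_mask=attention_mask)
print(sequence_output.shape)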