Skip to content

Commit

Permalink
Fix PretrainedTokenizer saving. (PaddlePaddle#648)
Browse files Browse the repository at this point in the history
  • Loading branch information
guoshengCS authored Jun 27, 2021
1 parent 3038913 commit 5a4f57c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/data/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def to_tokens(self, indices):

tokens = []
for idx in indices:
if not isinstance(idx, int):
if not isinstance(idx, (int, np.integer)):
warnings.warn(
"The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
)
Expand Down
9 changes: 6 additions & 3 deletions paddlenlp/transformers/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,12 @@ def save_pretrained(self, save_dir):
# reload from save_directory
model = BertForSequenceClassification.from_pretrained('./trained_model/')
"""
assert os.path.isdir(
save_dir), "save_dir ({}) is not available.".format(save_dir)
# Save model config
assert not os.path.isfile(
save_dir
), "Saving directory ({}) should be a directory, not a file".format(
save_dir)
os.makedirs(save_dir, exist_ok=True)
# Save model config
self.save_model_config(save_dir)
# Save model
file_name = os.path.join(save_dir,
Expand Down
23 changes: 12 additions & 11 deletions paddlenlp/transformers/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import six
import unicodedata
from shutil import copyfile
from typing import Iterable, Iterator, Optional, List, Any, Callable, Union

from paddlenlp.utils.downloader import get_path_from_url
Expand Down Expand Up @@ -525,9 +526,12 @@ def save_pretrained(self, save_directory):
# reload from save_directory
tokenizer = BertTokenizer.from_pretrained('trained_model')
"""
assert os.path.isdir(
assert not os.path.isfile(
save_directory
), "Saving directory ({}) should be a directory".format(save_directory)
), "Saving directory ({}) should be a directory, not a file".format(
save_directory)
os.makedirs(save_directory, exist_ok=True)

tokenizer_config_file = os.path.join(save_directory,
self.tokenizer_config_file)
# init_config is set in metaclass created `__init__`,
Expand All @@ -540,19 +544,16 @@ def save_pretrained(self, save_directory):
def save_resources(self, save_directory):
"""
Save tokenizer related resources to `resource_files_names` indicating
files under `save_directory`.
Currently, it only can support saving `vocab` of tokenizer by using
`self.save_vocabulary(file_name, self.vocab)`. Override it if necessary.
files under `save_directory` by copying directly. Override it if necessary.
Args:
save_directory (str): Directory to save files into.
"""
assert hasattr(self, 'vocab') and len(
self.resource_files_names) == 1, "Must overwrite `save_resources`"
file_name = os.path.join(save_directory,
list(self.resource_files_names.values())[0])
self.save_vocabulary(file_name, self.vocab)
for name, file_name in self.resource_files_names.items():
src_path = self.init_config[name]
dst_path = os.path.join(save_directory, file_name)
if os.path.abspath(src_path) != os.path.abspath(dst_path):
copyfile(src_path, dst_path)

@staticmethod
def load_vocabulary(filepath,
Expand Down

0 comments on commit 5a4f57c

Please sign in to comment.