Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#1166 from wawltor/taskflow_download_check
Browse files Browse the repository at this point in the history

update the download check for the Taskflow
  • Loading branch information
wawltor authored Oct 14, 2021
2 parents 1b4821b + 6feb45a commit e51e699
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 100 deletions.
2 changes: 1 addition & 1 deletion paddlenlp/taskflow/dependency_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self,
ddparser, ddparser-ernie-1.0 and ddoarser-ernie-gram-zh")
word_vocab_path = download_file(
self._task_path, self.model + os.path.sep + "word_vocab.json",
URLS[self.model][0], URLS[self.model][1], self.model)
URLS[self.model][0], URLS[self.model][1])
rel_vocab_path = download_file(
self._task_path, self.model + os.path.sep + "rel_vocab.json",
URLS[self.model][0], URLS[self.model][1])
Expand Down
8 changes: 3 additions & 5 deletions paddlenlp/taskflow/knowledge_mining.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,11 @@ class WordTagTask(Task):
def __init__(self, model, task, **kwargs):
super().__init__(model=model, task=task, **kwargs)
self._static_mode = False
self._log_name = self.kwargs[
'log_name'] if 'log_name' in self.kwargs else 'wordtag'
self._linking = self.kwargs[
'linking'] if 'linking' in self.kwargs else False
term_schema_path = download_file(
self._task_path, "termtree_type.csv", URLS['termtree_type'][0],
URLS['termtree_type'][1], self._log_name)
term_schema_path = download_file(self._task_path, "termtree_type.csv",
URLS['termtree_type'][0],
URLS['termtree_type'][1])
term_data_path = download_file(self._task_path, "TermTree.V1.0",
URLS['TermTree.V1.0'][0],
URLS['TermTree.V1.0'][1])
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/taskflow/lexical_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(self, task, model, **kwargs):
self._usage = usage
word_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "word.dic",
URLS['lac_params'][0], URLS['lac_params'][1], 'lexical_analysis')
URLS['lac_params'][0], URLS['lac_params'][1])
tag_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "tag.dic",
URLS['lac_params'][0], URLS['lac_params'][1])
Expand Down
9 changes: 0 additions & 9 deletions paddlenlp/taskflow/poetry_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,3 @@ class PoetryGenerationTask(TextGenerationTask):

def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
if self._static_mode:
download_file(
self._task_path, "static" + os.path.sep + "inference.pdiparams",
URLS[self.model][0], URLS[self.model][1], "poetry_generation")
self._get_inference_model()
else:
self._construct_model(model)
self._construct_tokenizer(model)
self.kwargs['generation_task'] = 'poetry_generation'
30 changes: 0 additions & 30 deletions paddlenlp/taskflow/pos_tagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@
from .utils import download_file
from .lexical_analysis import load_vocab, LacTask

URLS = {
"pos_tagging_params": [
"https://paddlenlp.bj.bcebos.com/taskflow/lexical_analysis/lac/lac_params.tar.gz",
'ee9a3eaba5f74105410410e3c5b28fbc'
],
}

usage = r"""
from paddlenlp import Taskflow
Expand Down Expand Up @@ -58,29 +51,6 @@ class POSTaggingTask(LacTask):

def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._static_mode = False
self._usage = usage
word_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "word.dic",
URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1],
'pos_tagging')
tag_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "tag.dic",
URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1])
q2b_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "q2b.dic",
URLS['pos_tagging_params'][0], URLS['pos_tagging_params'][1])
self._word_vocab = load_vocab(word_dict_path)
self._tag_vocab = load_vocab(tag_dict_path)
self._q2b_vocab = load_vocab(q2b_dict_path)
self._id2word_dict = dict(
zip(self._word_vocab.values(), self._word_vocab.keys()))
self._id2tag_dict = dict(
zip(self._tag_vocab.values(), self._tag_vocab.keys()))
if self._static_mode:
self._get_inference_model()
else:
self._construct_model(model)

def _postprocess(self, inputs):
"""
Expand Down
9 changes: 0 additions & 9 deletions paddlenlp/taskflow/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,3 @@ class QuestionAnsweringTask(TextGenerationTask):

def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
if self._static_mode:
download_file(
self._task_path, "static" + os.path.sep + "inference.pdiparams",
URLS[self.model][0], URLS[self.model][1], "question_answering")
self._get_inference_model()
else:
self._construct_model(model)
self._construct_tokenizer(model)
self.kwargs['generation_task'] = 'question_answering'
2 changes: 1 addition & 1 deletion paddlenlp/taskflow/sentiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _construct_model(self, model):
model_instance = SkepSequenceModel.from_pretrained(
model, num_classes=len(self._label_map))
model_path = download_file(self._task_path, model + ".pdparams",
URLS[model][0], URLS[model][1], model)
URLS[model][0], URLS[model][1])
state_dict = paddle.load(model_path)
model_instance.set_state_dict(state_dict)
self._model = model_instance
Expand Down
5 changes: 4 additions & 1 deletion paddlenlp/taskflow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import paddle
from ..utils.env import PPNLP_HOME
from ..utils.log import logger
from .utils import static_mode_guard, dygraph_mode_guard
from .utils import download_check, static_mode_guard, dygraph_mode_guard


class Task(metaclass=abc.ABCMeta):
Expand All @@ -44,6 +44,9 @@ def __init__(self, model, task, **kwargs):
self._config = None
self._task_path = os.path.join(PPNLP_HOME, "taskflow", self.task,
self.model)
self._task_flag = self.kwargs[
'task_flag'] if 'task_flag' in self.kwargs else self.model
download_check(self._task_flag)

@abstractmethod
def _construct_model(self, model):
Expand Down
33 changes: 22 additions & 11 deletions paddlenlp/taskflow/taskflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"models": {
"wordtag": {
"task_class": WordTagTask,
"log_name": 'knowledge_mining_wordtag',
"task_flag": 'knowledge_mining-wordtag',
"linking": True,
}
},
Expand All @@ -48,7 +48,7 @@
"models": {
"wordtag": {
"task_class": NERTask,
"log_name": 'ner_wordtag',
"task_flag": 'ner-wordtag',
"linking": False,
}
},
Expand All @@ -60,6 +60,7 @@
"models": {
"gpt-cpm-large-cn": {
"task_class": PoetryGenerationTask,
"task_flag": 'poetry_generation-gpt-cpm-large-cn',
},
},
"default": {
Expand All @@ -70,6 +71,7 @@
"models": {
"gpt-cpm-large-cn": {
"task_class": QuestionAnsweringTask,
"task_flag": 'question_answering-gpt-cpm-large-cn',
},
},
"default": {
Expand All @@ -81,7 +83,8 @@
"lac": {
"task_class": LacTask,
"hidden_size": 128,
"emb_dim": 128
"emb_dim": 128,
"task_flag": 'lexical_analysis-gru_crf',
}
},
"default": {
Expand All @@ -93,7 +96,8 @@
"lac": {
"task_class": WordSegmentationTask,
"hidden_size": 128,
"emb_dim": 128
"emb_dim": 128,
"task_flag": 'word_segmentation-gru_crf',
}
},
"default": {
Expand All @@ -105,7 +109,8 @@
"lac": {
"task_class": POSTaggingTask,
"hidden_size": 128,
"emb_dim": 128
"emb_dim": 128,
"task_flag": 'pos_tagging-gru_crf',
}
},
"default": {
Expand All @@ -115,10 +120,12 @@
'sentiment_analysis': {
"models": {
"bilstm": {
"task_class": SentaTask
"task_class": SentaTask,
"task_flag": 'sentiment_analysis-bilstm',
},
"skep_ernie_1.0_large_ch": {
"task_class": SkepTask
"task_class": SkepTask,
"task_flag": 'sentiment_analysis-skep_ernie_1.0_large_ch',
}
},
"default": {
Expand All @@ -128,13 +135,16 @@
'dependency_parsing': {
"models": {
"ddparser": {
"task_class": DDParserTask
"task_class": DDParserTask,
"task_flag": 'dependency_parsing-biaffine',
},
"ddparser-ernie-1.0": {
"task_class": DDParserTask
"task_class": DDParserTask,
"task_flag": 'dependency_parsing-ernie-1.0',
},
"ddparser-ernie-gram-zh": {
"task_class": DDParserTask
"task_class": DDParserTask,
"task_flag": 'dependency_parsing-ernie-gram-zh',
},
},
"default": {
Expand All @@ -144,7 +154,8 @@
'text_correction': {
"models": {
"csc-ernie-1.0": {
"task_class": CSCTask
"task_class": CSCTask,
"task_flag": "text_correction-csc-ernie-1.0"
},
},
"default": {
Expand Down
9 changes: 9 additions & 0 deletions paddlenlp/taskflow/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._static_mode = True
self._usage = usage
if self._static_mode:
download_file(self._task_path,
"static" + os.path.sep + "inference.pdiparams",
URLS[self.model][0], URLS[self.model][1])
self._get_inference_model()
else:
self._construct_model(model)
self._construct_tokenizer(model)
self.kwargs['generation_task'] = task

def _construct_input_spec(self):
"""
Expand Down
25 changes: 18 additions & 7 deletions paddlenlp/taskflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
DOWNLOAD_CHECK = False


def download_file(save_dir, filename, url, md5=None, task=None):
def download_file(save_dir, filename, url, md5=None):
"""
Download the file from the url to specified directory.
Check md5 value when the file is exists, if the md5 value is the same as the existed file, just use
Expand All @@ -44,12 +44,6 @@ def download_file(save_dir, filename, url, md5=None, task=None):
md5(string, optional): The md5 value that checking the version downloaded.
"""
logger.disable()
global DOWNLOAD_CHECK
if not DOWNLOAD_CHECK:
DOWNLOAD_CHECK = True
checker = DownloaderCheck(task)
checker.start()
checker.join()
fullname = os.path.join(save_dir, filename)
if os.path.exists(fullname):
if md5 and (not md5file(fullname) == md5):
Expand All @@ -60,6 +54,23 @@ def download_file(save_dir, filename, url, md5=None, task=None):
return fullname


def download_check(task):
    """
    Run the one-time download status check for the specified task.

    The check is performed at most once per process: the module-level
    ``DOWNLOAD_CHECK`` flag guards against repeated checks. The logger is
    silenced while the checker thread runs so its output does not pollute
    the console.

    Args:
        task (str): The name (task flag) of the specified task.
    """
    logger.disable()
    global DOWNLOAD_CHECK
    try:
        if not DOWNLOAD_CHECK:
            DOWNLOAD_CHECK = True
            # DownloaderCheck is a thread; start it and wait for completion.
            checker = DownloaderCheck(task)
            checker.start()
            checker.join()
    finally:
        # Always re-enable logging, even if the checker raises; otherwise
        # the logger would stay disabled for the rest of the process.
        logger.enable()


def add_docstrings(*docstr):
"""
The function that add the doc string to doc of class.
Expand Down
25 changes: 0 additions & 25 deletions paddlenlp/taskflow/word_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,6 @@ class WordSegmentationTask(LacTask):

def __init__(self, task, model, **kwargs):
super().__init__(task=task, model=model, **kwargs)
self._static_mode = False
self._usage = usage
word_dict_path = download_file(
self._task_path, "lac_params" + os.path.sep + "word.dic",
URLS['word_segmentation_params'][0],
URLS['word_segmentation_params'][1], 'word_segmentation')
tag_dict_path = download_file(self._task_path,
"lac_params" + os.path.sep + "tag.dic",
URLS['word_segmentation_params'][0],
URLS['word_segmentation_params'][1])
q2b_dict_path = download_file(self._task_path,
"lac_params" + os.path.sep + "q2b.dic",
URLS['word_segmentation_params'][0],
URLS['word_segmentation_params'][1])
self._word_vocab = load_vocab(word_dict_path)
self._tag_vocab = load_vocab(tag_dict_path)
self._q2b_vocab = load_vocab(q2b_dict_path)
self._id2word_dict = dict(
zip(self._word_vocab.values(), self._word_vocab.keys()))
self._id2tag_dict = dict(
zip(self._tag_vocab.values(), self._tag_vocab.keys()))
if self._static_mode:
self._get_inference_model()
else:
self._construct_model(model)

def _postprocess(self, inputs):
"""
Expand Down

0 comments on commit e51e699

Please sign in to comment.