refactor: improvement check_paddle_installed (fxsjy#806)
vissssa authored and fxsjy committed Jan 9, 2020
1 parent 0868c32 commit dc2b788
Showing 3 changed files with 48 additions and 62 deletions.
jieba/__init__.py: 39 changes (18 additions & 21 deletions)
@@ -1,26 +1,24 @@
 from __future__ import absolute_import, unicode_literals
+
 __version__ = '0.41'
 __license__ = 'MIT'

-import re
-import os
-import sys
-import time
-import logging
 import marshal
+import re
 import tempfile
 import threading
-from math import log
+import time
 from hashlib import md5
-from ._compat import *
+from math import log
+
 from . import finalseg
+from ._compat import *

 if os.name == 'nt':
     from shutil import move as _replace_file
 else:
     _replace_file = os.rename

-
 _get_abs_path = lambda path: os.path.normpath(os.path.join(os.getcwd(), path))

 DEFAULT_DICT = None
@@ -47,10 +45,11 @@

 re_skip_default = re.compile("(\r\n|\s)", re.U)

+
 def setLogLevel(log_level):
-    global logger
     default_logger.setLevel(log_level)

+
 class Tokenizer(object):

     def __init__(self, dictionary=DEFAULT_DICT):
@@ -69,7 +68,8 @@ def __init__(self, dictionary=DEFAULT_DICT):
     def __repr__(self):
         return '<Tokenizer dictionary=%r>' % self.dictionary

-    def gen_pfdict(self, f):
+    @staticmethod
+    def gen_pfdict(f):
         lfreq = {}
         ltotal = 0
         f_name = resolve_filename(f)
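
Context for this hunk: gen_pfdict only reads from its file argument and never touches self, so the commit promotes it to a @staticmethod, which lets it be called on the class itself and signals that no instance state is involved. A toy version of the idea (simplified input handling; the real method also decodes bytes and reports the offending file and line on errors):

class Tokenizer(object):
    @staticmethod
    def gen_pfdict(lines):
        # Build the prefix-frequency dict: every prefix of every word gets
        # an entry, 0 unless the prefix is itself a listed word.
        lfreq, ltotal = {}, 0
        for line in lines:
            word, freq = line.split()[:2]
            lfreq[word] = int(freq)
            ltotal += int(freq)
            for i in range(len(word) - 1):
                lfreq.setdefault(word[:i + 1], 0)
        return lfreq, ltotal

print(Tokenizer.gen_pfdict(['abc 10', 'abd 5']))  # no instance required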
@@ -128,7 +128,7 @@ def initialize(self, dictionary=None):

             load_from_cache_fail = True
             if os.path.isfile(cache_file) and (abs_path == DEFAULT_DICT or
-                os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
+                    os.path.getmtime(cache_file) > os.path.getmtime(abs_path)):
                 default_logger.debug(
                     "Loading model from cache %s" % cache_file)
                 try:
@@ -201,7 +201,7 @@ def __cut_all(self, sentence):
         eng_scan = 0
         eng_buf = u''
         for k, L in iteritems(dag):
-            if eng_scan==1 and not re_eng.match(sentence[k]):
+            if eng_scan == 1 and not re_eng.match(sentence[k]):
                 eng_scan = 0
                 yield eng_buf
             if len(L) == 1 and k > old_j:
@@ -219,7 +219,7 @@ def __cut_all(self, sentence):
                     if j > k:
                         yield sentence[k:j + 1]
                 old_j = j
-        if eng_scan==1:
+        if eng_scan == 1:
             yield eng_buf

     def __cut_DAG_NO_HMM(self, sentence):
@@ -285,24 +285,21 @@ def __cut_DAG(self, sentence):
                 for elem in buf:
                     yield elem

-    def cut(self, sentence, cut_all = False, HMM = True,use_paddle = False):
-        '''
+    def cut(self, sentence, cut_all=False, HMM=True, use_paddle=False):
+        """
         The main function that segments an entire sentence that contains
         Chinese characters into separated words.
         Parameter:
             - sentence: The str(unicode) to be segmented.
             - cut_all: Model type. True for full pattern, False for accurate pattern.
             - HMM: Whether to use the Hidden Markov Model.
-        '''
-        is_paddle_installed = False
-        if use_paddle == True:
-            is_paddle_installed = check_paddle_install()
+        """
+        is_paddle_installed = check_paddle_install['is_paddle_installed']
         sentence = strdecode(sentence)
-        if use_paddle == True and is_paddle_installed == True:
+        if use_paddle and is_paddle_installed:
             if sentence is None or sentence == "" or sentence == u"":
                 yield sentence
                 return
             import jieba.lac_small.predict as predict
             results = predict.get_sent(sentence)
             for sent in results:
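The substantive change in this file is in cut(): instead of re-probing for paddle on every call through a helper function, it now reads a module-level flag that enable_paddle() in _compat.py fills in once. A minimal, self-contained sketch of that pattern (illustrative names, not the verbatim jieba code):

check_paddle_install = {'is_paddle_installed': False}  # shared mutable flag

def enable_backend():
    # Stands in for enable_paddle(): the expensive import/version probe
    # runs once, and the outcome is recorded in the shared dict.
    check_paddle_install['is_paddle_installed'] = True

def cut(sentence, use_paddle=False):
    # The per-call cost is now one dict lookup, not an import attempt.
    if use_paddle and check_paddle_install['is_paddle_installed']:
        return ['<paddle segmentation of %s>' % sentence]  # stand-in for LAC
    return sentence.split()  # stand-in for the default DAG/HMM cut

print(cut('a b c', use_paddle=True))  # flag still False: falls back
enable_backend()
print(cut('a b c', use_paddle=True))  # now takes the paddle branch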
jieba/_compat.py: 49 changes (20 additions & 29 deletions)
@@ -1,49 +1,55 @@
 # -*- coding: utf-8 -*-
+import logging
 import os
 import sys
-import logging

 log_console = logging.StreamHandler(sys.stderr)
 default_logger = logging.getLogger(__name__)
 default_logger.setLevel(logging.DEBUG)

+
+def setLogLevel(log_level):
+    global logger
+    default_logger.setLevel(log_level)
+
+
+check_paddle_install = {'is_paddle_installed': False}
+
 try:
     import pkg_resources
+
     get_module_res = lambda *res: pkg_resources.resource_stream(__name__,
                                                                 os.path.join(*res))
 except ImportError:
     get_module_res = lambda *res: open(os.path.normpath(os.path.join(
-        os.getcwd(), os.path.dirname(__file__), *res)), 'rb')
+            os.getcwd(), os.path.dirname(__file__), *res)), 'rb')

+
 def enable_paddle():
-    import_paddle_check = False
     try:
         import paddle
     except ImportError:
         default_logger.debug("Installing paddle-tiny, please wait a minute......")
         os.system("pip install paddlepaddle-tiny")
-        try:
-            import paddle
-        except ImportError:
-            default_logger.debug("Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
-                "Now, back to jieba basic cut......")
+        try:
+            import paddle
+        except ImportError:
+            default_logger.debug(
+                "Import paddle error, please use command to install: pip install paddlepaddle-tiny==1.6.1."
+                "Now, back to jieba basic cut......")
     if paddle.__version__ < '1.6.1':
         default_logger.debug("Find your own paddle version doesn't satisfy the minimum requirement (1.6.1), "
                              "please install paddle tiny by 'pip install --upgrade paddlepaddle-tiny', "
-                             "or upgrade paddle full version by 'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
+                             "or upgrade paddle full version by "
+                             "'pip install --upgrade paddlepaddle (-gpu for GPU version)' ")
     else:
         try:
             import jieba.lac_small.predict as predict
-            import_paddle_check = True
             default_logger.debug("Paddle enabled successfully......")
+            check_paddle_install['is_paddle_installed'] = True
         except ImportError:
             default_logger.debug("Import error, cannot find paddle.fluid and jieba.lac_small.predict module. "
-                "Now, back to jieba basic cut......")
+                                 "Now, back to jieba basic cut......")


 PY2 = sys.version_info[0] == 2
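
One caveat this hunk leaves in place: paddle.__version__ < '1.6.1' is a lexicographic string comparison, which misorders multi-digit components. A quick illustration, independent of the commit:

# Strings compare character by character, so '1.10.0' sorts before
# '1.6.1' even though 1.10.0 is the newer release.
print('1.10.0' < '1.6.1')  # True: lexicographic, not numeric

# Comparing tuples of ints avoids the pitfall.
as_tuple = lambda v: tuple(int(x) for x in v.split('.'))
print(as_tuple('1.10.0') < as_tuple('1.6.1'))  # False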

@@ -66,6 +72,7 @@ def enable_paddle():
 itervalues = lambda d: iter(d.values())
 iteritems = lambda d: iter(d.items())

+
 def strdecode(sentence):
     if not isinstance(sentence, text_type):
         try:
@@ -74,25 +81,9 @@ def enable_paddle():
         sentence = sentence.decode('gbk', 'ignore')
     return sentence

+
 def resolve_filename(f):
     try:
         return f.name
     except AttributeError:
         return repr(f)
-
-
-def check_paddle_install():
-    is_paddle_installed = False
-    try:
-        import paddle
-        if paddle.__version__ >= '1.6.1':
-            is_paddle_installed = True
-        else:
-            is_paddle_installed = False
-            default_logger.debug("Check the paddle version is not correct, the current version is "+ paddle.__version__+","
-                "please use command to install paddle: pip uninstall paddlepaddle(-gpu), "
-                "pip install paddlepaddle-tiny==1.6.1. Now, back to jieba basic cut......")
-    except ImportError:
-        default_logger.debug("Import paddle error, back to jieba basic cut......")
-        is_paddle_installed = False
-    return is_paddle_installed
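
Why replace the check_paddle_install() function with a one-entry dict rather than a plain module-level boolean? With from ._compat import *, each importing module gets its own binding to the object; rebinding a bare bool inside _compat would be invisible to modules that imported it earlier, while mutating the shared dict in place is seen by all of them. A short demonstration of that semantics point (generic names, unrelated to jieba):

shared = {'flag': False}
alias = shared            # what a star-import effectively hands the importer

shared['flag'] = True     # in-place mutation: visible through every alias
print(alias['flag'])      # True

shared = {'flag': False}  # rebinding creates a new object...
print(alias['flag'])      # ...so the alias still prints True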
jieba/posseg/__init__.py: 22 changes (10 additions & 12 deletions)
@@ -1,11 +1,11 @@
 from __future__ import absolute_import, unicode_literals
-import os
+
+import pickle
 import re
-import sys
+
 import jieba
-import pickle
-from .._compat import *
 from .viterbi import viterbi
+from .._compat import *

 PROB_START_P = "prob_start.p"
 PROB_TRANS_P = "prob_trans.p"
@@ -252,6 +252,7 @@ def cut(self, sentence, HMM=True):
     def lcut(self, *args, **kwargs):
         return list(self.cut(*args, **kwargs))

+
 # default Tokenizer instance

 dt = POSTokenizer(jieba.dt)
@@ -276,19 +277,16 @@ def cut(sentence, HMM=True, use_paddle=False):
     Note that this only works using dt, custom POSTokenizer
     instances are not supported.
     """
-    is_paddle_installed = False
-    if use_paddle == True:
-        is_paddle_installed = check_paddle_install()
-    if use_paddle==True and is_paddle_installed == True:
+    is_paddle_installed = check_paddle_install['is_paddle_installed']
+    if use_paddle and is_paddle_installed:
         if sentence is None or sentence == "" or sentence == u"":
             yield pair(None, None)
             return
         import jieba.lac_small.predict as predict
-        sents,tags = predict.get_result(strdecode(sentence))
-        for i,sent in enumerate(sents):
+        sents, tags = predict.get_result(strdecode(sentence))
+        for i, sent in enumerate(sents):
             if sent is None or tags[i] is None:
                 continue
-            yield pair(sent,tags[i])
+            yield pair(sent, tags[i])
         return
     global dt
     if jieba.pool is None:
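For reference, the calling sequence the refactored flag supports, following jieba's documented paddle mode (paddlepaddle-tiny must be installed; without it the flag stays False and both calls quietly use the default segmenter):

import jieba
import jieba.posseg as pseg

jieba.enable_paddle()  # one-time probe; sets check_paddle_install's flag

for word, flag in pseg.cut('我爱北京天安门', use_paddle=True):
    print('%s %s' % (word, flag))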
