diff --git a/requestspool_client/__init__.py b/requestspool_client/__init__.py index 57e9ee1..93644bc 100644 --- a/requestspool_client/__init__.py +++ b/requestspool_client/__init__.py @@ -105,27 +105,37 @@ def receive(self): return result +class RequestsPoolCheck(object): + def __init__(self, server_url): + import requests + import re + try: + # 获取路由 + routes = json.loads(requests.get(server_url + "/admin/route/all").content).get("route") + except: + print "尚未启动服务或服务运行异常" + raise ImportError + self.patterns = [re.compile(s) for s in routes if isinstance(s, basestring)] + + def check(self, url): + for p in self.patterns: + if p.match(url): + return True + return False + def patch_requests(server="localhost:8801"): - import requests - import re - try: - # 获取路由 - routes = json.loads(requests.get("http://%s/admin/route/all" % server).content).get("route") - except: - print "尚未启动服务或服务运行异常" - raise ImportError - patterns = [re.compile(s) for s in routes if isinstance(s, basestring)] + import requests.api + server_url = "http://" + server # patch - old_request = requests.request + old_request = requests.api.request + check = RequestsPoolCheck(server_url) def new_request(method, url, **kwargs): - need_proxy = False - for p in patterns: - if p.match(url): - need_proxy = True - break - if need_proxy: + if url.startswith(server_url): + return old_request(method, url, **kwargs) + + if check.check(url): url = 'http://{server}/{url}'.format(server=server, url=url) return old_request(method, url, **kwargs) - requests.request = new_request + requests.api.request = new_request