Skip to content

Commit

Permalink
refine proxylib.py
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslu committed Jan 3, 2015
1 parent 314c6eb commit 0d5eed2
Showing 1 changed file with 90 additions and 93 deletions.
183 changes: 90 additions & 93 deletions local/proxylib.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,30 +167,30 @@ def clear(self):
self.key_order = []


class CertUtil(object):
"""CertUtil module, based on mitmproxy"""

ca_vendor = 'GoAgent'
ca_keyfile = 'CA.crt'
ca_thumbprint = ''
ca_certdir = 'certs'
ca_digest = 'sha1' if sys.platform == 'win32' and sys.getwindowsversion() < (6,) else 'sha256'
ca_lock = threading.Lock()

@staticmethod
def create_ca():
class CertUtility(object):
"""Cert Utility module, based on mitmproxy"""

def __init__(self, vendor, filename, dirname):
self.ca_vendor = vendor
self.ca_keyfile = filename
self.ca_thumbprint = ''
self.ca_certdir = dirname
self.ca_digest = 'sha1' if sys.platform == 'win32' and sys.getwindowsversion() < (6,) else 'sha256'
self.ca_lock = threading.Lock()

def create_ca(self):
key = OpenSSL.crypto.PKey()
key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)
req = OpenSSL.crypto.X509Req()
subj = req.get_subject()
subj.countryName = 'CN'
subj.stateOrProvinceName = 'Internet'
subj.localityName = 'Cernet'
subj.organizationName = CertUtil.ca_vendor
subj.organizationalUnitName = '%s Root' % CertUtil.ca_vendor
subj.commonName = '%s CA' % CertUtil.ca_vendor
subj.organizationName = self.ca_vendor
subj.organizationalUnitName = self.ca_vendor
subj.commonName = self.ca_vendor
req.set_pubkey(key)
req.sign(key, CertUtil.ca_digest)
req.sign(key, self.ca_digest)
ca = OpenSSL.crypto.X509()
ca.set_serial_number(0)
ca.gmtime_adj_notBefore(0)
Expand All @@ -201,22 +201,19 @@ def create_ca():
ca.sign(key, 'sha1')
return key, ca

@staticmethod
def dump_ca():
key, ca = CertUtil.create_ca()
with open(CertUtil.ca_keyfile, 'wb') as fp:
def dump_ca(self):
key, ca = self.create_ca()
with open(self.ca_keyfile, 'wb') as fp:
fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, ca))
fp.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key))

@staticmethod
def get_cert_serial_number(commonname):
assert CertUtil.ca_thumbprint
saltname = '%s|%s' % (CertUtil.ca_thumbprint, commonname)
def get_cert_serial_number(self, commonname):
assert self.ca_thumbprint
saltname = '%s|%s' % (self.ca_thumbprint, commonname)
return int(hashlib.md5(saltname.encode('utf-8')).hexdigest(), 16)

@staticmethod
def _get_cert(commonname, sans=()):
with open(CertUtil.ca_keyfile, 'rb') as fp:
def _get_cert(self, commonname, sans=()):
with open(self.ca_keyfile, 'rb') as fp:
content = fp.read()
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, content)
ca = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, content)
Expand All @@ -229,7 +226,7 @@ def _get_cert(commonname, sans=()):
subj.countryName = 'CN'
subj.stateOrProvinceName = 'Internet'
subj.localityName = 'Cernet'
subj.organizationalUnitName = '%s Branch' % CertUtil.ca_vendor
subj.organizationalUnitName = self.ca_vendor
if commonname[0] == '.':
subj.commonName = '*' + commonname
subj.organizationName = '*' + commonname
Expand All @@ -240,12 +237,12 @@ def _get_cert(commonname, sans=()):
sans = [commonname] + [x for x in sans if x != commonname]
#req.add_extensions([OpenSSL.crypto.X509Extension(b'subjectAltName', True, ', '.join('DNS: %s' % x for x in sans)).encode()])
req.set_pubkey(pkey)
req.sign(pkey, CertUtil.ca_digest)
req.sign(pkey, self.ca_digest)

cert = OpenSSL.crypto.X509()
cert.set_version(2)
try:
cert.set_serial_number(CertUtil.get_cert_serial_number(commonname))
cert.set_serial_number(self.get_cert_serial_number(commonname))
except OpenSSL.SSL.Error:
cert.set_serial_number(int(time.time()*1000))
cert.gmtime_adj_notBefore(-600) #avoid crt time error warning
Expand All @@ -258,31 +255,29 @@ def _get_cert(commonname, sans=()):
else:
sans = [commonname] + [s for s in sans if s != commonname]
#cert.add_extensions([OpenSSL.crypto.X509Extension(b'subjectAltName', True, ', '.join('DNS: %s' % x for x in sans))])
cert.sign(key, CertUtil.ca_digest)
cert.sign(key, self.ca_digest)

certfile = os.path.join(CertUtil.ca_certdir, commonname + '.crt')
certfile = os.path.join(self.ca_certdir, commonname + '.crt')
with open(certfile, 'wb') as fp:
fp.write(OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert))
fp.write(OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, pkey))
return certfile

@staticmethod
def get_cert(commonname, sans=()):
def get_cert(self, commonname, sans=()):
if commonname.count('.') >= 2 and [len(x) for x in reversed(commonname.split('.'))] > [2, 4]:
commonname = '.'+commonname.partition('.')[-1]
certfile = os.path.join(CertUtil.ca_certdir, commonname + '.crt')
certfile = os.path.join(self.ca_certdir, commonname + '.crt')
if os.path.exists(certfile):
return certfile
elif OpenSSL is None:
return CertUtil.ca_keyfile
return self.ca_keyfile
else:
with CertUtil.ca_lock:
with self.ca_lock:
if os.path.exists(certfile):
return certfile
return CertUtil._get_cert(commonname, sans)
return self._get_cert(commonname, sans)

@staticmethod
def import_ca(certfile):
def import_ca(self, certfile):
commonname = os.path.splitext(os.path.basename(certfile))[0]
if sys.platform.startswith('win'):
import ctypes
Expand All @@ -301,8 +296,8 @@ def import_ca(certfile):
X509_ASN_ENCODING = 0x00000001
class CRYPT_HASH_BLOB(ctypes.Structure):
_fields_ = [('cbData', ctypes.c_ulong), ('pbData', ctypes.c_char_p)]
assert CertUtil.ca_thumbprint
crypt_hash = CRYPT_HASH_BLOB(20, binascii.a2b_hex(CertUtil.ca_thumbprint.replace(':', '')))
assert self.ca_thumbprint
crypt_hash = CRYPT_HASH_BLOB(20, binascii.a2b_hex(self.ca_thumbprint.replace(':', '')))
crypt_handle = crypt32.CertFindCertificateInStore(store_handle, X509_ASN_ENCODING, 0, CERT_FIND_HASH, ctypes.byref(crypt_hash), None)
if crypt_handle:
crypt32.CertFreeCertificateContext(crypt_handle)
Expand All @@ -327,8 +322,7 @@ class CRYPT_HASH_BLOB(ctypes.Structure):
logging.warning('please install *libnss3-tools* package to import GoAgent root ca')
return 0

@staticmethod
def remove_ca(name):
def remove_ca(self, name):
import ctypes
import ctypes.wintypes
class CERT_CONTEXT(ctypes.Structure):
Expand All @@ -348,43 +342,44 @@ class CERT_CONTEXT(ctypes.Structure):
if hasattr(cert, 'get_subject'):
cert = cert.get_subject()
cert_name = next((v for k, v in cert.get_components() if k == 'CN'), '')
if cert_name and name == cert_name:
if cert_name and name.lower() == cert_name.split()[0].lower():
crypt32.CertDeleteCertificateFromStore(crypt32.CertDuplicateCertificateContext(pCertCtx))
pCertCtx = crypt32.CertEnumCertificatesInStore(store_handle, pCertCtx)
return 0

@staticmethod
def check_ca():
def check_ca(self):
#Check CA exists
capath = os.path.join(os.path.dirname(os.path.abspath(__file__)), CertUtil.ca_keyfile)
certdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), CertUtil.ca_certdir)
capath = os.path.join(os.path.dirname(os.path.abspath(__file__)), self.ca_keyfile)
certdir = os.path.join(os.path.dirname(os.path.abspath(__file__)), self.ca_certdir)
if not os.path.exists(capath):
if os.path.exists(certdir):
any(os.remove(x) for x in glob.glob(certdir+'/*.crt')+glob.glob(certdir+'/.*.crt'))
if os.name == 'nt':
try:
CertUtil.remove_ca('%s CA' % CertUtil.ca_vendor)
self.remove_ca(self.ca_vendor)
except Exception as e:
logging.warning('CertUtil.remove_ca failed: %r', e)
CertUtil.dump_ca()
logging.warning('self.remove_ca failed: %r', e)
self.dump_ca()
with open(capath, 'rb') as fp:
CertUtil.ca_thumbprint = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, fp.read()).digest('sha1')
self.ca_thumbprint = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, fp.read()).digest('sha1')
#Check Certs
certfiles = glob.glob(certdir+'/*.crt')+glob.glob(certdir+'/.*.crt')
if certfiles:
filename = random.choice(certfiles)
commonname = os.path.splitext(os.path.basename(filename))[0]
with open(filename, 'rb') as fp:
serial_number = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, fp.read()).get_serial_number()
if serial_number != CertUtil.get_cert_serial_number(commonname):
if serial_number != self.get_cert_serial_number(commonname):
any(os.remove(x) for x in certfiles)
#Check CA imported
if CertUtil.import_ca(capath) != 0:
if self.import_ca(capath) != 0:
logging.warning('install root certificate failed, Please run as administrator/root/sudo')
#Check Certs Dir
if not os.path.exists(certdir):
os.makedirs(certdir)

CertUtil = CertUtility('GoAgent', 'CA.crt', 'certs')


class SSLConnection(object):
"""OpenSSL Connection Wapper"""
Expand Down Expand Up @@ -861,30 +856,39 @@ def get_process_list():

def forward_socket(local, remote, timeout, bufsize):
"""forward socket"""
def __io_copy(dest, source, timeout):
try:
dest.settimeout(timeout)
source.settimeout(timeout)
while 1:
data = source.recv(bufsize)
try:
tick = 1
timecount = timeout
while 1:
timecount -= tick
if timecount <= 0:
break
(ins, _, errors) = select.select([local, remote], [], [local, remote], tick)
if errors:
break
for sock in ins:
data = sock.recv(bufsize)
if not data:
break
dest.sendall(data)
except socket.timeout:
pass
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE):
raise
if e.args[0] in (errno.EBADF,):
return
finally:
for sock in (dest, source):
try:
sock.close()
except StandardError:
pass
thread.start_new_thread(__io_copy, (remote.dup(), local.dup(), timeout))
__io_copy(local, remote, timeout)
if sock is remote:
local.sendall(data)
timecount = timeout
else:
remote.sendall(data)
timecount = timeout
except socket.timeout:
pass
except (socket.error, ssl.SSLError, OpenSSL.SSL.Error) as e:
if e.args[0] not in (errno.ECONNABORTED, errno.ECONNRESET, errno.ENOTCONN, errno.EPIPE):
raise
if e.args[0] in (errno.EBADF,):
return
finally:
for sock in (remote, local):
try:
sock.close()
except StandardError:
pass


class LocalProxyServer(SocketServer.ThreadingTCPServer):
Expand All @@ -893,6 +897,14 @@ class LocalProxyServer(SocketServer.ThreadingTCPServer):
allow_reuse_address = True
daemon_threads = True

def __init__(self, listener, RequestHandlerClass, bind_and_activate=True):
"""Constructor. May be extended, do not override."""
if hasattr(listener, 'getsockname'):
SocketServer.BaseServer.__init__(self, listener.getsockname(), RequestHandlerClass)
self.socket = listener
else:
SocketServer.ThreadingTCPServer.__init__(self, listener, RequestHandlerClass, bind_and_activate)

def close_request(self, request):
try:
request.close()
Expand Down Expand Up @@ -1117,22 +1129,11 @@ def handle_connect(self, handler, kwargs):
port = handler.port
local = handler.connection
remote = None
handler.send_response(200)
handler.end_headers()
handler.connection.send('HTTP/1.1 200 OK\r\n\r\n')
handler.close_connection = 1
data = local.recv(1024)
if not data:
local.close()
return
data_is_clienthello = is_clienthello(data)
if data_is_clienthello:
kwargs['client_hello'] = data
for i in xrange(self.max_retry):
try:
remote = handler.net2.create_tcp_connection(host, port, handler.net2.connect_timeout, **kwargs)
if not data_is_clienthello and remote and not isinstance(remote, Exception):
remote.sendall(data)
break
except StandardError as e:
logging.exception('%s "FORWARD %s %s:%d %s" %r', handler.address_string(), handler.command, host, port, handler.protocol_version, e)
if hasattr(remote, 'close'):
Expand All @@ -1143,10 +1144,6 @@ def handle_connect(self, handler, kwargs):
if hasattr(remote, 'fileno'):
# reset timeout default to avoid long http upload failure, but it will delay timeout retry :(
remote.settimeout(None)
data = data_is_clienthello and getattr(remote, 'data', None)
if data:
del remote.data
local.sendall(data)
forward_socket(local, remote, 60, bufsize=256*1024)


Expand Down

0 comments on commit 0d5eed2

Please sign in to comment.