diff --git a/shadowsocks/obfsplugin/__init__.py b/shadowsocks/obfsplugin/__init__.py index 8b13789..401c7b7 100644 --- a/shadowsocks/obfsplugin/__init__.py +++ b/shadowsocks/obfsplugin/__init__.py @@ -1 +1,18 @@ +#!/usr/bin/env python +# +# Copyright 2015 clowwindy +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from __future__ import absolute_import, division, print_function, \ + with_statement diff --git a/shadowsocks/obfsplugin/auth.py b/shadowsocks/obfsplugin/auth.py new file mode 100644 index 0000000..a745e09 --- /dev/null +++ b/shadowsocks/obfsplugin/auth.py @@ -0,0 +1,787 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import base64 +import time +import datetime +import random +import math +import struct +import zlib +import hmac +import hashlib + +import shadowsocks +from shadowsocks import common, lru_cache, encrypt +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr + +def create_auth_sha1_v4(method): + return auth_sha1_v4(method) + +def create_auth_aes128_md5(method): + return auth_aes128_sha1(method, hashlib.md5) + +def create_auth_aes128_sha1(method): + return auth_aes128_sha1(method, hashlib.sha1) + +obfs_map = { + 'auth_sha1_v4': (create_auth_sha1_v4,), + 'auth_sha1_v4_compatible': (create_auth_sha1_v4,), + 'auth_aes128_md5': (create_auth_aes128_md5,), + 'auth_aes128_sha1': (create_auth_aes128_sha1,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class auth_base(plain.plain): + def __init__(self, method): + super(auth_base, self).__init__(method) + self.method = method + self.no_compatible_method = '' + self.overhead = 7 + + def init_data(self): + return '' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def not_match_return(self, buf): + self.raw_trans = True + self.overhead = 0 + if self.method == self.no_compatible_method: + return (b'E'*2048, False) + return (buf, False) + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id - 64 + self.back = begin_id + 1 + self.alloc = {} + self.enable = True + self.last_update = time.time() + + def update(self): + self.last_update = time.time() + + def is_active(self): + return time.time() - self.last_update < 60 * 3 + + def re_enable(self, connection_id): + self.enable = True + self.front = connection_id - 64 + self.back = connection_id + 1 + self.alloc = {} + + def insert(self, connection_id): + if not self.enable: + logging.warn('obfs auth: not enable') + return False + if not self.is_active(): + self.re_enable(connection_id) + self.update() + if connection_id < self.front: + logging.warn('obfs auth: deprecated id, someone replay attack') + return False + if connection_id > self.front + 0x4000: + logging.warn('obfs auth: wrong id') + return False + if connection_id in self.alloc: + logging.warn('obfs auth: duplicate id, someone replay attack') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + return True + +class obfs_auth_v2_data(object): + def __init__(self): + self.client_id = lru_cache.LRUCache() + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, client_id, connection_id): + if client_id in self.client_id: + self.client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, client_id, connection_id): + if self.client_id.get(client_id, None) is None or not self.client_id[client_id].enable: + if self.client_id.first() is None or len(self.client_id) < self.max_client: + if client_id not in self.client_id: + #TODO: check + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + + if not self.client_id[self.client_id.first()].is_active(): + del self.client_id[self.client_id.first()] + if client_id not in self.client_id: + #TODO: check + self.client_id[client_id] = client_queue(connection_id) + else: + self.client_id[client_id].re_enable(connection_id) + return self.client_id[client_id].insert(connection_id) + + logging.warn('auth_sha1_v2: no inactive client') + return False + else: + return self.client_id[client_id].insert(connection_id) + +class auth_sha1_v4(auth_base): + def __init__(self, method): + super(auth_sha1_v4, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 8100 + self.decrypt_packet_num = 0 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = b"auth_sha1_v4" + self.no_compatible_method = 'auth_sha1_v4' + + def init_data(self): + return obfs_auth_v2_data() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def rnd_data(self, buf_size): + if buf_size > 1200: + return b'\x01' + + if buf_size > 400: + rnd_data = os.urandom(common.ord(os.urandom(1)[0]) % 256) + else: + rnd_data = os.urandom(struct.unpack('>H', os.urandom(2))[0] % 512) + + if len(rnd_data) < 128: + return common.chr(len(rnd_data) + 1) + rnd_data + else: + return common.chr(255) + struct.pack('>H', len(rnd_data) + 3) + rnd_data + + def pack_data(self, buf): + data = self.rnd_data(len(buf)) + buf + data_len = len(data) + 8 + crc = binascii.crc32(struct.pack('>H', data_len)) & 0xFFFF + data = struct.pack('H', data_len) + data + adler32 = zlib.adler32(data) & 0xFFFFFFFF + data += struct.pack('H', data_len) + self.salt + self.server_info.key) & 0xFFFFFFFF + data = struct.pack('H', data_len) + data + data += hmac.new(self.server_info.iv + self.server_info.key, data, hashlib.sha1).digest()[:10] + return data + + def auth_data(self): + utc_time = int(time.time()) & 0xFFFFFFFF + if self.server_info.data.connection_id > 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + crc = struct.pack('H', self.recv_buf[:2])[0] + if length >= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + if struct.pack('H', self.recv_buf[5:7])[0] + 4 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + if self.raw_trans: + return buf + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) <= 6: + return (b'', False) + crc = struct.pack('H', self.recv_buf[:2])[0] + if length > len(self.recv_buf): + return (b'', False) + sha1data = hmac.new(self.server_info.recv_iv + self.server_info.key, self.recv_buf[:length - 10], hashlib.sha1).digest()[:10] + if sha1data != self.recv_buf[length - 10:length]: + logging.error('auth_sha1_v4 data uncorrect auth HMAC-SHA1') + return self.not_match_return(self.recv_buf) + pos = common.ord(self.recv_buf[6]) + if pos < 255: + pos += 6 + else: + pos = struct.unpack('>H', self.recv_buf[7:9])[0] + 6 + out_buf = self.recv_buf[pos:length - 10] + if len(out_buf) < 12: + logging.info('auth_sha1_v4: too short, data %s' % (binascii.hexlify(self.recv_buf),)) + return self.not_match_return(self.recv_buf) + utc_time = struct.unpack(' self.max_time_dif: + logging.info('auth_sha1_v4: wrong timestamp, time_dif %d, data %s' % (time_dif, binascii.hexlify(out_buf),)) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(client_id, connection_id): + self.has_recv_header = True + out_buf = out_buf[12:] + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('auth_sha1_v4: auth fail, data %s' % (binascii.hexlify(out_buf),)) + return self.not_match_return(self.recv_buf) + self.recv_buf = self.recv_buf[length:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + crc = struct.pack('H', self.recv_buf[:2])[0] + if length >= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + logging.info('auth_sha1_v4: over size') + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if struct.pack('H', self.recv_buf[5:7])[0] + 4 + out_buf += self.recv_buf[pos:length - 4] + self.recv_buf = self.recv_buf[length:] + if pos == length - 4: + sendback = True + + if out_buf: + self.server_info.data.update(self.client_id, self.connection_id) + self.decrypt_packet_num += 1 + return (out_buf, sendback) + +class obfs_auth_mu_data(object): + def __init__(self): + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn('auth_aes128: no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + +class auth_aes128_sha1(auth_base): + def __init__(self, method, hashfunc): + super(auth_aes128_sha1, self).__init__(method) + self.hashfunc = hashfunc + self.recv_buf = b'' + self.unit_len = 8100 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = hashfunc == hashlib.md5 and b"auth_aes128_md5" or b"auth_aes128_sha1" + self.no_compatible_method = hashfunc == hashlib.md5 and "auth_aes128_md5" or 'auth_aes128_sha1' + self.extra_wait_size = struct.unpack('>H', os.urandom(2))[0] % 1024 + self.pack_id = 1 + self.recv_id = 1 + self.user_id = None + self.user_key = None + self.last_rnd_len = 0 + self.overhead = 9 + + def init_data(self): + return obfs_auth_mu_data() + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def trapezoid_random_float(self, d): + if d == 0: + return random.random() + s = random.random() + a = 1 - d + return (math.sqrt(a * a + 4 * d * s) - a) / (2 * d) + + def trapezoid_random_int(self, max_val, d): + v = self.trapezoid_random_float(d) + return int(v * max_val) + + def rnd_data_len(self, buf_size, full_buf_size): + if full_buf_size >= self.server_info.buffer_size: + return 0 + tcp_mss = self.server_info.tcp_mss + rev_len = tcp_mss - buf_size - 9 + if rev_len == 0: + return 0 + if rev_len < 0: + if rev_len > -tcp_mss: + return self.trapezoid_random_int(rev_len + tcp_mss, -0.3) + return common.ord(os.urandom(1)[0]) % 32 + if buf_size > 900: + return struct.unpack('>H', os.urandom(2))[0] % rev_len + return self.trapezoid_random_int(rev_len, -0.3) + + def rnd_data(self, buf_size, full_buf_size): + data_len = self.rnd_data_len(buf_size, full_buf_size) + + if data_len < 128: + return common.chr(data_len + 1) + os.urandom(data_len) + + return common.chr(255) + struct.pack(' 400: + rnd_len = struct.unpack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len], ogn_data_len) + buf = buf[self.unit_len:] + ret += self.pack_data(buf, ogn_data_len) + self.last_rnd_len = ogn_data_len + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = common.ord(self.recv_buf[4]) + if pos < 255: + pos += 4 + else: + pos = struct.unpack(' self.unit_len: + ret += self.pack_data(buf[:self.unit_len], ogn_data_len) + buf = buf[self.unit_len:] + ret += self.pack_data(buf, ogn_data_len) + self.last_rnd_len = ogn_data_len + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) >= 7 or len(self.recv_buf) in [2, 3]: + recv_len = min(len(self.recv_buf), 7) + mac_key = self.server_info.recv_iv + self.server_info.key + sha1data = hmac.new(mac_key, self.recv_buf[:1], self.hashfunc).digest()[:recv_len - 1] + if sha1data != self.recv_buf[1:recv_len]: + return self.not_match_return(self.recv_buf) + + if len(self.recv_buf) < 31: + return (b'', False) + sha1data = hmac.new(mac_key, self.recv_buf[7:27], self.hashfunc).digest()[:4] + if sha1data != self.recv_buf[27:31]: + logging.error('%s data uncorrect auth HMAC-SHA1 from %s:%d, data %s' % (self.no_compatible_method, self.server_info.client, self.server_info.client_port, binascii.hexlify(self.recv_buf))) + if len(self.recv_buf) < 31 + self.extra_wait_size: + return (b'', False) + return self.not_match_return(self.recv_buf) + + uid = self.recv_buf[7:11] + if uid in self.server_info.users: + self.user_id = uid + self.user_key = self.hashfunc(self.server_info.users[uid]).digest() + self.server_info.update_user_func(uid) + else: + if not self.server_info.users: + self.user_key = self.server_info.key + else: + self.user_key = self.server_info.recv_iv + encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + self.salt, 'aes-128-cbc') + head = encryptor.decrypt(b'\x00' * 16 + self.recv_buf[11:27] + b'\x00') # need an extra byte or recv empty + length = struct.unpack(' self.max_time_dif: + logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(self.user_id, client_id, connection_id): + self.has_recv_header = True + out_buf = self.recv_buf[31 + rnd_len:length - 4] + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) + return self.not_match_return(self.recv_buf) + self.recv_buf = self.recv_buf[length:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 8192 or length < 7: + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + logging.info(self.no_compatible_method + ': over size') + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + if hmac.new(mac_key, self.recv_buf[:length - 4], self.hashfunc).digest()[:4] != self.recv_buf[length - 4:length]: + logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = common.ord(self.recv_buf[4]) + if pos < 255: + pos += 4 + else: + pos = struct.unpack('> 17) ^ (y >> 26)) & xorshift128plus.max_int + self.v1 = x + return (x + y) & xorshift128plus.max_int + + def init_from_bin(self, bin): + bin += b'\0' * 16 + self.v0 = struct.unpack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class auth_base(plain.plain): + def __init__(self, method): + super(auth_base, self).__init__(method) + self.method = method + self.no_compatible_method = '' + self.overhead = 4 + + def init_data(self): + return '' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def not_match_return(self, buf): + self.raw_trans = True + self.overhead = 0 + if self.method == self.no_compatible_method: + return (b'E'*2048, False) + return (buf, False) + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id - 64 + self.back = begin_id + 1 + self.alloc = {} + self.enable = True + self.last_update = time.time() + self.ref = 0 + + def update(self): + self.last_update = time.time() + + def addref(self): + self.ref += 1 + + def delref(self): + if self.ref > 0: + self.ref -= 1 + + def is_active(self): + return (self.ref > 0) and (time.time() - self.last_update < 60 * 10) + + def re_enable(self, connection_id): + self.enable = True + self.front = connection_id - 64 + self.back = connection_id + 1 + self.alloc = {} + + def insert(self, connection_id): + if not self.enable: + logging.warn('obfs auth: not enable') + return False + if not self.is_active(): + self.re_enable(connection_id) + self.update() + if connection_id < self.front: + logging.warn('obfs auth: deprecated id, someone replay attack') + return False + if connection_id > self.front + 0x4000: + logging.warn('obfs auth: wrong id') + return False + if connection_id in self.alloc: + logging.warn('obfs auth: duplicate id, someone replay attack') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + self.addref() + return True + +class obfs_auth_chain_data(object): + def __init__(self, name): + self.name = name + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + #TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn(self.name + ': no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + + def remove(self, user_id, client_id): + if user_id in self.user_id: + local_client_id = self.user_id[user_id] + if client_id in local_client_id: + local_client_id[client_id].delref() + +class auth_chain_a(auth_base): + def __init__(self, method): + super(auth_chain_a, self).__init__(method) + self.hashfunc = hashlib.md5 + self.recv_buf = b'' + self.unit_len = 2800 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = b"auth_chain_a" + self.no_compatible_method = 'auth_chain_a' + self.pack_id = 1 + self.recv_id = 1 + self.user_id = None + self.user_id_num = 0 + self.user_key = None + self.overhead = 4 + self.client_over_head = 4 + self.last_client_hash = b'' + self.last_server_hash = b'' + self.random_client = xorshift128plus() + self.random_server = xorshift128plus() + self.encryptor = None + + def init_data(self): + return obfs_auth_chain_data(self.method) + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def trapezoid_random_float(self, d): + if d == 0: + return random.random() + s = random.random() + a = 1 - d + return (math.sqrt(a * a + 4 * d * s) - a) / (2 * d) + + def trapezoid_random_int(self, max_val, d): + v = self.trapezoid_random_float(d) + return int(v * max_val) + + def rnd_data_len(self, buf_size, last_hash, random): + if buf_size > 1440: + return 0 + random.init_from_bin_len(last_hash, buf_size) + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % 1021 + + def udp_rnd_data_len(self, last_hash, random): + random.init_from_bin(last_hash) + return random.next() % 127 + + def rnd_start_pos(self, rand_len, random): + if rand_len > 0: + return random.next() % 8589934609 % rand_len + return 0 + + def rnd_data(self, buf_size, buf, last_hash, random): + rand_len = self.rnd_data_len(buf_size, last_hash, random) + + rnd_data_buf = os.urandom(rand_len) + + if buf_size == 0: + return rnd_data_buf + else: + if rand_len > 0: + start_pos = self.rnd_start_pos(rand_len, random) + return rnd_data_buf[:start_pos] + buf + rnd_data_buf[start_pos:] + else: + return buf + + def pack_client_data(self, buf): + buf = self.encryptor.encrypt(buf) + data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client) + data_len = len(data) + 8 + mac_key = self.user_key + struct.pack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = os.urandom(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_client_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_client_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + + if length + 4 > len(self.recv_buf): + break + + server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() + if server_hash[:2] != self.recv_buf[length + 2 : length + 4]: + logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data uncorrect checksum') + + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + self.rnd_start_pos(rand_len, self.random_server) + out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + self.last_server_hash = server_hash + if self.recv_id == 1: + self.server_info.tcp_mss = struct.unpack(' self.unit_len: + ret += self.pack_server_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_server_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) >= 12 or len(self.recv_buf) in [7, 8]: + recv_len = min(len(self.recv_buf), 12) + mac_key = self.server_info.recv_iv + self.server_info.key + md5data = hmac.new(mac_key, self.recv_buf[:4], self.hashfunc).digest() + if md5data[:recv_len - 4] != self.recv_buf[4:recv_len]: + return self.not_match_return(self.recv_buf) + + if len(self.recv_buf) < 12 + 24: + return (b'', False) + + self.last_client_hash = md5data + uid = struct.unpack(' self.max_time_dif: + logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(self.user_id, client_id, connection_id): + self.has_recv_header = True + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) + return self.not_match_return(self.recv_buf) + + self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') + self.recv_buf = self.recv_buf[36:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + logging.info(self.no_compatible_method + ': over size') + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + + if length + 4 > len(self.recv_buf): + break + + client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() + if client_hash[:2] != self.recv_buf[length + 2 : length + 4]: + logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 0: + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + self.rnd_start_pos(rand_len, self.random_client) + out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + self.last_client_hash = client_hash + self.recv_buf = self.recv_buf[length + 4:] + if data_len == 0: + sendback = True + + if out_buf: + self.server_info.data.update(self.user_id, self.client_id, self.connection_id) + return (out_buf, sendback) + + def client_udp_pre_encrypt(self, buf): + if self.user_key is None: + if b':' in to_bytes(self.server_info.protocol_param): + try: + items = to_bytes(self.server_info.protocol_param).split(':') + self.user_key = self.hashfunc(items[1]).digest() + self.user_id = struct.pack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class http_simple(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.host = None + self.port = 0 + self.recv_buffer = b'' + self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", + b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", + b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", + b"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/535.11 (KHTML, like Gecko) Ubuntu/11.10 Chromium/27.0.1453.93 Chrome/27.0.1453.93 Safari/537.36", + b"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:35.0) Gecko/20100101 Firefox/35.0", + b"Mozilla/5.0 (compatible; WOW64; MSIE 10.0; Windows NT 6.2)", + b"Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/533.20.25 (KHTML, like Gecko) Version/5.0.4 Safari/533.20.27", + b"Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.3; Trident/7.0; .NET4.0E; .NET4.0C)", + b"Mozilla/5.0 (Windows NT 6.3; Trident/7.0; rv:11.0) like Gecko", + b"Mozilla/5.0 (Linux; Android 4.4; Nexus 5 Build/BuildID) AppleWebKit/537.36 (KHTML, like Gecko) Version/4.0 Chrome/30.0.0.0 Mobile Safari/537.36", + b"Mozilla/5.0 (iPad; CPU OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3", + b"Mozilla/5.0 (iPhone; CPU iPhone OS 5_0 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9A334 Safari/7534.48.3"] + + def encode_head(self, buf): + hexstr = binascii.hexlify(buf) + chs = [] + for i in range(0, len(hexstr), 2): + chs.append(b"%" + hexstr[i:i+2]) + return b''.join(chs) + + def client_encode(self, buf): + if self.has_sent_header: + return buf + head_size = len(self.server_info.iv) + self.server_info.head_len + if len(buf) - head_size > 64: + headlen = head_size + random.randint(0, 64) + else: + headlen = len(buf) + headdata = buf[:headlen] + buf = buf[headlen:] + port = b'' + if self.server_info.port != 80: + port = b':' + to_bytes(str(self.server_info.port)) + body = None + hosts = (self.server_info.obfs_param or self.server_info.host) + pos = hosts.find("#") + if pos >= 0: + body = hosts[pos + 1:].replace("\n", "\r\n") + body = body.replace("\\n", "\r\n") + hosts = hosts[:pos] + hosts = hosts.split(',') + host = random.choice(hosts) + http_head = b"GET /" + self.encode_head(headdata) + b" HTTP/1.1\r\n" + http_head += b"Host: " + to_bytes(host) + port + b"\r\n" + if body: + http_head += body + "\r\n\r\n" + else: + http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n" + http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\nDNT: 1\r\nConnection: keep-alive\r\n\r\n" + self.has_sent_header = True + return http_head + buf + + def client_decode(self, buf): + if self.has_recv_header: + return (buf, False) + pos = buf.find(b'\r\n\r\n') + if pos >= 0: + self.has_recv_header = True + return (buf[pos + 4:], False) + else: + return (b'', False) + + def server_encode(self, buf): + if self.has_sent_header: + return buf + + header = b'HTTP/1.1 200 OK\r\nConnection: keep-alive\r\nContent-Encoding: gzip\r\nContent-Type: text/html\r\nDate: ' + header += to_bytes(datetime.datetime.now().strftime('%a, %d %b %Y %H:%M:%S GMT')) + header += b'\r\nServer: nginx\r\nVary: Accept-Encoding\r\n\r\n' + self.has_sent_header = True + return header + buf + + def get_data_from_http_header(self, buf): + ret_buf = b'' + lines = buf.split(b'\r\n') + if lines and len(lines) > 1: + hex_items = lines[0].split(b'%') + if hex_items and len(hex_items) > 1: + for index in range(1, len(hex_items)): + if len(hex_items[index]) < 2: + ret_buf += binascii.unhexlify('0' + hex_items[index]) + break + elif len(hex_items[index]) > 2: + ret_buf += binascii.unhexlify(hex_items[index][:2]) + break + else: + ret_buf += binascii.unhexlify(hex_items[index]) + return ret_buf + return b'' + + def get_host_from_http_header(self, buf): + ret_buf = b'' + lines = buf.split(b'\r\n') + if lines and len(lines) > 1: + for line in lines: + if match_begin(line, b"Host: "): + return common.to_str(line[6:]) + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http_simple': + return (b'E'*2048, False, False) + return (buf, True, False) + + def error_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + return (b'E'*2048, False, False) + + def server_decode(self, buf): + if self.has_recv_header: + return (buf, True, False) + + self.recv_buffer += buf + buf = self.recv_buffer + if len(buf) > 10: + if match_begin(buf, b'GET ') or match_begin(buf, b'POST '): + if len(buf) > 65536: + self.recv_buffer = None + logging.warn('http_simple: over size') + return self.not_match_return(buf) + else: #not http header, run on original protocol + self.recv_buffer = None + logging.debug('http_simple: not match begin') + return self.not_match_return(buf) + else: + return (b'', True, False) + + if b'\r\n\r\n' in buf: + datas = buf.split(b'\r\n\r\n', 1) + ret_buf = self.get_data_from_http_header(buf) + host = self.get_host_from_http_header(buf) + if host and self.server_info.obfs_param: + pos = host.find(":") + if pos >= 0: + host = host[:pos] + hosts = self.server_info.obfs_param.split(',') + if host not in hosts: + return self.not_match_return(buf) + if len(ret_buf) < 4: + return self.error_return(buf) + if len(datas) > 1: + ret_buf += datas[1] + if len(ret_buf) >= 13: + self.has_recv_header = True + return (ret_buf, True, False) + return self.not_match_return(buf) + else: + return (b'', True, False) + +class http_post(http_simple): + def __init__(self, method): + super(http_post, self).__init__(method) + + def boundary(self): + return to_bytes(''.join([random.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789") for i in range(32)])) + + def client_encode(self, buf): + if self.has_sent_header: + return buf + head_size = len(self.server_info.iv) + self.server_info.head_len + if len(buf) - head_size > 64: + headlen = head_size + random.randint(0, 64) + else: + headlen = len(buf) + headdata = buf[:headlen] + buf = buf[headlen:] + port = b'' + if self.server_info.port != 80: + port = b':' + to_bytes(str(self.server_info.port)) + body = None + hosts = (self.server_info.obfs_param or self.server_info.host) + pos = hosts.find("#") + if pos >= 0: + body = hosts[pos + 1:].replace("\\n", "\r\n") + hosts = hosts[:pos] + hosts = hosts.split(',') + host = random.choice(hosts) + http_head = b"POST /" + self.encode_head(headdata) + b" HTTP/1.1\r\n" + http_head += b"Host: " + to_bytes(host) + port + b"\r\n" + if body: + http_head += body + "\r\n\r\n" + else: + http_head += b"User-Agent: " + random.choice(self.user_agent) + b"\r\n" + http_head += b"Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\nAccept-Language: en-US,en;q=0.8\r\nAccept-Encoding: gzip, deflate\r\n" + http_head += b"Content-Type: multipart/form-data; boundary=" + self.boundary() + b"\r\nDNT: 1\r\n" + http_head += b"Connection: keep-alive\r\n\r\n" + self.has_sent_header = True + return http_head + buf + + def not_match_return(self, buf): + self.has_sent_header = True + self.has_recv_header = True + if self.method == 'http_post': + return (b'E'*2048, False, False) + return (buf, True, False) + +class random_head(plain.plain): + def __init__(self, method): + self.method = method + self.has_sent_header = False + self.has_recv_header = False + self.raw_trans_sent = False + self.raw_trans_recv = False + self.send_buffer = b'' + + def client_encode(self, buf): + if self.raw_trans_sent: + return buf + self.send_buffer += buf + if not self.has_sent_header: + self.has_sent_header = True + data = os.urandom(common.ord(os.urandom(1)[0]) % 96 + 4) + crc = (0xffffffff - binascii.crc32(data)) & 0xffffffff + return data + struct.pack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class obfs_auth_data(object): + def __init__(self): + self.client_data = lru_cache.LRUCache(60 * 5) + self.client_id = os.urandom(32) + self.startup_time = int(time.time() - 60 * 30) & 0xFFFFFFFF + self.ticket_buf = {} + +class tls_ticket_auth(plain.plain): + def __init__(self, method): + self.method = method + self.handshake_status = 0 + self.send_buffer = b'' + self.recv_buffer = b'' + self.client_id = b'' + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.tls_version = b'\x03\x03' + self.overhead = 5 + + def init_data(self): + return obfs_auth_data() + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def sni(self, url): + url = common.to_bytes(url) + data = b"\x00" + struct.pack('>H', len(url)) + url + data = b"\x00\x00" + struct.pack('>H', len(data) + 2) + struct.pack('>H', len(data)) + data + return data + + def pack_auth_data(self, client_id): + utc_time = int(time.time()) & 0xFFFFFFFF + data = struct.pack('>I', utc_time) + os.urandom(18) + data += hmac.new(self.server_info.key + client_id, data, hashlib.sha1).digest()[:10] + return data + + def client_encode(self, buf): + if self.handshake_status == -1: + return buf + if self.handshake_status == 8: + ret = b'' + while len(buf) > 2048: + size = min(struct.unpack('>H', os.urandom(2))[0] % 4096 + 100, len(buf)) + ret += b"\x17" + self.tls_version + struct.pack('>H', size) + buf[:size] + buf = buf[size:] + if len(buf) > 0: + ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + return ret + if len(buf) > 0: + self.send_buffer += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + if self.handshake_status == 0: + self.handshake_status = 1 + data = self.tls_version + self.pack_auth_data(self.server_info.data.client_id) + b"\x20" + self.server_info.data.client_id + binascii.unhexlify(b"001cc02bc02fcca9cca8cc14cc13c00ac014c009c013009c0035002f000a" + b"0100") + ext = binascii.unhexlify(b"ff01000100") + host = self.server_info.obfs_param or self.server_info.host + if host and host[-1] in string.digits: + host = '' + hosts = host.split(',') + host = random.choice(hosts) + ext += self.sni(host) + ext += b"\x00\x17\x00\x00" + if host not in self.server_info.data.ticket_buf: + self.server_info.data.ticket_buf[host] = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 17 + 8) * 16) + ext += b"\x00\x23" + struct.pack('>H', len(self.server_info.data.ticket_buf[host])) + self.server_info.data.ticket_buf[host] + ext += binascii.unhexlify(b"000d001600140601060305010503040104030301030302010203") + ext += binascii.unhexlify(b"000500050100000000") + ext += binascii.unhexlify(b"00120000") + ext += binascii.unhexlify(b"75500000") + ext += binascii.unhexlify(b"000b00020100") + ext += binascii.unhexlify(b"000a0006000400170018") + data += struct.pack('>H', len(ext)) + ext + data = b"\x01\x00" + struct.pack('>H', len(data)) + data + data = b"\x16\x03\x01" + struct.pack('>H', len(data)) + data + return data + elif self.handshake_status == 1 and len(buf) == 0: + data = b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec + data += b"\x16" + self.tls_version + b"\x00\x20" + os.urandom(22) #Finished + data += hmac.new(self.server_info.key + self.server_info.data.client_id, data, hashlib.sha1).digest()[:10] + ret = data + self.send_buffer + self.send_buffer = b'' + self.handshake_status = 8 + return ret + return b'' + + def client_decode(self, buf): + if self.handshake_status == -1: + return (buf, False) + + if self.handshake_status == 8: + ret = b'' + self.recv_buffer += buf + while len(self.recv_buffer) > 5: + if ord(self.recv_buffer[0]) != 0x17: + logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) + raise Exception('server_decode appdata error') + size = struct.unpack('>H', self.recv_buffer[3:5])[0] + if len(self.recv_buffer) < size + 5: + break + buf = self.recv_buffer[5:size+5] + ret += buf + self.recv_buffer = self.recv_buffer[size+5:] + return (ret, False) + + if len(buf) < 11 + 32 + 1 + 32: + raise Exception('client_decode data error') + verify = buf[11:33] + if hmac.new(self.server_info.key + self.server_info.data.client_id, verify, hashlib.sha1).digest()[:10] != buf[33:43]: + raise Exception('client_decode data error') + if hmac.new(self.server_info.key + self.server_info.data.client_id, buf[:-10], hashlib.sha1).digest()[:10] != buf[-10:]: + raise Exception('client_decode data error') + return (b'', True) + + def server_encode(self, buf): + if self.handshake_status == -1: + return buf + if (self.handshake_status & 8) == 8: + ret = b'' + while len(buf) > 2048: + size = min(struct.unpack('>H', os.urandom(2))[0] % 4096 + 100, len(buf)) + ret += b"\x17" + self.tls_version + struct.pack('>H', size) + buf[:size] + buf = buf[size:] + if len(buf) > 0: + ret += b"\x17" + self.tls_version + struct.pack('>H', len(buf)) + buf + return ret + self.handshake_status |= 8 + data = self.tls_version + self.pack_auth_data(self.client_id) + b"\x20" + self.client_id + binascii.unhexlify(b"c02f000005ff01000100") + data = b"\x02\x00" + struct.pack('>H', len(data)) + data #server hello + data = b"\x16" + self.tls_version + struct.pack('>H', len(data)) + data + if random.randint(0, 8) < 1: + ticket = os.urandom((struct.unpack('>H', os.urandom(2))[0] % 164) * 2 + 64) + ticket = struct.pack('>H', len(ticket) + 4) + b"\x04\x00" + struct.pack('>H', len(ticket)) + ticket + data += b"\x16" + self.tls_version + ticket #New session ticket + data += b"\x14" + self.tls_version + b"\x00\x01\x01" #ChangeCipherSpec + finish_len = random.choice([32, 40]) + data += b"\x16" + self.tls_version + struct.pack('>H', finish_len) + os.urandom(finish_len - 10) #Finished + data += hmac.new(self.server_info.key + self.client_id, data, hashlib.sha1).digest()[:10] + if buf: + data += self.server_encode(buf) + return data + + def decode_error_return(self, buf): + self.handshake_status = -1 + self.overhead = 0 + if self.method == 'tls1.2_ticket_auth': + return (b'E'*2048, False, False) + return (buf, True, False) + + def server_decode(self, buf): + if self.handshake_status == -1: + return (buf, True, False) + + if (self.handshake_status & 4) == 4: + ret = b'' + self.recv_buffer += buf + while len(self.recv_buffer) > 5: + if ord(self.recv_buffer[0]) != 0x17 or ord(self.recv_buffer[1]) != 0x3 or ord(self.recv_buffer[2]) != 0x3: + logging.info("data = %s" % (binascii.hexlify(self.recv_buffer))) + raise Exception('server_decode appdata error') + size = struct.unpack('>H', self.recv_buffer[3:5])[0] + if len(self.recv_buffer) < size + 5: + break + ret += self.recv_buffer[5:size+5] + self.recv_buffer = self.recv_buffer[size+5:] + return (ret, True, False) + + if (self.handshake_status & 1) == 1: + self.recv_buffer += buf + buf = self.recv_buffer + verify = buf + if len(buf) < 11: + raise Exception('server_decode data error') + if not match_begin(buf, b"\x14" + self.tls_version + b"\x00\x01\x01"): #ChangeCipherSpec + raise Exception('server_decode data error') + buf = buf[6:] + if not match_begin(buf, b"\x16" + self.tls_version + b"\x00"): #Finished + raise Exception('server_decode data error') + verify_len = struct.unpack('>H', buf[3:5])[0] + 1 # 11 - 10 + if len(verify) < verify_len + 10: + return (b'', False, False) + if hmac.new(self.server_info.key + self.client_id, verify[:verify_len], hashlib.sha1).digest()[:10] != verify[verify_len:verify_len+10]: + raise Exception('server_decode data error') + self.recv_buffer = verify[verify_len + 10:] + status = self.handshake_status + self.handshake_status |= 4 + ret = self.server_decode(b'') + return ret; + + #raise Exception("handshake data = %s" % (binascii.hexlify(buf))) + self.recv_buffer += buf + buf = self.recv_buffer + ogn_buf = buf + if len(buf) < 3: + return (b'', False, False) + if not match_begin(buf, b'\x16\x03\x01'): + return self.decode_error_return(ogn_buf) + buf = buf[3:] + header_len = struct.unpack('>H', buf[:2])[0] + if header_len > len(buf) - 2: + return (b'', False, False) + + self.recv_buffer = self.recv_buffer[header_len + 5:] + self.handshake_status = 1 + buf = buf[2:header_len + 2] + if not match_begin(buf, b'\x01\x00'): #client hello + logging.info("tls_auth not client hello message") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if struct.unpack('>H', buf[:2])[0] != len(buf) - 2: + logging.info("tls_auth wrong message size") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + if not match_begin(buf, self.tls_version): + logging.info("tls_auth wrong tls version") + return self.decode_error_return(ogn_buf) + buf = buf[2:] + verifyid = buf[:32] + buf = buf[32:] + sessionid_len = ord(buf[0]) + if sessionid_len < 32: + logging.info("tls_auth wrong sessionid_len") + return self.decode_error_return(ogn_buf) + sessionid = buf[1:sessionid_len + 1] + buf = buf[sessionid_len+1:] + self.client_id = sessionid + sha1 = hmac.new(self.server_info.key + sessionid, verifyid[:22], hashlib.sha1).digest()[:10] + utc_time = struct.unpack('>I', verifyid[:4])[0] + time_dif = common.int32((int(time.time()) & 0xffffffff) - utc_time) + if self.server_info.obfs_param: + try: + self.max_time_dif = int(self.server_info.obfs_param) + except: + pass + if self.max_time_dif > 0 and (time_dif < -self.max_time_dif or time_dif > self.max_time_dif \ + or common.int32(utc_time - self.server_info.data.startup_time) < -self.max_time_dif / 2): + logging.info("tls_auth wrong time") + return self.decode_error_return(ogn_buf) + if sha1 != verifyid[22:]: + logging.info("tls_auth wrong sha1") + return self.decode_error_return(ogn_buf) + if self.server_info.data.client_data.get(verifyid[:22]): + logging.info("replay attack detect, id = %s" % (binascii.hexlify(verifyid))) + return self.decode_error_return(ogn_buf) + self.server_info.data.client_data.sweep() + self.server_info.data.client_data[verifyid[:22]] = sessionid + if len(self.recv_buffer) >= 11: + ret = self.server_decode(b'') + return (ret[0], True, True) + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (b'', False, True) + diff --git a/shadowsocks/obfsplugin/plain.py b/shadowsocks/obfsplugin/plain.py new file mode 100644 index 0000000..8c6355c --- /dev/null +++ b/shadowsocks/obfsplugin/plain.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging + +from shadowsocks.common import ord + +def create_obfs(method): + return plain(method) + +obfs_map = { + 'plain': (create_obfs,), + 'origin': (create_obfs,), +} + +class plain(object): + def __init__(self, method): + self.method = method + self.server_info = None + + def init_data(self): + return b'' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return 0 + + def get_server_info(self): + return self.server_info + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_pre_encrypt(self, buf): + return buf + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + # (buffer_to_recv, is_need_to_encode_and_send_back) + return (buf, False) + + def client_post_decrypt(self, buf): + return buf + + def server_pre_encrypt(self, buf): + return buf + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + # (buffer_to_recv, is_need_decrypt, is_need_to_encode_and_send_back) + return (buf, True, False) + + def server_post_decrypt(self, buf): + return (buf, False) + + def client_udp_pre_encrypt(self, buf): + return buf + + def client_udp_post_decrypt(self, buf): + return buf + + def server_udp_pre_encrypt(self, buf, uid): + return buf + + def server_udp_post_decrypt(self, buf): + return (buf, None) + + def dispose(self): + pass + + def get_head_size(self, buf, def_value): + if len(buf) < 2: + return def_value + head_type = ord(buf[0]) & 0x7 + if head_type == 1: + return 7 + if head_type == 4: + return 19 + if head_type == 3: + return 4 + ord(buf[1]) + return def_value + diff --git a/shadowsocks/obfsplugin/verify.py b/shadowsocks/obfsplugin/verify.py new file mode 100644 index 0000000..0dc0ca6 --- /dev/null +++ b/shadowsocks/obfsplugin/verify.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# +# Copyright 2015-2015 breakwa11 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import os +import sys +import hashlib +import logging +import binascii +import base64 +import time +import datetime +import random +import struct +import zlib +import hmac +import hashlib + +import shadowsocks +from shadowsocks import common +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr + +def create_verify_deflate(method): + return verify_deflate(method) + +obfs_map = { + 'verify_deflate': (create_verify_deflate,), +} + +def match_begin(str1, str2): + if len(str1) >= len(str2): + if str1[:len(str2)] == str2: + return True + return False + +class obfs_verify_data(object): + def __init__(self): + pass + +class verify_base(plain.plain): + def __init__(self, method): + super(verify_base, self).__init__(method) + self.method = method + + def init_data(self): + return obfs_verify_data() + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + +class verify_deflate(verify_base): + def __init__(self, method): + super(verify_deflate, self).__init__(method) + self.recv_buf = b'' + self.unit_len = 32700 + self.decrypt_packet_num = 0 + self.raw_trans = False + + def pack_data(self, buf): + if len(buf) == 0: + return b'' + data = zlib.compress(buf) + data = struct.pack('>H', len(data)) + data[2:] + return data + + def client_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768 or length < 6: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'x\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return out_buf + + def server_pre_encrypt(self, buf): + ret = b'' + while len(buf) > self.unit_len: + ret += self.pack_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 2: + length = struct.unpack('>H', self.recv_buf[:2])[0] + if length >= 32768 or length < 6: + self.raw_trans = True + self.recv_buf = b'' + if self.decrypt_packet_num == 0: + return (b'E'*2048, False) + else: + raise Exception('server_post_decrype data error') + if length > len(self.recv_buf): + break + + out_buf += zlib.decompress(b'\x78\x9c' + self.recv_buf[2:length]) + self.recv_buf = self.recv_buf[length:] + + if out_buf: + self.decrypt_packet_num += 1 + return (out_buf, False) +