Skip to content

Commit

Permalink
Fix Lawouach#179 properly
Browse files Browse the repository at this point in the history
Thanks @medington

Closes Lawouach#219

Signed-off-by: Sylvain Hellegouarch <[email protected]>
  • Loading branch information
Lawouach committed Feb 28, 2018
2 parents b50de2b + 854b33b commit 60d5384
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
66 changes: 33 additions & 33 deletions test/test_client.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,28 @@
# -*- coding: utf-8 -*-
from base64 import b64encode
from hashlib import sha1
import os
import socket
import time
import unittest

from mock import MagicMock, call, patch
from mock import MagicMock, patch

from ws4py.manager import WebSocketManager
from ws4py.websocket import WebSocket
from ws4py import WS_KEY
from ws4py.exc import HandshakeError
from ws4py.framing import Frame, OPCODE_TEXT, OPCODE_CLOSE
from ws4py.messaging import CloseControlMessage
from ws4py.client import WebSocketBaseClient
from ws4py.client.threadedclient import WebSocketClient

class BasicClientTest(unittest.TestCase):
def test_invalid_hostname_in_url(self):
self.assertRaises(ValueError, WebSocketBaseClient, url="qsdfqsd65qsd354")

def test_invalid_scheme_in_url(self):
self.assertRaises(ValueError, WebSocketBaseClient, url="ftp://localhost")

def test_invalid_hostname_in_url(self):
self.assertRaises(ValueError, WebSocketBaseClient, url="ftp://?/")

def test_parse_unix_schemes(self):
c = WebSocketBaseClient(url="ws+unix:///my.socket")
self.assertEqual(c.scheme, "ws+unix")
Expand All @@ -35,95 +31,98 @@ def test_parse_unix_schemes(self):
self.assertEqual(c.unix_socket_path, "/my.socket")
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, "/my.socket")

c = WebSocketBaseClient(url="wss+unix:///my.socket")
self.assertEqual(c.scheme, "wss+unix")
self.assertEqual(c.host, "localhost")
self.assertIsNone(c.port)
self.assertEqual(c.unix_socket_path, "/my.socket")
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, "/my.socket")

def test_parse_ws_scheme(self):
c = WebSocketBaseClient(url="ws://127.0.0.1/")
self.assertEqual(c.scheme, "ws")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 80)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 80))

def test_parse_ws_scheme_when_missing_resource(self):
c = WebSocketBaseClient(url="ws://127.0.0.1")
self.assertEqual(c.scheme, "ws")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 80)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 80))

def test_parse_ws_scheme_with_port(self):
c = WebSocketBaseClient(url="ws://127.0.0.1:9090")
self.assertEqual(c.scheme, "ws")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 9090)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 9090))

def test_parse_ws_scheme_with_query_string(self):
c = WebSocketBaseClient(url="ws://127.0.0.1/?token=value")
self.assertEqual(c.scheme, "ws")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 80)
self.assertEqual(c.resource, "/?token=value")
self.assertEqual(c.bind_addr, ("127.0.0.1", 80))

def test_parse_wss_scheme(self):
c = WebSocketBaseClient(url="wss://127.0.0.1/")
self.assertEqual(c.scheme, "wss")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 443)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 443))

def test_parse_wss_scheme_when_missing_resource(self):
c = WebSocketBaseClient(url="wss://127.0.0.1")
self.assertEqual(c.scheme, "wss")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 443)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 443))

def test_parse_wss_scheme_with_port(self):
c = WebSocketBaseClient(url="wss://127.0.0.1:9090")
self.assertEqual(c.scheme, "wss")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 9090)
self.assertEqual(c.resource, "/")
self.assertEqual(c.bind_addr, ("127.0.0.1", 9090))

def test_parse_wss_scheme_with_query_string(self):
c = WebSocketBaseClient(url="wss://127.0.0.1/?token=value")
self.assertEqual(c.scheme, "wss")
self.assertEqual(c.host, "127.0.0.1")
self.assertEqual(c.port, 443)
self.assertEqual(c.resource, "/?token=value")
self.assertEqual(c.bind_addr, ("127.0.0.1", 443))

@patch('ws4py.client.socket')
def test_connect_and_close(self, sock):

s = MagicMock()
sock.socket.return_value = s
sock.getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 0, "",
("127.0.0.1", 80, 0, 0))]

c = WebSocketBaseClient(url="ws://127.0.0.1/?token=value")

s.recv.return_value = b"\r\n".join([
b"HTTP/1.1 101 Switching Protocols",
b"Connection: Upgrade",
b"Sec-Websocket-Version: 13",
b"Content-Type: text/plain;charset=utf-8",
b"Sec-Websocket-Accept: " + b64encode(sha1(c.key + WS_KEY).digest()),
b"Sec-WebSocket-Protocol: proto1, proto2",
b"Sec-WebSocket-Extensions: ext1, ext2",
b"Sec-WebSocket-Extensions: ext3",
b"Upgrade: websocket",
b"Date: Sun, 26 Jul 2015 12:32:55 GMT",
b"Server: ws4py/test",
Expand All @@ -132,6 +131,8 @@ def test_connect_and_close(self, sock):

c.connect()
s.connect.assert_called_once_with(("127.0.0.1", 80))
self.assertEqual(c.protocols, [b'proto1', b'proto2'])
self.assertEqual(c.extensions, [b'ext1', b'ext2', b'ext3'])

s.reset_mock()
c.close(code=1006, reason="boom")
Expand All @@ -140,32 +141,32 @@ def test_connect_and_close(self, sock):
f.parser.send(args[0][0])
f.parser.close()
self.assertIn(b'boom', f.unmask(f.body))

@patch('ws4py.client.socket')
def test_empty_response(self, sock):

s = MagicMock()
sock.socket.return_value = s
sock.getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 0, "",
("127.0.0.1", 80, 0, 0))]

c = WebSocketBaseClient(url="ws://127.0.0.1/?token=value")

s.recv.return_value = b""
self.assertRaises(HandshakeError, c.connect)
s.shutdown.assert_called_once_with(socket.SHUT_RDWR)
s.close.assert_called_once_with()

@patch('ws4py.client.socket')
def test_invdalid_response_code(self, sock):
def test_invalid_response_code(self, sock):

s = MagicMock()
sock.socket.return_value = s
sock.getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 0, "",
("127.0.0.1", 80, 0, 0))]

c = WebSocketBaseClient(url="ws://127.0.0.1/?token=value")

s.recv.return_value = b"\r\n".join([
b"HTTP/1.1 200 Switching Protocols",
b"Connection: Upgrade",
Expand All @@ -181,18 +182,18 @@ def test_invdalid_response_code(self, sock):
self.assertRaises(HandshakeError, c.connect)
s.shutdown.assert_called_once_with(socket.SHUT_RDWR)
s.close.assert_called_once_with()

@patch('ws4py.client.socket')
def test_invalid_response_headers(self, sock):

for key_header, invalid_value in ((b'upgrade', b'boom'),
(b'connection', b'bim')):
s = MagicMock()
sock.socket.return_value = s
sock.getaddrinfo.return_value = [(socket.AF_INET, socket.SOCK_STREAM, 0, "",
("127.0.0.1", 80, 0, 0))]
c = WebSocketBaseClient(url="ws://127.0.0.1/?token=value")

status_line = b"HTTP/1.1 101 Switching Protocols"
headers = {
b"connection": b"Upgrade",
Expand All @@ -205,15 +206,15 @@ def test_invalid_response_headers(self, sock):
}

headers[key_header] = invalid_value

request = [status_line] + [k + b" : " + v for (k, v) in headers.items()] + [b'\r\n']
s.recv.return_value = b"\r\n".join(request)

self.assertRaises(HandshakeError, c.connect)
s.shutdown.assert_called_once_with(socket.SHUT_RDWR)
s.close.assert_called_once_with()
sock.reset_mock()

class ThreadedClientTest(unittest.TestCase):

@patch('ws4py.client.socket')
Expand Down Expand Up @@ -277,7 +278,6 @@ def test_thread_is_started_once_connected_secure(self):
self.assertFalse(self.client._th.is_alive())



if __name__ == '__main__':
suite = unittest.TestSuite()
loader = unittest.TestLoader()
Expand Down
12 changes: 6 additions & 6 deletions ws4py/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,23 @@ def __init__(self, url, protocols=None, extensions=None,
.. code-block:: python
>>> from websocket.client import WebSocketBaseClient
>>> from ws4py.client import WebSocketBaseClient
>>> ws = WebSocketBaseClient('ws://localhost/ws')
Here is an example for a TCP client over SSL:
.. code-block:: python
>>> from websocket.client import WebSocketBaseClient
>>> from ws4py.client import WebSocketBaseClient
>>> ws = WebSocketBaseClient('wss://localhost/ws')
Finally an example of a Unix-domain connection:
.. code-block:: python
>>> from websocket.client import WebSocketBaseClient
>>> from ws4py.client import WebSocketBaseClient
>>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
Note that in this case, the initial Upgrade request
Expand All @@ -61,7 +61,7 @@ def __init__(self, url, protocols=None, extensions=None,
.. code-block:: python
>>> from websocket.client import WebSocketBaseClient
>>> from ws4py.client import WebSocketBaseClient
>>> ws = WebSocketBaseClient('ws+unix:///tmp/my.sock')
>>> ws.resource = '/ws'
>>> ws.connect()
Expand Down Expand Up @@ -333,10 +333,10 @@ def process_handshake_header(self, headers):
raise HandshakeError("Invalid challenge response: %s" % value)

elif header == b'sec-websocket-protocol':
protocols.append(value.decode('utf-8'))
protocols.extend([x.strip() for x in value.split(b',')])

elif header == b'sec-websocket-extensions':
extensions.append(value.decode('utf-8'))
extensions.extend([x.strip() for x in value.split(b',')])

return protocols, extensions

Expand Down

0 comments on commit 60d5384

Please sign in to comment.