Skip to content

Commit

Permalink
Make all unit tests pass on py3.3/3.4
Browse files Browse the repository at this point in the history
  • Loading branch information
brutasse authored and Mark Roberts committed Sep 3, 2014
1 parent 83af510 commit cf0b7f0
Show file tree
Hide file tree
Showing 22 changed files with 349 additions and 271 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ language: python
python:
- 2.6
- 2.7
- 3.3
- 3.4
- pypy

env:
Expand Down
2 changes: 1 addition & 1 deletion kafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _next_id(self):
"""
Generate a new correlation id
"""
return KafkaClient.ID_GEN.next()
return next(KafkaClient.ID_GEN)

def _send_broker_unaware_request(self, requestId, request):
"""
Expand Down
19 changes: 11 additions & 8 deletions kafka/codec.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from cStringIO import StringIO
from io import BytesIO
import gzip
import struct

_XERIAL_V1_HEADER = (-126, 'S', 'N', 'A', 'P', 'P', 'Y', 0, 1, 1)
import six
from six.moves import xrange

_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1)
_XERIAL_V1_FORMAT = 'bccccccBii'

try:
Expand All @@ -21,7 +24,7 @@ def has_snappy():


def gzip_encode(payload):
buffer = StringIO()
buffer = BytesIO()
handle = gzip.GzipFile(fileobj=buffer, mode="w")
handle.write(payload)
handle.close()
Expand All @@ -32,7 +35,7 @@ def gzip_encode(payload):


def gzip_decode(payload):
buffer = StringIO(payload)
buffer = BytesIO(payload)
handle = gzip.GzipFile(fileobj=buffer, mode='r')
result = handle.read()
handle.close()
Expand Down Expand Up @@ -68,9 +71,9 @@ def _chunker():
for i in xrange(0, len(payload), xerial_blocksize):
yield payload[i:i+xerial_blocksize]

out = StringIO()
out = BytesIO()

header = ''.join([struct.pack('!' + fmt, dat) for fmt, dat
header = b''.join([struct.pack('!' + fmt, dat) for fmt, dat
in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER)])

out.write(header)
Expand Down Expand Up @@ -121,8 +124,8 @@ def snappy_decode(payload):

if _detect_xerial_stream(payload):
# TODO ? Should become a fileobj ?
out = StringIO()
byt = buffer(payload[16:])
out = BytesIO()
byt = payload[16:]
length = len(byt)
cursor = 0

Expand Down
10 changes: 6 additions & 4 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from random import shuffle
from threading import local

import six

from kafka.common import ConnectionError

log = logging.getLogger("kafka")
Expand All @@ -19,7 +21,7 @@ def collect_hosts(hosts, randomize=True):
randomize the returned list.
"""

if isinstance(hosts, basestring):
if isinstance(hosts, six.string_types):
hosts = hosts.strip().split(',')

result = []
Expand Down Expand Up @@ -92,7 +94,7 @@ def _read_bytes(self, num_bytes):
# Receiving empty string from recv signals
# that the socket is in error. we will never get
# more data from this socket
if data == '':
if data == b'':
raise socket.error("Not enough data to read message -- did server kill socket?")

except socket.error:
Expand All @@ -103,7 +105,7 @@ def _read_bytes(self, num_bytes):
log.debug("Read %d/%d bytes from Kafka", num_bytes - bytes_left, num_bytes)
responses.append(data)

return ''.join(responses)
return b''.join(responses)

##################
# Public API #
Expand Down Expand Up @@ -144,7 +146,7 @@ def recv(self, request_id):

# Read the remainder of the response
resp = self._read_bytes(size)
return str(resp)
return resp

def copy(self):
"""
Expand Down
11 changes: 9 additions & 2 deletions kafka/consumer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import absolute_import

from itertools import izip_longest, repeat
try:
from itertools import zip_longest as izip_longest, repeat # pylint: disable-msg=E0611
except ImportError: # python 2
from itertools import izip_longest as izip_longest, repeat
import logging
import time
import numbers
from threading import Lock
from multiprocessing import Process, Queue as MPQueue, Event, Value
from Queue import Empty, Queue

try:
from Queue import Empty, Queue
except ImportError:  # python 3
from queue import Empty, Queue

import kafka
from kafka.common import (
Expand Down
14 changes: 10 additions & 4 deletions kafka/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@
import time
import random

from Queue import Empty
try:
from queue import Empty
except ImportError:
from Queue import Empty
from collections import defaultdict
from itertools import cycle
from multiprocessing import Queue, Process

import six
from six.moves import xrange

from kafka.common import (
ProduceRequest, TopicAndPartition, UnsupportedCodecError, UnknownTopicOrPartitionError
)
Expand Down Expand Up @@ -172,8 +178,8 @@ def send_messages(self, topic, partition, *msg):
if not isinstance(msg, (list, tuple)):
raise TypeError("msg is not a list or tuple!")

# Raise TypeError if any message is not encoded as a str
if any(not isinstance(m, str) for m in msg):
# Raise TypeError if any message is not encoded as bytes
if any(not isinstance(m, six.binary_type) for m in msg):
raise TypeError("all produce message payloads must be type str")

if self.async:
Expand Down Expand Up @@ -221,7 +227,7 @@ class SimpleProducer(Producer):
batch_send_every_t - If set, messages are send after this timeout
random_start - If true, randomize the initial partition which the
the first message block will be published to, otherwise
if false, the first message block will always publish
if false, the first message block will always publish
to partition 0 before cycling through each partition
"""
def __init__(self, client, async=False,
Expand Down
13 changes: 8 additions & 5 deletions kafka/protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import logging
import struct
import zlib

import six

from six.moves import xrange

from kafka.codec import (
gzip_encode, gzip_decode, snappy_encode, snappy_decode
Expand All @@ -13,7 +16,7 @@
UnsupportedCodecError
)
from kafka.util import (
read_short_string, read_int_string, relative_unpack,
crc32, read_short_string, read_int_string, relative_unpack,
write_short_string, write_int_string, group_by_topic_and_partition
)

Expand Down Expand Up @@ -67,7 +70,7 @@ def _encode_message_set(cls, messages):
Offset => int64
MessageSize => int32
"""
message_set = ""
message_set = b""
for message in messages:
encoded_message = KafkaProtocol._encode_message(message)
message_set += struct.pack('>qi%ds' % len(encoded_message), 0, len(encoded_message), encoded_message)
Expand All @@ -94,7 +97,7 @@ def _encode_message(cls, message):
msg = struct.pack('>BB', message.magic, message.attributes)
msg += write_int_string(message.key)
msg += write_int_string(message.value)
crc = zlib.crc32(msg)
crc = crc32(msg)
msg = struct.pack('>i%ds' % len(msg), crc, msg)
else:
raise ProtocolError("Unexpected magic number: %d" % message.magic)
Expand Down Expand Up @@ -146,7 +149,7 @@ def _decode_message(cls, data, offset):
of the MessageSet payload).
"""
((crc, magic, att), cur) = relative_unpack('>iBB', data, 0)
if crc != zlib.crc32(data[4:]):
if crc != crc32(data[4:]):
raise ChecksumError("Message checksum failed")

(key, cur) = read_int_string(data, cur)
Expand Down
26 changes: 21 additions & 5 deletions kafka/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
import collections
import struct
import sys
import zlib
from threading import Thread, Event

import six

from kafka.common import BufferUnderflowError


def crc32(data):
    """
    Return the CRC-32 checksum of *data* as a signed 32-bit integer.

    Python 2's zlib.crc32 returns a value in [-2**31, 2**31-1] while
    Python 3's returns [0, 2**32-1]. Kafka's protocol encodes the CRC as
    a signed 32-bit field, so normalize to Python 2's signed behavior on
    both versions.
    """
    # Mask to the unsigned 32-bit value first; this makes the conversion
    # below identical on Python 2 and 3, with no need to branch on six.PY3.
    crc = zlib.crc32(data) & 0xffffffff
    # Reinterpret as signed two's complement: any value with the high bit
    # set wraps negative. NOTE: the previous `crc > 2**31` comparison
    # mishandled the boundary value 2**31 itself, leaving it positive
    # instead of mapping it to -2**31.
    if crc >= 2**31:
        crc -= 2**32
    return crc


def write_int_string(s):
if s is not None and not isinstance(s, str):
raise TypeError('Expected "%s" to be str\n'
if s is not None and not isinstance(s, six.binary_type):
raise TypeError('Expected "%s" to be bytes\n'
'data=%s' % (type(s), repr(s)))
if s is None:
return struct.pack('>i', -1)
Expand All @@ -17,12 +33,12 @@ def write_int_string(s):


def write_short_string(s):
if s is not None and not isinstance(s, str):
raise TypeError('Expected "%s" to be str\n'
if s is not None and not isinstance(s, six.binary_type):
raise TypeError('Expected "%s" to be bytes\n'
'data=%s' % (type(s), repr(s)))
if s is None:
return struct.pack('>h', -1)
elif len(s) > 32767 and sys.version < (2, 7):
elif len(s) > 32767 and sys.version_info < (2, 7):
# Python 2.6 issues a deprecation warning instead of a struct error
raise struct.error(len(s))
else:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def run(cls):
is also supported for message sets.
""",
keywords="apache kafka",
install_requires=['six'],
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
Expand Down
10 changes: 5 additions & 5 deletions test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import shutil
import subprocess
import tempfile
import urllib2
from six.moves import urllib
import uuid

from urlparse import urlparse
from six.moves.urllib.parse import urlparse # pylint: disable-msg=E0611
from test.service import ExternalService, SpawnedService
from test.testutil import get_open_port

Expand Down Expand Up @@ -42,12 +42,12 @@ def download_official_distribution(cls,
try:
url = url_base + distfile + '.tgz'
logging.info("Attempting to download %s", url)
response = urllib2.urlopen(url)
except urllib2.HTTPError:
response = urllib.request.urlopen(url)
except urllib.error.HTTPError:
logging.exception("HTTP Error")
url = url_base + distfile + '.tar.gz'
logging.info("Attempting to download %s", url)
response = urllib2.urlopen(url)
response = urllib.request.urlopen(url)

logging.info("Saving distribution file to %s", output_file)
with open(output_file, 'w') as output_file_fd:
Expand Down
21 changes: 11 additions & 10 deletions test/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import unittest

import six
from mock import MagicMock, patch

from kafka import KafkaClient
Expand All @@ -15,25 +16,25 @@ def test_init_with_list(self):
with patch.object(KafkaClient, 'load_metadata_for_topics'):
client = KafkaClient(hosts=['kafka01:9092', 'kafka02:9092', 'kafka03:9092'])

self.assertItemsEqual(
[('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
client.hosts)
self.assertEqual(
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
sorted(client.hosts))

def test_init_with_csv(self):
with patch.object(KafkaClient, 'load_metadata_for_topics'):
client = KafkaClient(hosts='kafka01:9092,kafka02:9092,kafka03:9092')

self.assertItemsEqual(
[('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
client.hosts)
self.assertEqual(
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
sorted(client.hosts))

def test_init_with_unicode_csv(self):
with patch.object(KafkaClient, 'load_metadata_for_topics'):
client = KafkaClient(hosts=u'kafka01:9092,kafka02:9092,kafka03:9092')

self.assertItemsEqual(
[('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)],
client.hosts)
self.assertEqual(
sorted([('kafka01', 9092), ('kafka02', 9092), ('kafka03', 9092)]),
sorted(client.hosts))

def test_send_broker_unaware_request_fail(self):
'Tests that call fails when all hosts are unavailable'
Expand All @@ -58,7 +59,7 @@ def mock_get_conn(host, port):
with self.assertRaises(KafkaUnavailableError):
client._send_broker_unaware_request(1, 'fake request')

for key, conn in mocked_conns.iteritems():
for key, conn in six.iteritems(mocked_conns):
conn.send.assert_called_with(1, 'fake request')

def test_send_broker_unaware_request(self):
Expand Down
Loading

0 comments on commit cf0b7f0

Please sign in to comment.