Skip to content

Commit

Permalink
KAFKA-3996; ByteBufferMessageSet.writeTo() should be non-blocking
Browse files Browse the repository at this point in the history
Also:
* Introduce a blocking variant to be used by `FileMessageSet.append`
* Add tests
* Minor clean-ups

Author: Ismael Juma <[email protected]>

Reviewers: Jun Rao <[email protected]>

Closes apache#1669 from ijuma/kafka-3996-byte-buffer-message-set-write-to-non-blocking
  • Loading branch information
ijuma authored and junrao committed Jul 27, 2016
1 parent 4059f07 commit 64842f4
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
public class ByteBufferSend implements Send {

private final String destination;
private final int size;
protected final ByteBuffer[] buffers;
private int remaining;
private int size;
private boolean pending = false;

public ByteBufferSend(String destination, ByteBuffer... buffers) {
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/kafka/log/FileMessageSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class FileMessageSet private[kafka](@volatile var file: File,
* Append these messages to the message set
*/
def append(messages: ByteBufferMessageSet) {
val written = messages.writeTo(channel, 0, messages.sizeInBytes)
val written = messages.writeFullyTo(channel)
_size.getAndAdd(written)
}

Expand Down
18 changes: 15 additions & 3 deletions core/src/main/scala/kafka/message/ByteBufferMessageSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -294,16 +294,28 @@ class ByteBufferMessageSet(val buffer: ByteBuffer) extends MessageSet with Loggi
}

/** Write the messages in this set to the given channel */
def writeTo(channel: GatheringByteChannel, offset: Long, size: Int): Int = {
// Ignore offset and size from input. We just want to write the whole buffer to the channel.
def writeFullyTo(channel: GatheringByteChannel): Int = {
buffer.mark()
var written = 0
while(written < sizeInBytes)
while (written < sizeInBytes)
written += channel.write(buffer)
buffer.reset()
written
}

/** Write the messages in this set to the given channel starting at the given offset byte.
* Less than the complete amount may be written, but no more than maxSize can be. The number
* of bytes written is returned */
def writeTo(channel: GatheringByteChannel, offset: Long, maxSize: Int): Int = {
if (offset > Int.MaxValue)
throw new IllegalArgumentException(s"offset should not be larger than Int.MaxValue: $offset")
val dup = buffer.duplicate()
val position = offset.toInt
dup.position(position)
dup.limit(math.min(buffer.limit, position + maxSize))
channel.write(dup)
}

override def isMagicValueInAllWrapperMessages(expectedMagicValue: Byte): Boolean = {
for (messageAndOffset <- shallowIterator) {
if (messageAndOffset.message.magic != expectedMagicValue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,45 @@

package kafka.message

import java.io.RandomAccessFile
import java.nio.ByteBuffer
import java.nio.channels.{FileChannel, GatheringByteChannel}
import java.nio.file.StandardOpenOption

import org.junit.Assert._
import kafka.utils.TestUtils._
import kafka.log.FileMessageSet
import kafka.utils.TestUtils
import org.scalatest.junit.JUnitSuite
import org.junit.Test

import scala.collection.mutable.ArrayBuffer

trait BaseMessageSetTestCases extends JUnitSuite {


private class StubByteChannel(bytesToConsumePerBuffer: Int) extends GatheringByteChannel {

val data = new ArrayBuffer[Byte]

def write(srcs: Array[ByteBuffer], offset: Int, length: Int): Long = {
srcs.map { src =>
val array = new Array[Byte](math.min(bytesToConsumePerBuffer, src.remaining))
src.get(array)
data ++= array
array.length
}.sum
}

def write(srcs: Array[ByteBuffer]): Long = write(srcs, 0, srcs.map(_.remaining).sum)

def write(src: ByteBuffer): Int = write(Array(src)).toInt

def isOpen: Boolean = true

def close() {}

}


val messages = Array(new Message("abcd".getBytes), new Message("efgh".getBytes), new Message("ijkl".getBytes))

def createMessageSet(messages: Seq[Message]): MessageSet
Expand Down Expand Up @@ -56,20 +86,48 @@ trait BaseMessageSetTestCases extends JUnitSuite {
@Test
def testWriteTo() {
// test empty message set
testWriteToWithMessageSet(createMessageSet(Array[Message]()))
testWriteToWithMessageSet(createMessageSet(messages))
checkWriteToWithMessageSet(createMessageSet(Array[Message]()))
checkWriteToWithMessageSet(createMessageSet(messages))
}

def testWriteToWithMessageSet(set: MessageSet) {
/* Tests that writing to a channel that doesn't consume all the bytes in the buffer works correctly */
@Test
def testWriteToChannelThatConsumesPartially() {
val bytesToConsumePerBuffer = 50
val messages = (0 until 10).map(_ => new Message(TestUtils.randomString(100).getBytes))
val messageSet = createMessageSet(messages)
val messageSetSize = messageSet.sizeInBytes

val channel = new StubByteChannel(bytesToConsumePerBuffer)

var remaining = messageSetSize
var iterations = 0
while (remaining > 0) {
remaining -= messageSet.writeTo(channel, messageSetSize - remaining, remaining)
iterations += 1
}

assertEquals((messageSetSize / bytesToConsumePerBuffer) + 1, iterations)
checkEquals(new ByteBufferMessageSet(ByteBuffer.wrap(channel.data.toArray)).iterator, messageSet.iterator)
}

def checkWriteToWithMessageSet(messageSet: MessageSet) {
checkWriteWithMessageSet(messageSet, messageSet.writeTo(_, 0, messageSet.sizeInBytes))
}

def checkWriteWithMessageSet(set: MessageSet, write: GatheringByteChannel => Long) {
// do the write twice to ensure the message set is restored to its original state
for(i <- List(0,1)) {
for (_ <- 0 to 1) {
val file = tempFile()
val channel = new RandomAccessFile(file, "rw").getChannel()
val written = set.writeTo(channel, 0, 1024)
assertEquals("Expect to write the number of bytes in the set.", set.sizeInBytes, written)
val newSet = new FileMessageSet(file, channel)
checkEquals(set.iterator, newSet.iterator)
val channel = FileChannel.open(file.toPath, StandardOpenOption.READ, StandardOpenOption.WRITE)
try {
val written = write(channel)
assertEquals("Expect to write the number of bytes in the set.", set.sizeInBytes, written)
val newSet = new FileMessageSet(file, channel)
checkEquals(set.iterator, newSet.iterator)
} finally channel.close()
}
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,16 @@ class ByteBufferMessageSetTest extends BaseMessageSetTestCases {
messageTimestampType = TimestampType.CREATE_TIME,
messageTimestampDiffMaxMs = 5000L)._1, offset)
}

@Test
def testWriteFullyTo() {
checkWriteFullyToWithMessageSet(createMessageSet(Array[Message]()))
checkWriteFullyToWithMessageSet(createMessageSet(messages))
}

def checkWriteFullyToWithMessageSet(messageSet: ByteBufferMessageSet) {
checkWriteWithMessageSet(messageSet, messageSet.writeFullyTo)
}

/* check that offsets are assigned based on byte offset from the given base offset */
def checkOffsets(messages: ByteBufferMessageSet, baseOffset: Long) {
Expand Down

0 comments on commit 64842f4

Please sign in to comment.