[SPARK-17491] Close serialization stream to fix wrong answer bug in putIteratorAsBytes()

## What changes were proposed in this pull request?

`MemoryStore.putIteratorAsBytes()` may silently lose values when used with `KryoSerializer` because it does not close the serialization stream before attempting to deserialize the already-serialized values: records still sitting in Kryo's internal buffers are never written to the underlying stream, so they are missing from the data that gets read back.
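
To illustrate the failure mode, here is a minimal standalone sketch (not part of the patch; class name and record count are arbitrary) that drives Spark's `KryoSerializer` directly. Kryo buffers output internally, so trailing records can sit in its buffer until `close()` flushes them:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

object KryoCloseSketch {
  def main(args: Array[String]): Unit = {
    val instance = new KryoSerializer(new SparkConf()).newInstance()
    val bos = new ByteArrayOutputStream()
    val stream = instance.serializeStream(bos)
    (1 to 1000).foreach(i => stream.writeObject(i))
    // Without this close(), the last records may still sit in Kryo's internal
    // buffer and never reach `bos`, so the read-back below silently comes up short.
    stream.close()
    val values = instance
      .deserializeStream(new ByteArrayInputStream(bos.toByteArray))
      .asIterator.toSeq
    assert(values.length == 1000, s"read ${values.length} of 1000 values")
  }
}
```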

This is the root cause of a user-reported "wrong answer" bug in PySpark caching, reported by bennoleslie on the Spark user mailing list in a thread titled "pyspark persist MEMORY_ONLY vs MEMORY_AND_DISK". Because Spark 2.0 automatically uses KryoSerializer for "safe" types (such as byte arrays and primitives), this misuse of the serializer manifested as silent data corruption rather than a StreamCorruptedException (which you might get from JavaSerializer).

The minimal fix, implemented here, is to close the serialization stream before attempting to deserialize the written values. This patch also adds several assertions / precondition checks to prevent misuse of `PartiallySerializedBlock` and `ChunkedByteBufferOutputStream`.

## How was this patch tested?

The original bug was masked by an invalid assert in the memory store test cases: the old assert compared the two results record-by-record with `zip` but never first checked that the two collections had equal lengths, so missing records went unnoticed (illustrated below). With the assertion fixed, the updated test case reproduces this bug.
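
The pitfall in the old assertion, in isolation (illustrative values only):

```scala
object ZipAssertPitfall {
  def main(args: Array[String]): Unit = {
    val expected = Seq(1, 2, 3)
    val actual = Seq(1, 2) // one record was silently dropped
    // Old check: zip stops at the shorter collection, so the missing
    // third record is never compared and the assertion passes.
    expected.iterator.zip(actual.iterator).foreach { case (e, a) => assert(e == a) }
    // Fixed check: compare lengths first, as the new assertSameContents helper does.
    assert(actual.length == expected.length, "wrong number of values returned")
  }
}
```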

In addition, I added a new `PartiallySerializedBlockSuite` to unit test that component.

Author: Josh Rosen <[email protected]>

Closes apache#15043 from JoshRosen/partially-serialized-block-values-iterator-bugfix.
JoshRosen committed Sep 17, 2016
1 parent 86c2d39 commit 8faa521
Showing 8 changed files with 344 additions and 44 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -230,6 +230,7 @@ private[spark] object Task {
     dataOut.flush()
     val taskBytes = serializer.serialize(task)
     Utils.writeByteBuffer(taskBytes, out)
+    out.close()
     out.toByteBuffer
   }
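
Note: `out` here is a `ByteBufferOutputStream`, so this `close()` call is also what satisfies the new precondition (added further below) that `toByteBuffer` may only be called once the stream has been closed.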

core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala

@@ -33,7 +33,7 @@ import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.unsafe.Platform
-import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
+import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

@@ -277,6 +277,7 @@ private[spark] class MemoryStore(
         "released too much unroll memory")
       Left(new PartiallyUnrolledIterator(
         this,
+        MemoryMode.ON_HEAP,
         unrollMemoryUsedByThisBlock,
         unrolled = arrayValues.toIterator,
         rest = Iterator.empty))
@@ -285,7 +286,11 @@ private[spark] class MemoryStore(
       // We ran out of space while unrolling the values for this block
       logUnrollFailureMessage(blockId, vector.estimateSize())
       Left(new PartiallyUnrolledIterator(
-        this, unrollMemoryUsedByThisBlock, unrolled = vector.iterator, rest = values))
+        this,
+        MemoryMode.ON_HEAP,
+        unrollMemoryUsedByThisBlock,
+        unrolled = vector.iterator,
+        rest = values))
     }
   }

@@ -394,7 +399,7 @@ private[spark] class MemoryStore(
         redirectableStream,
         unrollMemoryUsedByThisBlock,
         memoryMode,
-        bbos.toChunkedByteBuffer,
+        bbos,
         values,
         classTag))
     }
@@ -655,20 +660,22 @@ private[spark] class MemoryStore(
  * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
  * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryMode the memory mode (on- or off-heap).
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param unrolled an iterator for the partially-unrolled values.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
+    memoryMode: MemoryMode,
     unrollMemory: Long,
     private[this] var unrolled: Iterator[T],
     rest: Iterator[T])
   extends Iterator[T] {
 
   private def releaseUnrollMemory(): Unit = {
-    memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+    memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     // SPARK-17503: Garbage collects the unrolling memory before the life end of
     // PartiallyUnrolledIterator.
     unrolled = null
@@ -706,7 +713,7 @@ private[storage] class PartiallyUnrolledIterator[T](
 /**
  * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
  */
-private class RedirectableOutputStream extends OutputStream {
+private[storage] class RedirectableOutputStream extends OutputStream {
   private[this] var os: OutputStream = _
   def setOutputStream(s: OutputStream): Unit = { os = s }
   override def write(b: Int): Unit = os.write(b)
@@ -726,7 +733,8 @@ private class RedirectableOutputStream extends OutputStream {
  * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
  * @param memoryMode whether the unroll memory is on- or off-heap
- * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param bbos byte buffer output stream containing the partially-serialized values.
+ *             [[redirectableOutputStream]] initially points to this output stream.
  * @param rest the rest of the original iterator passed to
  *             [[MemoryStore.putIteratorAsValues()]].
  * @param classTag the [[ClassTag]] for the block.
@@ -735,14 +743,19 @@ private[storage] class PartiallySerializedBlock[T](
     memoryStore: MemoryStore,
     serializerManager: SerializerManager,
     blockId: BlockId,
-    serializationStream: SerializationStream,
-    redirectableOutputStream: RedirectableOutputStream,
-    unrollMemory: Long,
+    private val serializationStream: SerializationStream,
+    private val redirectableOutputStream: RedirectableOutputStream,
+    val unrollMemory: Long,
     memoryMode: MemoryMode,
-    unrolled: ChunkedByteBuffer,
+    bbos: ChunkedByteBufferOutputStream,
     rest: Iterator[T],
     classTag: ClassTag[T]) {
 
+  private lazy val unrolledBuffer: ChunkedByteBuffer = {
+    bbos.close()
+    bbos.toChunkedByteBuffer
+  }
+
   // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of
   // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task
   // completion listener here in order to ensure that `unrolled.dispose()` is called at least once.
@@ -751,23 +764,42 @@ private[storage] class PartiallySerializedBlock[T](
     taskContext.addTaskCompletionListener { _ =>
       // When a task completes, its unroll memory will automatically be freed. Thus we do not call
       // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing.
-      unrolled.dispose()
+      unrolledBuffer.dispose()
     }
   }
 
+  // Exposed for testing
+  private[storage] def getUnrolledChunkedByteBuffer: ChunkedByteBuffer = unrolledBuffer
+
+  private[this] var discarded = false
+  private[this] var consumed = false
+
+  private def verifyNotConsumedAndNotDiscarded(): Unit = {
+    if (consumed) {
+      throw new IllegalStateException(
+        "Can only call one of finishWritingToStream() or valuesIterator() and can only call once.")
+    }
+    if (discarded) {
+      throw new IllegalStateException("Cannot call methods on a discarded PartiallySerializedBlock")
+    }
+  }
+
   /**
    * Called to dispose of this block and free its memory.
    */
   def discard(): Unit = {
-    try {
-      // We want to close the output stream in order to free any resources associated with the
-      // serializer itself (such as Kryo's internal buffers). close() might cause data to be
-      // written, so redirect the output stream to discard that data.
-      redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
-      serializationStream.close()
-    } finally {
-      unrolled.dispose()
-      memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+    if (!discarded) {
+      try {
+        // We want to close the output stream in order to free any resources associated with the
+        // serializer itself (such as Kryo's internal buffers). close() might cause data to be
+        // written, so redirect the output stream to discard that data.
+        redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream())
+        serializationStream.close()
+      } finally {
+        discarded = true
+        unrolledBuffer.dispose()
+        memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
+      }
     }
   }

@@ -776,8 +808,10 @@ private[storage] class PartiallySerializedBlock[T](
    * and then serializing the values from the original input iterator.
    */
   def finishWritingToStream(os: OutputStream): Unit = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
-    ByteStreams.copy(unrolled.toInputStream(dispose = true), os)
+    ByteStreams.copy(unrolledBuffer.toInputStream(dispose = true), os)
     memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory)
     redirectableOutputStream.setOutputStream(os)
     while (rest.hasNext) {
@@ -794,13 +828,22 @@ private[storage] class PartiallySerializedBlock[T](
    * `close()` on it to free its resources.
    */
   def valuesIterator: PartiallyUnrolledIterator[T] = {
+    verifyNotConsumedAndNotDiscarded()
+    consumed = true
+    // Close the serialization stream so that the serializer's internal buffers are freed and any
+    // "end-of-stream" markers can be written out so that `unrolled` is a valid serialized stream.
+    serializationStream.close()
     // `unrolled`'s underlying buffers will be freed once this input stream is fully read:
     val unrolledIter = serializerManager.dataDeserializeStream(
-      blockId, unrolled.toInputStream(dispose = true))(classTag)
+      blockId, unrolledBuffer.toInputStream(dispose = true))(classTag)
+    // The unroll memory will be freed once `unrolledIter` is fully consumed in
+    // PartiallyUnrolledIterator. If the iterator is not consumed by the end of the task then any
+    // extra unroll memory will automatically be freed by a `finally` block in `Task`.
     new PartiallyUnrolledIterator(
       memoryStore,
+      memoryMode,
       unrollMemory,
-      unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()),
+      unrolled = unrolledIter,
       rest = rest)
   }
 }
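
The new `consumed` / `discarded` flags make `PartiallySerializedBlock` a one-shot object: callers may invoke exactly one of `finishWritingToStream()` or `valuesIterator()`, exactly once, while `discard()` is safe to call repeatedly. A standalone sketch of the same pattern (hypothetical `OneShot` class, not from the patch):

```scala
// Hypothetical illustration of the one-shot contract enforced above.
class OneShot {
  private[this] var consumed = false
  private[this] var discarded = false

  private def verifyNotConsumedAndNotDiscarded(): Unit = {
    if (consumed) throw new IllegalStateException("already consumed")
    if (discarded) throw new IllegalStateException("already discarded")
  }

  def consume(): Unit = { verifyNotConsumedAndNotDiscarded(); consumed = true }
  def discard(): Unit = { discarded = true } // idempotent by design
}

object OneShotDemo extends App {
  val block = new OneShot
  block.consume()
  // A second consume(), like a second valuesIterator() call, now fails fast
  // with IllegalStateException instead of silently misbehaving.
  block.discard()
  block.discard() // repeated discard() stays safe, mirroring discard() above
}
```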
core/src/main/scala/org/apache/spark/util/ByteBufferOutputStream.scala

@@ -29,7 +29,32 @@ private[spark] class ByteBufferOutputStream(capacity: Int) extends ByteArrayOutputStream(capacity) {
 
   def getCount(): Int = count
 
+  private[this] var closed: Boolean = false
+
+  override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ByteBufferOutputStream")
+    super.write(b, off, len)
+  }
+
+  override def reset(): Unit = {
+    require(!closed, "cannot reset a closed ByteBufferOutputStream")
+    super.reset()
+  }
+
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   def toByteBuffer: ByteBuffer = {
-    return ByteBuffer.wrap(buf, 0, count)
+    require(closed, "can only call toByteBuffer() after ByteBufferOutputStream has been closed")
+    ByteBuffer.wrap(buf, 0, count)
   }
 }
core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala

@@ -49,17 +49,27 @@ private[spark] class ChunkedByteBufferOutputStream(
    */
   private[this] var position = chunkSize
   private[this] var _size = 0
+  private[this] var closed: Boolean = false
 
   def size: Long = _size
 
+  override def close(): Unit = {
+    if (!closed) {
+      super.close()
+      closed = true
+    }
+  }
+
   override def write(b: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     allocateNewChunkIfNeeded()
     chunks(lastChunkIndex).put(b.toByte)
     position += 1
     _size += 1
   }
 
   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
+    require(!closed, "cannot write to a closed ChunkedByteBufferOutputStream")
     var written = 0
     while (written < len) {
       allocateNewChunkIfNeeded()
@@ -73,7 +83,6 @@ private[spark] class ChunkedByteBufferOutputStream(
 
   @inline
   private def allocateNewChunkIfNeeded(): Unit = {
-    require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called")
     if (position == chunkSize) {
       chunks += allocator(chunkSize)
       lastChunkIndex += 1
@@ -82,6 +91,7 @@ private[spark] class ChunkedByteBufferOutputStream(
   }
 
   def toChunkedByteBuffer: ChunkedByteBuffer = {
+    require(closed, "cannot call toChunkedByteBuffer() unless close() has been called")
     require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once")
     toChunkedByteBufferWasCalled = true
     if (lastChunkIndex == -1) {
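
With these checks in place, the expected lifecycle is write, then `close()`, then `toChunkedByteBuffer()` at most once. A small usage sketch (hypothetical demo code, placed under `org.apache.spark.util.io` because the class is `private[spark]`; the tiny chunk size is purely for demonstration):

```scala
package org.apache.spark.util.io

import java.nio.ByteBuffer

object ChunkedWriteSketch {
  def main(args: Array[String]): Unit = {
    // 4-byte chunks with an on-heap allocator, so the 6 bytes below span two chunks.
    val out = new ChunkedByteBufferOutputStream(4, (size: Int) => ByteBuffer.allocate(size))
    out.write(Array[Byte](1, 2, 3, 4, 5, 6))
    out.close() // required first: toChunkedByteBuffer() now asserts the stream is closed
    val buffer = out.toChunkedByteBuffer
    assert(buffer.size == 6L)
    // A further write() or a second toChunkedByteBuffer() would fail the new require() checks.
  }
}
```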
34 changes: 16 additions & 18 deletions core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala
@@ -79,6 +79,13 @@ class MemoryStoreSuite
     (memoryStore, blockInfoManager)
   }
 
+  private def assertSameContents[T](expected: Seq[T], actual: Seq[T], hint: String): Unit = {
+    assert(actual.length === expected.length, s"wrong number of values returned in $hint")
+    expected.iterator.zip(actual.iterator).foreach { case (e, a) =>
+      assert(e === a, s"$hint did not return original values!")
+    }
+  }
+
   test("reserve/release unroll memory") {
     val (memoryStore, _) = makeMemoryStore(12000)
     assert(memoryStore.currentUnrollMemory === 0)
@@ -137,9 +144,7 @@ class MemoryStoreSuite
     var putResult = putIteratorAsValues("unroll", smallList.iterator, ClassTag.Any)
     assert(putResult.isRight)
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -152,9 +157,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(memoryStore.contains("someBlock2"))
     assert(!memoryStore.contains("someBlock1"))
-    smallList.iterator.zip(memoryStore.getValues("unroll").get).foreach { case (e, a) =>
-      assert(e === a, "getValues() did not return original values!")
-    }
+    assertSameContents(smallList, memoryStore.getValues("unroll").get.toSeq, "getValues")
     blockInfoManager.lockForWriting("unroll")
     assert(memoryStore.remove("unroll"))
     blockInfoManager.removeBlock("unroll")
@@ -167,9 +170,7 @@ class MemoryStoreSuite
     assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(!memoryStore.contains("someBlock2"))
     assert(putResult.isLeft)
-    bigList.iterator.zip(putResult.left.get).foreach { case (e, a) =>
-      assert(e === a, "putIterator() did not return original values!")
-    }
+    assertSameContents(bigList, putResult.left.get.toSeq, "putIterator")
     // The unroll memory was freed once the iterator returned by putIterator() was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -316,9 +317,8 @@ class MemoryStoreSuite
     assert(res.isLeft)
     assert(memoryStore.currentUnrollMemoryForThisTask > 0)
     val valuesReturnedFromFailedPut = res.left.get.valuesIterator.toSeq // force materialization
-    valuesReturnedFromFailedPut.zip(bigList).foreach { case (e, a) =>
-      assert(e === a, "PartiallySerializedBlock.valuesIterator() did not return original values!")
-    }
+    assertSameContents(
+      bigList, valuesReturnedFromFailedPut, "PartiallySerializedBlock.valuesIterator()")
     // The unroll memory was freed once the iterator was fully traversed.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
@@ -340,12 +340,10 @@ class MemoryStoreSuite
     res.left.get.finishWritingToStream(bos)
     // The unroll memory was freed once the block was fully written.
     assert(memoryStore.currentUnrollMemoryForThisTask === 0)
-    val deserializationStream = serializerManager.dataDeserializeStream[Any](
-      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any)
-    deserializationStream.zip(bigList.iterator).foreach { case (e, a) =>
-      assert(e === a,
-        "PartiallySerializedBlock.finishWritingtoStream() did not write original values!")
-    }
+    val deserializedValues = serializerManager.dataDeserializeStream[Any](
+      "b1", new ByteBufferInputStream(bos.toByteBuffer))(ClassTag.Any).toSeq
+    assertSameContents(
+      bigList, deserializedValues, "PartiallySerializedBlock.finishWritingToStream()")
   }
 
   test("multiple unrolls by the same thread") {