[SPARK-17110] Fix StreamCorruptedException in BlockManager.getRemoteValues()

## What changes were proposed in this pull request?

This patch fixes a `java.io.StreamCorruptedException` error affecting remote reads of cached values when certain data types are used. The problem stems from apache#11801 / SPARK-13990, a patch to have Spark automatically pick the "best" serializer when caching RDDs. If PySpark cached a PythonRDD, then this would be cached as an `RDD[Array[Byte]]` and the automatic serializer selection would pick KryoSerializer for replication and block transfer. However, the `getRemoteValues()` / `getRemoteBytes()` code path did not pass proper class tags in order to enable the same serializer to be used during deserialization, causing Java to be inappropriately used instead of Kryo, leading to the StreamCorruptedException.
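
To make the failure mode concrete, here is a small, self-contained Scala sketch (illustrative names and a deliberately simplified selection rule, not the actual `SerializerManager` logic) of how a ClassTag-driven serializer choice diverges when the read side has no class tag:

```scala
import scala.reflect.ClassTag

object SerializerChoiceSketch {
  // Hypothetical, simplified stand-in for SPARK-13990-style selection:
  // only class tags known to be safe use Kryo; everything else falls back to Java.
  private val kryoSafeTags: Set[ClassTag[_]] = Set(
    ClassTag.Int, ClassTag.Long, ClassTag.Double,
    ClassTag(classOf[String]), ClassTag(classOf[Array[Byte]]))

  def serializerNameFor(tag: ClassTag[_]): String =
    if (kryoSafeTags.contains(tag)) "KryoSerializer" else "JavaSerializer"

  def main(args: Array[String]): Unit = {
    // A cached PythonRDD block is an RDD[Array[Byte]] block, so the write side picks Kryo...
    println(serializerNameFor(ClassTag(classOf[Array[Byte]])))  // KryoSerializer
    // ...but a read path with no class tag effectively has ClassTag.Any and
    // falls back to Java serialization, which cannot parse the Kryo bytes.
    println(serializerNameFor(ClassTag.Any))                    // JavaSerializer
  }
}
```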

We already fixed a similar bug in apache#14311, which dealt with similar issues in block replication. Prior to that patch, it seems that we had no tests to ensure that block replication actually succeeded. Similarly, prior to this bug fix patch it looks like we had no tests to perform remote reads of cached data, which is why this bug was able to remain latent for so long.

This patch addresses the bug by modifying `BlockManager`'s `get()` and `getRemoteValues()` methods to accept ClassTags, allowing the proper class tag to be threaded through the `getOrElseUpdate` code path (which is used by `rdd.iterator`).
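
The sketch below (hypothetical, simplified stand-ins rather than the real `BlockManager` API) illustrates the shape of that threading: the explicit ClassTag supplied to `getOrElseUpdate` flows through `get` into the remote-read path:

```scala
import scala.reflect.ClassTag

// Simplified stand-ins for illustration only; not the real Spark classes.
class SketchBlockManager {
  private def getRemoteValues[T: ClassTag](blockId: String): Option[Iterator[T]] = {
    val ct = implicitly[ClassTag[T]]
    // With the class tag available here, deserialization can pick the same
    // serializer that was used when the block was written.
    println(s"deserializing $blockId with class tag $ct")
    None
  }

  def get[T: ClassTag](blockId: String): Option[Iterator[T]] =
    getRemoteValues[T](blockId)

  def getOrElseUpdate[T](blockId: String, classTag: ClassTag[T])
                        (makeIterator: () => Iterator[T]): Iterator[T] =
    get[T](blockId)(classTag).getOrElse(makeIterator())
}

object ClassTagThreadingSketch {
  def main(args: Array[String]): Unit = {
    val bm = new SketchBlockManager
    // rdd.iterator-style call site: the RDD's element ClassTag is passed explicitly.
    val iter = bm.getOrElseUpdate("rdd_0_0", implicitly[ClassTag[Array[Byte]]]) {
      () => Iterator(Array[Byte](1, 2, 3))
    }
    println(iter.map(_.toList).toList)
  }
}
```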

## How was this patch tested?

Extended the caching tests in `DistributedSuite` to exercise the `getRemoteValues` path, plus manual testing to verify that the PySpark bug reproduction in SPARK-17110 is fixed.
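
A condensed, illustrative version of that kind of end-to-end check (assuming a local-cluster master and spark-core on the classpath; this is not the exact suite code) looks roughly like:

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkEnv}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}

object RemoteReadSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf()
      .setMaster("local-cluster[2, 1, 1024]")
      .setAppName("SPARK-17110 remote-read sketch"))
    try {
      // Cache with replication so the blocks live on executors, not on the driver.
      val rdd = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2)
      rdd.count()  // materialize the cached blocks
      val blockManager = SparkEnv.get.blockManager
      val blockIds = rdd.partitions.indices.map(i => RDDBlockId(rdd.id, i))
      // Reading through the driver's BlockManager forces getRemoteBytes/getRemoteValues,
      // and get[Int] supplies the class tag needed to choose the matching serializer.
      val values = blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet
      assert(values == (1 to 1000).toSet)
    } finally {
      sc.stop()
    }
  }
}
```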

Author: Josh Rosen <[email protected]>

Closes apache#14952 from JoshRosen/SPARK-17110.
JoshRosen committed Sep 6, 2016
1 parent 8bbb08a commit 29cfab3
Showing 6 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -44,7 +44,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
 assertValid()
 val blockManager = SparkEnv.get.blockManager
 val blockId = split.asInstanceOf[BlockRDDPartition].blockId
-blockManager.get(blockId) match {
+blockManager.get[T](blockId) match {
 case Some(block) => block.data.asInstanceOf[Iterator[T]]
 case None =>
 throw new Exception("Could not compute split, block " + blockId + " not found")
@@ -180,11 +180,12 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
 * Deserializes an InputStream into an iterator of values and disposes of it when the end of
 * the iterator is reached.
 */
-def dataDeserializeStream[T: ClassTag](
+def dataDeserializeStream[T](
 blockId: BlockId,
-inputStream: InputStream): Iterator[T] = {
+inputStream: InputStream)
+(classTag: ClassTag[T]): Iterator[T] = {
 val stream = new BufferedInputStream(inputStream)
-getSerializer(implicitly[ClassTag[T]])
+getSerializer(classTag)
 .newInstance()
 .deserializeStream(wrapStream(blockId, stream))
 .asIterator.asInstanceOf[Iterator[T]]
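
Because the ClassTag is now an explicit second parameter list rather than a context bound, callers pass whatever tag they actually have; the test changes later in this diff pass `data.elementClassTag` or `ClassTag.Any`. A minimal sketch of that call shape, using hypothetical names rather than the real `SerializerManager`:

```scala
import java.io.{ByteArrayInputStream, InputStream}
import scala.reflect.ClassTag

// Illustrative only: mirrors the new signature shape, not the real SerializerManager.
object DeserializeSignatureSketch {
  def dataDeserializeStream[T](blockId: String, inputStream: InputStream)
                              (classTag: ClassTag[T]): Iterator[T] = {
    // The real method would choose a serializer based on classTag and wrap the stream;
    // here we only show how the tag is supplied at the call site.
    println(s"block $blockId: deserializing with $classTag")
    Iterator.empty
  }

  def main(args: Array[String]): Unit = {
    val in = new ByteArrayInputStream(Array.emptyByteArray)
    dataDeserializeStream("rdd_0_0", in)(implicitly[ClassTag[Int]])  // precise tag
    dataDeserializeStream("input-0-1", in)(ClassTag.Any)             // no type info
  }
}
```
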
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -520,10 +520,11 @@ private[spark] class BlockManager(
 *
 * This does not acquire a lock on this block in this JVM.
 */
-private def getRemoteValues(blockId: BlockId): Option[BlockResult] = {
+private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
+val ct = implicitly[ClassTag[T]]
 getRemoteBytes(blockId).map { data =>
 val values =
-serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))
+serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true))(ct)
 new BlockResult(values, DataReadMethod.Network, data.size)
 }
 }
@@ -602,13 +603,13 @@ private[spark] class BlockManager(
 * any locks if the block was fetched from a remote block manager. The read lock will
 * automatically be freed once the result's `data` iterator is fully consumed.
 */
-def get(blockId: BlockId): Option[BlockResult] = {
+def get[T: ClassTag](blockId: BlockId): Option[BlockResult] = {
 val local = getLocalValues(blockId)
 if (local.isDefined) {
 logInfo(s"Found block $blockId locally")
 return local
 }
-val remote = getRemoteValues(blockId)
+val remote = getRemoteValues[T](blockId)
 if (remote.isDefined) {
 logInfo(s"Found block $blockId remotely")
 return remote
@@ -660,7 +661,7 @@
 makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = {
 // Attempt to read the block from local or remote storage. If it's present, then we don't need
 // to go through the local-get-or-put path.
-get(blockId) match {
+get[T](blockId)(classTag) match {
 case Some(block) =>
 return Left(block)
 case _ =>
@@ -1204,8 +1205,8 @@
 /**
 * Read a block consisting of a single object.
 */
-def getSingle(blockId: BlockId): Option[Any] = {
-get(blockId).map(_.data.next())
+def getSingle[T: ClassTag](blockId: BlockId): Option[T] = {
+get[T](blockId).map(_.data.next().asInstanceOf[T])
 }

 /**
6 changes: 4 additions & 2 deletions core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -170,10 +170,12 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
 blockManager.master.getLocations(blockId).foreach { cmId =>
 val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
 blockId.toString)
-val deserialized = serializerManager.dataDeserializeStream[Int](blockId,
-new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList
+val deserialized = serializerManager.dataDeserializeStream(blockId,
+new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList
 assert(deserialized === (1 to 100).toList)
 }
+// This will exercise the getRemoteBytes / getRemoteValues code paths:
+assert(blockIds.flatMap(id => blockManager.get[Int](id).get.data).toSet === (1 to 1000).toSet)
 }

 Seq(
@@ -120,7 +120,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
 val blockId = partition.blockId

 def getBlockFromBlockManager(): Option[Iterator[T]] = {
-blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
+blockManager.get[T](blockId).map(_.data.asInstanceOf[Iterator[T]])
 }

 def getBlockFromWriteAheadLog(): Iterator[T] = {
@@ -163,7 +163,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
 dataRead.rewind()
 }
 serializerManager
-.dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream())
+.dataDeserializeStream(
+blockId, new ChunkedByteBuffer(dataRead).toInputStream())(elementClassTag)
 .asInstanceOf[Iterator[T]]
 }

@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.postfixOps
+import scala.reflect.ClassTag

 import org.apache.hadoop.conf.Configuration
 import org.scalatest.{BeforeAndAfter, Matchers}
@@ -163,7 +164,7 @@ class ReceivedBlockHandlerSuite
 val bytes = reader.read(fileSegment)
 reader.close()
 serializerManager.dataDeserializeStream(
-generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList
+generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream())(ClassTag.Any).toList
 }
 loggedData shouldEqual data
 }
