SPARK-1708. Add a ClassTag on Serializer and things that depend on it
This pull request contains a rebased patch from @heathermiller (heathermiller#1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility.
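
For context, a context bound such as `[T: ClassTag]` asks the compiler to capture `T`'s runtime class at each call site, which is exactly what a pickling-based serializer needs in order to choose a pickler later. A minimal illustrative sketch (the helper name is hypothetical, not part of this patch):

```scala
import scala.reflect.{ClassTag, classTag}

object ClassTagSketch {
  // With a ClassTag in scope, code can recover T's runtime class even after erasure;
  // frameworks such as Scala Pickling rely on this to select serializers.
  def runtimeClassOf[T: ClassTag](value: T): Class[_] = classTag[T].runtimeClass

  val a = runtimeClassOf("hello")        // classOf[String]
  val b = runtimeClassOf(Seq(1, 2, 3))   // the runtime class of Seq
}
```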

One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection in contrast were being serialized directly.
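
Roughly, the distinction is that a broadcast value is handed to the serializer as a stand-alone object, so it must carry its own tag, while an accumulator only ever travels as a field of an already-serialized task. A hypothetical simplification (not Spark's actual types):

```scala
import scala.reflect.ClassTag

object SerializationPaths {
  // Broadcast-style: the value is serialized on its own, so the entry point can
  // demand (and record) a tag for its type.
  def serializeStandalone[T: ClassTag](value: T): Array[Byte] = ???

  // Accumulator-style: the value is only a field inside a task that the closure
  // serializer writes as a whole, so there is no per-value type parameter to tag.
  class Task(val accumulatorUpdates: Map[Long, Any]) extends Serializable
}
```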

CC @rxin, @pwendell, @heathermiller

Author: Matei Zaharia <[email protected]>

Closes apache#700 from mateiz/spark-1708 and squashes the following commits:

1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java
3b449ed [Matei Zaharia] test fix
2209a27 [Matei Zaharia] Code style fixes
9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it
mateiz authored and pwendell committed May 10, 2014
1 parent 8e94d27 commit 7eefc9d
Showing 22 changed files with 103 additions and 72 deletions.
7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable}

import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.reflect.ClassTag

import org.apache.spark.serializer.JavaSerializer

@@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
}

private[spark]
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
private[spark] class
GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
extends AccumulableParam[R, T] {

def addAccumulator(growable: R, elem: T): R = {
growable += elem
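
The new `GrowableAccumulableParam` signature above combines a view bound (`<%`) with a context bound (`: ClassTag`) on the same type parameter. That desugars to two implicit parameters, roughly as in this illustrative sketch:

```scala
import scala.collection.generic.Growable
import scala.reflect.ClassTag

object BoundsSketch {
  // `R <% Growable[T] with TraversableOnce[T] with Serializable : ClassTag` is shorthand
  // for an implicit view from R plus an implicit ClassTag[R]:
  def growableParam[R, T](initial: R)(
      implicit view: R => Growable[T] with TraversableOnce[T] with Serializable,
      tag: ClassTag[R]): String =
    s"accumulable backed by ${tag.runtimeClass.getName}"
}
```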
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
@@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging {
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = {
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
@@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)

/** Shut down the SparkContext. */
def stop() {
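
In the `JavaSparkContext.broadcast` change above, `fakeClassTag` stands in for a real tag because Java callers have no way to materialize one. The idea is roughly the following sketch (the actual helper lives in Spark's Java API and may differ in detail):

```scala
import scala.reflect.ClassTag

object FakeClassTagSketch {
  // Java code cannot supply a ClassTag[T], so a tag for AnyRef is substituted. That is
  // acceptable here because the tag is only carried along for later use by the
  // serializer, not used to instantiate arrays of T.
  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]

  // Usage mirroring the diff above: the second parameter list supplies the implicit.
  // def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)
}
```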
@@ -21,6 +21,8 @@ import java.io.Serializable

import org.apache.spark.SparkException

import scala.reflect.ClassTag

/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
* cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -50,7 +52,7 @@ import org.apache.spark.SparkException
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {

/**
* Flag signifying whether the broadcast variable is valid
@@ -17,6 +17,8 @@

package org.apache.spark.broadcast

import scala.reflect.ClassTag

import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
@@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
@@ -19,6 +19,8 @@ package org.apache.spark.broadcast

import java.util.concurrent.atomic.AtomicLong

import scala.reflect.ClassTag

import org.apache.spark._

private[spark] class BroadcastManager(
@@ -56,7 +58,7 @@ private[spark] class BroadcastManager(

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) = {
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

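
Seen together with the `BroadcastFactory` change above, the context bound simply threads the implicit tag from `SparkContext.broadcast` down to the concrete broadcast implementation. A minimal sketch with simplified, hypothetical names (not the real Spark classes):

```scala
import scala.reflect.ClassTag

abstract class Bcast[T: ClassTag](val id: Long)
class HttpBcast[T: ClassTag](value: T, id: Long) extends Bcast[T](id)

trait BcastFactory {
  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Bcast[T]
}

class HttpBcastFactory extends BcastFactory {
  // The implicit ClassTag[T] supplied by the caller is re-supplied automatically here.
  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Bcast[T] =
    new HttpBcast[T](value, id)
}

class Manager(factory: BcastFactory) {
  // Each layer declares the same context bound, so the tag recorded at the public
  // entry point reaches the class that is eventually serialized.
  def newBroadcast[T: ClassTag](value: T, isLocal: Boolean): Bcast[T] =
    factory.newBroadcast(value, isLocal, id = 0L)
}
```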
@@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream}
import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit

import scala.reflect.ClassTag

import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
@@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
* (through a HTTP server running at the driver) and stored in the BlockManager of the
* executor to speed up future accesses.
*/
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class HttpBroadcast[T: ClassTag](
@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

def getValue = value_
@@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging {
files += file.getAbsolutePath
}

def read[T](id: Long): T = {
def read[T: ClassTag](id: Long): T = {
logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name

@@ -17,6 +17,8 @@

package org.apache.spark.broadcast

import scala.reflect.ClassTag

import org.apache.spark.{SecurityManager, SparkConf}

/**
@@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)

def stop() { HttpBroadcast.stop() }
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}

import scala.reflect.ClassTag
import scala.math
import scala.util.Random

@@ -44,7 +45,8 @@ import org.apache.spark.util.Utils
* copies of the broadcast data (one per executor) as done by the
* [[org.apache.spark.broadcast.HttpBroadcast]].
*/
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class TorrentBroadcast[T: ClassTag](
@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

def getValue = value_
@@ -17,6 +17,8 @@

package org.apache.spark.broadcast

import scala.reflect.ClassTag

import org.apache.spark.{SecurityManager, SparkConf}

/**
@@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory {
TorrentBroadcast.initialize(isDriver, conf)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)

def stop() { TorrentBroadcast.stop() }
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -84,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging {
"part-%05d".format(splitId)
}

def writeToFile[T](
def writeToFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
blockSize: Int = -1
@@ -160,7 +160,7 @@ private[spark] object CheckpointRDD extends Logging {
val conf = SparkHadoopUtil.get.newConfiguration()
val fs = path.getFileSystem(conf)
val broadcastedConf = sc.broadcast(new SerializableWritable(conf))
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _)
sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
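
Note that this call site (and the similar one in the checkpoint-data change further down) gained an explicit type argument. Once `writeToFile` takes a ClassTag, the call sites pin `T` explicitly so the right tag is captured when the method is partially applied. A toy illustration with a hypothetical signature:

```scala
import scala.reflect.ClassTag

object PinnedTypeArgument {
  // Hypothetical stand-in for the curried writeToFile after this change.
  def writeToFile[T: ClassTag](path: String, blockSize: Int = -1)(iter: Iterator[T]): Unit = ()

  // Pinning T up front fixes both the function type and the ClassTag captured by the
  // eta-expanded function value that gets passed to runJob.
  val writer: Iterator[Int] => Unit = writeToFile[Int]("part-00000", 1024) _
}
```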
@@ -77,7 +77,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
slice = in.readInt()

val ser = sfactory.newInstance()
Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())
Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject[Seq[T]]())
}
}
}
@@ -92,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (newRDD.partitions.size != rdd.partitions.size) {
throw new SparkException(
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io._
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
@@ -36,7 +38,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
* But only call it every 10,000th time to avoid bloated serialization streams (when
* the stream 'resets' object class descriptions have to be re-written)
*/
def writeObject[T](t: T): SerializationStream = {
def writeObject[T: ClassTag](t: T): SerializationStream = {
objOut.writeObject(t)
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
@@ -46,6 +48,7 @@
}
this
}

def flush() { objOut.flush() }
def close() { objOut.close() }
}
@@ -57,26 +60,26 @@ extends DeserializationStream {
Class.forName(desc.getName, false, loader)
}

def readObject[T](): T = objIn.readObject().asInstanceOf[T]
def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
def close() { objIn.close() }
}

private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
out.writeObject(t)
out.close()
ByteBuffer.wrap(bos.toByteArray)
}

def deserialize[T](bytes: ByteBuffer): T = {
def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}

def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis, loader)
in.readObject().asInstanceOf[T]
@@ -31,6 +31,8 @@ import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}

import scala.reflect.ClassTag

/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*
@@ -95,7 +97,7 @@ private[spark]
class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
val output = new KryoOutput(outStream)

def writeObject[T](t: T): SerializationStream = {
def writeObject[T: ClassTag](t: T): SerializationStream = {
kryo.writeClassAndObject(output, t)
this
}
@@ -108,7 +110,7 @@ private[spark]
class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
val input = new KryoInput(inStream)

def readObject[T](): T = {
def readObject[T: ClassTag](): T = {
try {
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
@@ -131,18 +133,18 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
lazy val output = ks.newKryoOutput()
lazy val input = new KryoInput()

def serialize[T](t: T): ByteBuffer = {
def serialize[T: ClassTag](t: T): ByteBuffer = {
output.clear()
kryo.writeClassAndObject(output, t)
ByteBuffer.wrap(output.toBytes)
}

def deserialize[T](bytes: ByteBuffer): T = {
def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
input.setBuffer(bytes.array)
kryo.readClassAndObject(input).asInstanceOf[T]
}

def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
input.setBuffer(bytes.array)
17 changes: 9 additions & 8 deletions core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
@@ -59,17 +61,17 @@ object Serializer {
*/
@DeveloperApi
trait SerializerInstance {
def serialize[T](t: T): ByteBuffer
def serialize[T: ClassTag](t: T): ByteBuffer

def deserialize[T](bytes: ByteBuffer): T
def deserialize[T: ClassTag](bytes: ByteBuffer): T

def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T

def serializeStream(s: OutputStream): SerializationStream

def deserializeStream(s: InputStream): DeserializationStream

def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = {
// Default implementation uses serializeStream
val stream = new ByteArrayOutputStream()
serializeStream(stream).writeAll(iterator)
@@ -85,18 +87,17 @@
}
}


/**
* :: DeveloperApi ::
* A stream for writing serialized objects.
*/
@DeveloperApi
trait SerializationStream {
def writeObject[T](t: T): SerializationStream
def writeObject[T: ClassTag](t: T): SerializationStream
def flush(): Unit
def close(): Unit

def writeAll[T](iter: Iterator[T]): SerializationStream = {
def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
while (iter.hasNext) {
writeObject(iter.next())
}
@@ -111,7 +112,7 @@ trait SerializationStream {
*/
@DeveloperApi
trait DeserializationStream {
def readObject[T](): T
def readObject[T: ClassTag](): T
def close(): Unit

/**
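
With these signatures, a caller supplies the tag either implicitly from its own context bound or by naming the type at the call site. A small usage sketch against the Java serializer (API as of this branch; treat as illustrative):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object SerializerUsageSketch {
  val ser = new JavaSerializer(new SparkConf()).newInstance()

  // serialize picks up ClassTag[List[Int]] implicitly from the argument's static type;
  // deserialize takes its tag from the explicit type argument.
  val bytes = ser.serialize(List(1, 2, 3))
  val back  = ser.deserialize[List[Int]](bytes)
}
```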
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -850,7 +850,7 @@ private[spark] object Utils extends Logging {
/**
* Clone an object using a Spark serializer.
*/
def clone[T](value: T, serializer: SerializerInstance): T = {
def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}

