SPARK-1480: Clean up use of classloaders
The Spark codebase is a bit fast and loose when accessing classloaders, and this has caused a few bugs to surface in master.

This patch defines some utility methods for accessing classloaders. This makes the intention when accessing a classloader much more explicit in the code and fixes a few cases where the wrong one was chosen.

case (a) -> We want the classloader that loaded Spark
case (b) -> We want the context class loader, or if not present, we want (a)

This patch provides a better fix for SPARK-1403 (https://issues.apache.org/jira/browse/SPARK-1403) than the current workaround, which it reverts. It also fixes a previously unreported bug: the `./spark-submit` script did not work when running with a `local` master, because the executor classloader did not properly delegate to the context class loader (if one is defined), and in local mode the context class loader is set by the `./spark-submit` script. A unit test is added for that case.
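
For reference, a minimal sketch of the two cases and a typical call site (illustrative only -- the real methods are `getSparkClassLoader` and `getContextOrSparkClassLoader` added to `org.apache.spark.util.Utils` in this patch; the wrapper object and `loadClass` helper below are hypothetical):

object ClassLoaderSketch {
  // Case (a): the loader that loaded the Spark classes themselves.
  def getSparkClassLoader: ClassLoader = getClass.getClassLoader

  // Case (b): the thread's context class loader if one is set, otherwise fall back to (a).
  def getContextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

  // Typical call site: resolve a class name against the most specific loader available,
  // e.g. classes defined by the REPL or shipped in a user's application jar.
  def loadClass(className: String): Class[_] =
    Class.forName(className, true, getContextOrSparkClassLoader)
}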

Author: Patrick Wendell <[email protected]>

Closes #398 from pwendell/class-loaders and squashes the following commits:

b4a1a58 [Patrick Wendell] Minor clean up
14f1272 [Patrick Wendell] SPARK-1480: Clean up use of classloaders
pwendell committed Apr 13, 2014
1 parent ca11919 commit 4bc07ee
Showing 15 changed files with 78 additions and 35 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/Logging.scala
@@ -22,6 +22,7 @@ import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.impl.StaticLoggerBinder

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils

/**
* :: DeveloperApi ::
Expand Down Expand Up @@ -115,8 +116,7 @@ trait Logging {
val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
if (!log4jInitialized && usingLog4j) {
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- val classLoader = this.getClass.getClassLoader
- Option(classLoader.getResource(defaultLogProps)) match {
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
case Some(url) =>
PropertyConfigurator.configure(url)
log.info(s"Using Spark's default log4j profile: $defaultLogProps")
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -292,7 +292,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
- val loader = this.getClass.getClassLoader
+ val currentLoader = Utils.getContextOrSparkClassLoader

// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -301,8 +301,8 @@
}.toArray
val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
userClassPathFirst match {
- case true => new ChildExecutorURLClassLoader(urls, loader)
- case false => new ExecutorURLClassLoader(urls, loader)
+ case true => new ChildExecutorURLClassLoader(urls, currentLoader)
+ case false => new ExecutorURLClassLoader(urls, currentLoader)
}
}

@@ -50,21 +50,13 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- val cl = Thread.currentThread.getContextClassLoader
- try {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(getClass.getClassLoader)
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
- this.driver = driver
- val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
- executor = new Executor(
- executorInfo.getExecutorId.getValue,
- slaveInfo.getHostname,
- properties)
- } finally {
- // Work around for SPARK-1480
- Thread.currentThread.setContextClassLoader(cl)
- }
+ logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+ this.driver = driver
+ val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
+ executor = new Executor(
+ executorInfo.getExecutorId.getValue,
+ slaveInfo.getHostname,
+ properties)
}

override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
@@ -24,6 +24,7 @@ import scala.collection.mutable
import scala.util.matching.Regex

import org.apache.spark.Logging
import org.apache.spark.util.Utils

private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {

@@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
try {
is = configFile match {
case Some(f) => new FileInputStream(f)
- case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF)
+ case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
}

if (is != null) {
@@ -54,7 +54,6 @@ private[spark] object ResultTask {

def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
- val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
@@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties}
import scala.xml.XML

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils

/**
* An interface to build Schedulable tree
@@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf)
schedulerAllocFile.map { f =>
new FileInputStream(f)
}.getOrElse {
- getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
+ Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE)
}
}

@@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
try {
if (serializedData != null && serializedData.limit() > 0) {
reason = serializer.get().deserialize[TaskEndReason](
- serializedData, getClass.getClassLoader)
+ serializedData, Utils.getSparkClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
// Log an error but keep going here -- the task failed, so not catastropic if we can't
// deserialize the reason.
- val loader = Thread.currentThread.getContextClassLoader
+ val loader = Utils.getContextOrSparkClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
case ex: Throwable => {}
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils

private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
extends SerializationStream {
@@ -86,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
}

def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader)
+ new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
}

def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -33,6 +33,7 @@ import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}

import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.util.Utils

/**
* Utilities for launching a web server using Jetty's HTTP Server class
@@ -124,7 +125,7 @@ private[spark] object JettyUtils extends Logging {
contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
val staticHandler = new DefaultServlet
val holder = new ServletHolder(staticHandler)
- Option(getClass.getClassLoader.getResource(resourceBase)) match {
+ Option(Utils.getSparkClassLoader.getResource(resourceBase)) match {
case Some(res) =>
holder.setInitParameter("resourceBase", res.toString)
case None =>
15 changes: 15 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -116,6 +116,21 @@ private[spark] object Utils extends Logging {
}
}

/**
* Get the ClassLoader which loaded Spark.
*/
def getSparkClassLoader = getClass.getClassLoader

/**
* Get the Context ClassLoader on this thread or, if not present, the ClassLoader that
* loaded Spark.
*
* This should be used whenever passing a ClassLoader to Class.ForName or finding the currently
* active loader when setting up ClassLoader delegation chains.
*/
def getContextOrSparkClassLoader =
Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader)

/**
* Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}.
*/
@@ -17,12 +17,12 @@

package org.apache.spark.executor

- import java.io.File
import java.net.URLClassLoader

import org.scalatest.FunSuite

- import org.apache.spark.TestUtils
+ import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils}
+ import org.apache.spark.util.Utils

class ExecutorURLClassLoaderSuite extends FunSuite {

@@ -63,5 +63,33 @@ class ExecutorURLClassLoaderSuite extends FunSuite {
}
}

test("driver sets context class loader in local mode") {
// Test the case where the driver program sets a context classloader and then runs a job
// in local mode. This is what happens when ./spark-submit is called with "local" as the
// master.
val original = Thread.currentThread().getContextClassLoader

val className = "ClassForDriverTest"
val jar = TestUtils.createJarWithClasses(Seq(className))
val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
Thread.currentThread().setContextClassLoader(contextLoader)

val sc = new SparkContext("local", "driverLoaderTest")

try {
sc.makeRDD(1 to 5, 2).mapPartitions { x =>
val loader = Thread.currentThread().getContextClassLoader
Class.forName(className, true, loader).newInstance()
Seq().iterator
}.count()
}
catch {
case e: SparkException if e.getMessage.contains("ClassNotFoundException") =>
fail("Local executor could not find class", e)
case t: Throwable => fail("Unexpected exception ", t)
}

sc.stop()
Thread.currentThread().setContextClassLoader(original)
}
}
7 changes: 4 additions & 3 deletions repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -39,6 +39,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
import org.apache.spark.Logging
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.util.Utils

/** The Scala interactive shell. It provides a read-eval-print loop
* around the Interpreter class.
@@ -130,7 +131,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
def history = in.history

/** The context class loader at the time this object was created */
- protected val originalClassLoader = Thread.currentThread.getContextClassLoader
+ protected val originalClassLoader = Utils.getContextOrSparkClassLoader

// classpath entries added via :cp
var addedClasspath: String = ""
@@ -177,7 +178,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
override lazy val formatting = new Formatting {
def prompt = SparkILoop.this.prompt
}
- override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
+ override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
}

/** Create a new interpreter. */
@@ -871,7 +872,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter,
}

val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
- val m = u.runtimeMirror(getClass.getClassLoader)
+ val m = u.runtimeMirror(Utils.getSparkClassLoader)
private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
u.TypeTag[T](
m,
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst

import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File}

import org.apache.spark.util.{Utils => SparkUtils}

package object util {
/**
* Returns a path to a temporary file that probably does not exist.
@@ -54,7 +56,7 @@ package object util {
def resourceToString(
resource:String,
encoding: String = "UTF-8",
- classLoader: ClassLoader = this.getClass.getClassLoader) = {
+ classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = {
val inStream = classLoader.getResourceAsStream(resource)
val outStream = new ByteArrayOutputStream
try {
@@ -26,6 +26,7 @@ import scala.reflect.runtime.universe.runtimeMirror
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar._
import org.apache.spark.util.Utils

private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0
@@ -254,7 +255,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
- val mirror = runtimeMirror(getClass.getClassLoader)
+ val mirror = runtimeMirror(Utils.getSparkClassLoader)
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
}

@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.{Serializer, Kryo}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils

class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
@@ -44,7 +45,7 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
- kryo.setClassLoader(this.getClass.getClassLoader)
+ kryo.setClassLoader(Utils.getSparkClassLoader)
kryo
}
}
