[SPARK-19540][SQL] Add ability to clone SparkSession wherein cloned session has an identical copy of the SessionState

Forking a newSession() from a SparkSession currently creates a new SparkSession that does not retain SessionState (i.e. temporary tables, SQL config, registered functions, etc.). This change adds a cloneSession() method that creates a new SparkSession with a copy of the parent's SessionState.

Subsequent changes to the base session are not propagated to the cloned session; the clone is independent after creation. If the base session is changed after the clone has been created (say, the user registers a new UDF), the new UDF will not be available inside the clone. The same goes for configs and temp tables.
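
A minimal sketch of the intended semantics (note: cloneSession() is private[sql] in this patch, so code like this would have to live inside the org.apache.spark.sql package, e.g. in a test; the UDF names are made up for illustration):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    spark.udf.register("strlen", (s: String) => s.length)   // registered before cloning
    val cloned = spark.cloneSession()                        // copies conf, temp views, UDFs

    spark.udf.register("strrev", (s: String) => s.reverse)  // registered after cloning
    cloned.sql("SELECT strlen('abc')").show()                // works: "strlen" was copied
    // cloned.sql("SELECT strrev('abc')") would fail: base changes do not propagate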

Unit tests

Author: Kunal Khamar <[email protected]>
Author: Shixiong Zhu <[email protected]>

Closes apache#16826 from kunalkhamar/fork-sparksession.
kunalkhamar authored and zsxwing committed Mar 8, 2017
1 parent 1bf9012 commit 6570cfd
Showing 20 changed files with 981 additions and 236 deletions.
In CatalystConf:

@@ -66,6 +66,8 @@ trait CatalystConf {
 
   /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */
   def joinReorderDPThreshold: Int
+
+  override def clone(): CatalystConf = throw new CloneNotSupportedException()
 }
 
 
@@ -85,4 +87,7 @@ case class SimpleCatalystConf(
     joinReorderDPThreshold: Int = 12,
     warehousePath: String = "/user/hive/warehouse",
     sessionLocalTimeZone: String = TimeZone.getDefault().getID)
-  extends CatalystConf
+  extends CatalystConf {
+
+  override def clone(): SimpleCatalystConf = this.copy()
+}
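
The pattern above (a default clone() that throws, overridden by each concrete conf) in a standalone sketch with hypothetical names:

    trait Conf {
      def caseSensitiveAnalysis: Boolean
      // Implementations must opt in to cloning explicitly.
      override def clone(): Conf = throw new CloneNotSupportedException()
    }

    case class SimpleConf(caseSensitiveAnalysis: Boolean = true) extends Conf {
      // An immutable case class can clone itself with a plain copy().
      override def clone(): SimpleConf = this.copy()
    }

    val conf = SimpleConf()
    assert(conf.clone() == conf)   // equal value, distinct instance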
In FunctionRegistry:

@@ -64,6 +64,8 @@ trait FunctionRegistry {
   /** Clear all registered functions. */
   def clear(): Unit
 
+  /** Create a copy of this registry with identical functions as this registry. */
+  override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
 }
 
 class SimpleFunctionRegistry extends FunctionRegistry {
@@ -107,7 +109,7 @@ class SimpleFunctionRegistry extends FunctionRegistry {
     functionBuilders.clear()
   }
 
-  def copy(): SimpleFunctionRegistry = synchronized {
+  override def clone(): SimpleFunctionRegistry = synchronized {
     val registry = new SimpleFunctionRegistry
     functionBuilders.iterator.foreach { case (name, (info, builder)) =>
       registry.registerFunction(name, info, builder)
@@ -150,6 +152,7 @@ object EmptyFunctionRegistry extends FunctionRegistry {
     throw new UnsupportedOperationException
   }
 
+  override def clone(): FunctionRegistry = this
 }
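
A quick sketch of the independence clone() buys, assuming SimpleFunctionRegistry's existing registerFunction/lookupFunctionBuilder signatures and the two-argument ExpressionInfo constructor:

    import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
    import org.apache.spark.sql.catalyst.expressions.{ExpressionInfo, Literal}

    val original = new SimpleFunctionRegistry
    original.registerFunction("one", new ExpressionInfo("Literal", "one"), _ => Literal(1))

    val cloned = original.clone()   // copies the builder map entry by entry
    cloned.registerFunction("two", new ExpressionInfo("Literal", "two"), _ => Literal(2))

    assert(cloned.lookupFunctionBuilder("one").isDefined)   // copied from the original
    assert(original.lookupFunctionBuilder("two").isEmpty)   // registered only on the clone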
In SessionCatalog:

@@ -50,7 +50,6 @@ object SessionCatalog {
 class SessionCatalog(
     externalCatalog: ExternalCatalog,
     globalTempViewManager: GlobalTempViewManager,
-    functionResourceLoader: FunctionResourceLoader,
     functionRegistry: FunctionRegistry,
     conf: CatalystConf,
     hadoopConf: Configuration,
@@ -66,16 +65,19 @@
     this(
       externalCatalog,
       new GlobalTempViewManager("global_temp"),
-      DummyFunctionResourceLoader,
       functionRegistry,
       conf,
       new Configuration(),
       CatalystSqlParser)
+    functionResourceLoader = DummyFunctionResourceLoader
   }
 
   // For testing only.
   def this(externalCatalog: ExternalCatalog) {
-    this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
+    this(
+      externalCatalog,
+      new SimpleFunctionRegistry,
+      SimpleCatalystConf(caseSensitiveAnalysis = true))
   }
 
   /** List of temporary tables, mapping from table name to their logical plan. */
@@ -89,6 +91,8 @@ class SessionCatalog(
   @GuardedBy("this")
   protected var currentDb = formatDatabaseName(DEFAULT_DATABASE)
 
+  @volatile var functionResourceLoader: FunctionResourceLoader = _
+
   /**
    * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"),
    * i.e. if this name only contains characters, numbers, and _.
@@ -987,6 +991,9 @@
    * by a tuple (resource type, resource uri).
    */
   def loadFunctionResources(resources: Seq[FunctionResource]): Unit = {
+    if (functionResourceLoader == null) {
+      throw new IllegalStateException("functionResourceLoader has not yet been initialized")
+    }
     resources.foreach(functionResourceLoader.loadResource)
   }
 
@@ -1182,4 +1189,29 @@
     }
   }
 
+  /**
+   * Create a new [[SessionCatalog]] with the provided parameters. `externalCatalog` and
+   * `globalTempViewManager` are `inherited`, while `currentDb` and `tempTables` are copied.
+   */
+  def newSessionCatalogWith(
+      conf: CatalystConf,
+      hadoopConf: Configuration,
+      functionRegistry: FunctionRegistry,
+      parser: ParserInterface): SessionCatalog = {
+    val catalog = new SessionCatalog(
+      externalCatalog,
+      globalTempViewManager,
+      functionRegistry,
+      conf,
+      hadoopConf,
+      parser)
+
+    synchronized {
+      catalog.currentDb = currentDb
+      // copy over temporary tables
+      tempTables.foreach(kv => catalog.tempTables.put(kv._1, kv._2))
+    }
+
+    catalog
+  }
 }
In SessionCatalogSuite:

@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.catalog
 
+import org.apache.hadoop.conf.Configuration
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
@@ -1197,6 +1199,59 @@ class SessionCatalogSuite extends PlanTest {
     }
   }
 
+  test("clone SessionCatalog - temp views") {
+    val externalCatalog = newEmptyCatalog()
+    val original = new SessionCatalog(externalCatalog)
+    val tempTable1 = Range(1, 10, 1, 10)
+    original.createTempView("copytest1", tempTable1, overrideIfExists = false)
+
+    // check if tables copied over
+    val clone = original.newSessionCatalogWith(
+      SimpleCatalystConf(caseSensitiveAnalysis = true),
+      new Configuration(),
+      new SimpleFunctionRegistry,
+      CatalystSqlParser)
+    assert(original ne clone)
+    assert(clone.getTempView("copytest1") == Some(tempTable1))
+
+    // check if clone and original independent
+    clone.dropTable(TableIdentifier("copytest1"), ignoreIfNotExists = false, purge = false)
+    assert(original.getTempView("copytest1") == Some(tempTable1))
+
+    val tempTable2 = Range(1, 20, 2, 10)
+    original.createTempView("copytest2", tempTable2, overrideIfExists = false)
+    assert(clone.getTempView("copytest2").isEmpty)
+  }
+
+  test("clone SessionCatalog - current db") {
+    val externalCatalog = newEmptyCatalog()
+    val db1 = "db1"
+    val db2 = "db2"
+    val db3 = "db3"
+
+    externalCatalog.createDatabase(newDb(db1), ignoreIfExists = true)
+    externalCatalog.createDatabase(newDb(db2), ignoreIfExists = true)
+    externalCatalog.createDatabase(newDb(db3), ignoreIfExists = true)
+
+    val original = new SessionCatalog(externalCatalog)
+    original.setCurrentDatabase(db1)
+
+    // check if current db copied over
+    val clone = original.newSessionCatalogWith(
+      SimpleCatalystConf(caseSensitiveAnalysis = true),
+      new Configuration(),
+      new SimpleFunctionRegistry,
+      CatalystSqlParser)
+    assert(original ne clone)
+    assert(clone.getCurrentDatabase == db1)
+
+    // check if clone and original independent
+    clone.setCurrentDatabase(db2)
+    assert(original.getCurrentDatabase == db1)
+    original.setCurrentDatabase(db3)
+    assert(clone.getCurrentDatabase == db2)
+  }
+
   test("SPARK-19737: detect undefined functions without triggering relation resolution") {
     import org.apache.spark.sql.catalyst.dsl.plans._
 
In ExperimentalMethods:

@@ -46,4 +46,10 @@ class ExperimentalMethods private[sql]() {
 
   @volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
 
+  override def clone(): ExperimentalMethods = {
+    val result = new ExperimentalMethods
+    result.extraStrategies = extraStrategies
+    result.extraOptimizations = extraOptimizations
+    result
+  }
 }
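
clone() here copies just two references, which is safe because Seq is immutable: any later "change" builds a new Seq instead of mutating the shared one. A tiny illustration of that reasoning (generic Scala, not Spark API):

    val original = Seq("pushdown-rule", "prune-rule")
    val shared = original                 // reference copy, as in ExperimentalMethods.clone
    val extended = shared :+ "new-rule"   // "mutation" yields a fresh Seq

    assert(original.size == 2)            // the original holder is untouched
    assert(extended.size == 3)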
In sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (45 additions, 14 deletions):

@@ -21,7 +21,6 @@ import java.io.Closeable
 import java.util.concurrent.atomic.AtomicReference
 
 import scala.collection.JavaConverters._
-import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
@@ -43,7 +42,7 @@
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.Utils
 
@@ -67,15 +66,22 @@
  *     .config("spark.some.config.option", "some-value")
  *     .getOrCreate()
  * }}}
+ *
+ * @param sparkContext The Spark context associated with this Spark session.
+ * @param existingSharedState If supplied, use the existing shared state
+ *   instead of creating a new one.
+ * @param parentSessionState If supplied, inherit all session state (i.e. temporary
+ *   views, SQL config, UDFs etc) from parent.
  */
 @InterfaceStability.Stable
 class SparkSession private(
     @transient val sparkContext: SparkContext,
-    @transient private val existingSharedState: Option[SharedState])
+    @transient private val existingSharedState: Option[SharedState],
+    @transient private val parentSessionState: Option[SessionState])
   extends Serializable with Closeable with Logging { self =>
 
   private[sql] def this(sc: SparkContext) {
-    this(sc, None)
+    this(sc, None, None)
   }
 
   sparkContext.assertNotStopped()
@@ -108,6 +114,7 @@
   /**
    * State isolated across sessions, including SQL configurations, temporary tables, registered
    * functions, and everything else that accepts a [[org.apache.spark.sql.internal.SQLConf]].
+   * If `parentSessionState` is not null, the `SessionState` will be a copy of the parent.
    *
    * This is internal to Spark and there is no guarantee on interface stability.
    *
@@ -116,9 +123,13 @@
   @InterfaceStability.Unstable
   @transient
   lazy val sessionState: SessionState = {
-    SparkSession.reflect[SessionState, SparkSession](
-      SparkSession.sessionStateClassName(sparkContext.conf),
-      self)
+    parentSessionState
+      .map(_.clone(this))
+      .getOrElse {
+        SparkSession.instantiateSessionState(
+          SparkSession.sessionStateClassName(sparkContext.conf),
+          self)
+      }
   }
 
   /**
@@ -208,7 +219,25 @@
    * @since 2.0.0
    */
   def newSession(): SparkSession = {
-    new SparkSession(sparkContext, Some(sharedState))
+    new SparkSession(sparkContext, Some(sharedState), parentSessionState = None)
   }
 
+  /**
+   * Create an identical copy of this `SparkSession`, sharing the underlying `SparkContext`
+   * and shared state. All the state of this session (i.e. SQL configurations, temporary tables,
+   * registered functions) is copied over, and the cloned session is set up with the same shared
+   * state as this session. The cloned session is independent of this session, that is, any
+   * non-global change in either session is not reflected in the other.
+   *
+   * @note Other than the `SparkContext`, all shared state is initialized lazily.
+   * This method will force the initialization of the shared state to ensure that parent
+   * and child sessions are set up with the same shared state. If the underlying catalog
+   * implementation is Hive, this will initialize the metastore, which may take some time.
+   */
+  private[sql] def cloneSession(): SparkSession = {
+    val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState))
+    result.sessionState // force copy of SessionState
+    result
+  }
 
@@ -971,16 +1000,18 @@ object SparkSession {
   }
 
   /**
-   * Helper method to create an instance of [[T]] using a single-arg constructor that
-   * accepts an [[Arg]].
+   * Helper method to create an instance of `SessionState` based on `className` from conf.
+   * The result is either `SessionState` or `HiveSessionState`.
    */
-  private def reflect[T, Arg <: AnyRef](
+  private def instantiateSessionState(
       className: String,
-      ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = {
+      sparkSession: SparkSession): SessionState = {
+
     try {
+      // get `SessionState.apply(SparkSession)`
       val clazz = Utils.classForName(className)
-      val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass)
-      ctor.newInstance(ctorArg).asInstanceOf[T]
+      val method = clazz.getMethod("apply", sparkSession.getClass)
+      method.invoke(null, sparkSession).asInstanceOf[SessionState]
     } catch {
       case NonFatal(e) =>
         throw new IllegalArgumentException(s"Error while instantiating '$className':", e)
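
The rewritten helper relies on the fact that a Scala companion object's public methods get static forwarders on the companion class, so a static-style Method.invoke(null, ...) reaches SessionState.apply(SparkSession) without a compile-time dependency on the (possibly Hive-specific) class. The trick in isolation, with a hypothetical Widget class:

    class Widget(val label: String)
    object Widget { def apply(label: String): Widget = new Widget(label) }

    // Assumes Widget is a top-level class visible under this name on the classpath.
    val clazz = Class.forName("Widget")                     // Spark uses Utils.classForName
    val method = clazz.getMethod("apply", classOf[String])  // the static forwarder
    val widget = method.invoke(null, "w").asInstanceOf[Widget]
    assert(widget.label == "w")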
In PreprocessTableCreation:

@@ -66,7 +66,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
  * Preprocess [[CreateTable]], to do some normalization and checking.
  */
 case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
-  private val catalog = sparkSession.sessionState.catalog
+  // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
+  private def catalog = sparkSession.sessionState.catalog
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     // When we CREATE TABLE without specifying the table schema, we should fail the query if
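
The def-versus-val comment above is load-bearing: the rule is constructed while SessionState is still being built, so a val would dereference sparkSession.sessionState at that moment and re-enter the in-progress lazy val. A stripped-down reproduction of the hazard (hypothetical names, not Spark classes):

    class Owner {
      lazy val state: State = new State(this)   // plays the role of sessionState
    }

    class State(owner: Owner) {
      val rule = new Rule(owner)   // rules are built eagerly during construction
    }

    class Rule(owner: Owner) {
      // With `val state = owner.state` this line would run inside State's
      // constructor and re-enter the still-initializing lazy val above.
      private def state = owner.state   // a def defers the read until the rule fires
    }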
In SQLConf:

@@ -1019,6 +1019,14 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
   def clear(): Unit = {
     settings.clear()
   }
+
+  override def clone(): SQLConf = {
+    val result = new SQLConf
+    getAllConfs.foreach {
+      case(k, v) => if (v ne null) result.setConfString(k, v)
+    }
+    result
+  }
 }
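
A sketch of the copy semantics, using the setConfString/getConfString accessors SQLConf already exposes (SQLConf is private[sql], so this would live inside the org.apache.spark.sql packages, e.g. in a test):

    import org.apache.spark.sql.internal.SQLConf

    val conf = new SQLConf
    conf.setConfString("spark.sql.shuffle.partitions", "5")

    val cloned = conf.clone()   // re-sets every non-null entry on a fresh SQLConf
    cloned.setConfString("spark.sql.shuffle.partitions", "10")

    assert(conf.getConfString("spark.sql.shuffle.partitions") == "5")     // unchanged
    assert(cloned.getConfString("spark.sql.shuffle.partitions") == "10")  // independent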
(Diff of the remaining changed files omitted.)