Skip to content

Commit

Permalink
[SPARK-8873] [MESOS] Clean up shuffle files if external shuffle servi…
Browse files Browse the repository at this point in the history
…ce is used

This patch builds directly on apache#7820, which is largely written by tnachen. The only addition is one commit for cleaning up the code. There should be no functional differences between this and apache#7820.

Author: Timothy Chen <[email protected]>
Author: Andrew Or <[email protected]>

Closes apache#7881 from andrewor14/tim-cleanup-mesos-shuffle and squashes the following commits:

8894f7d [Andrew Or] Clean up code
2a5fa10 [Andrew Or] Merge branch 'mesos_shuffle_clean' of github.com:tnachen/spark into tim-cleanup-mesos-shuffle
fadff89 [Timothy Chen] Address comments.
e4d0f1d [Timothy Chen] Clean up external shuffle data on driver exit with Mesos.
  • Loading branch information
tnachen authored and Andrew Or committed Aug 3, 2015
1 parent 1ebd41b commit 95dccc6
Show file tree
Hide file tree
Showing 15 changed files with 394 additions and 17 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2658,7 +2658,7 @@ object SparkContext extends Logging {
val coarseGrained = sc.conf.getBoolean("spark.mesos.coarse", false)
val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, sc, url)
new CoarseMesosSchedulerBackend(scheduler, sc, url, sc.env.securityManager)
} else {
new MesosSchedulerBackend(scheduler, sc, url)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslServerBootstrap
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.util.TransportConf
import org.apache.spark.util.Utils

/**
Expand All @@ -45,11 +46,16 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()

private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
private val blockHandler = newShuffleBlockHandler(transportConf)
private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler)

private var server: TransportServer = _

/** Create a new shuffle block handler. Factored out for subclasses to override. */
protected def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler = {
new ExternalShuffleBlockHandler(conf)
}

/** Starts the external shuffle service if the user has configured us to. */
def startIfEnabled() {
if (enabled) {
Expand Down Expand Up @@ -93,14 +99,21 @@ object ExternalShuffleService extends Logging {
private val barrier = new CountDownLatch(1)

def main(args: Array[String]): Unit = {
main(args, (conf: SparkConf, sm: SecurityManager) => new ExternalShuffleService(conf, sm))
}

/** A helper main method that allows the caller to call this with a custom shuffle service. */
private[spark] def main(
args: Array[String],
newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = {
val sparkConf = new SparkConf
Utils.loadDefaultSparkProperties(sparkConf)
val securityManager = new SecurityManager(sparkConf)

// we override this value since this service is started from the command line
// and we assume the user really wants it to be running
sparkConf.set("spark.shuffle.service.enabled", "true")
server = new ExternalShuffleService(sparkConf, securityManager)
server = newShuffleService(sparkConf, securityManager)
server.start()

installShutdownHook()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.deploy.mesos

import java.net.SocketAddress

import scala.collection.mutable

import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.ExternalShuffleService
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage
import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver
import org.apache.spark.network.util.TransportConf

/**
 * A shuffle block handler that, in addition to the block transfer messages understood by
 * [[ExternalShuffleBlockHandler]], accepts [[RegisterDriver]] requests sent by Spark drivers
 * running on Mesos.
 *
 * Every registered driver is remembered by the remote socket address of its connection. When
 * that connection terminates, `applicationRemoved` is invoked for the associated app id with
 * local-directory cleanup requested.
 *
 * NOTE(review): `connectedApps` is a plain mutable map with no synchronization. If this handler
 * can be invoked from several transport threads at once, access needs guarding -- confirm.
 *
 * @param transportConf transport configuration handed to the parent handler
 */
private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportConf)
  extends ExternalShuffleBlockHandler(transportConf) with Logging {

  // Stores a map of driver socket addresses to app ids
  private val connectedApps = new mutable.HashMap[SocketAddress, String]

  /**
   * Handles one decoded block transfer message.
   *
   * A [[RegisterDriver]] message records the sending driver's app id under the client's remote
   * address and acknowledges with an empty payload. Any other message is delegated to the
   * parent handler.
   *
   * @param message the decoded message
   * @param client the client the message arrived on; its socket address keys the registration
   * @param callback invoked with an empty byte array once a driver registration is recorded
   */
  protected override def handleMessage(
      message: BlockTransferMessage,
      client: TransportClient,
      callback: RpcResponseCallback): Unit = {
    message match {
      case RegisterDriverParam(appId) =>
        val address = client.getSocketAddress
        logDebug(s"Received registration request from app $appId (remote address $address).")
        // A different app registering from an already-known address replaces the previous
        // registration, so the previously registered app is removed first.
        connectedApps.get(address).filter(_ != appId).foreach { existingAppId =>
          logError(s"A new app '$appId' has connected to existing address $address, " +
            s"removing previously registered app '$existingAppId'.")
          applicationRemoved(existingAppId, true /* cleanupLocalDirs */)
        }
        connectedApps(address) = appId
        callback.onSuccess(new Array[Byte](0))
      case _ => super.handleMessage(message, client, callback)
    }
  }

  /**
   * On connection termination, clean up shuffle files written by the associated application.
   *
   * @param client the client whose connection was terminated
   */
  override def connectionTerminated(client: TransportClient): Unit = {
    val address = client.getSocketAddress
    // Drop the mapping before running the cleanup so that the entry never outlives the
    // connection, even if `applicationRemoved` throws.
    connectedApps.remove(address) match {
      case Some(appId) =>
        logInfo(s"Application $appId disconnected (address was $address).")
        applicationRemoved(appId, true /* cleanupLocalDirs */)
      case None =>
        logWarning(s"Unknown $address disconnected.")
    }
  }

  /** An extractor object for matching [[RegisterDriver]] message. */
  private object RegisterDriverParam {
    def unapply(r: RegisterDriver): Option[String] = Some(r.getAppId)
  }
}

/**
 * An [[ExternalShuffleService]] variant for Mesos deployments.
 *
 * The only difference from the parent service is the block handler it installs: a
 * [[MesosExternalShuffleBlockHandler]], which lets drivers register themselves so that their
 * shuffle files can be cleaned up once the driver's connection goes away.
 *
 * @param conf Spark configuration passed through to the parent service
 * @param securityManager security manager passed through to the parent service
 */
private[mesos] class MesosExternalShuffleService(conf: SparkConf, securityManager: SecurityManager)
  extends ExternalShuffleService(conf, securityManager) {

  /** Supplies the Mesos-aware block handler in place of the default one. */
  protected override def newShuffleBlockHandler(conf: TransportConf): ExternalShuffleBlockHandler =
    new MesosExternalShuffleBlockHandler(conf)
}

/** Command-line entry point that launches a [[MesosExternalShuffleService]]. */
private[spark] object MesosExternalShuffleService extends Logging {

  /**
   * Delegates to the shared `ExternalShuffleService.main` helper, supplying a factory that
   * builds the Mesos-specific service instead of the default one.
   *
   * @param args command-line arguments, forwarded unchanged
   */
  def main(args: Array[String]): Unit = {
    val mesosServiceFactory: (SparkConf, SecurityManager) => ExternalShuffleService =
      new MesosExternalShuffleService(_, _)
    ExternalShuffleService.main(args, mesosServiceFactory)
  }
}


6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
*
* It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence.
*
* The lift-cycle will be:
* The life-cycle of an endpoint is:
*
* constructor onStart receive* onStop
* constructor -> onStart -> receive* -> onStop
*
* Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use
* Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use
* [[ThreadSafeRpcEndpoint]]
*
* If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ import scala.collection.mutable.{HashMap, HashSet}

import com.google.common.collect.HashBiMap
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
import org.apache.mesos.{Scheduler => MScheduler, _}
import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver}

import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
import org.apache.spark.rpc.RpcAddress
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}

/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
Expand All @@ -46,7 +49,8 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException, TaskState}
private[spark] class CoarseMesosSchedulerBackend(
scheduler: TaskSchedulerImpl,
sc: SparkContext,
master: String)
master: String,
securityManager: SecurityManager)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
with MScheduler
with MesosSchedulerUtils {
Expand All @@ -56,12 +60,19 @@ private[spark] class CoarseMesosSchedulerBackend(
// Maximum number of cores to acquire (TODO: we'll need more flexible controls here)
val maxCores = conf.get("spark.cores.max", Int.MaxValue.toString).toInt

// If shuffle service is enabled, the Spark driver will register with the shuffle service.
// This is for cleaning up shuffle files reliably.
private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)

// Cores we have acquired with each Mesos task ID
val coresByTaskId = new HashMap[Int, Int]
var totalCoresAcquired = 0

val slaveIdsWithExecutors = new HashSet[String]

// Mapping from slave Id to hostname
private val slaveIdToHost = new HashMap[String, String]

val taskIdToSlaveId: HashBiMap[Int, String] = HashBiMap.create[Int, String]
// How many times tasks on each slave failed
val failuresBySlaveId: HashMap[String, Int] = new HashMap[String, Int]
Expand Down Expand Up @@ -90,6 +101,19 @@ private[spark] class CoarseMesosSchedulerBackend(
private val slaveOfferConstraints =
parseConstraintString(sc.conf.get("spark.mesos.constraints", ""))

// A client for talking to the external shuffle service, if it is enabled
private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = {
if (shuffleServiceEnabled) {
Some(new MesosExternalShuffleClient(
SparkTransportConf.fromSparkConf(conf),
securityManager,
securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled()))
} else {
None
}
}

var nextMesosTaskId = 0

@volatile var appId: String = _
Expand Down Expand Up @@ -188,6 +212,7 @@ private[spark] class CoarseMesosSchedulerBackend(

override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
appId = frameworkId.getValue
mesosExternalShuffleClient.foreach(_.init(appId))
logInfo("Registered as framework ID " + appId)
markRegistered()
}
Expand Down Expand Up @@ -244,6 +269,7 @@ private[spark] class CoarseMesosSchedulerBackend(

// accept the offer and launch the task
logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus")
slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname
d.launchTasks(
Collections.singleton(offer.getId),
Collections.singleton(taskBuilder.build()), filters)
Expand All @@ -261,7 +287,27 @@ private[spark] class CoarseMesosSchedulerBackend(
val taskId = status.getTaskId.getValue.toInt
val state = status.getState
logInfo(s"Mesos task $taskId is now $state")
val slaveId: String = status.getSlaveId.getValue
stateLock.synchronized {
// If the shuffle service is enabled, have the driver register with each one of the
// shuffle services. This allows the shuffle services to clean up state associated with
// this application when the driver exits. There is currently not a great way to detect
// this through Mesos, since the shuffle services are set up independently.
if (TaskState.fromMesos(state).equals(TaskState.RUNNING) &&
slaveIdToHost.contains(slaveId) &&
shuffleServiceEnabled) {
assume(mesosExternalShuffleClient.isDefined,
"External shuffle client was not instantiated even though shuffle service is enabled.")
// TODO: Remove this and allow the MesosExternalShuffleService to detect
// framework termination when new Mesos Framework HTTP API is available.
val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337)
val hostname = slaveIdToHost.remove(slaveId).get
logDebug(s"Connecting to shuffle service on slave $slaveId, " +
s"host $hostname, port $externalShufflePort for app ${conf.getAppId}")
mesosExternalShuffleClient.get
.registerDriverWithShuffleService(hostname, externalShufflePort)
}

if (TaskState.isFinished(TaskState.fromMesos(state))) {
val slaveId = taskIdToSlaveId(taskId)
slaveIdsWithExecutors -= slaveId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.scalatest.mock.MockitoSugar
import org.scalatest.BeforeAndAfter

import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite}

class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
with LocalSparkContext
Expand Down Expand Up @@ -59,7 +59,8 @@ class CoarseMesosSchedulerBackendSuite extends SparkFunSuite
private def createSchedulerBackend(
taskScheduler: TaskSchedulerImpl,
driver: SchedulerDriver): CoarseMesosSchedulerBackend = {
val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master") {
val securityManager = mock[SecurityManager]
val backend = new CoarseMesosSchedulerBackend(taskScheduler, sc, "master", securityManager) {
override protected def createSchedulerDriver(
masterUrl: String,
scheduler: Scheduler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public List<String> buildCommand(Map<String, String> env) throws IOException {
} else if (className.equals("org.apache.spark.executor.MesosExecutorBackend")) {
javaOptsKeys.add("SPARK_EXECUTOR_OPTS");
memKey = "SPARK_EXECUTOR_MEMORY";
} else if (className.equals("org.apache.spark.deploy.ExternalShuffleService")) {
} else if (className.equals("org.apache.spark.deploy.ExternalShuffleService") ||
className.equals("org.apache.spark.deploy.mesos.MesosExternalShuffleService")) {
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_SHUFFLE_OPTS");
memKey = "SPARK_DAEMON_MEMORY";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -79,6 +80,10 @@ public boolean isActive() {
return channel.isOpen() || channel.isActive();
}

public SocketAddress getSocketAddress() {
return channel.remoteAddress();
}

/**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ public ExternalShuffleBlockHandler(TransportConf conf) {
@Override
public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
handleMessage(msgObj, client, callback);
}

protected void handleMessage(
BlockTransferMessage msgObj,
TransportClient client,
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
OpenBlocks msg = (OpenBlocks) msgObj;
List<ManagedBuffer> blocks = Lists.newArrayList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ public class ExternalShuffleClient extends ShuffleClient {
private final boolean saslEncryptionEnabled;
private final SecretKeyHolder secretKeyHolder;

private TransportClientFactory clientFactory;
private String appId;
protected TransportClientFactory clientFactory;
protected String appId;

/**
* Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled,
Expand All @@ -71,6 +71,10 @@ public ExternalShuffleClient(
this.saslEncryptionEnabled = saslEncryptionEnabled;
}

protected void checkInit() {
assert appId != null : "Called before init()";
}

@Override
public void init(String appId) {
this.appId = appId;
Expand All @@ -89,7 +93,7 @@ public void fetchBlocks(
final String execId,
String[] blockIds,
BlockFetchingListener listener) {
assert appId != null : "Called before init()";
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
Expand Down Expand Up @@ -132,7 +136,7 @@ public void registerWithShuffleServer(
int port,
String execId,
ExecutorShuffleInfo executorInfo) throws IOException {
assert appId != null : "Called before init()";
checkInit();
TransportClient client = clientFactory.createClient(host, port);
byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
Expand Down
Loading

0 comments on commit 95dccc6

Please sign in to comment.