Skip to content

Commit

Permalink
[SPARK-48480][SS][CONNECT] StreamingQueryListener should not be affec…
Browse files Browse the repository at this point in the history
…ted by spark.interrupt()

### What changes were proposed in this pull request?

This PR implements a small architecture change for the server side listenerBusListener. Before, when the first `addListener` call reaches to the server, there is a thread created, and there is a latch to hold this thread long running. This is to prevent this thread from returning, which would send a `ResultComplete` to the client, and closes the client receiving iterator. In client side listener we need to keep the iterator open all the time (until the last `removeListener` call) to keep receiving events.

In this PR, we delegate the sending of the final `ResultComplete` to the listener thread itself. Now the thread doesn't need to be held stuck. This would 1. remove a hanging thread running on the server and 2. Shield the listener from being effected by `spark.interruptAll`.

`spark.interruptAll` interrupts all spark connect threads. So before this change, the listener long-running thread is also interrupted, therefore would be affected by it and stop sending back events. Now the long-running thread is closed, so it won't be affected.

### Why are the changes needed?

Spark Connect improvement.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added unit test

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#46929 from WweiL/listener-uninterruptible.

Authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
WweiL authored and HyukjinKwon committed Aug 3, 2024
1 parent 9e35d04 commit aaf602a
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 33 deletions.
40 changes: 40 additions & 0 deletions python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,46 @@ def test_listener_events_spark_command(self):
# Remove again to verify this won't throw any error
self.spark.streams.removeListener(test_listener)

def test_server_listener_uninterruptible(self):
listener = TestListenerLocalV1()

try:
self.spark.streams.addListener(listener)
q = (
self.spark.readStream.format("rate")
.load()
.writeStream.format("noop")
.queryName("test_listener_uninterruptible")
.start()
)

self.assertEqual(len(listener.start), 1)
self.assertEqual(str(listener.start[0].id), q.id)

while q.lastProgress is None:
q.awaitTermination(0.5)

# Interrupt should stop the query but should not impact the listener,
# therefore there should be a QueryTerminatedEvent sent from the server.
self.spark.interruptAll()

# Need to wait a while before the query really stops
while q.isActive:
q.awaitTermination(0.5)

# Need to wait a while before QueryTerminatedEvent reaches client
while len(listener.terminated) == 0:
time.sleep(1)

self.assertEqual(len(listener.terminated), 1)
self.assertEqual(str(listener.terminated[0].id), q.id)

finally:
for listener in self.spark.streams._sqlb._listener_bus:
self.spark.streams.removeListener(listener)
for q in self.spark.streams.active:
q.stop()


if __name__ == "__main__":
import unittest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,17 +238,44 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
}
completed = true // no longer interruptible

if (executeHolder.reattachable) {
// Reattachable execution sends a ResultComplete at the end of the stream
// to signal that there isn't more coming.
executeHolder.responseObserver.onNextComplete(createResultComplete())
} else {
executeHolder.responseObserver.onCompleted()
// If the request starts a long running iterator (e.g. StreamingQueryListener needs
// a long-running iterator to continuously stream back events, it runs in a separate
// thread, and holds the responseObserver to send back the listener events.)
// In such cases, even after the ExecuteThread returns, we still want to keep the
// client side iterator open, i.e. don't send the ResultComplete to the client.
// So delegate the sending of the final ResultComplete to the listener thread itself.
if (!shouldDelegateCompleteResponse(executeHolder.request)) {
if (executeHolder.reattachable) {
// Reattachable execution sends a ResultComplete at the end of the stream
// to signal that there isn't more coming.
executeHolder.responseObserver.onNextComplete(createResultComplete())
} else {
executeHolder.responseObserver.onCompleted()
}
}
}
}
}

/**
* Perform a check to see if we should delegate sending ResultCompelete. Currently, the
* ADD_LISTENER_BUS_LISTENER command creates a new thread and continuously streams back listener
* events to the client side StreamingQueryListenerBus. In this case, we would like to delegate
* the sending of the final ResultComplete to the handler thread itself.
* @param request
* The request to check
* @return
* True if we should delegate sending the final ResultComplete to the handler thread, i.e.
* don't send a ResultComplete when the ExecuteThread returns.
*/
private def shouldDelegateCompleteResponse(request: proto.ExecutePlanRequest): Boolean = {
request.getPlan.getOpTypeCase == proto.Plan.OpTypeCase.COMMAND &&
request.getPlan.getCommand.getCommandTypeCase ==
proto.Command.CommandTypeCase.STREAMING_QUERY_LISTENER_BUS_COMMAND &&
request.getPlan.getCommand.getStreamingQueryListenerBusCommand.getCommandCase ==
proto.StreamingQueryListenerBusCommand.CommandCase.ADD_LISTENER_BUS_LISTENER
}

private def handlePlan(request: proto.ExecutePlanRequest): Unit = {
val responseObserver = executeHolder.responseObserver

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,20 @@ class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) ex
return
}
}
logInfo(log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" +
log"[UserId: ${MDC(LogKeys.USER_ID, userId)}]" +
log"[operationId: ${MDC(LogKeys.OPERATION_HANDLE_ID, executeHolder.operationId)}] " +
log"Server side listener added. Now blocking until " +
log"all client side listeners are removed or there is error transmitting the event back.")
// Block the handling thread, and have serverListener continuously send back new events
listenerHolder.streamingQueryListenerLatch.await()
logInfo(
log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" +
log"[UserId: ${MDC(LogKeys.USER_ID, userId)}]" +
log"[operationId: ${MDC(LogKeys.OPERATION_HANDLE_ID, executeHolder.operationId)}] " +
log"Server side listener long-running handling thread ended.")
log"Server side listener added.")

case StreamingQueryListenerBusCommand.CommandCase.REMOVE_LISTENER_BUS_LISTENER =>
listenerHolder.isServerSideListenerRegistered match {
case true =>
sessionHolder.streamingServersideListenerHolder.cleanUp()
logInfo(log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" +
log"[UserId: ${MDC(LogKeys.USER_ID, userId)}]" +
log"[operationId: ${MDC(LogKeys.OPERATION_HANDLE_ID, executeHolder.operationId)}] " +
log"Server side listener removed.")
case false =>
logWarning(log"[SessionId: ${MDC(LogKeys.SESSION_ID, sessionId)}]" +
log"[UserId: ${MDC(LogKeys.USER_ID, userId)}]" +
Expand All @@ -121,11 +119,6 @@ class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) ex
case StreamingQueryListenerBusCommand.CommandCase.COMMAND_NOT_SET =>
throw new IllegalArgumentException("Missing command in StreamingQueryListenerBusCommand")
}
// If this thread is the handling thread of the original ADD_LISTENER_BUS_LISTENER command,
// this will be sent when the latch is counted down (either through
// a REMOVE_LISTENER_BUS_LISTENER command, or long-lived gRPC throws.
// If this thread is the handling thread of the REMOVE_LISTENER_BUS_LISTENER command,
// this is hit right away.
executeHolder.eventsManager.postFinished()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.connect.service

import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}

import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal
Expand All @@ -29,6 +29,7 @@ import org.apache.spark.connect.proto.StreamingQueryEventType
import org.apache.spark.connect.proto.StreamingQueryListenerEvent
import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.connect.execution.ExecuteResponseObserver
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.ArrayImplicits._

Expand All @@ -41,8 +42,6 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
// There is only one listener per sessionHolder, but each listener is responsible for all events
// of all streaming queries in the SparkSession.
var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None
// The count down latch to hold the long-running listener thread before sending ResultComplete.
var streamingQueryListenerLatch = new CountDownLatch(1)
// The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent.
// Events for corresponding query will be sent back to client with
// the WriteStreamOperationStart response, so that the client can handle the event before
Expand Down Expand Up @@ -70,23 +69,21 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) {
val serverListener = new SparkConnectListenerBusListener(this, responseObserver)
sessionHolder.session.streams.addListener(serverListener)
streamingQueryServerSideListener = Some(serverListener)
streamingQueryListenerLatch = new CountDownLatch(1)
}

/**
* The cleanup of the server side listener and related resources. This method is called when the
* REMOVE_LISTENER_BUS_LISTENER command is received or when responseObserver.onNext throws an
* exception. It removes the listener from the session, clears the cache. Also it counts down
* the latch, so the long-running thread can proceed to send back the final ResultComplete
* response.
* exception. It removes the listener from the session, clears the cache. Also it sends back the
* final ResultComplete response.
*/
def cleanUp(): Unit = lock.synchronized {
streamingQueryServerSideListener.foreach { listener =>
sessionHolder.session.streams.removeListener(listener)
listener.sendResultComplete()
}
streamingQueryStartedEventCache.clear()
streamingQueryServerSideListener = None
streamingQueryListenerLatch.countDown()
}
}

Expand All @@ -104,8 +101,6 @@ private[sql] class SparkConnectListenerBusListener(
val sessionHolder = serverSideListenerHolder.sessionHolder
// The method used to stream back the events to the client.
// The event is serialized to json and sent to the client.
// The responseObserver is what of the first executeThread (long running thread),
// which is held still by the streamingQueryListenerLatch.
// If any exception is thrown while transmitting back the event, the listener is removed,
// all related sources are cleaned up, and the long-running thread will proceed to send
// the final ResultComplete response.
Expand Down Expand Up @@ -141,6 +136,16 @@ private[sql] class SparkConnectListenerBusListener(
}
}

def sendResultComplete(): Unit = {
responseObserver
.asInstanceOf[ExecuteResponseObserver[ExecutePlanResponse]]
.onNextComplete(
ExecutePlanResponse
.newBuilder()
.setResultComplete(ExecutePlanResponse.ResultComplete.newBuilder().build())
.build())
}

// QueryStartedEvent is sent to client along with WriteStreamOperationStartResult
override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
serverSideListenerHolder.streamingQueryStartedEventCache.put(event.runId.toString, event)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto.{Command, ExecutePlanResponse}
import org.apache.spark.sql.connect.SparkConnectTestUtils
import org.apache.spark.sql.connect.execution.ExecuteResponseObserver
import org.apache.spark.sql.connect.planner.SparkConnectStreamingQueryListenerHandler
import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryListener}
import org.apache.spark.sql.streaming.Trigger.ProcessingTime
Expand Down Expand Up @@ -186,7 +187,7 @@ class SparkConnectListenerBusListenerSuite
when(executeHolder.sessionHolder).thenReturn(sessionHolder)
when(executeHolder.operationId).thenReturn("operationId")

val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
val responseObserver = mock[ExecuteResponseObserver[ExecutePlanResponse]]
doThrow(new RuntimeException("I'm dead"))
.when(responseObserver)
.onNext(any[ExecutePlanResponse]())
Expand All @@ -204,14 +205,13 @@ class SparkConnectListenerBusListenerSuite
sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty)
assert(spark.streams.listListeners().size === listenerCntBeforeThrow)
assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
}

}

test("Proper handling on onNext throw - query progress") {
val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
val responseObserver = mock[StreamObserver[ExecutePlanResponse]]
val responseObserver = mock[ExecuteResponseObserver[ExecutePlanResponse]]
doThrow(new RuntimeException("I'm dead"))
.when(responseObserver)
.onNext(any[ExecutePlanResponse]())
Expand All @@ -235,7 +235,6 @@ class SparkConnectListenerBusListenerSuite
eventually(timeout(5.seconds), interval(500.milliseconds)) {
assert(!spark.streams.listListeners().contains(listenerBusListener))
assert(listenerHolder.streamingQueryStartedEventCache.isEmpty)
assert(listenerHolder.streamingQueryListenerLatch.getCount === 0)
}
}
}

0 comments on commit aaf602a

Please sign in to comment.