[SPARK-16350][SQL] Fix support for incremental planning in writeStream.foreach()

## What changes were proposed in this pull request?

There are cases where the `complete` output mode does not output the updated aggregated values; for details please refer to [SPARK-16350](https://issues.apache.org/jira/browse/SPARK-16350).

The cause is that `ForeachSink.addBatch()` calls `data.as[T].foreachPartition { iter => ... }`, and `foreachPartition()` does not support incremental planning for now.

This patch makes `foreachPartition()` support incremental planning in `ForeachSink` by building a special version of `Dataset` whose `rdd()` method supports incremental planning.
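
For reference, the kind of query affected looks roughly like the sketch below. It is adapted from the new `complete`-mode unit test added in this patch; it assumes a streaming-capable `SparkSession` with its implicits in scope, `TestForeachWriter` is the helper defined in the suite, and the checkpoint path is illustrative.

```scala
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.OutputMode

// A streaming aggregation whose running count is pushed to a ForeachWriter.
// Before this fix, ForeachSink planned each batch without IncrementalExecution,
// so `complete` mode did not deliver the updated aggregated value to the writer.
val input = MemoryStream[Int]
val query = input.toDS()
  .groupBy().count().as[Long].map(_.toInt)
  .writeStream
  .option("checkpointLocation", "/tmp/foreach-sink-demo") // illustrative path
  .outputMode(OutputMode.Complete)
  .foreach(new TestForeachWriter())
  .start()

input.addData(1, 2, 3, 4)
query.processAllAvailable() // the writer should now observe the count 4
```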

## How was this patch tested?

Added a unit test that failed before the change.

Author: Liwei Lin <[email protected]>

Closes apache#14030 from lw-lin/fix-foreach-complete.
lw-lin authored and zsxwing committed Jul 7, 2016
1 parent a04cab8 commit 0f7175d
Showing 3 changed files with 117 additions and 13 deletions.
@@ -18,7 +18,9 @@
package org.apache.spark.sql.execution.streaming

import org.apache.spark.TaskContext
import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde

/**
* A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
@@ -30,7 +32,41 @@ import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {

override def addBatch(batchId: Long, data: DataFrame): Unit = {
data.as[T].foreachPartition { iter =>
// TODO: Refine this method when SPARK-16264 is resolved; see comments below.

// This logic should've been as simple as:
// ```
// data.as[T].foreachPartition { iter => ... }
// ```
//
// Unfortunately, doing that would just break the incremental planning. The reason is that
// `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just
// does not support `IncrementalExecution`.
//
// So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()`
// method supporting incremental planning. But in the long run, we should generally make newly
// created Datasets use `IncrementalExecution` where necessary (which is what SPARK-16264 tries to
// resolve).

val datasetWithIncrementalExecution =
new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) {
override lazy val rdd: RDD[T] = {
val objectType = exprEnc.deserializer.dataType
val deserialized = CatalystSerde.deserialize[T](logicalPlan)

// was originally: sparkSession.sessionState.executePlan(deserialized) ...
val incrementalExecution = new IncrementalExecution(
this.sparkSession,
deserialized,
data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId)
incrementalExecution.toRdd.mapPartitions { rows =>
rows.map(_.get(0, objectType))
}.asInstanceOf[RDD[T]]
}
}
datasetWithIncrementalExecution.foreachPartition { iter =>
if (writer.open(TaskContext.getPartitionId(), batchId)) {
var isFailed = false
try {
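For context on the `writer.open(...)`, `process`, and `close` calls that `addBatch` drives once per partition and batch, a `ForeachWriter` along the lines of the suite's `TestForeachWriter` might look like the sketch below. The class name and the `println` bodies are illustrative, not the actual test helper.

```scala
import org.apache.spark.sql.ForeachWriter

// Minimal ForeachWriter sketch: open() decides whether this partition should be
// processed for the given version (batch), process() handles each value, and
// close() is called at the end with the error, if any.
class PrintingForeachWriter extends ForeachWriter[Int] {
  override def open(partitionId: Long, version: Long): Boolean = {
    println(s"open partition=$partitionId version=$version")
    true // returning false skips process() for this partition
  }
  override def process(value: Int): Unit = println(s"process $value")
  override def close(errorOrNull: Throwable): Unit = println(s"close $errorOrNull")
}
```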
@@ -30,8 +30,8 @@ import org.apache.spark.sql.streaming.OutputMode
class IncrementalExecution private[sql](
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
outputMode: OutputMode,
checkpointLocation: String,
val outputMode: OutputMode,
val checkpointLocation: String,
val currentBatchId: Long)
extends QueryExecution(sparkSession, logicalPlan) {

@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.ForeachWriter
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext

class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
@@ -35,35 +35,103 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
sqlContext.streams.active.foreach(_.stop())
}

test("foreach") {
test("foreach() with `append` output mode") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(2).writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.outputMode(OutputMode.Append)
.foreach(new TestForeachWriter())
.start()

// -- batch 0 ---------------------------------------
input.addData(1, 2, 3, 4)
query.processAllAvailable()

val expectedEventsForPartition0 = Seq(
var expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 1),
ForeachSinkSuite.Process(value = 3),
ForeachSinkSuite.Close(None)
)
val expectedEventsForPartition1 = Seq(
var expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 0),
ForeachSinkSuite.Process(value = 2),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)

val allEvents = ForeachSinkSuite.allEvents()
var allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))

ForeachSinkSuite.clear()

// -- batch 1 ---------------------------------------
input.addData(5, 6, 7, 8)
query.processAllAvailable()

expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 1),
ForeachSinkSuite.Process(value = 5),
ForeachSinkSuite.Process(value = 7),
ForeachSinkSuite.Close(None)
)
expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 1),
ForeachSinkSuite.Process(value = 6),
ForeachSinkSuite.Process(value = 8),
ForeachSinkSuite.Close(None)
)

allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert {
allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
}
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))

query.stop()
}
}

test("foreach() with `complete` output mode") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]

val query = input.toDS()
.groupBy().count().as[Long].map(_.toInt)
.writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
.outputMode(OutputMode.Complete)
.foreach(new TestForeachWriter())
.start()

// -- batch 0 ---------------------------------------
input.addData(1, 2, 3, 4)
query.processAllAvailable()

var allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 1)
var expectedEvents = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)
assert(allEvents === Seq(expectedEvents))

ForeachSinkSuite.clear()

// -- batch 1 ---------------------------------------
input.addData(5, 6, 7, 8)
query.processAllAvailable()

allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 1)
expectedEvents = Seq(
ForeachSinkSuite.Open(partition = 0, version = 1),
ForeachSinkSuite.Process(value = 8),
ForeachSinkSuite.Close(None)
)
assert(allEvents === Seq(expectedEvents))

query.stop()
}
}
