From f7ec854f1b7f575c4c7437daf8e6992c684b6de2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 13:51:28 -0700 Subject: [PATCH] Revert "[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long" This reverts commit 90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f. --- .../aggregate/TungstenAggregate.scala | 3 +- .../execution/joins/BroadcastHashJoin.scala | 18 +- .../spark/sql/execution/joins/HashJoin.scala | 31 +- .../sql/execution/joins/HashedRelation.scala | 688 ++++++------------ .../execution/joins/ShuffledHashJoin.scala | 51 +- .../BenchmarkWholeStageCodegen.scala | 132 +--- .../spark/sql/execution/ExchangeSuite.scala | 8 +- .../execution/joins/HashedRelationSuite.scala | 48 +- 8 files changed, 346 insertions(+), 633 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 692fef703f7c8..0a5a72c52a372 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"") + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -467,7 +467,6 @@ case class TungstenAggregate( s""" ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { - $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a8f854136c1f9..e3d554c2de20a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -51,7 +50,10 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode(buildKeys) + val mode = HashedRelationBroadcastMode( + canJoinKeyFitWithinLong, + rewriteKeyExpr(buildKeys), + buildPlan.output) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -66,7 +68,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) join(streamedIter, hashed, numOutputRows) } } @@ -103,7 +105,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.estimatedSize()); + | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -116,13 +118,15 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + if (canJoinKeyFitWithinLong) { // generate the join key as Long - val ev = streamedKeys.head.gen(ctx) + val expr = rewriteKeyExpr(streamedKeys).head + val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4c912d371e05e..8f45d57126723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -59,13 +59,9 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = { - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) - buildSide match { - case BuildLeft => (lkeys, rkeys) - case BuildRight => (rkeys, lkeys) - } + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) } /** @@ -88,8 +84,17 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 + // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same + // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys + // with two same ints have hash code 0, we rotate the bits of second one. 
+ val rotated = if (e.dataType == IntegerType) { + // (e >>> 15) | (e << 17) + BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) + } else { + e + } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -100,11 +105,17 @@ trait HashJoin { keyExpr :: Nil } + protected lazy val canJoinKeyFitWithinLong: Boolean = { + val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) + val key = rewriteKeyExpr(buildKeys) + sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] + } + protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(buildKeys) + UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(streamedKeys) + UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 4959f60dab275..5ccb435686f23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,22 +18,24 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkConf, SparkEnv, SparkException} -import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} +import org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[execution] sealed trait HashedRelation extends KnownSizeEstimation { +private[execution] sealed trait HashedRelation { /** * Returns matched rows. * @@ -72,36 +74,51 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def asReadOnlyCopy(): HashedRelation + /** + * Returns the size of used memory. + */ + def getMemorySize: Long = 1L // to make the test happy + /** * Release any used resources. 
*/ - def close(): Unit + def close(): Unit = {} + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. + * + * Note: The caller should make sure that these InternalRow are different objects. */ def apply( + canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int = 64, - taskMemoryManager: TaskMemoryManager = null): HashedRelation = { - val mm = Option(taskMemoryManager).getOrElse { - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { - if (key.length == 1 && key.head.dataType == LongType) { - LongHashedRelation(input, key, sizeEstimate, mm) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(input, keyGenerator, sizeEstimate) } else { - UnsafeHashedRelation(input, key, sizeEstimate, mm) + UnsafeHashedRelation( + input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } } } @@ -116,7 +133,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with Externalizable { + extends HashedRelation with KnownSizeEstimation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization @@ -125,6 +142,10 @@ private[joins] class UnsafeHashedRelation( override def asReadOnlyCopy(): UnsafeHashedRelation = new UnsafeHashedRelation(numFields, binaryMap) + override def getMemorySize: Long = { + binaryMap.getTotalMemoryConsumption + } + override def estimatedSize: Long = { binaryMap.getTotalMemoryConsumption } @@ -255,10 +276,20 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int, - taskMemoryManager: TaskMemoryManager): HashedRelation = { + keyGenerator: UnsafeProjection, + sizeEstimate: Int): HashedRelation = { + val taskMemoryManager = if (TaskContext.get() != null) { + TaskContext.get().taskMemoryManager() + } else { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -269,7 +300,6 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows - val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -291,471 +321,144 @@ private[joins] object UnsafeHashedRelation { } } -private[joins] object LongToUnsafeRowMap { - // the largest 
prime that below 2^n - val LARGEST_PRIMES = { - // https://primes.utm.edu/lists/2small/0bit.html - val diffs = Seq( - 0, 1, 1, 3, 1, 3, 1, 5, - 3, 3, 9, 3, 1, 3, 19, 15, - 1, 5, 1, 3, 9, 3, 15, 3, - 39, 5, 39, 57, 3, 35, 1, 5 - ) - val primes = new Array[Int](32) - primes(0) = 1 - var power2 = 1 - (1 until 32).foreach { i => - power2 *= 2 - primes(i) = power2 - diffs(i) - } - primes - } -} - /** - * An append-only hash map mapping from key of Long to UnsafeRow. - * - * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array - * (`page`) in this format: - * - * [bytes of row1][address1][bytes of row2][address1] ... - * - * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key - * could have multiple values. the address at the end of last value for every key is 0. - * - * The keys and addresses of their values could be stored in two modes: - * - * 1) sparse mode: the keys and addresses are stored in `array` as: - * - * [key1][address1][key2][address2]...[] - * - * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 - * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address - * hash collision. - * - * 2) dense mode: all the addresses are packed into a single array of long, as: - * - * [address1] [address2] ... - * - * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is - * determined by `key1 - minKey`. - * - * The map is created as sparse mode, then key-value could be appended into it. Once finish - * appending, caller could all optimize() to try to turn the map into dense mode, which is faster - * to probe. + * An interface for a hashed relation that the key is a Long. */ -private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) - extends MemoryConsumer(mm) with Externalizable { - import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ - - // Whether the keys are stored in dense mode or not. - private var isDense = false - - // The minimum value of keys. - private var minKey = Long.MaxValue - - // The Maxinum value of keys. - private var maxKey = Long.MinValue - - // Sparse mode: the actual capacity of map, is a prime number. - private var cap: Int = 0 - - // The array to store the key and offset of UnsafeRow in the page. - // - // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... - // Dense mode: [offset1 | size1] [offset2 | size2] - private var array: Array[Long] = null - - // The page to store all bytes of UnsafeRow and the pointer to next rows. - // [row1][pointer1] [row2][pointer2] - private var page: Array[Byte] = null - - // Current write cursor in the page. - private var cursor = Platform.BYTE_ARRAY_OFFSET - - // The total number of values of all keys. - private var numValues = 0 - - // The number of unique keys. 
- private var numKeys = 0 - - // needed by serializer - def this() = { - this( - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0), - 0) - } - - private def acquireMemory(size: Long): Unit = { - // do not support spilling - val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) - if (got < size) { - mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) - throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") - } - } - - private def freeMemory(size: Long): Unit = { - mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) - } - - private def init(): Unit = { - if (mm != null) { - cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ - sys.error(s"Can't create map with capacity $capacity") - } - acquireMemory(cap * 2 * 8 + (1 << 20)) - array = new Array[Long](cap * 2) - page = new Array[Byte](1 << 20) // 1M bytes - } - } - - init() - - def spill(size: Long, trigger: MemoryConsumer): Long = { - 0L - } - - /** - * Returns whether all the keys are unique. - */ - def keyIsUnique: Boolean = numKeys == numValues - - /** - * Returns total memory consumption. - */ - def getTotalMemoryConsumption: Long = { - array.length * 8 + page.length - } - - /** - * Returns the slot of array that store the keys (sparse mode). - */ - private def getSlot(key: Long): Int = { - var s = (key % cap).toInt - if (s < 0) { - s += cap - } - s * 2 - } - - private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - val offset = address >>> 32 - val size = address & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - resultRow +private[joins] trait LongHashedRelation extends HashedRelation { + override def get(key: InternalRow): Iterator[InternalRow] = { + get(key.getLong(0)) } - - /** - * Returns the single UnsafeRow for given key, or null if not found. - */ - def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { - if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) - } - } else { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return getRow(array(pos + 1), resultRow) - } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - } - null + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) } +} - /** - * Returns an interator of UnsafeRow for multiple linked values. - */ - private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { - new Iterator[UnsafeRow] { - var addr = address - override def hasNext: Boolean = addr != 0 - override def next(): UnsafeRow = { - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - addr = Platform.getLong(page, offset + size) - resultRow - } - } - } +private[joins] final class GeneralLongHashedRelation( + private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) + extends LongHashedRelation with Externalizable { - /** - * Returns an iterator for all the values for the given key, or null if no value found. 
- */ - def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { - if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) - } - } else { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return valueIter(array(pos + 1), resultRow) - } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - } - null - } - - /** - * Appends the key and row into this map. - */ - def append(key: Long, row: UnsafeRow): Unit = { - if (key < minKey) { - minKey = key - } - if (key > maxKey) { - maxKey = key - } + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { - val used = page.length - if (used * 2L > (1L << 31)) { - sys.error("Can't allocate a page that is larger than 2G") - } - acquireMemory(used * 2) - val newPage = new Array[Byte](used * 2) - System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) - page = newPage - freeMemory(used) - } + override def keyIsUnique: Boolean = false - // copy the bytes of UnsafeRow - val offset = cursor - Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) - cursor += row.getSizeInBytes - Platform.putLong(page, cursor, 0) - cursor += 8 - numValues += 1 - updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) - } + override def asReadOnlyCopy(): GeneralLongHashedRelation = + new GeneralLongHashedRelation(hashTable) - /** - * Update the address in array for given key. - */ - private def updateIndex(key: Long, address: Long): Unit = { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0 && array(pos) != key) { - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - if (array(pos + 1) == 0) { - // this is the first value for this key, put the address in array. - array(pos) = key - array(pos + 1) = address - numKeys += 1 - if (numKeys * 2 > cap) { - // reach half of the capacity - growArray() - } + override def get(key: Long): Iterator[InternalRow] = { + val rows = hashTable.get(key) + if (rows != null) { + rows.toIterator } else { - // there is another value for this key, put the address at the end of final value. - var addr = array(pos + 1) - var pointer = (addr >>> 32) + (addr & 0xffffffffL) - while (Platform.getLong(page, pointer) != 0) { - addr = Platform.getLong(page, pointer) - pointer = (addr >>> 32) + (addr & 0xffffffffL) - } - Platform.putLong(page, pointer, address) - } - } - - private def growArray(): Unit = { - val old_cap = cap - var old_array = array - cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ - sys.error(s"Can't grow map any more than $cap") - } - numKeys = 0 - acquireMemory(cap * 2 * 8) - array = new Array[Long](cap * 2) - var i = 0 - while (i < old_array.length) { - if (old_array(i + 1) > 0) { - updateIndex(old_array(i), old_array(i + 1)) - } - i += 2 - } - old_array = null // release the reference to old array - freeMemory(old_cap * 2 * 8) - } - - /** - * Try to turn the map into dense mode, which is faster to probe. 
- */ - def optimize(): Unit = { - val range = maxKey - minKey - // Convert to dense mode if it does not require more memory or could fit within L1 cache - if (range < array.length || range < 1024) { - try { - acquireMemory((range + 1) * 8) - } catch { - case e: SparkException => - // there is no enough memory to convert - return - } - val denseArray = new Array[Long]((range + 1).toInt) - var i = 0 - while (i < array.length) { - if (array(i + 1) > 0) { - val idx = (array(i) - minKey).toInt - denseArray(idx) = array(i + 1) - } - i += 2 - } - val old_length = array.length - array = denseArray - isDense = true - freeMemory(old_length * 8) - } - } - - /** - * Free all the memory acquired by this map. - */ - def free(): Unit = { - if (page != null) { - freeMemory(page.length) - page = null - } - if (array != null) { - freeMemory(array.length * 8) - array = null + null } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeBoolean(isDense) - out.writeLong(minKey) - out.writeLong(maxKey) - out.writeInt(numKeys) - out.writeInt(numValues) - out.writeInt(cap) - - out.writeInt(array.length) - val buffer = new Array[Byte](4 << 10) - var offset = Platform.LONG_ARRAY_OFFSET - val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET - while (offset < end) { - val size = Math.min(buffer.length, end - offset) - Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) - out.write(buffer, 0, size) - offset += size - } - - val used = cursor - Platform.BYTE_ARRAY_OFFSET - out.writeInt(used) - out.write(page, 0, used) + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) } override def readExternal(in: ObjectInput): Unit = { - isDense = in.readBoolean() - minKey = in.readLong() - maxKey = in.readLong() - numKeys = in.readInt() - numValues = in.readInt() - cap = in.readInt() - - val length = in.readInt() - array = new Array[Long](length) - val buffer = new Array[Byte](4 << 10) - var offset = Platform.LONG_ARRAY_OFFSET - val end = length * 8 + Platform.LONG_ARRAY_OFFSET - while (offset < end) { - val size = Math.min(buffer.length, end - offset) - in.readFully(buffer, 0, size) - Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) - offset += size - } - - val numBytes = in.readInt() - page = new Array[Byte](numBytes) - in.readFully(page) + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) } } -private[joins] class LongHashedRelation( - private var nFields: Int, - private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { - - private var resultRow: UnsafeRow = new UnsafeRow(nFields) +/** + * A relation that pack all the rows into a byte array, together with offsets and sizes. + * + * All the bytes of UnsafeRow are packed together as `bytes`: + * + * [ Row0 ][ Row1 ][] ... [ RowN ] + * + * With keys: + * + * start start+1 ... start+N + * + * `offsets` are offsets of UnsafeRows in the `bytes` + * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. 
+ * + * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: + * + * start = 3 + * offsets = [0, 0, 24] + * sizes = [24, 0, 32] + * bytes = [0 - 24][][24 - 56] + */ +private[joins] final class LongArrayRelation( + private var numFields: Int, + private var start: Long, + private var offsets: Array[Int], + private var sizes: Array[Int], + private var bytes: Array[Byte] + ) extends LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, null) + def this() = this(0, 0L, null, null, null) - override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) + override def keyIsUnique: Boolean = true - override def estimatedSize: Long = { - map.getTotalMemoryConsumption + override def asReadOnlyCopy(): LongArrayRelation = { + new LongArrayRelation(numFields, start, offsets, sizes, bytes) } - override def get(key: InternalRow): Iterator[InternalRow] = { - if (key.isNullAt(0)) { - null - } else { - get(key.getLong(0)) - } + override def getMemorySize: Long = { + offsets.length * 4 + sizes.length * 4 + bytes.length } - override def getValue(key: InternalRow): InternalRow = { - if (key.isNullAt(0)) { - null + override def get(key: Long): Iterator[InternalRow] = { + val row = getValue(key) + if (row != null) { + Seq(row).toIterator } else { - getValue(key.getLong(0)) + null } } - override def get(key: Long): Iterator[InternalRow] = - map.get(key, resultRow) - + var resultRow = new UnsafeRow(numFields) override def getValue(key: Long): InternalRow = { - map.getValue(key, resultRow) - } - - override def keyIsUnique: Boolean = map.keyIsUnique - - override def close(): Unit = { - map.free() + val idx = (key - start).toInt + if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { + resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + resultRow + } else { + null + } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(nFields) - out.writeObject(map) + out.writeInt(numFields) + out.writeLong(start) + out.writeInt(sizes.length) + var i = 0 + while (i < sizes.length) { + out.writeInt(sizes(i)) + i += 1 + } + out.writeInt(bytes.length) + out.write(bytes) } override def readExternal(in: ObjectInput): Unit = { - nFields = in.readInt() - resultRow = new UnsafeRow(nFields) - map = in.readObject().asInstanceOf[LongToUnsafeRowMap] + numFields = in.readInt() + resultRow = new UnsafeRow(numFields) + start = in.readLong() + val length = in.readInt() + // read sizes of rows + sizes = new Array[Int](length) + offsets = new Array[Int](length) + var i = 0 + var offset = 0 + while (i < length) { + offsets(i) = offset + sizes(i) = in.readInt() + offset += sizes(i) + i += 1 + } + // read all the bytes + val total = in.readInt() + assert(total == offset) + bytes = new Array[Byte](total) + in.readFully(bytes) } } @@ -763,45 +466,96 @@ private[joins] class LongHashedRelation( * Create hashed relation with key that is long. 
*/ private[joins] object LongHashedRelation { + + val DENSE_FACTOR = 0.2 + def apply( - input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int, - taskMemoryManager: TaskMemoryManager): LongHashedRelation = { + input: Iterator[InternalRow], + keyGenerator: Projection, + sizeEstimate: Int): HashedRelation = { - val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) - val keyGenerator = UnsafeProjection.create(key) + // TODO: use LongToBytesMap for better memory efficiency + val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of key -> rows var numFields = 0 + var keyIsUnique = true + var minKey = Long.MaxValue + var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.isNullAt(0)) { + if (!rowKey.anyNull) { val key = rowKey.getLong(0) - map.append(key, unsafeRow) + minKey = math.min(minKey, key) + maxKey = math.max(maxKey, key) + val existingMatchList = hashTable.get(key) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(key, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += unsafeRow + } + } + + if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { + // The keys are dense enough, so use LongArrayRelation + val length = (maxKey - minKey).toInt + 1 + val sizes = new Array[Int](length) + val offsets = new Array[Int](length) + var offset = 0 + var i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + offsets(i) = offset + sizes(i) = rows(0).getSizeInBytes + offset += sizes(i) + } + i += 1 + } + val bytes = new Array[Byte](offset) + i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) + } + i += 1 } + new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) + } else { + new GeneralLongHashedRelation(hashTable) } - map.optimize() - new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
*/ -private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) - extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode( + canJoinKeyFitWithinLong: Boolean, + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + val generator = UnsafeProjection.create(keys, attributes) + HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } + private lazy val canonicalizedKeys: Seq[Expression] = { + keys.map { e => + BindReferences.bindReference(e.canonicalized, attributes) + } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey + case m: HashedRelationBroadcastMode => + canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && + canonicalizedKeys == m.canonicalizedKeys case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 0c3e3c3fc18a1..bf86096379283 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.TaskContext +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.memory.MemoryMode import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -56,20 +57,54 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { val context = TaskContext.get() - val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) - // This relation is usually used until the end of task. + if (!canJoinKeyFitWithinLong) { + // build BytesToBytesMap + val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) + // This relation is usually used until the end of task. + context.addTaskCompletionListener((t: TaskContext) => + relation.close() + ) + return relation + } + + // try to acquire some memory for the hash table, it could trigger other operator to free some + // memory. The memory acquired here will mostly be used until the end of task. + val memoryManager = context.taskMemoryManager() + var acquired = 0L + var used = 0L context.addTaskCompletionListener((t: TaskContext) => - relation.close() + memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) ) - relation + + val copiedIter = iter.map { row => + // It's hard to guess what's exactly memory will be used, we have a rough guess here. 
+ // TODO: use LongToBytesMap instead of HashMap for memory efficiency + // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers + val needed = 150 + row.getSizeInBytes + if (needed > acquired - used) { + val got = memoryManager.acquireExecutionMemory( + Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) + acquired += got + if (got < needed) { + throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + + "hash join, please use sort merge join by setting " + + "spark.sql.join.preferSortMergeJoin=true") + } + } + used += needed + // HashedRelation requires that the UnsafeRow should be separate objects. + row.copy() + } + + HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter) + val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) join(streamIter, hashed, numOutputRows) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 352fd07d0e8b0..5dbf61987635d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,7 +21,6 @@ import java.util.HashMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} -import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap @@ -180,8 +179,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X - Join w long codegen=true 321 / 371 65.3 15.3 9.3X + Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X + Join w long codegen=true 275 / 352 76.2 13.1 19.4X */ runBenchmark("Join w long duplicated", N) { @@ -194,8 +193,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X - Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X + Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X + Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X */ val dim2 = broadcast(sqlContext.range(M) @@ -212,8 +211,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X - Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X + Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X + Join w 2 ints 
codegen=true 2565 / 2816 8.2 122.3 3.5X */ val dim3 = broadcast(sqlContext.range(M) @@ -260,8 +259,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X - outer join w long codegen=true 261 / 276 80.5 12.4 11.7X + outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X + outer join w long codegen=true 216 / 226 97.2 10.3 26.3X */ runBenchmark("semi join w long", N) { @@ -273,8 +272,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X - semi join w long codegen=true 237 / 244 88.3 11.3 8.1X + semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X + semi join w long codegen=true 211 / 229 99.2 10.1 22.2X */ } @@ -327,8 +326,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X - shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X + shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X + shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X */ } @@ -350,11 +349,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 20 << 20 + val N = 10 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("UnsafeRowhash") { iter => + benchmark.addCase("hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -369,34 +368,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } - benchmark.addCase("murmur3 hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = Murmur3_x86_32.hashLong(i, 42) - key.setInt(0, h) - s += h - i += 1 - } - } - benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 var s = 0 while (i < N) { - var h = i % p - if (h < 0) { - h += p - } - key.setInt(0, h) + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashLong(i % 1000, 42) s += h i += 1 } @@ -495,42 +475,6 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } - Seq(false, true).foreach { optimized => - benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => - var i = 0 - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new LongToUnsafeRowMap(taskMemoryManager, 64) - while (i < 65536) { - value.setInt(0, i) - val key = i % 100000 - map.append(key, value) - i += 1 - } - if (optimized) { - map.optimize() - } 
- var s = 0 - i = 0 - while (i < N) { - val key = i % 100000 - if (map.getValue(key, value) != null) { - s += 1 - } - i += 1 - } - } - } - Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -549,27 +493,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - val numKeys = 65536 - while (i < numKeys) { + while (i < N) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (!loc.isDefined) { + if (loc.isDefined) { + value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } - i += 1 - } - i = 0 - var s = 0 - while (i < N) { - key.setInt(0, i % 100000) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 100000, 42)) - if (loc.isDefined) { - s += 1 - } - i += 1 } } } @@ -600,19 +535,16 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - UnsafeRow hash 267 / 284 78.4 12.8 1.0X - murmur3 hash 102 / 129 205.5 4.9 2.6X - fast hash 79 / 96 263.8 3.8 3.4X - arrayEqual 164 / 172 128.2 7.8 1.6X - Java HashMap (Long) 321 / 399 65.4 15.3 0.8X - Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X - Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X - LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X - LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X - BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X - BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X - Aggregate HashMap 121 / 131 173.3 5.8 2.2X - */ + hash 112 / 116 93.2 10.7 1.0X + fast hash 65 / 69 160.9 6.2 1.7X + arrayEqual 66 / 69 159.1 6.3 1.7X + Java HashMap (Long) 137 / 182 76.3 13.1 0.8X + Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X + Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X + BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X + BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X + Aggregate HashMap 56 / 62 187.9 5.3 2.0X + */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 17f2343cf971e..9680f3a008a59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) - val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) + val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) + val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = 
BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(output) + val hashMode = HashedRelationBroadcastMode(true, output, plan.output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) + HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 371a9ed617d65..ed87a99439521 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,23 +30,15 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { - val mm = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) - val buildKey = Seq(BoundReference(0, IntegerType, false)) - val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -108,45 +100,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - test("LongToUnsafeRowMap") { + test("LongArrayRelation") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val key = Seq(BoundReference(0, IntegerType, false)) - val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) - assert(longRelation.keyIsUnique) + val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) + val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) + assert(longRelation.isInstanceOf[LongArrayRelation]) + val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] (0 until 100).foreach { i => - val row = longRelation.getValue(i) + val row = longArrayRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } - val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) - assert(!longRelation2.keyIsUnique) - (0 until 100).foreach { i => - val rows = longRelation2.get(i).toArray - assert(rows.length === 2) - assert(rows(0).getInt(0) === i) - assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) - assert(rows(1).getInt(1) === i + 1) - } - val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longRelation2.writeExternal(out) + longArrayRelation.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new 
LongHashedRelation() + val relation = new LongArrayRelation() relation.readExternal(in) - assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val rows = relation.get(i).toArray - assert(rows.length === 2) - assert(rows(0).getInt(0) === i) - assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) - assert(rows(1).getInt(1) === i + 1) + val row = longArrayRelation.getValue(i) + assert(row.getInt(0) === i) + assert(row.getInt(1) === i + 1) } } }
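
As an aside to the HashJoin.rewriteKeyExpr hunk above: the restored code packs two int join keys into a single Long and rotates the second int by 17 bits, because Long.hashCode is (l >>> 32) ^ l, so keys whose two int columns carry the same value would otherwise all hash to 0. Below is a minimal, standalone Scala sketch of that packing on raw Ints; the object and method names are illustrative only and do not appear in the patch, which performs the equivalent work at the Catalyst expression level (ShiftLeft / ShiftRightUnsigned / BitwiseAnd / BitwiseOr).

    // Illustrative sketch only -- not part of the patch above.
    object PackedLongKeySketch {
      // 17-bit left rotation of a 32-bit int, mirroring "(e >>> 15) | (e << 17)"
      // in the rewriteKeyExpr hunk.
      def rotate17(v: Int): Int = (v >>> 15) | (v << 17)

      // Pack two int keys into one Long: first key in the high 32 bits,
      // rotated second key masked into the low 32 bits.
      def packTwoInts(first: Int, second: Int): Long =
        (first.toLong << 32) | (rotate17(second).toLong & 0xffffffffL)

      def main(args: Array[String]): Unit = {
        val packed = packTwoInts(7, 7)
        // Without the rotation the packed value would be 0x0000000700000007,
        // whose Long hash code is 7 ^ 7 = 0 for every "same two ints" key;
        // with the rotation the hash code is non-zero.
        println(f"packed = 0x$packed%016x, hashCode = ${packed.hashCode()}")
      }
    }
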