[SPARK-17490][SQL] Optimize SerializeFromObject() for a primitive array
## What changes were proposed in this pull request?

Waiting for apache#13680 to be merged.

This PR optimizes `SerializeFromObject()` for a primitive array. It is derived from apache#13758 and addresses one of the problems raised there in a simpler way.

The current implementation always generates `GenericArrayData` from `SerializeFromObject()` for any type of array in a logical plan. This involves boxing in the constructor of `GenericArrayData` when `SerializeFromObject()` receives a primitive array.

This PR enables `SerializeFromObject()` to generate `UnsafeArrayData` for a primitive array. This avoids boxing when creating an instance of `ArrayData` in the code generated by Catalyst.
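
To make the boxing cost concrete, here is a minimal JDK-only sketch. It does not use Spark's classes: `toBoxedArray` and `toFlatBytes` are hypothetical stand-ins for the object-per-element path and the flat-buffer path, and the sketch ignores whatever header or null-tracking the real `UnsafeArrayData` format carries.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BoxingSketch {
    // Stand-in for the old path: each int is boxed into an Integer,
    // so the result holds one object reference per element.
    static Object[] toBoxedArray(int[] values) {
        Object[] boxed = new Object[values.length];
        for (int i = 0; i < values.length; i++) {
            boxed[i] = values[i]; // autoboxing: int -> Integer
        }
        return boxed;
    }

    // Stand-in for the new path: copy the raw 4-byte values into one flat buffer.
    static byte[] toFlatBytes(int[] values) {
        ByteBuffer buffer = ByteBuffer.allocate(values.length * Integer.BYTES)
                                      .order(ByteOrder.nativeOrder());
        buffer.asIntBuffer().put(values); // bulk copy, no per-element objects
        return buffer.array();
    }

    // Read element `index` back from the flat representation.
    static int readInt(byte[] flat, int index) {
        return ByteBuffer.wrap(flat)
                         .order(ByteOrder.nativeOrder())
                         .getInt(index * Integer.BYTES);
    }

    public static void main(String[] args) {
        int[] input = {1, 2, 3};
        Object[] boxed = toBoxedArray(input);
        byte[] flat = toFlatBytes(input);
        System.out.println("boxed element is Integer: " + (boxed[1] instanceof Integer));
        System.out.println("flat size in bytes: " + flat.length);
        System.out.println("flat element 2: " + readInt(flat, 2));
    }
}
```

The boxed path touches every element individually, while the flat path is a single bulk copy. That difference is the kind of overhead this PR removes.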

This PR also generates `UnsafeArrayData` when a primitive array goes through `RowEncoder.serializerFor` or `CatalystTypeConverters.createToCatalystConverter`.
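
Both the `ScalaReflection` and `ArrayData.toArrayData` changes in this commit decide at runtime whether the input really is a primitive array and fall back to `GenericArrayData` otherwise. The following JDK-only sketch shows that decision in isolation. `isPrimitiveArray` mirrors the `cls.isArray && cls.getComponentType.isPrimitive` check from the diff; `chooseRepresentation` and its string labels are hypothetical and exist only for illustration.

```java
public class PrimitiveArrayDispatchSketch {
    // Same reflection test as in the ScalaReflection change:
    // true for int[], double[], ...; false for Integer[], String[], int[][].
    static boolean isPrimitiveArray(Class<?> cls) {
        return cls.isArray() && cls.getComponentType().isPrimitive();
    }

    // Hypothetical helper: names the representation that would be chosen.
    static String chooseRepresentation(Object input) {
        if (input != null && isPrimitiveArray(input.getClass())) {
            return "unsafe";  // stand-in for UnsafeArrayData.fromPrimitiveArray
        }
        return "generic";     // stand-in for new GenericArrayData(...)
    }

    public static void main(String[] args) {
        System.out.println(chooseRepresentation(new int[] {1, 2}));
        System.out.println(chooseRepresentation(new Integer[] {1, null}));
        System.out.println(chooseRepresentation(new String[] {"a"}));
    }
}
```

Note that the reflection check alone would also accept `char[]`. In the patch it is only reached after the element type has matched one of the seven listed Catalyst types, so the two conditions work together.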

The performance improvement of `SerializeFromObject()` is up to 2.0x (best time for a `Double` array: 1668 ms without this PR, 821 ms with it).

```
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)

Without this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            556 /  608         15.1          66.3       1.0X
Double                                        1668 / 1746          5.0         198.8       0.3X

With this PR
Write an array in Dataset:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
Int                                            352 /  401         23.8          42.0       1.0X
Double                                         821 /  885         10.2          97.9       0.4X
```
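
The speedups implied by the two tables can be recomputed from the best-time columns. This is a small sketch of that arithmetic; the class and method names are hypothetical, and the numbers are copied from the tables above.

```java
import java.util.Locale;

public class SpeedupSketch {
    // Speedup = time before / time after (both in milliseconds).
    static double speedup(double beforeMs, double afterMs) {
        return beforeMs / afterMs;
    }

    public static void main(String[] args) {
        // Best times (ms) without and with this PR, taken from the benchmark output.
        double intSpeedup = speedup(556, 352);
        double doubleSpeedup = speedup(1668, 821);
        System.out.println(String.format(Locale.ROOT, "Int:    %.2fx", intSpeedup));
        System.out.println(String.format(Locale.ROOT, "Double: %.2fx", doubleSpeedup));
    }
}
```

The `Double` case gives roughly 2.03x, which is where the "up to 2.0x" figure comes from. The `Int` case improves by roughly 1.58x.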

Here is an example program of the kind that occurs in MLlib, as described in [SPARK-16070](https://issues.apache.org/jira/browse/SPARK-16070).

```
sparkContext.parallelize(Seq(Array(1, 2)), 1).toDS.map(e => e).show
```

Generated code before applying this PR

``` java
/* 039 */   protected void processNext() throws java.io.IOException {
/* 040 */     while (inputadapter_input.hasNext()) {
/* 041 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 042 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 043 */
/* 044 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 045 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 046 */
/* 047 */       boolean mapelements_isNull = false || false;
/* 048 */       int[] mapelements_value = null;
/* 049 */       if (!mapelements_isNull) {
/* 050 */         Object mapelements_funcResult = null;
/* 051 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 052 */         if (mapelements_funcResult == null) {
/* 053 */           mapelements_isNull = true;
/* 054 */         } else {
/* 055 */           mapelements_value = (int[]) mapelements_funcResult;
/* 056 */         }
/* 057 */
/* 058 */       }
/* 059 */       mapelements_isNull = mapelements_value == null;
/* 060 */
/* 061 */       serializefromobject_argIsNulls[0] = mapelements_isNull;
/* 062 */       serializefromobject_argValue = mapelements_value;
/* 063 */
/* 064 */       boolean serializefromobject_isNull = false;
/* 065 */       for (int idx = 0; idx < 1; idx++) {
/* 066 */         if (serializefromobject_argIsNulls[idx]) { serializefromobject_isNull = true; break; }
/* 067 */       }
/* 068 */
/* 069 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null : new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_argValue);
/* 070 */       serializefromobject_holder.reset();
/* 071 */
/* 072 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 073 */
/* 074 */       if (serializefromobject_isNull) {
/* 075 */         serializefromobject_rowWriter.setNullAt(0);
/* 076 */       } else {
/* 077 */         // Remember the current cursor so that we can calculate how many bytes are
/* 078 */         // written later.
/* 079 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 080 */
/* 081 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 082 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 083 */           // grow the global buffer before writing data.
/* 084 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 085 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 086 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 087 */
/* 088 */         } else {
/* 089 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 090 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 091 */
/* 092 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 093 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 094 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 095 */             } else {
/* 096 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 097 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 098 */             }
/* 099 */           }
/* 100 */         }
/* 101 */
/* 102 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 103 */       }
/* 104 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 105 */       append(serializefromobject_result);
/* 106 */       if (shouldStop()) return;
/* 107 */     }
/* 108 */   }
/* 109 */ }
```

Generated code after applying this PR

``` java
/* 035 */   protected void processNext() throws java.io.IOException {
/* 036 */     while (inputadapter_input.hasNext()) {
/* 037 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 038 */       int[] inputadapter_value = (int[])inputadapter_row.get(0, null);
/* 039 */
/* 040 */       Object mapelements_obj = ((Expression) references[0]).eval(null);
/* 041 */       scala.Function1 mapelements_value1 = (scala.Function1) mapelements_obj;
/* 042 */
/* 043 */       boolean mapelements_isNull = false || false;
/* 044 */       int[] mapelements_value = null;
/* 045 */       if (!mapelements_isNull) {
/* 046 */         Object mapelements_funcResult = null;
/* 047 */         mapelements_funcResult = mapelements_value1.apply(inputadapter_value);
/* 048 */         if (mapelements_funcResult == null) {
/* 049 */           mapelements_isNull = true;
/* 050 */         } else {
/* 051 */           mapelements_value = (int[]) mapelements_funcResult;
/* 052 */         }
/* 053 */
/* 054 */       }
/* 055 */       mapelements_isNull = mapelements_value == null;
/* 056 */
/* 057 */       boolean serializefromobject_isNull = mapelements_isNull;
/* 058 */       final ArrayData serializefromobject_value = serializefromobject_isNull ? null : org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.fromPrimitiveArray(mapelements_value);
/* 059 */       serializefromobject_isNull = serializefromobject_value == null;
/* 060 */       serializefromobject_holder.reset();
/* 061 */
/* 062 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 063 */
/* 064 */       if (serializefromobject_isNull) {
/* 065 */         serializefromobject_rowWriter.setNullAt(0);
/* 066 */       } else {
/* 067 */         // Remember the current cursor so that we can calculate how many bytes are
/* 068 */         // written later.
/* 069 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 070 */
/* 071 */         if (serializefromobject_value instanceof UnsafeArrayData) {
/* 072 */           final int serializefromobject_sizeInBytes = ((UnsafeArrayData) serializefromobject_value).getSizeInBytes();
/* 073 */           // grow the global buffer before writing data.
/* 074 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 075 */           ((UnsafeArrayData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 076 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 077 */
/* 078 */         } else {
/* 079 */           final int serializefromobject_numElements = serializefromobject_value.numElements();
/* 080 */           serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 081 */
/* 082 */           for (int serializefromobject_index = 0; serializefromobject_index < serializefromobject_numElements; serializefromobject_index++) {
/* 083 */             if (serializefromobject_value.isNullAt(serializefromobject_index)) {
/* 084 */               serializefromobject_arrayWriter.setNullInt(serializefromobject_index);
/* 085 */             } else {
/* 086 */               final int serializefromobject_element = serializefromobject_value.getInt(serializefromobject_index);
/* 087 */               serializefromobject_arrayWriter.write(serializefromobject_index, serializefromobject_element);
/* 088 */             }
/* 089 */           }
/* 090 */         }
/* 091 */
/* 092 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 093 */       }
/* 094 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 095 */       append(serializefromobject_result);
/* 096 */       if (shouldStop()) return;
/* 097 */     }
/* 098 */   }
/* 099 */ }
```
## How was this patch tested?

Added tests to `DatasetSuite`, `RowEncoderSuite`, and `CatalystTypeConvertersSuite`.

Author: Kazuaki Ishizaki <[email protected]>

Closes apache#15044 from kiszk/SPARK-17490.
kiszk authored and hvanhovell committed Nov 7, 2016
1 parent 8f0ea01 commit 19cf208
Showing 7 changed files with 203 additions and 14 deletions.

@@ -441,6 +441,22 @@ object ScalaReflection extends ScalaReflection {
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
MapObjects(serializerFor(_, elementType, newPath), input, dt)

case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType) =>
val cls = input.dataType.asInstanceOf[ObjectType].cls
if (cls.isArray && cls.getComponentType.isPrimitive) {
StaticInvoke(
classOf[UnsafeArrayData],
ArrayType(dt, false),
"fromPrimitiveArray",
input :: Nil)
} else {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(dt, schemaFor(elementType).nullable))
}

case dt =>
NewInstance(
classOf[GenericArrayData],

@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkException
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects._
@@ -119,18 +119,19 @@ object RowEncoder {
"fromString",
inputObject :: Nil)

-case t @ ArrayType(et, _) => et match {
-case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
-// TODO: validate input type for primitive array.
-NewInstance(
-classOf[GenericArrayData],
-inputObject :: Nil,
-dataType = t)
-case _ => MapObjects(
-element => serializerFor(ValidateExternalType(element, et), et),
-inputObject,
-ObjectType(classOf[Object]))
-}
+case t @ ArrayType(et, cn) =>
+et match {
+case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+StaticInvoke(
+classOf[ArrayData],
+t,
+"toArrayData",
+inputObject :: Nil)
+case _ => MapObjects(
+element => serializerFor(ValidateExternalType(element, et), et),
+inputObject,
+ObjectType(classOf[Object]))
+}

case t @ MapType(kt, vt, valueNullable) =>
val keys =

@@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util

import scala.reflect.ClassTag

-import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
import org.apache.spark.sql.types.DataType

object ArrayData {
def toArrayData(input: Any): ArrayData = input match {
case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
case other => new GenericArrayData(other)
}
}

abstract class ArrayData extends SpecializedGetters with Serializable {
def numElements(): Int


@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._

class CatalystTypeConvertersSuite extends SparkFunSuite {
@@ -61,4 +63,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite {
test("option handling in createToCatalystConverter") {
assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123)
}

test("primitive array handling") {
val intArray = Array(1, 100, 10000)
val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray)
val intArrayType = ArrayType(IntegerType, false)
assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray)

val doubleArray = Array(1.1, 111.1, 11111.1)
val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray)
val doubleArrayType = ArrayType(DoubleType, false)
assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray)
=== doubleArray)
}

test("An array with null handling") {
val intArray = Array(1, null, 100, null, 10000)
val intGenericArray = new GenericArrayData(intArray)
val intArrayType = ArrayType(IntegerType, true)
assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray)
=== intArray)
assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray)
== intGenericArray)

val doubleArray = Array(1.1, null, 111.1, null, 11111.1)
val doubleGenericArray = new GenericArrayData(doubleArray)
val doubleArrayType = ArrayType(DoubleType, true)
assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray)
=== doubleArray)
assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray)
== doubleGenericArray)
}
}

@@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}

test("RowEncoder should support primitive arrays") {
val schema = new StructType()
.add("booleanPrimitiveArray", ArrayType(BooleanType, false))
.add("bytePrimitiveArray", ArrayType(ByteType, false))
.add("shortPrimitiveArray", ArrayType(ShortType, false))
.add("intPrimitiveArray", ArrayType(IntegerType, false))
.add("longPrimitiveArray", ArrayType(LongType, false))
.add("floatPrimitiveArray", ArrayType(FloatType, false))
.add("doublePrimitiveArray", ArrayType(DoubleType, false))
val encoder = RowEncoder(schema).resolveAndBind()
val input = Seq(
Array(true, false),
Array(1.toByte, 64.toByte, Byte.MaxValue),
Array(1.toShort, 255.toShort, Short.MaxValue),
Array(1, 10000, Int.MaxValue),
Array(1.toLong, 1000000.toLong, Long.MaxValue),
Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue),
Array(11.1111, 123456.7890123, Double.MaxValue)
)
val row = encoder.toRow(Row.fromSeq(input))
val convertedBack = encoder.fromRow(row)
input.zipWithIndex.map { case (array, index) =>
assert(convertedBack.getSeq(index) === array)
}
}

test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType()
.add("array", ArrayType(IntegerType))

18 changes: 18 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1033,6 +1033,24 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(agg, ds.groupBy('id % 2).agg(count('id)))
}
}

test("identity map for primitive arrays") {
val arrayByte = Array(1.toByte, 2.toByte, 3.toByte)
val arrayInt = Array(1, 2, 3)
val arrayLong = Array(1.toLong, 2.toLong, 3.toLong)
val arrayDouble = Array(1.1, 2.2, 3.3)
val arrayString = Array("a", "b", "c")
val dsByte = sparkContext.parallelize(Seq(arrayByte), 1).toDS.map(e => e)
val dsInt = sparkContext.parallelize(Seq(arrayInt), 1).toDS.map(e => e)
val dsLong = sparkContext.parallelize(Seq(arrayLong), 1).toDS.map(e => e)
val dsDouble = sparkContext.parallelize(Seq(arrayDouble), 1).toDS.map(e => e)
val dsString = sparkContext.parallelize(Seq(arrayString), 1).toDS.map(e => e)
checkDataset(dsByte, arrayByte)
checkDataset(dsInt, arrayInt)
checkDataset(dsLong, arrayLong)
checkDataset(dsDouble, arrayDouble)
checkDataset(dsString, arrayString)
}
}

case class Generic[T](id: T, value: Double)

@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.benchmark

import scala.concurrent.duration._

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.util.Benchmark

/**
* Benchmark [[PrimitiveArray]] for DataFrame and Dataset program using primitive array
* To run this:
* 1. replace ignore(...) with test(...)
* 2. build/sbt "sql/test-only *benchmark.PrimitiveArrayBenchmark"
*
* Benchmarks in this file are skipped in normal builds.
*/
class PrimitiveArrayBenchmark extends BenchmarkBase {

def writeDatasetArray(iters: Int): Unit = {
import sparkSession.implicits._

val count = 1024 * 1024 * 2

val sc = sparkSession.sparkContext
val primitiveIntArray = Array.fill[Int](count)(65535)
val dsInt = sc.parallelize(Seq(primitiveIntArray), 1).toDS
dsInt.count // force to build dataset
val intArray = { i: Int =>
var n = 0
var len = 0
while (n < iters) {
len += dsInt.map(e => e).queryExecution.toRdd.collect.length
n += 1
}
}
val primitiveDoubleArray = Array.fill[Double](count)(65535.0)
val dsDouble = sc.parallelize(Seq(primitiveDoubleArray), 1).toDS
dsDouble.count // force to build dataset
val doubleArray = { i: Int =>
var n = 0
var len = 0
while (n < iters) {
len += dsDouble.map(e => e).queryExecution.toRdd.collect.length
n += 1
}
}

val benchmark = new Benchmark("Write an array in Dataset", count * iters)
benchmark.addCase("Int ")(intArray)
benchmark.addCase("Double")(doubleArray)
benchmark.run
/*
OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64
Intel Xeon E3-12xx v2 (Ivy Bridge)
Write an array in Dataset: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
------------------------------------------------------------------------------------------------
Int 352 / 401 23.8 42.0 1.0X
Double 821 / 885 10.2 97.9 0.4X
*/
}

ignore("Write an array in Dataset") {
writeDatasetArray(4)
}
}
