[SPARK-38415][SQL] Update the histogram_numeric (x, y) result type to make x == the input type

### What changes were proposed in this pull request?

This pull request updates the histogram_numeric SQL function to support more numeric input types, returning the result as an array of structs with two fields each. The first field ('x') now has the same type as the first argument to the histogram_numeric aggregate function, rather than always having double type as before this change. This removes the need for the user to cast the result before using it.

Example behavior after this change becomes effective:

SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
(TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS tab(col);

Result type: struct<histogram_numeric(col, 3):array<struct<x:timestamp,y:double>>>.
Query output: [{"x":2017-03-01 00:00:00,"y":1.0},{"x":2017-04-01 00:00:00,"y":1.0},{"x":2017-05-01 00:00:00,"y":1.0}].

### Why are the changes needed?

This removes the need for users to explicitly cast the function result type in many cases.
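As a minimal sketch of the difference (assuming a local SparkSession named `spark`; the transform/named_struct query shown first is just one way users previously recovered the input type):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Before this change: 'x' was always DOUBLE, so recovering the input type
// required an explicit cast on each bin, for example:
spark.sql(
  """SELECT transform(histogram_numeric(col, 3),
    |                 b -> named_struct('x', CAST(b.x AS TIMESTAMP), 'y', b.y))
    |FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
    |            (TIMESTAMP '2017-04-01 00:00:00') AS tab(col)""".stripMargin).show(false)

// After this change: 'x' already has the input type (timestamp here), no cast needed.
spark.sql(
  """SELECT histogram_numeric(col, 3)
    |FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
    |            (TIMESTAMP '2017-04-01 00:00:00') AS tab(col)""".stripMargin).printSchema()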

### Does this PR introduce _any_ user-facing change?

Yes, it changes the `histogram_numeric` function result type.

### How was this patch tested?

Unit tests and file-based query tests.

Closes apache#35735 from dtenedor/numeric-histogram-types.

Lead-authored-by: Daniel Tenedorio <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Mar 14, 2022
1 parent 2844a18 commit 130bcce
Showing 7 changed files with 336 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/sql-migration-guide.md
@@ -24,6 +24,8 @@ license: |

## Upgrading from Spark SQL 3.2 to 3.3

- Since Spark 3.3, the `histogram_numeric` function in Spark SQL returns an array of structs (x, y), where the type of the 'x' field is propagated from the input values consumed by the aggregate function. In Spark 3.2 or earlier, 'x' always had double type. To revert to the previous behavior, set the configuration `spark.sql.legacy.histogramNumericPropagateInputType` (available since Spark 3.3) to `false`; see the sketch after this list.

- Since Spark 3.3, `DayTimeIntervalType` in Spark SQL is mapped to Arrow's `Duration` type in the `ArrowWriter` and `ArrowColumnVector` developer APIs. Previously, `DayTimeIntervalType` was mapped to Arrow's `Interval` type, which does not match the types that Spark SQL maps to in other languages; for example, `DayTimeIntervalType` is mapped to `java.time.Duration` in Java.

- Since Spark 3.3, the functions `lpad` and `rpad` have been overloaded to support byte sequences. When the first argument is a byte sequence, the optional padding pattern must also be a byte sequence and the result is a BINARY value. The default padding pattern in this case is the zero byte.
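A minimal sketch of the histogram_numeric opt-out described in the first entry above (assuming a running SparkSession named `spark`):

// Restore the Spark 3.2 behavior: 'x' is always DOUBLE.
spark.conf.set("spark.sql.legacy.histogramNumericPropagateInputType", "false")
spark.sql("SELECT histogram_numeric(col, 3) FROM VALUES (1), (2), (3) AS tab(col)").printSchema()
// 'x' is reported as double again; set the flag back to true to propagate the
// input type (int in this example).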
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure,
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.GenericArrayData
-import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DateType, DayTimeIntervalType, DoubleType, IntegerType, NumericType, StructField, StructType, TimestampNTZType, TimestampType, TypeCollection, YearMonthIntervalType}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.NumericHistogram

/**
@@ -46,12 +47,13 @@ import org.apache.spark.sql.util.NumericHistogram
smaller datasets. Note that this function creates a histogram with non-uniform
bin widths. It offers no guarantees in terms of the mean-squared-error of the
histogram, but in practice is comparable to the histograms produced by the R/S-Plus
-statistical computing packages.
+statistical computing packages. Note: the output type of the 'x' field in the return value is
+propagated from the input value consumed in the aggregate function.
""",
examples = """
Examples:
> SELECT _FUNC_(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col);
[{"x":0.0,"y":1.0},{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":10.0,"y":1.0}]
[{"x":0,"y":1.0},{"x":1,"y":1.0},{"x":2,"y":1.0},{"x":10,"y":1.0}]
""",
group = "agg_funcs",
since = "3.3.0")
@@ -72,6 +74,8 @@ case class HistogramNumeric(
case n: Int => n
}

private lazy val propagateInputType: Boolean = SQLConf.get.histogramNumericPropagateInputType
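// Note: the lazy val above reads the session conf once on first use, so the chosen
// output type stays stable for the lifetime of this expression instance.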

override def inputTypes: Seq[AbstractDataType] = {
// Support NumericType, DateType, TimestampType and TimestampNTZType, YearMonthIntervalType,
// DayTimeIntervalType since their internal types are all numeric,
@@ -124,8 +128,33 @@ case class HistogramNumeric(
null
} else {
val result = (0 until buffer.getUsedBins).map { index =>
// Note that the 'coord.x' and 'coord.y' have double-precision floating point type here.
val coord = buffer.getBin(index)
-InternalRow.apply(coord.x, coord.y)
+if (propagateInputType) {
+// If spark.sql.legacy.histogramNumericPropagateInputType is set to true,
+// we need to internally convert the 'coord.x' value to the expected result type, for
+// cases like integer types, timestamps, and intervals which are valid inputs to the
+// numeric histogram aggregate function. For example, in this case:
+// 'SELECT histogram_numeric(col, 3) FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)'
+// returns an array of structs where the first field has LongType.
+val result: Any = left.dataType match {
+case ByteType => coord.x.toByte
+case IntegerType | DateType | _: YearMonthIntervalType =>
+coord.x.toInt
+case FloatType => coord.x.toFloat
+case ShortType => coord.x.toShort
+case _: DayTimeIntervalType | LongType | TimestampType | TimestampNTZType =>
+coord.x.toLong
+case _ => coord.x
+}
+InternalRow.apply(result, coord.y)
+} else {
+// Otherwise, just apply the double-precision values in 'coord.x' and 'coord.y' to the
+// output row directly. In this case: 'SELECT histogram_numeric(col, 3)
+// FROM VALUES (0L), (1L), (2L), (10L) AS tab(col)' returns an array of structs where the
+// first field has DoubleType.
+InternalRow.apply(coord.x, coord.y)
+}
}
new GenericArrayData(result)
}
@@ -157,10 +186,17 @@ case class HistogramNumeric(

override def nullable: Boolean = true

-override def dataType: DataType =
+override def dataType: DataType = {
+// If spark.sql.legacy.histogramNumericPropagateInputType is set to true, the output data
+// type of this aggregate function is an array of structs, where each struct has two
+// fields (x, y): one of the same data type as the left child and another of double type.
+// Otherwise, the 'x' field always has double type.
ArrayType(new StructType(Array(
-StructField("x", DoubleType, true),
+StructField(name = "x",
+dataType = if (propagateInputType) left.dataType else DoubleType,
+nullable = true),
StructField("y", DoubleType, true))), true)
+}

override def prettyName: String = "histogram_numeric"
}
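To make the narrowing step above concrete, here is a self-contained sketch (narrowBinCenter is a hypothetical helper written for illustration; the Catalyst type objects are the real ones matched in the code):

import org.apache.spark.sql.types._

// A bin center is accumulated as Double internally and narrowed back to the
// declared input type before being emitted.
def narrowBinCenter(x: Double, inputType: DataType): Any = inputType match {
  case ByteType => x.toByte
  case ShortType => x.toShort
  case IntegerType | DateType | _: YearMonthIntervalType => x.toInt
  case LongType | TimestampType | TimestampNTZType | _: DayTimeIntervalType => x.toLong
  case FloatType => x.toFloat
  case _ => x // DoubleType (and any other input) keeps double precision
}

assert(narrowBinCenter(10.0, IntegerType) == 10)
assert(narrowBinCenter(10.0, LongType) == 10L)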
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3561,6 +3561,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE =
buildConf("spark.sql.legacy.histogramNumericPropagateInputType")
.internal()
.doc("The histogram_numeric function computes a histogram on numeric 'expr' using nb bins. " +
"The return value is an array of (x,y) pairs representing the centers of the histogram's " +
"bins. If this config is set to true, the output type of the 'x' field in the return " +
"value is propagated from the input value consumed in the aggregate function. Otherwise, " +
"'x' always has double type.")
.version("3.3.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
@@ -4299,6 +4311,9 @@ class SQLConf extends Serializable with Logging {

def useV1Command: Boolean = getConf(SQLConf.LEGACY_USE_V1_COMMAND)

def histogramNumericPropagateInputType: Boolean =
getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala
@@ -17,18 +17,25 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.sql.Timestamp
import java.time.{Duration, Period}

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions.{DslString, DslSymbol}
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{DoubleType, IntegerType}
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.NumericHistogram

-class HistogramNumericSuite extends SparkFunSuite {
+class HistogramNumericSuite extends SparkFunSuite with SQLHelper with Logging {

private val random = new java.util.Random()

@@ -76,7 +83,6 @@ class HistogramNumericSuite extends SparkFunSuite {
}

test("class HistogramNumeric, sql string") {
-val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
assertEqual(s"histogram_numeric(a, 3)",
new HistogramNumeric("a".attr, Literal(3)).sql: String)

@@ -106,23 +112,47 @@
}

test("class HistogramNumeric, automatically add type casting for parameters") {
-val testRelation = LocalRelation('a.int)
+// These are the types of input relations under test. We exercise the unit test with
+// several input column types to inspect the behavior of query analysis for the
+// aggregate function.
+val relations = Seq(LocalRelation('a.double),
+LocalRelation('a.int),
+LocalRelation('a.timestamp),
+LocalRelation('a.dayTimeInterval()),
+LocalRelation('a.yearMonthInterval()))

-// accuracy types must be integral, no type casting
+// These are the types of the second 'nbins' argument to the aggregate function.
+// These accuracy types must be integral; no type casting is allowed.
val nBinsExpressions = Seq(
Literal(2.toByte),
Literal(100.toShort),
Literal(100),
Literal(1000L))

-nBinsExpressions.foreach { nBins =>
+// Iterate through each of the input relation column types and 'nbins' expression types
+// under test.
+for {
+relation <- relations
+nBins <- nBinsExpressions
+} {
// We expect each relation under test to have exactly one output attribute.
assert(relation.output.length == 1)
val relationAttributeType = relation.output(0).dataType
val agg = new HistogramNumeric(UnresolvedAttribute("a"), nBins)
-val analyzed = testRelation.select(agg).analyze.expressions.head
+val analyzed = relation.select(agg).analyze.expressions.head
analyzed match {
case Alias(agg: HistogramNumeric, _) =>
assert(agg.resolved)
-assert(agg.child.dataType == IntegerType)
+assert(agg.child.dataType == relationAttributeType)
assert(agg.nBins.dataType == IntegerType)
// We expect the output type of the histogram aggregate function to be an array of structs
// where the first element of each struct has the same type as the original input
// attribute.
val expectedType =
ArrayType(
StructType(Seq(
StructField("x", relationAttributeType, nullable = true),
StructField("y", DoubleType, nullable = true))))
assert(agg.dataType == expectedType)
case _ => fail()
}
}
@@ -151,6 +181,84 @@
assert(agg.eval(buffer) != null)
}

test("class HistogramNumeric, exercise many different numeric input types") {
val inputs = Seq(
(Literal(null),
Literal(null),
Literal(null)),
(Literal(0),
Literal(1),
Literal(2)),
(Literal(0L),
Literal(1L),
Literal(2L)),
(Literal(0.toShort),
Literal(1.toShort),
Literal(2.toShort)),
(Literal(0F),
Literal(1F),
Literal(2F)),
(Literal(0D),
Literal(1D),
Literal(2D)),
(Literal(Timestamp.valueOf("2017-03-01 00:00:00")),
Literal(Timestamp.valueOf("2017-03-02 00:00:00")),
Literal(Timestamp.valueOf("2017-03-03 00:00:00"))),
(Literal(Duration.ofSeconds(1111)),
Literal(Duration.ofSeconds(1211)),
Literal(Duration.ofSeconds(1311))),
(Literal(Period.ofMonths(10)),
Literal(Period.ofMonths(11)),
Literal(Period.ofMonths(12))))
for ((left, middle, right) <- inputs) {
// Check that the 'propagateInputType' bit correctly toggles the output type.
withSQLConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE.key -> "false") {
val aggDoubleOutputType = new HistogramNumeric(
BoundReference(0, left.dataType, nullable = true), Literal(5))
assert(aggDoubleOutputType.dataType match {
case ArrayType(StructType(Array(
StructField("x", DoubleType, _, _),
StructField("y", _, _, _))), true) => true
})
}
val aggPropagateOutputType = new HistogramNumeric(
BoundReference(0, left.dataType, nullable = true), Literal(5))
assert(aggPropagateOutputType.left.dataType ==
(aggPropagateOutputType.dataType match {
case
ArrayType(StructType(Array(
StructField("x", lhs@_, true, _),
StructField("y", _, true, _))), true) => lhs
}))
// Now consume some input values and check the result.
val buffer = new GenericInternalRow(new Array[Any](1))
aggPropagateOutputType.initialize(buffer)
// Consume three non-empty rows in the aggregation.
aggPropagateOutputType.update(buffer, InternalRow(left.value))
aggPropagateOutputType.update(buffer, InternalRow(middle.value))
aggPropagateOutputType.update(buffer, InternalRow(right.value))
// Evaluate the aggregate function.
val result = aggPropagateOutputType.eval(buffer)
if (left.dataType != NullType) {
assert(result != null)
// Sanity-check the sum of the heights.
var ys = 0.0
result match {
case v: GenericArrayData =>
for (row <- v.array) {
row match {
case r: GenericInternalRow =>
assert(r.values.length == 2)
ys += r.values(1).asInstanceOf[Double]
}
}
}
// As a basic sanity check, the sum of the heights of the bins should be greater than one.
assert(ys > 1)
}
}
}

private def compareEquals(left: NumericHistogram, right: NumericHistogram): Boolean = {
left.getNumBins == right.getNumBins && left.getUsedBins == right.getUsedBins &&
(0 until left.getUsedBins).forall { i =>
sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -357,7 +357,7 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_samp(c1, c2):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col):int> |
| org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col):int> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<histogram_numeric(col, 5):array<struct<x:double,y:double>>> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<histogram_numeric(col, 5):array<struct<x:int,y:double>>> |
| org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct<approx_count_distinct(col1):bigint> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<kurtosis(col):double> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col):int> |
27 changes: 27 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -199,6 +199,7 @@ FROM testData
GROUP BY a IS NULL;


-- Histogram aggregates with different numeric input types
SELECT
histogram_numeric(col, 2) as histogram_2,
histogram_numeric(col, 3) as histogram_3,
@@ -210,6 +211,32 @@ FROM VALUES
(21), (22), (23), (24), (25), (26), (27), (28), (29), (30),
(31), (32), (33), (34), (35), (3), (37), (38), (39), (40),
(41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (1), (2), (3) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (1L), (2L), (3L) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (1F), (2F), (3F) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (1D), (2D), (3D) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (1S), (2S), (3S) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS BYTE)), (CAST(2 AS BYTE)), (CAST(3 AS BYTE)) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS TINYINT)), (CAST(2 AS TINYINT)), (CAST(3 AS TINYINT)) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS SMALLINT)), (CAST(2 AS SMALLINT)), (CAST(3 AS SMALLINT)) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES
(CAST(1 AS BIGINT)), (CAST(2 AS BIGINT)), (CAST(3 AS BIGINT)) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (TIMESTAMP '2017-03-01 00:00:00'),
(TIMESTAMP '2017-04-01 00:00:00'), (TIMESTAMP '2017-05-01 00:00:00') AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '100-00' YEAR TO MONTH),
(INTERVAL '110-00' YEAR TO MONTH), (INTERVAL '120-00' YEAR TO MONTH) AS tab(col);
SELECT histogram_numeric(col, 3) FROM VALUES (INTERVAL '12 20:4:0' DAY TO SECOND),
(INTERVAL '12 21:4:0' DAY TO SECOND), (INTERVAL '12 22:4:0' DAY TO SECOND) AS tab(col);
SELECT histogram_numeric(col, 3)
FROM VALUES (NULL), (NULL), (NULL) AS tab(col);
SELECT histogram_numeric(col, 3)
FROM VALUES (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)), (CAST(NULL AS DOUBLE)) AS tab(col);
SELECT histogram_numeric(col, 3)
FROM VALUES (CAST(NULL AS INT)), (CAST(NULL AS INT)), (CAST(NULL AS INT)) AS tab(col);
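A hedged sketch of how the propagated result types for a few of the queries above could be checked programmatically (assumes a local SparkSession named `spark`; the expected strings follow this patch's type mapping):

Seq(
  "(1), (2), (3)" -> "array<struct<x:int,y:double>>",
  "(1L), (2L), (3L)" -> "array<struct<x:bigint,y:double>>",
  "(1F), (2F), (3F)" -> "array<struct<x:float,y:double>>"
).foreach { case (values, expectedType) =>
  val df = spark.sql(s"SELECT histogram_numeric(col, 3) FROM VALUES $values AS tab(col)")
  // DataType.simpleString renders e.g. array<struct<x:int,y:double>>
  assert(df.schema.head.dataType.simpleString == expectedType)
}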


-- SPARK-37613: Support ANSI Aggregate Function: regr_count
SELECT regr_count(y, x) FROM testRegression;