Skip to content

Commit

Permalink
[SPARK-27768][SQL] Support Infinity/NaN-related float/double literals…
Browse files Browse the repository at this point in the history
… case-insensitively

## What changes were proposed in this pull request?
Here is the problem description from the JIRA.
```
When the inputs contain the constant 'infinity', Spark SQL does not generate the expected results.

SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('1'), (CAST('infinity' AS DOUBLE))) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('infinity'), ('1')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('infinity'), ('infinity')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-infinity'), ('infinity')) v(x);
 The root cause: Spark SQL does not recognize the special constants in a case insensitive way. In PostgreSQL, they are recognized in a case insensitive way.

Link: https://www.postgresql.org/docs/9.3/datatype-numeric.html
```

In this PR, the casting code is enhanced to handle these `special` string literals in case insensitive manner.

## How was this patch tested?
Added tests in CastSuite and modified existing test suites.

Closes apache#25331 from dilipbiswal/double_infinity.

Authored-by: Dilip Biswal <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
dilipbiswal authored and dongjoon-hyun committed Aug 13, 2019
1 parent f1d6b19 commit 331f265
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 62 deletions.
89 changes: 89 additions & 0 deletions docs/sql-migration-guide-upgrade.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,95 @@ license: |

- Since Spark 3.0, Dataset query fails if it contains ambiguous column reference that is caused by self join. A typical example: `val df1 = ...; val df2 = df1.filter(...);`, then `df1.join(df2, df1("a") > df2("a"))` returns an empty result which is quite confusing. This is because Spark cannot resolve Dataset column references that point to tables being self joined, and `df1("a")` is exactly the same as `df2("a")` in Spark. To restore the behavior before Spark 3.0, you can set `spark.sql.analyzer.failAmbiguousSelfJoin` to `false`.

- Since Spark 3.0, `Cast` function processes string literals such as 'Infinity', '+Infinity', '-Infinity', 'NaN', 'Inf', '+Inf', '-Inf' in case insensitive manner when casting the literals to `Double` or `Float` type to ensure greater compatibility with other database systems. This behaviour change is illustrated in the table below:
<table class="table">
<tr>
<th>
<b>Operation</b>
</th>
<th>
<b>Result prior to Spark 3.0</b>
</th>
<th>
<b>Result starting Spark 3.0</b>
</th>
</tr>
<tr>
<td>
CAST('infinity' AS DOUBLE)<br>
CAST('+infinity' AS DOUBLE)<br>
CAST('inf' AS DOUBLE)<br>
CAST('+inf' AS DOUBLE)<br>
</td>
<td>
NULL
</td>
<td>
Double.PositiveInfinity
</td>
</tr>
<tr>
<td>
CAST('-infinity' AS DOUBLE)<br>
CAST('-inf' AS DOUBLE)<br>
</td>
<td>
NULL
</td>
<td>
Double.NegativeInfinity
</td>
</tr>
<tr>
<td>
CAST('infinity' AS FLOAT)<br>
CAST('+infinity' AS FLOAT)<br>
CAST('inf' AS FLOAT)<br>
CAST('+inf' AS FLOAT)<br>
</td>
<td>
NULL
</td>
<td>
Float.PositiveInfinity
</td>
</tr>
<tr>
<td>
CAST('-infinity' AS FLOAT)<br>
CAST('-inf' AS FLOAT)<br>
</td>
<td>
NULL
</td>
<td>
Float.NegativeInfinity
</td>
</tr>
<tr>
<td>
CAST('nan' AS DOUBLE)
</td>
<td>
NULL
</td>
<td>
Double.NaN
</td>
</tr>
<tr>
<td>
CAST('nan' AS FLOAT)
</td>
<td>
NULL
</td>
<td>
Float.NaN
</td>
</tr>
</table>

## Upgrading from Spark SQL 2.4 to 2.4.1

- The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.math.{BigDecimal => JavaBigDecimal}
import java.time.ZoneId
import java.util.Locale
import java.util.concurrent.TimeUnit._

import org.apache.spark.SparkException
Expand Down Expand Up @@ -193,6 +194,22 @@ object Cast {
}

def resolvableNullability(from: Boolean, to: Boolean): Boolean = !from || to

/**
* We process literals such as 'Infinity', 'Inf', '-Infinity' and 'NaN' etc in case
* insensitive manner to be compatible with other database systems such as PostgreSQL and DB2.
*/
def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = {
v.trim.toLowerCase(Locale.ROOT) match {
case "inf" | "+inf" | "infinity" | "+infinity" =>
if (isFloat) Float.PositiveInfinity else Double.PositiveInfinity
case "-inf" | "-infinity" =>
if (isFloat) Float.NegativeInfinity else Double.NegativeInfinity
case "nan" =>
if (isFloat) Float.NaN else Double.NaN
case _ => null
}
}
}

/**
Expand Down Expand Up @@ -563,8 +580,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// DoubleConverter
private[this] def castToDouble(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toDouble catch {
case _: NumberFormatException => null
buildCast[UTF8String](_, s => {
val doubleStr = s.toString
try doubleStr.toDouble catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(doubleStr, false)
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1d else 0d)
Expand All @@ -579,8 +600,12 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
// FloatConverter
private[this] def castToFloat(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try s.toString.toFloat catch {
case _: NumberFormatException => null
buildCast[UTF8String](_, s => {
val floatStr = s.toString
try floatStr.toFloat catch {
case _: NumberFormatException =>
Cast.processFloatingPointSpecialLiterals(floatStr, true)
}
})
case BooleanType =>
buildCast[Boolean](_, b => if (b) 1f else 0f)
Expand Down Expand Up @@ -718,9 +743,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case ByteType => castToByteCode(from, ctx)
case ShortType => castToShortCode(from, ctx)
case IntegerType => castToIntCode(from, ctx)
case FloatType => castToFloatCode(from)
case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from)
case DoubleType => castToDoubleCode(from, ctx)

case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
Expand Down Expand Up @@ -1260,48 +1285,66 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
}

private[this] def castToFloatCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
code"""
private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType =>
val floatStr = ctx.freshVariable("floatStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $floatStr = $c.toString();
try {
$evPrim = Float.valueOf($c.toString());
$evPrim = Float.valueOf($floatStr);
} catch (java.lang.NumberFormatException e) {
$evNull = true;
final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
if (f == null) {
$evNull = true;
} else {
$evPrim = f.floatValue();
}
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (float) $c;"
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (float) $c;"
}
}

private[this] def castToDoubleCode(from: DataType): CastFunction = from match {
case StringType =>
(c, evPrim, evNull) =>
code"""
private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case StringType =>
val doubleStr = ctx.freshVariable("doubleStr", StringType)
(c, evPrim, evNull) =>
code"""
final String $doubleStr = $c.toString();
try {
$evPrim = Double.valueOf($c.toString());
$evPrim = Double.valueOf($doubleStr);
} catch (java.lang.NumberFormatException e) {
$evNull = true;
final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
if (d == null) {
$evNull = true;
} else {
$evPrim = d.doubleValue();
}
}
"""
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (double) $c;"
case BooleanType =>
(c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
case DateType =>
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
case DecimalType() =>
(c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (double) $c;"
}
}

private[this] def castArrayCode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1049,4 +1049,30 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
}
}

test("Process Infinity, -Infinity, NaN in case insensitive manner") {
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NegativeInfinity)
}
Seq("inf", "+inf", "infinity", "+infiNity", " infinity ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.PositiveInfinity)
}
Seq("-infinity", "-infiniTy", " -infinity ", " -inf ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NegativeInfinity)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, FloatType), Float.NaN)
}
Seq("nan", "nAn", " nan ").foreach { value =>
checkEvaluation(cast(value, DoubleType), Double.NaN)
}

// Invalid literals when casted to double and float results in null.
Seq(DoubleType, FloatType).foreach { dataType =>
checkEvaluation(cast("badvalue", dataType), null)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,14 @@ select avg(CAST(null AS DOUBLE)) from range(1,4);
select sum(CAST('NaN' AS DOUBLE)) from range(1,4);
select avg(CAST('NaN' AS DOUBLE)) from range(1,4);

-- [SPARK-27768] verify correct results for infinite inputs
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x);
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('1')) v(x);
FROM (VALUES ('infinity'), ('1')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('Infinity')) v(x);
FROM (VALUES ('infinity'), ('infinity')) v(x);
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-Infinity'), ('Infinity')) v(x);

FROM (VALUES ('-infinity'), ('infinity')) v(x);

-- test accuracy with a large input offset
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
Expand Down
2 changes: 0 additions & 2 deletions sql/core/src/test/resources/sql-tests/inputs/pgSQL/float4.sql
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ INSERT INTO FLOAT4_TBL VALUES ('1.2345678901234e-20');

-- special inputs
SELECT float('NaN');
-- [SPARK-28060] Float type can not accept some special inputs
SELECT float('nan');
SELECT float(' NAN ');
SELECT float('infinity');
Expand All @@ -49,7 +48,6 @@ SELECT float('N A N');
SELECT float('NaN x');
SELECT float(' INFINITY x');

-- [SPARK-28060] Float type can not accept some special inputs
SELECT float('Infinity') + 100.0;
SELECT float('Infinity') / float('Infinity');
SELECT float('nan') / float('nan');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ SELECT double('-10e-400');

-- special inputs
SELECT double('NaN');
-- [SPARK-28060] Double type can not accept some special inputs
SELECT double('nan');
SELECT double(' NAN ');
SELECT double('infinity');
Expand All @@ -49,7 +48,6 @@ SELECT double('NaN x');
SELECT double(' INFINITY x');

SELECT double('Infinity') + 100.0;
-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner
SELECT double('Infinity') / double('Infinity');
SELECT double('NaN') / double('NaN');
-- [SPARK-28315] Decimal can not accept NaN as input
Expand Down Expand Up @@ -190,7 +188,7 @@ SELECT tanh(double('1'));
SELECT asinh(double('1'));
SELECT acosh(double('2'));
SELECT atanh(double('0.5'));
-- [SPARK-27768] Infinity, -Infinity, NaN should be recognized in a case insensitive manner

-- test Inf/NaN cases for hyperbolic functions
SELECT sinh(double('Infinity'));
SELECT sinh(double('-Infinity'));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ NaN

-- !query 29
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('Infinity' AS DOUBLE))) v(x)
FROM (VALUES (CAST('1' AS DOUBLE)), (CAST('infinity' AS DOUBLE))) v(x)
-- !query 29 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 29 output
Expand All @@ -245,7 +245,7 @@ Infinity NaN

-- !query 30
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('1')) v(x)
FROM (VALUES ('infinity'), ('1')) v(x)
-- !query 30 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 30 output
Expand All @@ -254,7 +254,7 @@ Infinity NaN

-- !query 31
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('Infinity'), ('Infinity')) v(x)
FROM (VALUES ('infinity'), ('infinity')) v(x)
-- !query 31 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 31 output
Expand All @@ -263,7 +263,7 @@ Infinity NaN

-- !query 32
SELECT avg(CAST(x AS DOUBLE)), var_pop(CAST(x AS DOUBLE))
FROM (VALUES ('-Infinity'), ('Infinity')) v(x)
FROM (VALUES ('-infinity'), ('infinity')) v(x)
-- !query 32 schema
struct<avg(CAST(x AS DOUBLE)):double,var_pop(CAST(x AS DOUBLE)):double>
-- !query 32 output
Expand Down
Loading

0 comments on commit 331f265

Please sign in to comment.