Skip to content

Commit 6b34e74

Browse files
cloud-fan authored and gatorsmile committed
[SPARK-19178][SQL] convert string of large numbers to int should return null
## What changes were proposed in this pull request? When we convert a string to an integral type, we first convert that string to `decimal(20, 0)`, so that a string in decimal format can be turned into a truncated integral value, e.g. `CAST('1.2' AS int)` will return `1`. However, this causes problems when we convert a string containing a large number to an integral type, e.g. `CAST('1234567890123' AS int)` will return `1912276171`, while Hive returns null as expected. This is a long-standing bug (it seems to have been there since the first day Spark SQL was created). This PR fixes the bug by adding native support for converting `UTF8String` to integral types. ## How was this patch tested? New regression tests. Author: Wenchen Fan <[email protected]> Closes apache#16550 from cloud-fan/string-to-int.
1 parent 7f24a0b commit 6b34e74

File tree

5 files changed

+414
-25
lines changed

5 files changed

+414
-25
lines changed

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

+184
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,190 @@ public UTF8String translate(Map<Character, Character> dict) {
835835
return fromString(sb.toString());
836836
}
837837

838+
private int getDigit(byte b) {
839+
if (b >= '0' && b <= '9') {
840+
return b - '0';
841+
}
842+
throw new NumberFormatException(toString());
843+
}
844+
845+
/**
846+
* Parses this UTF8String to long.
847+
*
848+
* Note that, in this method we accumulate the result in negative format, and convert it to
849+
* positive format at the end, if this string is not started with '-'. This is because min value
850+
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
851+
* Integer.MIN_VALUE is '-2147483648'.
852+
*
853+
* This code is mostly copied from LazyLong.parseLong in Hive.
854+
*/
855+
public long toLong() {
856+
if (numBytes == 0) {
857+
throw new NumberFormatException("Empty string");
858+
}
859+
860+
byte b = getByte(0);
861+
final boolean negative = b == '-';
862+
int offset = 0;
863+
if (negative || b == '+') {
864+
offset++;
865+
if (numBytes == 1) {
866+
throw new NumberFormatException(toString());
867+
}
868+
}
869+
870+
final byte separator = '.';
871+
final int radix = 10;
872+
final long stopValue = Long.MIN_VALUE / radix;
873+
long result = 0;
874+
875+
while (offset < numBytes) {
876+
b = getByte(offset);
877+
offset++;
878+
if (b == separator) {
879+
// We allow decimals and will return a truncated integral in that case.
880+
// Therefore we won't throw an exception here (checking the fractional
881+
// part happens below.)
882+
break;
883+
}
884+
885+
int digit = getDigit(b);
886+
// We are going to process the new digit and accumulate the result. However, before doing
887+
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
888+
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
889+
if (result < stopValue) {
890+
throw new NumberFormatException(toString());
891+
}
892+
893+
result = result * radix - digit;
894+
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
895+
// can just use `result > 0` to check overflow. If result overflows, we should stop and throw
896+
// exception.
897+
if (result > 0) {
898+
throw new NumberFormatException(toString());
899+
}
900+
}
901+
902+
// This is the case when we've encountered a decimal separator. The fractional
903+
// part will not change the number, but we will verify that the fractional part
904+
// is well formed.
905+
while (offset < numBytes) {
906+
if (getDigit(getByte(offset)) == -1) {
907+
throw new NumberFormatException(toString());
908+
}
909+
offset++;
910+
}
911+
912+
if (!negative) {
913+
result = -result;
914+
if (result < 0) {
915+
throw new NumberFormatException(toString());
916+
}
917+
}
918+
919+
return result;
920+
}
921+
922+
/**
923+
* Parses this UTF8String to int.
924+
*
925+
* Note that, in this method we accumulate the result in negative format, and convert it to
926+
* positive format at the end, if this string is not started with '-'. This is because min value
927+
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
928+
* Integer.MIN_VALUE is '-2147483648'.
929+
*
930+
* This code is mostly copied from LazyInt.parseInt in Hive.
931+
*
932+
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
933+
* reasons, like Hive does.
934+
*/
935+
public int toInt() {
936+
if (numBytes == 0) {
937+
throw new NumberFormatException("Empty string");
938+
}
939+
940+
byte b = getByte(0);
941+
final boolean negative = b == '-';
942+
int offset = 0;
943+
if (negative || b == '+') {
944+
offset++;
945+
if (numBytes == 1) {
946+
throw new NumberFormatException(toString());
947+
}
948+
}
949+
950+
final byte separator = '.';
951+
final int radix = 10;
952+
final int stopValue = Integer.MIN_VALUE / radix;
953+
int result = 0;
954+
955+
while (offset < numBytes) {
956+
b = getByte(offset);
957+
offset++;
958+
if (b == separator) {
959+
// We allow decimals and will return a truncated integral in that case.
960+
// Therefore we won't throw an exception here (checking the fractional
961+
// part happens below.)
962+
break;
963+
}
964+
965+
int digit = getDigit(b);
966+
// We are going to process the new digit and accumulate the result. However, before doing
967+
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
968+
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
969+
if (result < stopValue) {
970+
throw new NumberFormatException(toString());
971+
}
972+
973+
result = result * radix - digit;
974+
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
975+
// we can just use `result > 0` to check overflow. If result overflows, we should stop and
976+
// throw exception.
977+
if (result > 0) {
978+
throw new NumberFormatException(toString());
979+
}
980+
}
981+
982+
// This is the case when we've encountered a decimal separator. The fractional
983+
// part will not change the number, but we will verify that the fractional part
984+
// is well formed.
985+
while (offset < numBytes) {
986+
if (getDigit(getByte(offset)) == -1) {
987+
throw new NumberFormatException(toString());
988+
}
989+
offset++;
990+
}
991+
992+
if (!negative) {
993+
result = -result;
994+
if (result < 0) {
995+
throw new NumberFormatException(toString());
996+
}
997+
}
998+
999+
return result;
1000+
}
1001+
1002+
public short toShort() {
1003+
int intValue = toInt();
1004+
short result = (short) intValue;
1005+
if (result != intValue) {
1006+
throw new NumberFormatException(toString());
1007+
}
1008+
1009+
return result;
1010+
}
1011+
1012+
public byte toByte() {
1013+
int intValue = toInt();
1014+
byte result = (byte) intValue;
1015+
if (result != intValue) {
1016+
throw new NumberFormatException(toString());
1017+
}
1018+
1019+
return result;
1020+
}
1021+
8381022
@Override
8391023
public String toString() {
8401024
return new String(getBytes(), StandardCharsets.UTF_8);

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

-16
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ object TypeCoercion {
5151
PromoteStrings ::
5252
DecimalPrecision ::
5353
BooleanEquality ::
54-
StringToIntegralCasts ::
5554
FunctionArgumentConversion ::
5655
CaseWhenCoercion ::
5756
IfCoercion ::
@@ -428,21 +427,6 @@ object TypeCoercion {
428427
}
429428
}
430429

431-
/**
432-
* When encountering a cast from a string representing a valid fractional number to an integral
433-
* type the jvm will throw a `java.lang.NumberFormatException`. Hive, in contrast, returns the
434-
* truncated version of this number.
435-
*/
436-
object StringToIntegralCasts extends Rule[LogicalPlan] {
437-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
438-
// Skip nodes who's children have not been resolved yet.
439-
case e if !e.childrenResolved => e
440-
441-
case Cast(e @ StringType(), t: IntegralType) =>
442-
Cast(Cast(e, DecimalType.forType(LongType)), t)
443-
}
444-
}
445-
446430
/**
447431
* This ensure that the types for various functions are as expected.
448432
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

+9-9
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
247247
// LongConverter
248248
private[this] def castToLong(from: DataType): Any => Any = from match {
249249
case StringType =>
250-
buildCast[UTF8String](_, s => try s.toString.toLong catch {
250+
buildCast[UTF8String](_, s => try s.toLong catch {
251251
case _: NumberFormatException => null
252252
})
253253
case BooleanType =>
@@ -263,7 +263,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
263263
// IntConverter
264264
private[this] def castToInt(from: DataType): Any => Any = from match {
265265
case StringType =>
266-
buildCast[UTF8String](_, s => try s.toString.toInt catch {
266+
buildCast[UTF8String](_, s => try s.toInt catch {
267267
case _: NumberFormatException => null
268268
})
269269
case BooleanType =>
@@ -279,7 +279,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
279279
// ShortConverter
280280
private[this] def castToShort(from: DataType): Any => Any = from match {
281281
case StringType =>
282-
buildCast[UTF8String](_, s => try s.toString.toShort catch {
282+
buildCast[UTF8String](_, s => try s.toShort catch {
283283
case _: NumberFormatException => null
284284
})
285285
case BooleanType =>
@@ -295,7 +295,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
295295
// ByteConverter
296296
private[this] def castToByte(from: DataType): Any => Any = from match {
297297
case StringType =>
298-
buildCast[UTF8String](_, s => try s.toString.toByte catch {
298+
buildCast[UTF8String](_, s => try s.toByte catch {
299299
case _: NumberFormatException => null
300300
})
301301
case BooleanType =>
@@ -498,7 +498,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
498498
s"""
499499
boolean $resultNull = $childNull;
500500
${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
501-
if (!${childNull}) {
501+
if (!$childNull) {
502502
${cast(childPrim, resultPrim, resultNull)}
503503
}
504504
"""
@@ -705,7 +705,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
705705
(c, evPrim, evNull) =>
706706
s"""
707707
try {
708-
$evPrim = Byte.valueOf($c.toString());
708+
$evPrim = $c.toByte();
709709
} catch (java.lang.NumberFormatException e) {
710710
$evNull = true;
711711
}
@@ -727,7 +727,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
727727
(c, evPrim, evNull) =>
728728
s"""
729729
try {
730-
$evPrim = Short.valueOf($c.toString());
730+
$evPrim = $c.toShort();
731731
} catch (java.lang.NumberFormatException e) {
732732
$evNull = true;
733733
}
@@ -749,7 +749,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
749749
(c, evPrim, evNull) =>
750750
s"""
751751
try {
752-
$evPrim = Integer.valueOf($c.toString());
752+
$evPrim = $c.toInt();
753753
} catch (java.lang.NumberFormatException e) {
754754
$evNull = true;
755755
}
@@ -771,7 +771,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
771771
(c, evPrim, evNull) =>
772772
s"""
773773
try {
774-
$evPrim = Long.valueOf($c.toString());
774+
$evPrim = $c.toLong();
775775
} catch (java.lang.NumberFormatException e) {
776776
$evNull = true;
777777
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
-- cast string representing a valid fractional number to integral should truncate the number
2+
SELECT CAST('1.23' AS int);
3+
SELECT CAST('1.23' AS long);
4+
SELECT CAST('-4.56' AS int);
5+
SELECT CAST('-4.56' AS long);
6+
7+
-- cast strings that are not numbers to integral should return null
8+
SELECT CAST('abc' AS int);
9+
SELECT CAST('abc' AS long);
10+
11+
-- cast string representing a very large number to integral should return null
12+
SELECT CAST('1234567890123' AS int);
13+
SELECT CAST('12345678901234567890123' AS long);
14+
15+
-- cast empty string to integral should return null
16+
SELECT CAST('' AS int);
17+
SELECT CAST('' AS long);
18+
19+
-- cast null to integral should return null
20+
SELECT CAST(NULL AS int);
21+
SELECT CAST(NULL AS long);
22+
23+
-- cast invalid decimal string to integral should return null
24+
SELECT CAST('123.a' AS int);
25+
SELECT CAST('123.a' AS long);
26+
27+
-- '-2147483648' is the smallest int value
28+
SELECT CAST('-2147483648' AS int);
29+
SELECT CAST('-2147483649' AS int);
30+
31+
-- '2147483647' is the largest int value
32+
SELECT CAST('2147483647' AS int);
33+
SELECT CAST('2147483648' AS int);
34+
35+
-- '-9223372036854775808' is the smallest long value
36+
SELECT CAST('-9223372036854775808' AS long);
37+
SELECT CAST('-9223372036854775809' AS long);
38+
39+
-- '9223372036854775807' is the largest long value
40+
SELECT CAST('9223372036854775807' AS long);
41+
SELECT CAST('9223372036854775808' AS long);
42+
43+
-- TODO: migrate all cast tests here.

0 commit comments

Comments
 (0)