Skip to content

Commit

Permalink
Refactor PR to move fixes into two other PRs
Browse files Browse the repository at this point in the history
Depends on apache#49452 and apache#49453

Signed-off-by: Xiaoguang Sun <[email protected]>
  • Loading branch information
sunxiaoguang committed Jan 11, 2025
1 parent 5929ca4 commit a9575ca
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
}

test("SPARK-50704: Test SQL function push down with different types and casts in WHERE clause") {
withTable(s"$catalogName.test_pushdown") {
val tableName = s"$catalogName.test_pushdown"
withTable(tableName) {
// Define test values for different data types
val boolean = true
val int = 1
Expand All @@ -252,18 +253,17 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
val float = 0.123
val binary = "X'123456'"
val decimal = "-.001234567E+2BD"
val tableName = "test_pushdown"

// Create a table with various data types
sql(s"""CREATE TABLE $catalogName.$tableName (
sql(s"""CREATE TABLE $tableName (
boolean_col BOOLEAN, byte_col BYTE, tinyint_col TINYINT, short_col SHORT,
smallint_col SMALLINT, int_col INT, integer_col INTEGER, long_col LONG,
bigint_col BIGINT, float_col FLOAT, real_col REAL, double_col DOUBLE,
str_col STRING, binary_col BINARY, decimal_col DECIMAL(10, 7), dec_col DEC(10, 7),
numeric_col NUMERIC(10, 7))""")

// Insert test values into the table
sql(s"""INSERT INTO $catalogName.$tableName VALUES ($boolean, $int, $int, $int,
sql(s"""INSERT INTO $tableName VALUES ($boolean, $int, $int, $int,
$int, $int, $int, $long, $long, $float, $float, $float, '$str', $binary, $decimal,
$decimal, $decimal)""")

Expand Down Expand Up @@ -546,30 +546,12 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
decimalColumns
),
Some(toStringLiteral)
),
generateTests(
"CAST",
"CAST(COLUMN AS STRING) = CAST(VALUE AS STRING)",
Seq(intColumns, longColumns, floatColumns, decimalColumns),
Some(toStringLiteral)
),
generateTests(
"CAST",
"CAST(COLUMN AS INT) = CAST(VALUE AS INT)",
Seq(intColumns, floatColumns, decimalColumns)
),
generateTests(
"CAST",
"ABS(ABS(CAST(COLUMN AS DOUBLE)) - ABS(CAST(VALUE AS DOUBLE))) <= 0.00001",
Seq(intColumns, floatColumns, longColumns, decimalColumns)
)
).flatten

// Execute the generated test cases
// Run every generated predicate against the test table; each one is expected to
// match exactly the single row inserted above.
functions.foreach { case (name, query) =>
  // Build the statement once under a name that does not shadow the suite's `sql`
  // helper, so the failure message below always reports the exact query text.
  val statement = s"SELECT * FROM $tableName WHERE $query"
  val rows = spark.sql(statement).collect()
  assert(rows.length === 1, s"Function `$name` pushdown test failed: $statement")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ abstract class JdbcDialect extends Serializable with Logging {
case dateValue: Date => "'" + dateValue + "'"
case dateValue: LocalDate => s"'${DateFormatter().format(dateValue)}'"
case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ")
case binaryValue: Array[Byte] => binaryValue.map("%02X".format(_)).mkString("X'", "", "'")
case _ => value
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,6 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
} else {
super.visitAggregateFunction(funcName, isDistinct, inputs)
}

/**
 * Renders a `CAST(expr AS type)` fragment, choosing the type name to place after `AS`.
 *
 * @param expr         the already-compiled expression text to be cast
 * @param exprDataType the data type of `expr`; not consulted by this implementation
 * @param dataType     the target data type of the cast
 * @return the string `CAST(<expr> AS <type name>)`
 */
override def visitCast(expr: String, exprDataType: DataType, dataType: DataType): String = {
  // Default type name: the dialect's JDBC type definition when one is available,
  // otherwise the data type's own type name.
  def fallbackTypeName: String =
    getJDBCType(dataType).fold(dataType.typeName)(_.databaseTypeDefinition)

  val castTarget = dataType match {
    // MySQL uses BINARY in the cast function for the type BLOB
    case BinaryType => "BINARY"
    // MySQL uses CHAR in the cast function for the type LONGTEXT
    case StringType => "CHAR"
    // MySQL uses SIGNED INTEGER in the cast function for the types SMALLINT, INTEGER and BIGINT
    case ShortType | IntegerType | LongType => "SIGNED INTEGER"
    case _ => fallbackTypeName
  }

  s"CAST($expr AS $castTarget)"
}
}

override def compileExpression(expr: Expression): Option[String] = {
Expand Down Expand Up @@ -311,8 +298,6 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
case StringType => Option(JdbcType("LONGTEXT", java.sql.Types.LONGVARCHAR))
case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT))
case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT))
// We override getJDBCType so that DoubleType is mapped to DOUBLE instead.
case DoubleType => Option(JdbcType("DOUBLE", java.sql.Types.DOUBLE))
// scalastyle:off line.size.limit
// In MYSQL, DATETIME is TIMESTAMP WITHOUT TIME ZONE
// https://github.com/mysql/mysql-connector-j/blob/8.3.0/src/main/core-api/java/com/mysql/cj/MysqlType.java#L251
Expand Down

0 comments on commit a9575ca

Please sign in to comment.