# [SPARK-16625][SQL] General data types to be mapped to Oracle
## What changes were proposed in this pull request?

Spark currently converts **BooleanType** to **BIT(1)**, **LongType** to **BIGINT**, and **ByteType** to **BYTE** when saving a DataFrame to Oracle, but Oracle does not support the BIT, BIGINT, or BYTE types.

This PR instead maps the following _Spark types_ to _Oracle types_, following the [Oracle Developer's Guide](https://docs.oracle.com/cd/E19501-01/819-3659/gcmaz/); a usage sketch follows the table.

Spark Type | Oracle Type
----|----
BooleanType | NUMBER(1)
IntegerType | NUMBER(10)
LongType | NUMBER(19)
FloatType | NUMBER(19, 4)
DoubleType | NUMBER(19, 4)
ByteType | NUMBER(3)
ShortType | NUMBER(5)
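As a rough illustration (not part of this commit), this is how the write path resolves those DDL strings through the dialect registry. The sketch assumes a Spark 2.0 shell; `ddlType` is a hypothetical helper mirroring the one in the new JDBCSuite test below, and the JDBC URL is the placeholder used in that test:

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types._

// Look up the Oracle dialect by JDBC URL (any jdbc:oracle:* URL matches).
val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")

// The dialect's mapping wins; JdbcUtils.getCommonJDBCType is the generic fallback.
def ddlType(dt: DataType): String =
  oracleDialect.getJDBCType(dt)
    .orElse(JdbcUtils.getCommonJDBCType(dt))
    .map(_.databaseTypeDefinition)
    .get

ddlType(BooleanType)  // "NUMBER(1)"  -- was BIT(1) before this change
ddlType(LongType)     // "NUMBER(19)" -- was BIGINT
ddlType(ByteType)     // "NUMBER(3)"  -- was BYTE
```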

## How was this patch tested?

Added new tests in [JDBCSuite.scala](wangyum@22b0c2a#diff-dc4b58851b084b274df6fe6b189db84d) and [OracleDialect.scala](wangyum@22b0c2a#diff-5e0cadf526662f9281aa26315b3750ad).

Author: Yuming Wang <[email protected]>

Closes apache#14377 from wangyum/SPARK-16625.
wangyum authored and srowen committed Aug 5, 2016
1 parent e026064 commit 39a2b2e
Showing 3 changed files with 124 additions and 17 deletions.
OracleIntegrationSuite.scala

```diff
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.jdbc
 
-import java.sql.Connection
+import java.sql.{Connection, Date, Timestamp}
 import java.util.Properties
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 import org.apache.spark.tags.DockerTest
 
 /**
@@ -77,4 +79,74 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLContext {
     // verify the value is the inserted correct or not
     assert(rows(0).getString(0).equals("foo"))
   }
+
+  test("SPARK-16625: General data types to be mapped to Oracle") {
+    val props = new Properties()
+    props.put("oracle.jdbc.mapDateToTimestamp", "false")
+
+    val schema = StructType(Seq(
+      StructField("boolean_type", BooleanType, true),
+      StructField("integer_type", IntegerType, true),
+      StructField("long_type", LongType, true),
+      StructField("float_type", FloatType, true),
+      StructField("double_type", DoubleType, true),
+      StructField("byte_type", ByteType, true),
+      StructField("short_type", ShortType, true),
+      StructField("string_type", StringType, true),
+      StructField("binary_type", BinaryType, true),
+      StructField("date_type", DateType, true),
+      StructField("timestamp_type", TimestampType, true)
+    ))
+
+    val tableName = "test_oracle_general_types"
+    val booleanVal = true
+    val integerVal = 1
+    val longVal = 2L
+    val floatVal = 3.0f
+    val doubleVal = 4.0
+    val byteVal = 2.toByte
+    val shortVal = 5.toShort
+    val stringVal = "string"
+    val binaryVal = Array[Byte](6, 7, 8)
+    val dateVal = Date.valueOf("2016-07-26")
+    val timestampVal = Timestamp.valueOf("2016-07-26 11:49:45")
+
+    val data = spark.sparkContext.parallelize(Seq(
+      Row(
+        booleanVal, integerVal, longVal, floatVal, doubleVal, byteVal, shortVal, stringVal,
+        binaryVal, dateVal, timestampVal
+      )))
+
+    val dfWrite = spark.createDataFrame(data, schema)
+    dfWrite.write.jdbc(jdbcUrl, tableName, props)
+
+    val dfRead = spark.read.jdbc(jdbcUrl, tableName, props)
+    val rows = dfRead.collect()
+    // verify that each column was read back as the expected JVM type
+    val types = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(types(0).equals("class java.lang.Boolean"))
+    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(2).equals("class java.lang.Long"))
+    assert(types(3).equals("class java.lang.Float"))
+    assert(types(4).equals("class java.lang.Float"))
+    assert(types(5).equals("class java.lang.Integer"))
+    assert(types(6).equals("class java.lang.Integer"))
+    assert(types(7).equals("class java.lang.String"))
+    assert(types(8).equals("class [B"))
+    assert(types(9).equals("class java.sql.Date"))
+    assert(types(10).equals("class java.sql.Timestamp"))
+    // verify that the inserted values round-trip correctly
+    val values = rows(0)
+    assert(values.getBoolean(0).equals(booleanVal))
+    assert(values.getInt(1).equals(integerVal))
+    assert(values.getLong(2).equals(longVal))
+    assert(values.getFloat(3).equals(floatVal))
+    assert(values.getFloat(4).equals(doubleVal.toFloat))
+    assert(values.getInt(5).equals(byteVal.toInt))
+    assert(values.getInt(6).equals(shortVal.toInt))
+    assert(values.getString(7).equals(stringVal))
+    assert(values.getAs[Array[Byte]](8).mkString.equals("678"))
+    assert(values.getDate(9).equals(dateVal))
+    assert(values.getTimestamp(10).equals(timestampVal))
+  }
 }
```
OracleDialect.scala

```diff
@@ -28,28 +28,42 @@ private case object OracleDialect extends JdbcDialect {

   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
-    // Handle NUMBER fields that have no precision/scale in special way
-    // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
-    // For more details, please see
-    // https://github.com/apache/spark/pull/8780#issuecomment-145598968
-    // and
-    // https://github.com/apache/spark/pull/8780#issuecomment-144541760
-    if (sqlType == Types.NUMERIC && size == 0) {
-      // This is sub-optimal as we have to pick a precision/scale in advance whereas the data
-      // in Oracle is allowed to have different precision/scale for each value.
-      Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-    } else if (sqlType == Types.NUMERIC && md.build().getLong("scale") == -127) {
-      // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
-      // this to NUMERIC with -127 scale
-      // Not sure if there is a more robust way to identify the field as a float (or other
-      // numeric types that do not specify a scale.
-      Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+    if (sqlType == Types.NUMERIC) {
+      val scale = if (null != md) md.build().getLong("scale") else 0L
+      size match {
+        // Handle NUMBER fields that have no precision/scale in a special way, because
+        // JDBC ResultSetMetaData converts this to 0 precision and -127 scale.
+        // For more details, please see
+        // https://github.com/apache/spark/pull/8780#issuecomment-145598968
+        // and
+        // https://github.com/apache/spark/pull/8780#issuecomment-144541760
+        case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+        // Handle FLOAT fields in a special way, because JDBC ResultSetMetaData converts
+        // this to NUMERIC with -127 scale.
+        // Not sure if there is a more robust way to identify the field as a float (or
+        // other numeric types that do not specify a scale).
+        case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+        case 1 => Option(BooleanType)
+        case 3 | 5 | 10 => Option(IntegerType)
+        case 19 if scale == 0L => Option(LongType)
+        case 19 if scale == 4L => Option(FloatType)
+        case _ => None
+      }
     } else {
       None
     }
   }

   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+    // For more details, please see
+    // https://docs.oracle.com/cd/E19501-01/819-3659/gcmaz/
+    case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
+    case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
+    case LongType => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
+    case FloatType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
+    case DoubleType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
+    case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
+    case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
     case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR))
     case _ => None
   }
```
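To see what the new read-path branches do, here is a minimal sketch (again REPL-style, not from the commit): the `MetadataBuilder` stands in for the scale that the Oracle JDBC driver's `ResultSetMetaData` would report, and the JDBC URL is the placeholder from the tests:

```scala
import java.sql.Types

import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.MetadataBuilder

val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")

// NUMERIC with scale 0: the reported precision (size) picks the Catalyst type.
val noScale = new MetadataBuilder().putLong("scale", 0L)
oracleDialect.getCatalystType(Types.NUMERIC, "NUMBER", 1, noScale)   // Some(BooleanType)
oracleDialect.getCatalystType(Types.NUMERIC, "NUMBER", 10, noScale)  // Some(IntegerType)
oracleDialect.getCatalystType(Types.NUMERIC, "NUMBER", 19, noScale)  // Some(LongType)

// NUMBER(19, 4) is what the write path emits for Float and Double columns.
val scale4 = new MetadataBuilder().putLong("scale", 4L)
oracleDialect.getCatalystType(Types.NUMERIC, "NUMBER", 19, scale4)   // Some(FloatType)
```

Note that `NUMBER(19, 4)` reads back as `FloatType` even when the column was written from a `DoubleType`, which is why the integration test above asserts `java.lang.Float` for the double column.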
sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala (21 additions, 0 deletions)

```diff
@@ -739,6 +739,27 @@ class JDBCSuite extends SparkFunSuite
       map(_.databaseTypeDefinition).get == "VARCHAR2(255)")
   }
 
+  test("SPARK-16625: General data types to be mapped to Oracle") {
+
+    def getJdbcType(dialect: JdbcDialect, dt: DataType): String = {
+      dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).
+        map(_.databaseTypeDefinition).get
+    }
+
+    val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    assert(getJdbcType(oracleDialect, BooleanType) == "NUMBER(1)")
+    assert(getJdbcType(oracleDialect, IntegerType) == "NUMBER(10)")
+    assert(getJdbcType(oracleDialect, LongType) == "NUMBER(19)")
+    assert(getJdbcType(oracleDialect, FloatType) == "NUMBER(19, 4)")
+    assert(getJdbcType(oracleDialect, DoubleType) == "NUMBER(19, 4)")
+    assert(getJdbcType(oracleDialect, ByteType) == "NUMBER(3)")
+    assert(getJdbcType(oracleDialect, ShortType) == "NUMBER(5)")
+    assert(getJdbcType(oracleDialect, StringType) == "VARCHAR2(255)")
+    assert(getJdbcType(oracleDialect, BinaryType) == "BLOB")
+    assert(getJdbcType(oracleDialect, DateType) == "DATE")
+    assert(getJdbcType(oracleDialect, TimestampType) == "TIMESTAMP")
+  }
+
   private def assertEmptyQuery(sqlString: String): Unit = {
     assert(sql(sqlString).collect().isEmpty)
   }
```
