spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading
Date Mon, 25 Jul 2016 11:57:54 GMT
Repository: spark
Updated Branches:
  refs/heads/master 68b4020d0 -> 7ffd99ec5


[SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading

## What changes were proposed in this pull request?

Currently, `JDBCRDD.compute` performs type dispatch for every row in order to read the appropriate values.
This per-row dispatch is not necessary, because the schema is already kept in `JDBCRDD`.

So, appropriate converters can be created once up front according to the schema, and then applied
to each row.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14313 from HyukjinKwon/SPARK-16674.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7ffd99ec
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7ffd99ec
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7ffd99ec

Branch: refs/heads/master
Commit: 7ffd99ec5f267730734431097cbb700ad074bebe
Parents: 68b4020
Author: hyukjinkwon <gurwls223@gmail.com>
Authored: Mon Jul 25 19:57:47 2016 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Mon Jul 25 19:57:47 2016 +0800

----------------------------------------------------------------------
 .../execution/datasources/jdbc/JDBCRDD.scala    | 245 ++++++++++---------
 1 file changed, 129 insertions(+), 116 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7ffd99ec/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 24e2c1a..4c98430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
@@ -322,43 +322,134 @@ private[sql] class JDBCRDD(
     }
   }
 
-  // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
-  // we don't have to potentially poke around in the Metadata once for every
-  // row.
-  // Is there a better way to do this?  I'd rather be using a type that
-  // contains only the tags I define.
-  abstract class JDBCConversion
-  case object BooleanConversion extends JDBCConversion
-  case object DateConversion extends JDBCConversion
-  case class  DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
-  case object DoubleConversion extends JDBCConversion
-  case object FloatConversion extends JDBCConversion
-  case object IntegerConversion extends JDBCConversion
-  case object LongConversion extends JDBCConversion
-  case object BinaryLongConversion extends JDBCConversion
-  case object StringConversion extends JDBCConversion
-  case object TimestampConversion extends JDBCConversion
-  case object BinaryConversion extends JDBCConversion
-  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
+  // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
+  // into a field for `MutableRow`. The last argument `Int` means the index for the
+  // value to be set in the row and also used for the value to retrieve from `ResultSet`.
+  private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
 
   /**
-   * Maps a StructType to a type tag list.
+   * Creates `JDBCValueSetter`s according to [[StructType]], which can set
+   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] =
-    schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
-
-  private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match
{
-    case BooleanType => BooleanConversion
-    case DateType => DateConversion
-    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
-    case DoubleType => DoubleConversion
-    case FloatType => FloatConversion
-    case IntegerType => IntegerConversion
-    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-    case StringType => StringConversion
-    case TimestampType => TimestampConversion
-    case BinaryType => BinaryConversion
-    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+  def makeSetters(schema: StructType): Array[JDBCValueSetter] =
+    schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
+
+  private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+    // retrieve it, you will get wrong result 199.99.
+    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+    case DecimalType.Fixed(p, s) =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d,
p, s))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.size) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setLong(pos, rs.getLong(pos + 1))
+
+    case StringType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    case TimestampType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.lang.String]]
+              .map(UTF8String.fromString)
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map { date =>
+              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+              nullSafeConvert[java.math.BigDecimal](
+                decimal, d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new IllegalArgumentException(s"Unsupported array element " +
+            s"type ${dt.simpleString} based on binary")
+
+        case ArrayType(_, _) =>
+          throw new IllegalArgumentException("Nested arrays unsupported")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val array = nullSafeConvert[Object](
+          rs.getArray(pos + 1).getArray,
+          array => new GenericArrayData(elementConversion.apply(array)))
+        row.update(pos, array)
+
     case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
   }
 
@@ -398,93 +489,15 @@ private[sql] class JDBCRDD(
     stmt.setFetchSize(fetchSize)
     val rs = stmt.executeQuery()
 
-    val conversions = getConversions(schema)
+    val setters: Array[JDBCValueSetter] = makeSetters(schema)
     val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
 
     def getNext(): InternalRow = {
       if (rs.next()) {
         inputMetrics.incRecordsRead(1)
         var i = 0
-        while (i < conversions.length) {
-          val pos = i + 1
-          conversions(i) match {
-            case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
-            case DateConversion =>
-              // DateTimeUtils.fromJavaDate does not handle null value, so we need to check
it.
-              val dateVal = rs.getDate(pos)
-              if (dateVal != null) {
-                mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
-              } else {
-                mutableRow.update(i, null)
-              }
-            // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
-            // object returned by ResultSet.getBigDecimal is not correctly matched to the
table
-            // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
-            // If inserting values like 19999 into a column with NUMBER(12, 2) type, you
get through
-            // a BigDecimal object with scale as 0. But the dataframe schema has correct
type as
-            // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and
then
-            // retrieve it, you will get wrong result 199.99.
-            // So it is needed to set precision and scale for Decimal based on JDBC metadata.
-            case DecimalConversion(p, s) =>
-              val decimalVal = rs.getBigDecimal(pos)
-              if (decimalVal == null) {
-                mutableRow.update(i, null)
-              } else {
-                mutableRow.update(i, Decimal(decimalVal, p, s))
-              }
-            case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
-            case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
-            case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
-            case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
-            // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
-            case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
-            case TimestampConversion =>
-              val t = rs.getTimestamp(pos)
-              if (t != null) {
-                mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
-              } else {
-                mutableRow.update(i, null)
-              }
-            case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-            case BinaryLongConversion =>
-              val bytes = rs.getBytes(pos)
-              var ans = 0L
-              var j = 0
-              while (j < bytes.size) {
-                ans = 256 * ans + (255 & bytes(j))
-                j = j + 1
-              }
-              mutableRow.setLong(i, ans)
-            case ArrayConversion(elementConversion) =>
-              val array = rs.getArray(pos).getArray
-              if (array != null) {
-                val data = elementConversion match {
-                  case TimestampConversion =>
-                    array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
-                      nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
-                    }
-                  case StringConversion =>
-                    array.asInstanceOf[Array[java.lang.String]]
-                      .map(UTF8String.fromString)
-                  case DateConversion =>
-                    array.asInstanceOf[Array[java.sql.Date]].map { date =>
-                      nullSafeConvert(date, DateTimeUtils.fromJavaDate)
-                    }
-                  case DecimalConversion(p, s) =>
-                    array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
-                      nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p,
s))
-                    }
-                  case BinaryLongConversion =>
-                    throw new IllegalArgumentException(s"Unsupported array element conversion
$i")
-                  case _: ArrayConversion =>
-                    throw new IllegalArgumentException("Nested arrays unsupported")
-                  case _ => array.asInstanceOf[Array[Any]]
-                }
-                mutableRow.update(i, new GenericArrayData(data))
-              } else {
-                mutableRow.update(i, null)
-              }
-          }
+        while (i < setters.length) {
+          setters(i).apply(rs, mutableRow, i)
           if (rs.wasNull) mutableRow.setNullAt(i)
           i = i + 1
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message