spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-10195] [SQL] Data sources Filter should not expose internal types
Date Tue, 25 Aug 2015 08:06:54 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 e5cea566a -> a0f22cf29


[SPARK-10195] [SQL] Data sources Filter should not expose internal types

Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces.
This is a problem because types like UTF8String are not stable developer APIs and should not
be exposed to third parties.
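
To make the mismatch concrete, here is a minimal sketch (the column name "c" and the
value are hypothetical) of what a third-party source saw before this fix: a pushed-down
string equality carried Catalyst's internal UTF8String, so matching on java.lang.String
silently failed.

    import org.apache.spark.sql.sources.EqualTo
    import org.apache.spark.unsafe.types.UTF8String

    // Pre-fix, the Filter's value was the internal string type...
    val internal = EqualTo("c", UTF8String.fromString("foo"))
    // ...so a source that matched on String never saw a hit.
    internal.value match {
      case s: String => println(s"matched String: $s")  // never reached pre-fix
      case other => println(s"unexpected type: ${other.getClass.getName}")
    }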

This issue caused incompatibilities when upgrading our `spark-redshift` library to work against
Spark 1.5.0. To avoid such issues in the future, we should only expose public types through
these Filter objects. This patch accomplishes that by using CatalystTypeConverters to add
the appropriate conversions.
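
Concretely, the fix converts literal values with CatalystTypeConverters before they reach
a Filter; a minimal sketch of that conversion, using the same imports as the diff below:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.unsafe.types.UTF8String

    // A string literal's internal Catalyst representation...
    val catalystValue: Any = UTF8String.fromString("foo")
    // ...converted to the stable external type before it is exposed.
    val scalaValue: Any = convertToScala(catalystValue, StringType)
    assert(scalaValue.isInstanceOf[String])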

Author: Josh Rosen <joshrosen@databricks.com>

Closes #8403 from JoshRosen/datasources-internal-vs-external-types.

(cherry picked from commit 7bc9a8c6249300ded31ea931c463d0a8f798e193)
Signed-off-by: Reynold Xin <rxin@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a0f22cf2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a0f22cf2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a0f22cf2

Branch: refs/heads/branch-1.5
Commit: a0f22cf295a1d20814c5be6cc727e39e95a81c27
Parents: e5cea56
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Tue Aug 25 01:06:36 2015 -0700
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue Aug 25 01:06:51 2015 -0700

----------------------------------------------------------------------
 .../datasources/DataSourceStrategy.scala        | 67 ++++++++++----------
 .../execution/datasources/jdbc/JDBCRDD.scala    |  2 +-
 .../datasources/parquet/ParquetFilters.scala    | 19 +++---
 .../spark/sql/sources/FilteredScanSuite.scala   |  7 ++
 4 files changed, 54 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a0f22cf2/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2a4c40d..6c1ef6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
@@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
    */
   protected[sql] def selectFilters(filters: Seq[Expression]) = {
     def translate(predicate: Expression): Option[Filter] = predicate match {
-      case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
-        Some(sources.EqualTo(a.name, v))
-      case expressions.EqualTo(Literal(v, _), a: Attribute) =>
-        Some(sources.EqualTo(a.name, v))
-
-      case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
-        Some(sources.EqualNullSafe(a.name, v))
-      case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
-        Some(sources.EqualNullSafe(a.name, v))
-
-      case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
-        Some(sources.GreaterThan(a.name, v))
-      case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
-        Some(sources.LessThan(a.name, v))
-
-      case expressions.LessThan(a: Attribute, Literal(v, _)) =>
-        Some(sources.LessThan(a.name, v))
-      case expressions.LessThan(Literal(v, _), a: Attribute) =>
-        Some(sources.GreaterThan(a.name, v))
-
-      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
-        Some(sources.GreaterThanOrEqual(a.name, v))
-      case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
-        Some(sources.LessThanOrEqual(a.name, v))
-
-      case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
-        Some(sources.LessThanOrEqual(a.name, v))
-      case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
-        Some(sources.GreaterThanOrEqual(a.name, v))
+      case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
+        Some(sources.EqualTo(a.name, convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), a: Attribute) =>
+        Some(sources.EqualTo(a.name, convertToScala(v, t)))
+
+      case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
+        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+      case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
+        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+
+      case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
+        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
+        Some(sources.LessThan(a.name, convertToScala(v, t)))
+
+      case expressions.LessThan(a: Attribute, Literal(v, t)) =>
+        Some(sources.LessThan(a.name, convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), a: Attribute) =>
+        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
+        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
+        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
+        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
+        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
 
       case expressions.InSet(a: Attribute, set) =>
-        Some(sources.In(a.name, set.toArray))
+        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+        Some(sources.In(a.name, set.toArray.map(toScala)))
 
       // Because we only convert In to InSet in the Optimizer when there are more
       // than a certain number of items, it is possible that we still get an In
       // expression here that needs to be pushed down.
       case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
         val hSet = list.map(e => e.eval(EmptyRow))
-        Some(sources.In(a.name, hSet.toArray))
+        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+        Some(sources.In(a.name, hSet.toArray.map(toScala)))
 
       case expressions.IsNull(a: Attribute) =>
         Some(sources.IsNull(a.name))

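For the set-based cases above, the patch builds one converter per attribute via
createToScalaConverter rather than calling convertToScala per element. A sketch of that
shape, assuming a string-typed attribute:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.unsafe.types.UTF8String

    val toScala = CatalystTypeConverters.createToScalaConverter(StringType)
    val internalSet: Set[Any] = Set(UTF8String.fromString("a"), UTF8String.fromString("b"))
    // Each internal value becomes its external counterpart before being
    // handed to sources.In.
    val externalValues: Array[Any] = internalSet.toArray.map(toScala)
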
http://git-wip-us.apache.org/repos/asf/spark/blob/a0f22cf2/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e537d63..730d88b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
    * Converts value to SQL expression.
    */
   private def compileValue(value: Any): Any = value match {
-    case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
     case _ => value
   }
 
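With filter values now arriving as plain Strings, compileValue can match on String
directly. A standalone sketch of the same shape, substituting a simple quote-doubling
escape for the escapeSql helper used in the real code:

    // Converts a pushed-down filter value to a SQL literal fragment.
    def compileValue(value: Any): Any = value match {
      // Quote and escape string values for embedding in a WHERE clause.
      case stringValue: String => s"'${stringValue.replace("'", "''")}'"
      // Numbers and other values pass through unchanged.
      case _ => value
    }

    assert(compileValue("O'Brien") == "'O''Brien'")
    assert(compileValue(42) == 42)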

http://git-wip-us.apache.org/repos/asf/spark/blob/a0f22cf2/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 63915e0..83eaf8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -34,7 +34,6 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 private[sql] object ParquetFilters {
   val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -73,7 +72,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Any) => FilterApi.eq(
         binaryColumn(n),
-        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
     case BinaryType =>
       (n: String, v: Any) => FilterApi.eq(
         binaryColumn(n),
@@ -94,7 +93,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Any) => FilterApi.notEq(
         binaryColumn(n),
-        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
     case BinaryType =>
       (n: String, v: Any) => FilterApi.notEq(
         binaryColumn(n),
@@ -112,7 +111,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.lt(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -129,7 +129,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.ltEq(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -146,7 +147,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.gt(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -163,7 +165,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.gtEq(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -185,7 +188,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Set[Any]) =>
         FilterApi.userDefined(binaryColumn(n),
-          SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
+          SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
     case BinaryType =>
       (n: String, v: Set[Any]) =>
         FilterApi.userDefined(binaryColumn(n),

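Every ParquetFilters change above follows the same pattern: encode the plain String
explicitly as UTF-8 instead of pulling bytes out of a UTF8String. A minimal sketch using
the parquet-mr API imported by ParquetFilters (the column name "c" is hypothetical):

    import org.apache.parquet.filter2.predicate.FilterApi
    import org.apache.parquet.filter2.predicate.FilterApi.binaryColumn
    import org.apache.parquet.io.api.Binary

    // Filter values are now java.lang.String, so the bytes come from an
    // explicit UTF-8 encoding rather than UTF8String.getBytes.
    val v: Any = "foo"
    val predicate = FilterApi.eq(
      binaryColumn("c"),
      Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
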
http://git-wip-us.apache.org/repos/asf/spark/blob/a0f22cf2/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index c81c3d3..68ce37c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
 import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.sql._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
       case StringStartsWith("c", v) => _.startsWith(v)
       case StringEndsWith("c", v) => _.endsWith(v)
       case StringContains("c", v) => _.contains(v)
+      case EqualTo("c", v: String) => _.equals(v)
+      case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
+      case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
       case _ => (c: String) => true
     }
 
@@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
   testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
   testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
 
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)
+
   def testPushDown(sqlString: String, expectedCount: Int): Unit = {
     test(s"PushDown Returns $expectedCount: $sqlString") {
       val queryExecution = sql(sqlString).queryExecution

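The new test cases boil down to a simple guard: EqualTo and In on the string column must
see plain Strings, and a UTF8String reaching the source is a hard failure. A condensed
sketch of that predicate logic (predicate shapes taken from the diff above):

    import org.apache.spark.sql.sources.{EqualTo, Filter, In}
    import org.apache.spark.unsafe.types.UTF8String

    def predicateFor(filter: Filter): String => Boolean = filter match {
      case EqualTo("c", v: String) => _ == v
      case EqualTo("c", _: UTF8String) =>
        sys.error("UTF8String should not appear in filters")
      case In("c", values) =>
        (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
      case _ => _ => true
    }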

