spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast
Date Tue, 05 Jun 2018 18:31:25 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2c2a86b5d -> 93df3cd03


[SPARK-22384][SQL] Refine partition pruning when attribute is wrapped in Cast

## What changes were proposed in this pull request?

The SQL below fetches all partitions from the metastore, which puts a heavy burden on the metastore:
```
CREATE TABLE `partition_test`(`col` int) PARTITIONED BY (`pt` byte)
SELECT * FROM partition_test WHERE CAST(pt AS INT)=1
```
The reason is that the analyzed attribute `pt` is wrapped in `Cast`, so `HiveShim` fails
to generate a proper partition filter.
This PR proposes to take `Cast` into consideration when generating the partition filter.
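
For illustration only (this snippet is hand-built and not part of the patch): the analyzer wraps the partition attribute in a `Cast`, so a bare-`Attribute` pattern in the shim no longer matches, and the cast is only safe to look through when it cannot truncate.
```
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ByteType, IntegerType}

// `pt` mirrors the partition column of the example table above.
val pt = AttributeReference("pt", ByteType)()
// The analyzer turns `pt = 1` into `CAST(pt AS INT) = 1`, so a pattern like
// `case attr: Attribute` fails on the left-hand side and no filter is pushed.
val pred = EqualTo(Cast(pt, IntegerType), Literal(1))
// BYTE -> INT is a widening cast, so it is safe to look through:
assert(!Cast.mayTruncate(ByteType, IntegerType))
```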

## How was this patch tested?
Tests were added.
This PR also changes `HiveClientSuite` to use analyzed expressions instead of filter strings parsed by `CatalystSqlParser`.
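
As a hedged sketch of that suite change (the attribute below is constructed by hand rather than looked up from the test table's partition schema):
```
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Before: CatalystSqlParser.parseExpression("ds=20170101") produced an
// unresolved expression (`ds` had no data type yet).
// After: the Catalyst DSL builds a resolved expression directly.
val ds = AttributeReference("ds", IntegerType)()
val filter = ds === 20170101  // EqualTo(ds, Literal(20170101)), fully resolved
```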

Author: jinxing <jinxing6042@126.com>

Closes #19602 from jinxing64/SPARK-22384.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93df3cd0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93df3cd0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93df3cd0

Branch: refs/heads/master
Commit: 93df3cd03503fca7745141fbd2676b8bf70fe92f
Parents: 2c2a86b
Author: jinxing <jinxing6042@126.com>
Authored: Tue Jun 5 11:32:42 2018 -0700
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Jun 5 11:32:42 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/hive/client/HiveShim.scala |  23 ++++-
 .../spark/sql/hive/client/HiveClientSuite.scala | 102 ++++++++++++-------
 2 files changed, 86 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/93df3cd0/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index 948ba54..130e258 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -24,7 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Locale, Map => JMap, S
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
-import scala.util.Try
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
@@ -657,17 +656,31 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
 
     val useAdvanced = SQLConf.get.advancedPartitionPredicatePushdownEnabled
 
+    object ExtractAttribute {
+      def unapply(expr: Expression): Option[Attribute] = {
+        expr match {
+          case attr: Attribute => Some(attr)
+          case Cast(child, dt, _) if !Cast.mayTruncate(child.dataType, dt) => unapply(child)
+          case _ => None
+        }
+      }
+    }
+
     def convert(expr: Expression): Option[String] = expr match {
-      case In(NonVarcharAttribute(name), ExtractableLiterals(values)) if useAdvanced =>
+      case In(ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiterals(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case InSet(NonVarcharAttribute(name), ExtractableValues(values)) if useAdvanced =>
+      case InSet(ExtractAttribute(NonVarcharAttribute(name)), ExtractableValues(values))
+          if useAdvanced =>
         Some(convertInToOr(name, values))
 
-      case op @ SpecialBinaryComparison(NonVarcharAttribute(name), ExtractableLiteral(value)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractAttribute(NonVarcharAttribute(name)), ExtractableLiteral(value)) =>
         Some(s"$name ${op.symbol} $value")
 
-      case op @ SpecialBinaryComparison(ExtractableLiteral(value), NonVarcharAttribute(name)) =>
+      case op @ SpecialBinaryComparison(
+          ExtractableLiteral(value), ExtractAttribute(NonVarcharAttribute(name))) =>
         Some(s"$value ${op.symbol} $name")
 
       case And(expr1, expr2) if useAdvanced =>

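For context only (not part of the commit): a sketch of how `ExtractAttribute` above behaves on hand-built expressions. The object is assumed visible here for the demo; in the real shim it is local to `getPartitionsByFilter`.
```
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType, ShortType}

val ds = AttributeReference("ds", IntegerType)()
ExtractAttribute.unapply(ds)                  // Some(ds): bare attribute
ExtractAttribute.unapply(Cast(ds, LongType))  // Some(ds): INT -> LONG cannot truncate
ExtractAttribute.unapply(Cast(ds, ShortType)) // None: INT -> SHORT may truncate,
                                              // so no metastore filter is generated
```
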
http://git-wip-us.apache.org/repos/asf/spark/blob/93df3cd0/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
index f991352..55275f6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala
@@ -22,13 +22,13 @@ import org.apache.hadoop.hive.conf.HiveConf
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{EmptyRow, Expression, In, InSet}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.LongType
 
 // TODO: Refactor this to `HivePartitionFilteringSuite`
 class HiveClientSuite(version: String)
     extends HiveVersionSuite(version) with BeforeAndAfterAll {
-  import CatalystSqlParser._
 
   private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname
 
@@ -46,8 +46,7 @@ class HiveClientSuite(version: String)
     val hadoopConf = new Configuration()
     hadoopConf.setBoolean(tryDirectSqlKey, tryDirectSql)
     val client = buildClient(hadoopConf)
-    client
-      .runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
+    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (ds INT, h INT, chunk STRING)")
 
     val partitions =
       for {
@@ -66,6 +65,15 @@ class HiveClientSuite(version: String)
     client
   }
 
+  private def attr(name: String): Attribute = {
+    client.getTable("default", "test").partitionSchema.fields
+        .find(field => field.name.equals(name)) match {
+      case Some(field) => AttributeReference(field.name, field.dataType)()
+      case None =>
+        fail(s"Illegal name of partition attribute: $name")
+    }
+  }
+
   override def beforeAll() {
     super.beforeAll()
     client = init(true)
@@ -74,7 +82,7 @@ class HiveClientSuite(version: String)
   test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
     val client = init(false)
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
-      Seq(parseExpression("ds=20170101")))
+      Seq(attr("ds") === 20170101))
 
     assert(filteredPartitions.size == testPartitionCount)
   }
@@ -82,7 +90,7 @@ class HiveClientSuite(version: String)
   test("getPartitionsByFilter: ds<=>20170101") {
     // Should return all partitions where <=> is not supported
     testMetastorePartitionFiltering(
-      "ds<=>20170101",
+      attr("ds") <=> 20170101,
       20170101 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -90,7 +98,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds=20170101") {
     testMetastorePartitionFiltering(
-      "ds=20170101",
+      attr("ds") === 20170101,
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -100,7 +108,7 @@ class HiveClientSuite(version: String)
     // Should return all partitions where h=0 because getPartitionsByFilter does not support
     // comparisons to non-literal values
     testMetastorePartitionFiltering(
-      "ds=(20170101 + 1) and h=0",
+      attr("ds") === (Literal(20170101) + 1) && attr("h") === 0,
       20170101 to 20170103,
       0 to 0,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -108,7 +116,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk='aa'") {
     testMetastorePartitionFiltering(
-      "chunk='aa'",
+      attr("chunk") === "aa",
       20170101 to 20170103,
       0 to 23,
       "aa" :: Nil)
@@ -116,7 +124,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: 20170101=ds") {
     testMetastorePartitionFiltering(
-      "20170101=ds",
+      Literal(20170101) === attr("ds"),
       20170101 to 20170101,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -124,7 +132,15 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds=20170101 and h=10") {
     testMetastorePartitionFiltering(
-      "ds=20170101 and h=10",
+      attr("ds") === 20170101 && attr("h") === 10,
+      20170101 to 20170101,
+      10 to 10,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: cast(ds as long)=20170101L and h=10") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType) === 20170101L && attr("h") === 10,
       20170101 to 20170101,
       10 to 10,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -132,7 +148,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds=20170101 or ds=20170102") {
     testMetastorePartitionFiltering(
-      "ds=20170101 or ds=20170102",
+      attr("ds") === 20170101 || attr("ds") === 20170102,
       20170101 to 20170102,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -140,7 +156,15 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using IN expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil)
+  }
+
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using IN expression)") {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil)
@@ -148,7 +172,19 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: ds in (20170102, 20170103) (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "ds in (20170102, 20170103)",
+      attr("ds").in(20170102, 20170103),
+      20170102 to 20170103,
+      0 to 23,
+      "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
+        case expr @ In(v, list) if expr.inSetConvertible =>
+          InSet(v, list.map(_.eval(EmptyRow)).toSet)
+      })
+  }
+
+  test("getPartitionsByFilter: cast(ds as long) in (20170102L, 20170103L) (using INSET expression)")
+  {
+    testMetastorePartitionFiltering(
+      attr("ds").cast(LongType).in(20170102L, 20170103L),
       20170102 to 20170103,
       0 to 23,
       "aa" :: "ab" :: "ba" :: "bb" :: Nil, {
@@ -159,7 +195,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using IN expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil)
@@ -167,7 +203,7 @@ class HiveClientSuite(version: String)
 
   test("getPartitionsByFilter: chunk in ('ab', 'ba') (using INSET expression)") {
     testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba')",
+      attr("chunk").in("ab", "ba"),
       20170101 to 20170103,
       0 to 23,
       "ab" :: "ba" :: Nil, {
@@ -179,26 +215,24 @@ class HiveClientSuite(version: String)
   test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<8)") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<8)",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+        (attr("ds") === 20170102 && attr("h") < 8), day1 :: day2 :: Nil)
   }
 
  test("getPartitionsByFilter: (ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("aa", "ab", "ba", "bb"))
     // Day 2 should include all hours because we can't build a filter for h<(7+1)
     val day2 = (20170102 to 20170102, 0 to 23, Seq("aa", "ab", "ba", "bb"))
-    testMetastorePartitionFiltering(
-      "(ds=20170101 and h>=8) or (ds=20170102 and h<(7+1))",
-      day1 :: day2 :: Nil)
+    testMetastorePartitionFiltering((attr("ds") === 20170101 && attr("h") >= 8) ||
+        (attr("ds") === 20170102 && attr("h") < (Literal(7) + 1)), day1 :: day2 :: Nil)
   }
 
  test("getPartitionsByFilter: " +
      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))") {
     val day1 = (20170101 to 20170101, 8 to 23, Seq("ab", "ba"))
     val day2 = (20170102 to 20170102, 0 to 7, Seq("ab", "ba"))
-    testMetastorePartitionFiltering(
-      "chunk in ('ab', 'ba') and ((ds=20170101 and h>=8) or (ds=20170102 and h<8))",
+    testMetastorePartitionFiltering(attr("chunk").in("ab", "ba") &&
+        ((attr("ds") === 20170101 && attr("h") >= 8) || (attr("ds") === 20170102 && attr("h") < 8)),
       day1 :: day2 :: Nil)
   }
 
@@ -207,41 +241,41 @@ class HiveClientSuite(version: String)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String]): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
       identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedDs: Seq[Int],
       expectedH: Seq[Int],
       expectedChunks: Seq[String],
       transform: Expression => Expression): Unit = {
     testMetastorePartitionFiltering(
-      filterString,
+      filterExpr,
       (expectedDs, expectedH, expectedChunks) :: Nil,
-      identity)
+      transform)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])]): Unit = {
-    testMetastorePartitionFiltering(filterString, expectedPartitionCubes, identity)
+    testMetastorePartitionFiltering(filterExpr, expectedPartitionCubes, identity)
   }
 
   private def testMetastorePartitionFiltering(
-      filterString: String,
+      filterExpr: Expression,
       expectedPartitionCubes: Seq[(Seq[Int], Seq[Int], Seq[String])],
       transform: Expression => Expression): Unit = {
     val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
       Seq(
-        transform(parseExpression(filterString))
+        transform(filterExpr)
       ))
 
     val expectedPartitionCount = expectedPartitionCubes.map {



