spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-19020][SQL] Cardinality estimation of aggregate operator
Date Mon, 09 Jan 2017 19:29:50 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3ccabdfb4 -> 15c2bd01b


[SPARK-19020][SQL] Cardinality estimation of aggregate operator

## What changes were proposed in this pull request?

Support cardinality estimation of aggregate operator

## How was this patch tested?

Add test cases

Author: Zhenhua Wang <wzh_zju@163.com>
Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16431 from wzhfy/aggEstimation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/15c2bd01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/15c2bd01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/15c2bd01

Branch: refs/heads/master
Commit: 15c2bd01b03b1a07f10779f68118cd28f2c62c9a
Parents: 3ccabdf
Author: Zhenhua Wang <wzh_zju@163.com>
Authored: Mon Jan 9 11:29:42 2017 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Mon Jan 9 11:29:42 2017 -0800

----------------------------------------------------------------------
 .../plans/logical/basicLogicalOperators.scala   |   4 +-
 .../statsEstimation/AggregateEstimation.scala   |  57 ++++++++
 .../statsEstimation/AggEstimationSuite.scala    | 135 +++++++++++++++++++
 .../StatsEstimationTestBase.scala               |   5 +-
 4 files changed, 198 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/15c2bd01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 9b52a9c..b97c81c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation,
ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -495,7 +495,7 @@ case class Aggregate(
     child.constraints.union(getAliasedConstraints(nonAgg))
   }
 
-  override lazy val statistics: Statistics = {
+  override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse
{
     if (groupingExpressions.isEmpty) {
       super.statistics.copy(sizeInBytes = 1)
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/15c2bd01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
new file mode 100644
index 0000000..33ebc38
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
+
+
+object AggregateEstimation {
+  import EstimationUtils._
+
+  /**
+   * Estimate the number of output rows based on column stats of group-by columns, and propagate
+   * column stats for aggregate expressions. Returns None unless the child has a row count and
+   * every group-by expression is an attribute with column stats.
+   */
+  def estimate(agg: Aggregate): Option[Statistics] = {
+    val childStats = agg.child.statistics
+    val groupByAttrs = agg.groupingExpressions.collect { case a: Attribute => a }
+    val colStatsExist = groupByAttrs.size == agg.groupingExpressions.size &&
+      groupByAttrs.forall(childStats.attributeStats.contains)
+    if (rowCountsExist(agg.child) && colStatsExist) {
+      // A global aggregate (no group-by) always yields exactly one row, even for empty input.
+      // Otherwise use an upper bound: the product of per-column group counts (nulls form an
+      // extra group; distinctCount is taken to exclude them), capped by the child's row count.
+      val outputRows: BigInt = if (groupByAttrs.isEmpty) BigInt(1) else {
+        groupByAttrs.map { a =>
+          val stat = childStats.attributeStats(a)
+          if (stat.nullCount > 0) stat.distinctCount + 1 else stat.distinctCount
+        }.product.min(childStats.rowCount.get)
+      }
+      val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
+      Some(Statistics(
+        sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
+        rowCount = Some(outputRows),
+        attributeStats = outputAttrStats,
+        isBroadcastable = childStats.isBroadcastable))
+    } else {
+      None
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15c2bd01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
new file mode 100644
index 0000000..42ce2f8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+
+
+class AggEstimationSuite extends StatsEstimationTestBase {
+
+  /** Columns for testing; none of them contain nulls (nullCount = 0). */
+  private val columnInfo: Map[Attribute, ColumnStat] =
+    Map(
+      attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+        avgLen = 4, maxLen = 4),
+      attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
+        avgLen = 4, maxLen = 4))
+
+  private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+    columnInfo.map(kv => kv._1.name -> kv)
+
+  test("empty group-by column") {
+    val colNames = Seq("key11", "key12")
+    // Suppose table1 has 2 records: (1, 10), (2, 10)
+    val table1 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 2 * (4 + 4),
+        rowCount = Some(2),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+    checkAggStats(
+      child = table1,
+      colNames = Nil,
+      expectedRowCount = 1)
+  }
+
+  test("there's a primary key in group-by columns") {
+    val colNames = Seq("key11", "key12")
+    // Suppose table1 has 2 records: (1, 10), (2, 10)
+    val table1 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 2 * (4 + 4),
+        rowCount = Some(2),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+    checkAggStats(
+      child = table1,
+      colNames = colNames,
+      // Column key11 is a primary key, so row count = ndv of key11 = child's row count
+      expectedRowCount = table1.stats.rowCount.get)
+  }
+
+  test("the product of ndv's of group-by columns is too large") {
+    val colNames = Seq("key21", "key22")
+    // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+    val table2 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 4 * (4 + 4),
+        rowCount = Some(4),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+    checkAggStats(
+      child = table2,
+      colNames = colNames,
+      // Use child's row count as an upper bound
+      expectedRowCount = table2.stats.rowCount.get)
+  }
+
+  test("data contains all combinations of distinct values of group-by columns.") {
+    val colNames = Seq("key31", "key32")
+    // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
+    val table3 = StatsTestPlan(
+      outputList = colNames.map(nameToAttr),
+      stats = Statistics(
+        sizeInBytes = 6 * (4 + 4),
+        rowCount = Some(6),
+        attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+    checkAggStats(
+      child = table3,
+      colNames = colNames,
+      // Row count = product of ndv
+      expectedRowCount = nameToColInfo("key31")._2.distinctCount *
+        nameToColInfo("key32")._2.distinctCount)
+  }
+
+  private def checkAggStats(
+      child: LogicalPlan,
+      colNames: Seq[String],
+      expectedRowCount: BigInt): Unit = {
+
+    val columns = colNames.map(nameToAttr)
+    val testAgg = Aggregate(
+      groupingExpressions = columns,
+      aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
+      child = child)
+
+    val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
+    val expectedStats = Statistics(
+      sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
+      rowCount = Some(expectedRowCount),
+      attributeStats = expectedAttrStats)
+
+    assert(testAgg.statistics == expectedStats)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/15c2bd01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index fa5b290..0d81aa3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.IntegerType
 
 
 class StatsEstimationTestBase extends SparkFunSuite {
 
+  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
+
   /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
   def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
     : AttributeMap[ColumnStat] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message