spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-21984][SQL] Join estimation based on equi-height histogram
Date Tue, 19 Dec 2017 13:55:30 GMT
Repository: spark
Updated Branches:
  refs/heads/master ab7346f20 -> 571aa2755


[SPARK-21984][SQL] Join estimation based on equi-height histogram

## What changes were proposed in this pull request?

An equi-height histogram is one of the state-of-the-art statistics for cardinality estimation:
it provides better estimation accuracy and handles skewed data well.

This PR is to improve join estimation based on equi-height histogram. The difference from
basic estimation (based on ndv) is the logic for computing join cardinality and the new ndv
after join.

The main idea is as follows:
1. find overlapped ranges between two histograms from two join keys;
2. apply the formula `T(A IJ B) = T(A) * T(B) / max(V(A.k1), V(B.k1))` in each overlapped
range.

## How was this patch tested?
Added new test cases.

Author: Zhenhua Wang <wangzhenhua@huawei.com>

Closes #19594 from wzhfy/join_estimation_histogram.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/571aa275
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/571aa275
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/571aa275

Branch: refs/heads/master
Commit: 571aa275541d71dbef8f0c86eab4ef04d56e4394
Parents: ab7346f
Author: Zhenhua Wang <wangzhenhua@huawei.com>
Authored: Tue Dec 19 21:55:21 2017 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Dec 19 21:55:21 2017 +0800

----------------------------------------------------------------------
 .../statsEstimation/EstimationUtils.scala       | 169 +++++++++++++++
 .../statsEstimation/JoinEstimation.scala        |  54 ++++-
 .../statsEstimation/JoinEstimationSuite.scala   | 209 ++++++++++++++++++-
 3 files changed, 428 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/571aa275/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 6f868cb..71e852a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
+import scala.collection.mutable.ArrayBuffer
 import scala.math.BigDecimal.RoundingMode
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
@@ -212,4 +213,172 @@ object EstimationUtils {
     }
   }
 
+  /**
+   * Returns overlapped ranges between two histograms, in the given value range
+   * [lowerBound, upperBound].
+   *
+   * Bins that do not intersect [lowerBound, upperBound] are discarded, the remaining bins are
+   * trimmed to that range with [[trimBin]], and every (left bin, right bin) pair whose trimmed
+   * ranges intersect produces one [[OverlappedRange]]. Ranges are emitted in iteration order:
+   * left bins in the outer loop, right bins in the inner loop.
+   *
+   * @param leftHistogram histogram of the left side.
+   * @param rightHistogram histogram of the right side.
+   * @param lowerBound lower bound of the value range to consider.
+   * @param upperBound upper bound of the value range to consider.
+   * @return the overlapped ranges, with the ndv and number of rows each side contributes.
+   */
+  def getOverlappedRanges(
+      leftHistogram: Histogram,
+      rightHistogram: Histogram,
+      lowerBound: Double,
+      upperBound: Double): Seq[OverlappedRange] = {
+    val overlappedRanges = new ArrayBuffer[OverlappedRange]()
+    // Only bins whose range intersects [lowerBound, upperBound] have join possibility.
+    def intersectsBounds(b: HistogramBin): Boolean = b.lo <= upperBound && b.hi >= lowerBound
+    val leftBins = leftHistogram.bins.filter(intersectsBounds)
+    // Trimming depends only on the bin itself, so trim every right bin once up front instead of
+    // once per (left bin, right bin) pair.
+    val trimmedRightBins = rightHistogram.bins.filter(intersectsBounds)
+      .map(rb => trimBin(rb, rightHistogram.height, lowerBound, upperBound))
+
+    leftBins.foreach { lb =>
+      // Likewise, the trimmed left bin is the same for every right bin it is paired with.
+      val (left, leftHeight) = trimBin(lb, leftHistogram.height, lowerBound, upperBound)
+      trimmedRightBins.foreach { case (right, rightHeight) =>
+        // Only collect overlapped ranges.
+        if (left.lo <= right.hi && left.hi >= right.lo) {
+          val range = if (right.lo >= left.lo && right.hi >= left.hi) {
+            // Case1: the left bin is "smaller" than the right bin
+            //      left.lo            right.lo     left.hi          right.hi
+            // --------+------------------+------------+----------------+------->
+            if (left.hi == right.lo) {
+              // The overlapped range has only one value.
+              OverlappedRange(
+                lo = right.lo,
+                hi = right.lo,
+                leftNdv = 1,
+                rightNdv = 1,
+                leftNumRows = leftHeight / left.ndv,
+                rightNumRows = rightHeight / right.ndv
+              )
+            } else {
+              // Scale each side by the fraction of its bin width covered by the overlap.
+              val leftRatio = (left.hi - right.lo) / (left.hi - left.lo)
+              val rightRatio = (left.hi - right.lo) / (right.hi - right.lo)
+              OverlappedRange(
+                lo = right.lo,
+                hi = left.hi,
+                leftNdv = left.ndv * leftRatio,
+                rightNdv = right.ndv * rightRatio,
+                leftNumRows = leftHeight * leftRatio,
+                rightNumRows = rightHeight * rightRatio
+              )
+            }
+          } else if (right.lo <= left.lo && right.hi <= left.hi) {
+            // Case2: the left bin is "larger" than the right bin
+            //      right.lo           left.lo      right.hi         left.hi
+            // --------+------------------+------------+----------------+------->
+            if (right.hi == left.lo) {
+              // The overlapped range has only one value.
+              OverlappedRange(
+                lo = right.hi,
+                hi = right.hi,
+                leftNdv = 1,
+                rightNdv = 1,
+                leftNumRows = leftHeight / left.ndv,
+                rightNumRows = rightHeight / right.ndv
+              )
+            } else {
+              val leftRatio = (right.hi - left.lo) / (left.hi - left.lo)
+              val rightRatio = (right.hi - left.lo) / (right.hi - right.lo)
+              OverlappedRange(
+                lo = left.lo,
+                hi = right.hi,
+                leftNdv = left.ndv * leftRatio,
+                rightNdv = right.ndv * rightRatio,
+                leftNumRows = leftHeight * leftRatio,
+                rightNumRows = rightHeight * rightRatio
+              )
+            }
+          } else if (right.lo >= left.lo && right.hi <= left.hi) {
+            // Case3: the left bin contains the right bin
+            //      left.lo            right.lo     right.hi         left.hi
+            // --------+------------------+------------+----------------+------->
+            // NOTE(review): if the right bin is a single value strictly inside the left bin
+            // (right.lo == right.hi), leftRatio is 0 and the left side contributes no rows.
+            // Confirm whether leftHeight / left.ndv is intended for that case.
+            val leftRatio = (right.hi - right.lo) / (left.hi - left.lo)
+            OverlappedRange(
+              lo = right.lo,
+              hi = right.hi,
+              leftNdv = left.ndv * leftRatio,
+              rightNdv = right.ndv,
+              leftNumRows = leftHeight * leftRatio,
+              rightNumRows = rightHeight
+            )
+          } else {
+            assert(right.lo <= left.lo && right.hi >= left.hi)
+            // Case4: the right bin contains the left bin
+            //      right.lo           left.lo      left.hi          right.hi
+            // --------+------------------+------------+----------------+------->
+            // NOTE(review): symmetric to Case3 -- a single-value left bin strictly inside the
+            // right bin yields rightRatio == 0. Confirm this is intended.
+            val rightRatio = (left.hi - left.lo) / (right.hi - right.lo)
+            OverlappedRange(
+              lo = left.lo,
+              hi = left.hi,
+              leftNdv = left.ndv,
+              rightNdv = right.ndv * rightRatio,
+              leftNumRows = leftHeight,
+              rightNumRows = rightHeight * rightRatio
+            )
+          }
+          overlappedRanges += range
+        }
+      }
+    }
+    overlappedRanges.toSeq
+  }
+
+  /**
+   * Restricts a histogram bin to the value range [lowerBound, upperBound] and scales its
+   * statistics to the part that remains.
+   * @param bin the input histogram bin.
+   * @param height the number of rows of the given histogram bin inside an equi-height histogram.
+   * @param lowerBound lower bound of the given range.
+   * @param upperBound upper bound of the given range.
+   * @return the trimmed bin together with the number of rows estimated to fall inside it.
+   */
+  def trimBin(bin: HistogramBin, height: Double, lowerBound: Double, upperBound: Double)
+  : (HistogramBin, Double) = {
+    val startsAtOrBelowLower = bin.lo <= lowerBound
+    val endsAtOrAboveUpper = bin.hi >= upperBound
+    val (lo, hi) =
+      if (startsAtOrBelowLower && endsAtOrAboveUpper) {
+        // The bin covers the whole range: keep exactly the range.
+        //       bin.lo          lowerBound     upperBound      bin.hi
+        // --------+------------------+------------+-------------+------->
+        (lowerBound, upperBound)
+      } else if (startsAtOrBelowLower && bin.hi >= lowerBound) {
+        // The bin crosses lowerBound only: drop the part below it.
+        //       bin.lo          lowerBound      bin.hi      upperBound
+        // --------+------------------+------------+-------------+------->
+        (lowerBound, bin.hi)
+      } else if (bin.lo <= upperBound && endsAtOrAboveUpper) {
+        // The bin crosses upperBound only: drop the part above it.
+        //    lowerBound            bin.lo     upperBound       bin.hi
+        // --------+------------------+------------+-------------+------->
+        (bin.lo, upperBound)
+      } else {
+        // The bin lies entirely inside the range: nothing to cut.
+        //    lowerBound            bin.lo        bin.hi     upperBound
+        // --------+------------------+------------+-------------+------->
+        assert(bin.lo >= lowerBound && bin.hi <= upperBound)
+        (bin.lo, bin.hi)
+      }
+
+    if (hi != lo) {
+      assert(bin.hi != bin.lo)
+      // Scale ndv and row count by the fraction of the bin width that survives the trim.
+      val keptFraction = (hi - lo) / (bin.hi - bin.lo)
+      val trimmedNdv = math.ceil(bin.ndv * keptFraction).toLong
+      (HistogramBin(lo, hi, trimmedNdv), height * keptFraction)
+    } else {
+      // Only one value is left (this also covers bin.hi == bin.lo): it gets an equal share of
+      // the bin's rows.
+      (HistogramBin(lo, hi, 1), height / bin.ndv)
+    }
+  }
+
+  /**
+   * A join between two equi-height histograms may produce multiple overlapped ranges.
+   * Each overlapped range is produced by a part of one bin in the left histogram and a part of
+   * one bin in the right histogram.
+   * The ndv and row counts are fractional estimates, hence `Double`.
+   * @param lo lower bound of this overlapped range.
+   * @param hi higher bound of this overlapped range.
+   * @param leftNdv ndv in the left part.
+   * @param rightNdv ndv in the right part.
+   * @param leftNumRows number of rows in the left part.
+   * @param rightNumRows number of rows in the right part.
+   */
+  case class OverlappedRange(
+    lo: Double,
+    hi: Double,
+    leftNdv: Double,
+    rightNdv: Double,
+    leftNumRows: Double,
+    rightNumRows: Double)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/571aa275/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index b073108..f0294a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Histogram, Join, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 
 
@@ -191,8 +191,19 @@ case class JoinEstimation(join: Join) extends Logging {
       val rInterval = ValueInterval(rightKeyStat.min, rightKeyStat.max, rightKey.dataType)
       if (ValueInterval.isIntersected(lInterval, rInterval)) {
         val (newMin, newMax) = ValueInterval.intersect(lInterval, rInterval, leftKey.dataType)
-        val (card, joinStat) = computeByNdv(leftKey, rightKey, newMin, newMax)
-        keyStatsAfterJoin += (leftKey -> joinStat, rightKey -> joinStat)
+        val (card, joinStat) = (leftKeyStat.histogram, rightKeyStat.histogram) match {
+          case (Some(l: Histogram), Some(r: Histogram)) =>
+            computeByHistogram(leftKey, rightKey, l, r, newMin, newMax)
+          case _ =>
+            computeByNdv(leftKey, rightKey, newMin, newMax)
+        }
+        keyStatsAfterJoin += (
+          // Histograms are propagated as unchanged. During future estimation, they should be
+          // truncated by the updated max/min. In this way, only pointers of the histograms are
+          // propagated and thus reduce memory consumption.
+          leftKey -> joinStat.copy(histogram = leftKeyStat.histogram),
+          rightKey -> joinStat.copy(histogram = rightKeyStat.histogram)
+        )
         // Return cardinality estimated from the most selective join keys.
         if (card < joinCard) joinCard = card
       } else {
@@ -225,6 +236,43 @@ case class JoinEstimation(join: Join) extends Logging {
     (ceil(card), newStats)
   }
 
+  /**
+   * Compute join cardinality using equi-height histograms.
+   *
+   * Returns the estimated number of joined rows together with the column stat shared by both
+   * join keys after the join (new ndv, min and max; no nulls).
+   */
+  private def computeByHistogram(
+      leftKey: AttributeReference,
+      rightKey: AttributeReference,
+      leftHistogram: Histogram,
+      rightHistogram: Histogram,
+      newMin: Option[Any],
+      newMax: Option[Any]): (BigInt, ColumnStat) = {
+    // Only numeric values have equi-height histograms.
+    val lowerBound = newMin.get.toString.toDouble
+    val upperBound = newMax.get.toString.toDouble
+    val ranges = getOverlappedRanges(leftHistogram, rightHistogram, lowerBound, upperBound)
+
+    // If a range ends at the same value as the previous range, the current range has only one
+    // value and that value is already counted in the previous range, so it adds no new ndv.
+    val totalNdv = ranges.indices.foldLeft(0.0) { (ndvSoFar, i) =>
+      val range = ranges(i)
+      if (i > 0 && range.hi == ranges(i - 1).hi) {
+        ndvSoFar
+      } else {
+        ndvSoFar + math.min(range.leftNdv, range.rightNdv)
+      }
+    }
+
+    // Apply the formula T(A) * T(B) / max(V(A), V(B)) in each overlapped range and add them up.
+    val card = ranges.foldLeft(BigDecimal(0)) { (cardSoFar, range) =>
+      val rowsInRange =
+        range.leftNumRows * range.rightNumRows / math.max(range.leftNdv, range.rightNdv)
+      cardSoFar + rowsInRange
+    }
+
+    val leftKeyStat = leftStats.attributeStats(leftKey)
+    val rightKeyStat = rightStats.attributeStats(rightKey)
+    val newMaxLen = math.min(leftKeyStat.maxLen, rightKeyStat.maxLen)
+    val newAvgLen = (leftKeyStat.avgLen + rightKeyStat.avgLen) / 2
+    val newStats = ColumnStat(ceil(totalNdv), newMin, newMax, 0, newAvgLen, newMaxLen)
+    (ceil(card), newStats)
+  }
+
   /**
    * Propagate or update column stats for output attributes.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/571aa275/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 097c78e..26139d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeReference, EqualTo}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{DateType, TimestampType, _}
@@ -67,6 +67,213 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
     rowCount = 2,
     attributeStats = AttributeMap(Seq("key-1-2", "key-2-3").map(nameToColInfo)))
 
+  /**
+   * Builds two single-column plans from the given histograms, inner-joins them on their key
+   * columns, and checks the estimated join statistics in both join orders.
+   */
+  private def estimateByHistogram(
+      leftHistogram: Histogram,
+      rightHistogram: Histogram,
+      expectedMin: Double,
+      expectedMax: Double,
+      expectedNdv: Long,
+      expectedRows: Long): Unit = {
+    val leftKey = attr("key1")
+    val rightKey = attr("key2")
+    val leftPlan = generateJoinChild(leftKey, leftHistogram, expectedMin, expectedMax)
+    val rightPlan = generateJoinChild(rightKey, rightHistogram, expectedMin, expectedMax)
+
+    // Each key keeps the column stat of its own child plan, except that ndv, min and max are
+    // replaced by the expected post-join values.
+    def expectedKeyStat(plan: LogicalPlan, key: Attribute): ColumnStat =
+      plan.stats.attributeStats(key).copy(
+        distinctCount = expectedNdv, min = Some(expectedMin), max = Some(expectedMax))
+
+    val expectedStatsAfterJoin = Statistics(
+      sizeInBytes = expectedRows * (8 + 2 * 4),
+      rowCount = Some(expectedRows),
+      attributeStats = AttributeMap(Seq(
+        leftKey -> expectedKeyStat(leftPlan, leftKey),
+        rightKey -> expectedKeyStat(rightPlan, rightKey))))
+
+    // Join order should not affect estimation result.
+    val joins = Seq(
+      Join(leftPlan, rightPlan, Inner, Some(EqualTo(leftKey, rightKey))),
+      Join(rightPlan, leftPlan, Inner, Some(EqualTo(rightKey, leftKey))))
+    joins.foreach(join => assert(join.stats == expectedStatsAfterJoin))
+  }
+
+  /**
+   * Creates a single-column test plan whose row count is `height * number of bins` of the
+   * given histogram and whose column stat is inferred from that histogram.
+   * Note: `expectedMin` and `expectedMax` are accepted but not used here.
+   */
+  private def generateJoinChild(
+      col: Attribute,
+      histogram: Histogram,
+      expectedMin: Double,
+      expectedMax: Double): LogicalPlan = {
+    StatsTestPlan(
+      outputList = Seq(col),
+      rowCount = (histogram.height * histogram.bins.length).toLong,
+      attributeStats = AttributeMap(Seq(col -> inferColumnStat(histogram))))
+  }
+
+  /**
+   * Column statistics should be consistent with histograms in tests: min/max come from the
+   * first/last bin, and ndv is the sum of bin ndvs.
+   */
+  private def inferColumnStat(histogram: Histogram): ColumnStat = {
+    val bins = histogram.bins
+    // A bin that ends at the same value as the bin before it is skipped when summing ndv.
+    val ndv = bins.indices
+      .filter(i => i == 0 || bins(i).hi != bins(i - 1).hi)
+      .map(i => bins(i).ndv)
+      .sum
+    ColumnStat(distinctCount = ndv, min = Some(bins.head.lo),
+      max = Some(bins.last.hi), nullCount = 0, avgLen = 4, maxLen = 4,
+      histogram = Some(histogram))
+  }
+
+  test("equi-height histograms: a bin is contained by another one") {
+    // Left key covers [10, 60], right key covers [0, 100]; the join range is [10, 60].
+    val histogram1 = Histogram(height = 300, Array(
+      HistogramBin(lo = 10, hi = 30, ndv = 10), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+    val histogram2 = Histogram(height = 100, Array(
+      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+    // test bin trimming
+    // [0, 50] trimmed to [10, 50] keeps 4/5 of the bin: ndv 50 -> 40, rows 100 -> 80.
+    val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 10, upperBound = 60)
+    assert(t0 == HistogramBin(lo = 10, hi = 50, ndv = 40) && h0 == 80)
+    // [50, 100] trimmed to [50, 60] keeps 1/5 of the bin: ndv 40 -> 8, rows 100 -> 20.
+    val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 10, upperBound = 60)
+    assert(t1 == HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20)
+
+    val expectedRanges = Seq(
+      // histogram1.bins(0) overlaps t0
+      OverlappedRange(10, 30, 10, 40 * 1 / 2, 300, 80 * 1 / 2),
+      // histogram1.bins(1) overlaps t0
+      OverlappedRange(30, 50, 30 * 2 / 3, 40 * 1 / 2, 300 * 2 / 3, 80 * 1 / 2),
+      // histogram1.bins(1) overlaps t1
+      OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20)
+    )
+    assert(expectedRanges.equals(
+      getOverlappedRanges(histogram1, histogram2, lowerBound = 10, upperBound = 60)))
+
+    // Per range: ndv is min(leftNdv, rightNdv); rows are leftRows * rightRows / max(ndv).
+    estimateByHistogram(
+      leftHistogram = histogram1,
+      rightHistogram = histogram2,
+      expectedMin = 10,
+      expectedMax = 60,
+      expectedNdv = 10 + 20 + 8,
+      expectedRows = 300 * 40 / 20 + 200 * 40 / 20 + 100 * 20 / 10)
+  }
+
+  test("equi-height histograms: a bin has only one value after trimming") {
+    // Left key covers [50, 75], right key covers [0, 100]; the join range is [50, 75].
+    val histogram1 = Histogram(height = 300, Array(
+      HistogramBin(lo = 50, hi = 60, ndv = 10), HistogramBin(lo = 60, hi = 75, ndv = 3)))
+    val histogram2 = Histogram(height = 100, Array(
+      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+    // test bin trimming
+    // [0, 50] only touches the range at 50: one value remains, with height / ndv = 2 rows.
+    val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 50, upperBound = 75)
+    assert(t0 == HistogramBin(lo = 50, hi = 50, ndv = 1) && h0 == 2)
+    // [50, 100] trimmed to [50, 75] keeps half of the bin: ndv 40 -> 20, rows 100 -> 50.
+    val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 50, upperBound = 75)
+    assert(t1 == HistogramBin(lo = 50, hi = 75, ndv = 20) && h1 == 50)
+
+    val expectedRanges = Seq(
+      // histogram1.bins(0) overlaps t0
+      OverlappedRange(50, 50, 1, 1, 300 / 10, 2),
+      // histogram1.bins(0) overlaps t1
+      OverlappedRange(50, 60, 10, 20 * 10 / 25, 300, 50 * 10 / 25),
+      // histogram1.bins(1) overlaps t1
+      OverlappedRange(60, 75, 3, 20 * 15 / 25, 300, 50 * 15 / 25)
+    )
+    assert(expectedRanges.equals(
+      getOverlappedRanges(histogram1, histogram2, lowerBound = 50, upperBound = 75)))
+
+    estimateByHistogram(
+      leftHistogram = histogram1,
+      rightHistogram = histogram2,
+      expectedMin = 50,
+      expectedMax = 75,
+      expectedNdv = 1 + 8 + 3,
+      expectedRows = 30 * 2 / 1 + 300 * 20 / 10 + 300 * 30 / 12)
+  }
+
+  test("equi-height histograms: skew distribution (some bins have only one value)") {
+    // The left histogram has two whole bins holding only the value 30.
+    val histogram1 = Histogram(height = 300, Array(
+      HistogramBin(lo = 30, hi = 30, ndv = 1),
+      HistogramBin(lo = 30, hi = 30, ndv = 1),
+      HistogramBin(lo = 30, hi = 60, ndv = 30)))
+    val histogram2 = Histogram(height = 100, Array(
+      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 100, ndv = 40)))
+    // test bin trimming
+    val (t0, h0) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 60)
+    assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 40)
+    val (t1, h1) = trimBin(histogram2.bins(1), height = 100, lowerBound = 30, upperBound = 60)
+    assert(t1 ==HistogramBin(lo = 50, hi = 60, ndv = 8) && h1 == 20)
+
+    // One range per single-value left bin (both at 30), then the two ranges of the wide bin.
+    val expectedRanges = Seq(
+      OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+      OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+      OverlappedRange(30, 50, 30 * 2 / 3, 20, 300 * 2 / 3, 40),
+      OverlappedRange(50, 60, 30 * 1 / 3, 8, 300 * 1 / 3, 20)
+    )
+    assert(expectedRanges.equals(
+      getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 60)))
+
+    // The value 30 is counted once in ndv even though it appears in two ranges.
+    estimateByHistogram(
+      leftHistogram = histogram1,
+      rightHistogram = histogram2,
+      expectedMin = 30,
+      expectedMax = 60,
+      expectedNdv = 1 + 20 + 8,
+      expectedRows = 300 * 2 / 1 + 300 * 2 / 1 + 200 * 40 / 20 + 100 * 20 / 10)
+  }
+
+  test("equi-height histograms: skew distribution (histograms have different skewed values)") {
+    // The left histogram is skewed at 30, the right histogram at 50; the join range is
+    // [30, 50].
+    val histogram1 = Histogram(height = 300, Array(
+      HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+    val histogram2 = Histogram(height = 100, Array(
+      HistogramBin(lo = 0, hi = 50, ndv = 50), HistogramBin(lo = 50, hi = 50, ndv = 1)))
+    // test bin trimming
+    val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 50)
+    assert(t0 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h0 == 200)
+    val (t1, h1) = trimBin(histogram2.bins(0), height = 100, lowerBound = 30, upperBound = 50)
+    assert(t1 == HistogramBin(lo = 30, hi = 50, ndv = 20) && h1 == 40)
+
+    val expectedRanges = Seq(
+      OverlappedRange(30, 30, 1, 1, 300, 40 / 20),
+      OverlappedRange(30, 50, 20, 20, 200, 40),
+      OverlappedRange(50, 50, 1, 1, 200 / 20, 100)
+    )
+    assert(expectedRanges.equals(
+      getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 50)))
+
+    estimateByHistogram(
+      leftHistogram = histogram1,
+      rightHistogram = histogram2,
+      expectedMin = 30,
+      expectedMax = 50,
+      expectedNdv = 1 + 20,
+      expectedRows = 300 * 2 / 1 + 200 * 40 / 20 + 10 * 100 / 1)
+  }
+
+  test("equi-height histograms: skew distribution (both histograms have the same skewed value)") {
+    // Both histograms are skewed at 30, which is also the only value in the join range.
+    val histogram1 = Histogram(height = 300, Array(
+      HistogramBin(lo = 30, hi = 30, ndv = 1), HistogramBin(lo = 30, hi = 60, ndv = 30)))
+    val histogram2 = Histogram(height = 150, Array(
+      HistogramBin(lo = 0, hi = 30, ndv = 30), HistogramBin(lo = 30, hi = 30, ndv = 1)))
+    // test bin trimming
+    val (t0, h0) = trimBin(histogram1.bins(1), height = 300, lowerBound = 30, upperBound = 30)
+    assert(t0 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h0 == 10)
+    val (t1, h1) = trimBin(histogram2.bins(0), height = 150, lowerBound = 30, upperBound = 30)
+    assert(t1 == HistogramBin(lo = 30, hi = 30, ndv = 1) && h1 == 5)
+
+    // Every left bin pairs with every right bin, all at the single value 30.
+    val expectedRanges = Seq(
+      OverlappedRange(30, 30, 1, 1, 300, 5),
+      OverlappedRange(30, 30, 1, 1, 300, 150),
+      OverlappedRange(30, 30, 1, 1, 10, 5),
+      OverlappedRange(30, 30, 1, 1, 10, 150)
+    )
+    assert(expectedRanges.equals(
+      getOverlappedRanges(histogram1, histogram2, lowerBound = 30, upperBound = 30)))
+
+    estimateByHistogram(
+      leftHistogram = histogram1,
+      rightHistogram = histogram2,
+      expectedMin = 30,
+      expectedMax = 30,
+      // only one value: 30
+      expectedNdv = 1,
+      expectedRows = 300 * 5 / 1 + 300 * 150 / 1 + 10 * 5 / 1 + 10 * 150 / 1)
+  }
+
   test("cross join") {
     // table1 (key-1-5 int, key-5-9 int): (1, 9), (2, 8), (3, 7), (4, 6), (5, 5)
     // table2 (key-1-2 int, key-2-4 int): (1, 2), (2, 3), (2, 4)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message