spark-commits mailing list archives

From r...@apache.org
Subject spark git commit: [SPARK-19447] Fixing input metrics for range operator.
Date Tue, 07 Feb 2017 13:21:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master e99e34d0f -> 6ed285c68


[SPARK-19447] Fixing input metrics for range operator.

## What changes were proposed in this pull request?

This change introduces a new metric, "number of generated rows". It is used exclusively by
Range, which is a leaf in the query tree yet doesn't read any input data, and therefore cannot
report "recordsRead".

Additionally, the way in which metrics are reported by the code-generated version of Range was
changed. Previously, it was reported up front that all the records had been produced, which could
be confusing for a user monitoring execution progress in the UI. Now the metric is updated
gradually.

In order to avoid a negative impact on Range performance, the code generation was reworked: the
values are now produced in batches in a tight inner loop, while the metrics are updated in the
outer loop.
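
To make the loop structure concrete, below is a hand-written Scala analogue of the batched
production scheme. The actual change emits Java source via whole-stage code generation (see the
diff of `doProduce` below); the helper names here (`produceRange`, `emit`, `updateMetrics`) are
illustrative only.

```scala
// Illustrative sketch only: a hand-written analogue of the batched loop that the
// reworked code generation produces. It mirrors the control flow, not the generated code.
def produceRange(
    start: Long,
    step: Long,
    total: Long)(
    emit: Long => Unit,
    updateMetrics: Long => Unit): Unit = {
  val batchSize = 1000L            // same default batch size as in the patch
  var number = start
  var remaining = total
  while (remaining > 0) {
    val batch = math.min(batchSize, remaining)
    var i = 0L
    while (i < batch) {            // tight inner loop: no metric bookkeeping here
      emit(number)
      number += step
      i += 1
    }
    updateMetrics(batch)           // outer loop: metrics bumped once per batch
    remaining -= batch
  }
}
```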

The change also contains a number of unit tests, which should help ensure the correctness
of metrics for various input sources.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <ala@databricks.com>

Closes #16829 from ala/SPARK-19447.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6ed285c6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6ed285c6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6ed285c6

Branch: refs/heads/master
Commit: 6ed285c68fee451c45db7b01ca8ec1dea2efd479
Parents: e99e34d
Author: Ala Luszczak <ala@databricks.com>
Authored: Tue Feb 7 14:21:30 2017 +0100
Committer: Reynold Xin <rxin@databricks.com>
Committed: Tue Feb 7 14:21:30 2017 +0100

----------------------------------------------------------------------
 .../sql/execution/basicPhysicalOperators.scala  |  82 ++++++++----
 .../apache/spark/sql/DataFrameRangeSuite.scala  | 130 ++++++++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   |  53 --------
 .../InputGeneratedOutputMetricsSuite.scala      | 131 +++++++++++++++++++
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala   |  10 ++
 .../sql/hive/execution/HiveSerDeSuite.scala     |  19 +++
 6 files changed, 350 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799..792fb3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   override val output: Seq[Attribute] = range.output
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val numGenerated = metricTerm(ctx, "numGeneratedRows")
 
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
-    ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
     val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
-    ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
-    val checkEnd = if (step > 0) {
-      s"$number < $partitionEnd"
-    } else {
-      s"$number > $partitionEnd"
-    }
+
+    // In order to periodically update the metrics without inflicting performance penalty, this
+    // operator produces elements in batches. After a batch is complete, the metrics are updated
+    // and a new batch is started.
+    // In the implementation below, the code in the inner loop is producing all the values
+    // within a batch, while the code in the outer loop is setting batch parameters and updating
+    // the metrics.
+
+    // Once number == batchEnd, it's time to progress to the next batch.
+    val batchEnd = ctx.freshName("batchEnd")
+    ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+    // How many values should still be generated by this range operator.
+    val numElementsTodo = ctx.freshName("numElementsTodo")
+    ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+    // How many values should be generated in the next batch.
+    val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+    // The default size of a batch.
+    val batchSize = 1000L
 
     ctx.addNewFunction("initRange",
       s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
         |   $BigInt step = $BigInt.valueOf(${step}L);
         |   $BigInt start = $BigInt.valueOf(${start}L);
+        |   long partitionEnd;
         |
         |   $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
         |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
         |   } else {
         |     $number = st.longValue();
         |   }
+        |   $batchEnd = $number;
         |
         |   $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
         |     .multiply(step).add(start);
         |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $partitionEnd = Long.MAX_VALUE;
+        |     partitionEnd = Long.MAX_VALUE;
         |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $partitionEnd = Long.MIN_VALUE;
+        |     partitionEnd = Long.MIN_VALUE;
         |   } else {
-        |     $partitionEnd = end.longValue();
+        |     partitionEnd = end.longValue();
         |   }
         |
-        |   $numOutput.add(($partitionEnd - $number) / ${step}L);
+        |   $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+        |     $BigInt.valueOf($number));
+        |   $numElementsTodo  = startToEnd.divide(step).longValue();
+        |   if ($numElementsTodo < 0) {
+        |     $numElementsTodo = 0;
+        |   } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+        |     $numElementsTodo++;
+        |   }
         | }
        """.stripMargin)
 
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   initRange(partitionIndex);
       | }
       |
-      | while (!$overflow && $checkEnd) {
-      |  long $value = $number;
-      |  $number += ${step}L;
-      |  if ($number < $value ^ ${step}L < 0) {
-      |    $overflow = true;
-      |  }
-      |  ${consume(ctx, Seq(ev))}
-      |  if (shouldStop()) return;
+      | while (true) {
+      |   while ($number != $batchEnd) {
+      |     long $value = $number;
+      |     $number += ${step}L;
+      |     ${consume(ctx, Seq(ev))}
+      |     if (shouldStop()) return;
+      |   }
+      |
+      |   long $nextBatchTodo;
+      |   if ($numElementsTodo > ${batchSize}L) {
+      |     $nextBatchTodo = ${batchSize}L;
+      |     $numElementsTodo -= ${batchSize}L;
+      |   } else {
+      |     $nextBatchTodo = $numElementsTodo;
+      |     $numElementsTodo = 0;
+      |     if ($nextBatchTodo == 0) break;
+      |   }
+      |   $numOutput.add($nextBatchTodo);
+      |   $numGenerated.add($nextBatchTodo);
+      |
+      |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
      """.stripMargin
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
             }
 
             numOutputRows += 1
+            numGeneratedRows += 1
             unsafeRow.setLong(0, ret)
             unsafeRow
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000..6d2d776
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+  test("SPARK-7150 range api") {
+    // numSlice is greater than length
+    val res1 = spark.range(0, 10, 1, 15).select("id")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res2 = spark.range(3, 15, 3, 2).select("id")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    val res3 = spark.range(1, -2).select("id")
+    assert(res3.count == 0)
+
+    // start is positive, end is negative, step is negative
+    val res4 = spark.range(1, -2, -2, 6).select("id")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+    // start, end, step are negative
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+    // start, end are negative, step is positive
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
+    assert(res7.count == 0)
+
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    assert(res8.count == 3)
+    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    assert(res9.count == 2)
+    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = spark.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = spark.range(-1).select("id")
+    assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = spark.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    // difference between range start and end does not fit in a 64-bit integer
+    val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+    val res13 = spark.range(-n, n, n / 9).select("id")
+    assert(res13.count == 18)
+  }
+
+  test("Range with randomized parameters") {
+    val MAX_NUM_STEPS = 10L * 1000
+
+    val seed = System.currentTimeMillis()
+    val random = new Random(seed)
+
+    def randomBound(): Long = {
+      val n = if (random.nextBoolean()) {
+        random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+      } else {
+        random.nextLong() / 2
+      }
+      if (random.nextBoolean()) n else -n
+    }
+
+    for (l <- 1 to 10) {
+      val start = randomBound()
+      val end = randomBound()
+      val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+      val stepAbs = (abs(end - start) / numSteps) + 1
+      val step = if (start < end) stepAbs else -stepAbs
+      val partitions = random.nextInt(20) + 1
+
+      val expCount = (start until end by step).size
+      val expSum = (start until end by step).sum
+
+      for (codegen <- List(false, true)) {
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+          val res = spark.range(start, end, step, partitions).toDF("id").
+            agg(count("id"), sum("id")).collect()
+
+          withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+              s"$partitions codegen = $codegen") {
+            assert(!res.isEmpty)
+            assert(res.head.getLong(0) == expCount)
+            if (expCount > 0) {
+              assert(res.head.getLong(1) == expSum)
+            }
+          }
+        }
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b9..e6338ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
   }
 
-  test("SPARK-7150 range api") {
-    // numSlice is greater than length
-    val res1 = spark.range(0, 10, 1, 15).select("id")
-    assert(res1.count == 10)
-    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res2 = spark.range(3, 15, 3, 2).select("id")
-    assert(res2.count == 4)
-    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
-    val res3 = spark.range(1, -2).select("id")
-    assert(res3.count == 0)
-
-    // start is positive, end is negative, step is negative
-    val res4 = spark.range(1, -2, -2, 6).select("id")
-    assert(res4.count == 2)
-    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
-    // start, end, step are negative
-    val res5 = spark.range(-3, -8, -2, 1).select("id")
-    assert(res5.count == 3)
-    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
-    // start, end are negative, step is positive
-    val res6 = spark.range(-8, -4, 2, 1).select("id")
-    assert(res6.count == 2)
-    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
-    val res7 = spark.range(-10, -9, -20, 1).select("id")
-    assert(res7.count == 0)
-
-    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
-    assert(res8.count == 3)
-    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
-    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
-    assert(res9.count == 2)
-    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
-    // only end provided as argument
-    val res10 = spark.range(10).select("id")
-    assert(res10.count == 10)
-    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res11 = spark.range(-1).select("id")
-    assert(res11.count == 0)
-
-    // using the default slice number
-    val res12 = spark.range(3, 15, 3).select("id")
-    assert(res12.count == 4)
-    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-  }
-
   test("SPARK-8621: support empty string column name") {
     val df = Seq(Tuple1(1)).toDF("").as("t")
     // We should allow empty string as column name

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000..ddd7a03
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+  test("Range query input/output/generated metrics") {
+    val numRows = 150L
+    val numSelectedRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+      filter(x => x < numSelectedRows).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with repartitioning") {
+    val numRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(
+      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === numRows)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === 20 :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with more repartitioning") {
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res = MetricsTestHelper.runAndGetMetrics(
+        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+            .toDF()
+      )
+
+      assert(res.recordsRead.sum == 10)
+      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+      assert(res.generatedRows == 30 :: Nil)
+      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+    }
+  }
+}
+
+object MetricsTestHelper {
+  case class AggregatedMetricsResult(
+      recordsRead: List[Long],
+      shuffleRecordsRead: List[Long],
+      generatedRows: List[Long],
+      outputRows: List[Long])
+
+  private[this] def extractMetricValues(
+      df: DataFrame,
+      metricValues: Map[Long, String],
+      metricName: String): List[Long] = {
+    df.queryExecution.executedPlan.collect {
+      case plan if plan.metrics.contains(metricName) =>
+        metricValues(plan.metrics(metricName).id).toLong
+    }.toList.sorted
+  }
+
+  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+      AggregatedMetricsResult = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+
+    var recordsRead = List[Long]()
+    var shuffleRecordsRead = List[Long]()
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        if (taskEnd.taskMetrics != null) {
+          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+            recordsRead
+          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+            shuffleRecordsRead
+        }
+      }
+    }
+
+    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+    val prevUseWholeStageCodeGen =
+      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+    try {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      sparkContext.addSparkListener(listener)
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+    } finally {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+    }
+
+    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 0396254..14fbe9f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation,
   JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
+
+  test("Input/generated/output metrics on JDBC") {
+    val foobarCnt = spark.table("foobar").count()
+    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+    assert(res.recordsRead === foobarCnt :: Nil)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows.isEmpty)
+    assert(res.outputRows === foobarCnt :: Nil)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6ed285c6/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2..35c41b5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
   createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+  test("Test input/generated/output metrics") {
+    import TestHive._
+
+    val episodesCnt = sql("select * from episodes").count()
+    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+    assert(episodesRes.recordsRead === episodesCnt :: Nil)
+    assert(episodesRes.shuffleRecordsRead.sum === 0)
+    assert(episodesRes.generatedRows.isEmpty)
+    assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+    val serdeinsCnt = sql("select * from serdeins").count()
+    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+    assert(serdeinsRes.generatedRows.isEmpty)
+    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

