spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yh...@apache.org
Subject spark git commit: [SPARK-11490][SQL] variance should alias var_samp instead of var_pop.
Date Wed, 04 Nov 2015 17:34:58 GMT
Repository: spark
Updated Branches:
  refs/heads/master e0fc9c7e5 -> 3bd6f5d2a


[SPARK-11490][SQL] variance should alias var_samp instead of var_pop.

`stddev` is an alias for `stddev_samp`, so for consistency `variance` should likewise be an alias for `var_samp` (not `var_pop`).

Also took the opportunity to remove the internal Stddev and Variance expressions, keeping only
StddevSamp/StddevPop and VarianceSamp/VariancePop.

Author: Reynold Xin <rxin@databricks.com>

Closes #9449 from rxin/SPARK-11490.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bd6f5d2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bd6f5d2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bd6f5d2

Branch: refs/heads/master
Commit: 3bd6f5d2ae503468de0e218d51c331e249a862bb
Parents: e0fc9c7
Author: Reynold Xin <rxin@databricks.com>
Authored: Wed Nov 4 09:34:52 2015 -0800
Committer: Yin Huai <yhuai@databricks.com>
Committed: Wed Nov 4 09:34:52 2015 -0800

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |  4 +-
 .../catalyst/analysis/HiveTypeCoercion.scala    |  2 -
 .../apache/spark/sql/catalyst/dsl/package.scala |  8 ----
 .../expressions/aggregate/functions.scala       | 29 -------------
 .../catalyst/expressions/aggregate/utils.scala  | 12 ------
 .../sql/catalyst/expressions/aggregates.scala   | 45 +++++---------------
 .../scala/org/apache/spark/sql/DataFrame.scala  |  2 +-
 .../org/apache/spark/sql/GroupedData.scala      |  4 +-
 .../scala/org/apache/spark/sql/functions.scala  |  9 ++--
 .../spark/sql/DataFrameAggregateSuite.scala     | 17 +++-----
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 14 +++---
 11 files changed, 32 insertions(+), 114 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 24c1a7b..d4334d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -187,11 +187,11 @@ object FunctionRegistry {
     expression[Max]("max"),
     expression[Average]("mean"),
     expression[Min]("min"),
-    expression[Stddev]("stddev"),
+    expression[StddevSamp]("stddev"),
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
-    expression[Variance]("variance"),
+    expression[VarianceSamp]("variance"),
     expression[VariancePop]("var_pop"),
     expression[VarianceSamp]("var_samp"),
     expression[Skewness]("skewness"),

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 3c67567..84e2b13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -297,10 +297,8 @@ object HiveTypeCoercion {
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
       case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
-      case Stddev(e @ StringType()) => Stddev(Cast(e, DoubleType))
       case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType))
       case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType))
-      case Variance(e @ StringType()) => Variance(Cast(e, DoubleType))
       case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType))
       case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType))
       case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType))

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 787f67a..d8df664 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -159,14 +159,6 @@ package object dsl {
     def lower(e: Expression): Expression = Lower(e)
     def sqrt(e: Expression): Expression = Sqrt(e)
     def abs(e: Expression): Expression = Abs(e)
-    def stddev(e: Expression): Expression = Stddev(e)
-    def stddev_pop(e: Expression): Expression = StddevPop(e)
-    def stddev_samp(e: Expression): Expression = StddevSamp(e)
-    def variance(e: Expression): Expression = Variance(e)
-    def var_pop(e: Expression): Expression = VariancePop(e)
-    def var_samp(e: Expression): Expression = VarianceSamp(e)
-    def skewness(e: Expression): Expression = Skewness(e)
-    def kurtosis(e: Expression): Expression = Kurtosis(e)
 
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
     // TODO more implicit class for literal?

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index f2c3eca..10dc5e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -328,13 +328,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
   override val evaluateExpression = min
 }
 
-// Compute the sample standard deviation of a column
-case class Stddev(child: Expression) extends StddevAgg(child) {
-
-  override def isSample: Boolean = true
-  override def prettyName: String = "stddev"
-}
-
 // Compute the population standard deviation of a column
 case class StddevPop(child: Expression) extends StddevAgg(child) {
 
@@ -1274,28 +1267,6 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
   }
 }
 
-case class Variance(child: Expression,
-    mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {
-
-  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
-    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
-
-  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
-    copy(inputAggBufferOffset = newInputAggBufferOffset)
-
-  override def prettyName: String = "variance"
-
-  override protected val momentOrder = 2
-
-  override def getStatistic(n: Double, mean: Double, moments: Array[Double]): Double = {
-    require(moments.length == momentOrder + 1,
-      s"$prettyName requires ${momentOrder + 1} central moments, received: ${moments.length}")
-
-    if (n == 0.0) Double.NaN else moments(2) / n
-  }
-}
-
 case class VarianceSamp(child: Expression,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) {

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 564174f..644c621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -97,12 +97,6 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
-        case expressions.Stddev(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Stddev(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
         case expressions.StddevPop(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.StddevPop(child),
@@ -139,12 +133,6 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = false)
 
-        case expressions.Variance(child) =>
-          aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Variance(child),
-            mode = aggregate.Complete,
-            isDistinct = false)
-
         case expressions.VariancePop(child) =>
           aggregate.AggregateExpression2(
             aggregateFunction = aggregate.VariancePop(child),

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index bf59660..89d63ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -785,13 +785,6 @@ abstract class StddevAgg1(child: Expression) extends UnaryExpression with Partia
 
 }
 
-// Compute the sample standard deviation of a column
-case class Stddev(child: Expression) extends StddevAgg1(child) {
-
-  override def toString: String = s"STDDEV($child)"
-  override def isSample: Boolean = true
-}
-
 // Compute the population standard deviation of a column
 case class StddevPop(child: Expression) extends StddevAgg1(child) {
 
@@ -807,20 +800,21 @@ case class StddevSamp(child: Expression) extends StddevAgg1(child) {
 }
 
 case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 {
-    def this() = this(null)
-
-    override def children: Seq[Expression] = child :: Nil
-    override def nullable: Boolean = false
-    override def dataType: DataType = ArrayType(DoubleType)
-    override def toString: String = s"computePartialStddev($child)"
-    override def newInstance(): ComputePartialStdFunction =
-      new ComputePartialStdFunction(child, this)
+  def this() = this(null)
+
+  override def children: Seq[Expression] = child :: Nil
+  override def nullable: Boolean = false
+  override def dataType: DataType = ArrayType(DoubleType)
+  override def toString: String = s"computePartialStddev($child)"
+  override def newInstance(): ComputePartialStdFunction =
+    new ComputePartialStdFunction(child, this)
 }
 
 case class ComputePartialStdFunction (
     expr: Expression,
     base: AggregateExpression1
-) extends AggregateFunction1 {
+  ) extends AggregateFunction1 {
+
   def this() = this(null, null)  // Required for serialization
 
   private val computeType = DoubleType
@@ -1049,25 +1043,6 @@ case class Skewness(child: Expression) extends UnaryExpression with AggregateExp
 }
 
 // placeholder
-case class Variance(child: Expression) extends UnaryExpression with AggregateExpression1 {
-
-  override def newInstance(): AggregateFunction1 = {
-    throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " +
-      "please set spark.sql.useAggregate2 = true")
-  }
-
-  override def nullable: Boolean = false
-
-  override def dataType: DoubleType.type = DoubleType
-
-  override def foldable: Boolean = false
-
-  override def prettyName: String = "variance"
-
-  override def toString: String = s"VARIANCE($child)"
-}
-
-// placeholder
 case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 {
 
   override def newInstance(): AggregateFunction1 = {

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index fc0ab63..5e9c7ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1383,7 +1383,7 @@ class DataFrame private[sql](
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
-      "stddev" -> Stddev,
+      "stddev" -> StddevSamp,
       "min" -> Min,
       "max" -> Max)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index c2b2a40..7cf66b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -96,10 +96,10 @@ class GroupedData protected[sql](
       case "avg" | "average" | "mean" => Average
       case "max" => Max
       case "min" => Min
-      case "stddev" | "std" => Stddev
+      case "stddev" | "std" => StddevSamp
       case "stddev_pop" => StddevPop
       case "stddev_samp" => StddevSamp
-      case "variance" => Variance
+      case "variance" => VarianceSamp
       case "var_pop" => VariancePop
       case "var_samp" => VarianceSamp
       case "sum" => Sum

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c8c5283..c70c965 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -329,13 +329,12 @@ object functions {
   def skewness(e: Column): Column = Skewness(e.expr)
 
   /**
-   * Aggregate function: returns the unbiased sample standard deviation of
-   * the expression in a group.
+   * Aggregate function: alias for [[stddev_samp]].
    *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev(e: Column): Column = Stddev(e.expr)
+  def stddev(e: Column): Column = StddevSamp(e.expr)
 
   /**
    * Aggregate function: returns the unbiased sample standard deviation of
@@ -388,12 +387,12 @@ object functions {
   def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
 
   /**
-   * Aggregate function: returns the population variance of the values in a group.
+   * Aggregate function: alias for [[var_samp]].
    *
    * @group agg_funcs
    * @since 1.6.0
    */
-  def variance(e: Column): Column = Variance(e.expr)
+  def variance(e: Column): Column = VarianceSamp(e.expr)
 
   /**
    * Aggregate function: returns the unbiased variance of the values in a group.

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 9b23977..b0e2ffa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -226,23 +226,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     val absTol = 1e-8
 
     val sparkVariance = testData2.agg(variance('a))
-    val expectedVariance = Row(4.0 / 6.0)
-    checkAggregatesWithTol(sparkVariance, expectedVariance, absTol)
+    checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
     val sparkVariancePop = testData2.agg(var_pop('a))
-    checkAggregatesWithTol(sparkVariancePop, expectedVariance, absTol)
+    checkAggregatesWithTol(sparkVariancePop, Row(4.0 / 6.0), absTol)
 
     val sparkVarianceSamp = testData2.agg(var_samp('a))
-    val expectedVarianceSamp = Row(4.0 / 5.0)
-    checkAggregatesWithTol(sparkVarianceSamp, expectedVarianceSamp, absTol)
+    checkAggregatesWithTol(sparkVarianceSamp, Row(4.0 / 5.0), absTol)
 
     val sparkSkewness = testData2.agg(skewness('a))
-    val expectedSkewness = Row(0.0)
-    checkAggregatesWithTol(sparkSkewness, expectedSkewness, absTol)
+    checkAggregatesWithTol(sparkSkewness, Row(0.0), absTol)
 
     val sparkKurtosis = testData2.agg(kurtosis('a))
-    val expectedKurtosis = Row(-1.5)
-    checkAggregatesWithTol(sparkKurtosis, expectedKurtosis, absTol)
-
+    checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol)
   }
 
   test("zero moments") {
@@ -251,7 +246,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(
       emptyTableData.agg(variance('a)),
-      Row(0.0))
+      Row(Double.NaN))
 
     checkAnswer(
       emptyTableData.agg(var_samp('a)),

http://git-wip-us.apache.org/repos/asf/spark/blob/3bd6f5d2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 6388a8b..5731a35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -536,7 +536,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
         "AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) FROM nullInts"),
-      Row(0, -1.5, 1, 3, 2, 2.0 / 3.0, 1, 6, 3)
+      Row(0, -1.5, 1, 3, 2, 1.0, 1, 6, 3)
     )
   }
 
@@ -757,7 +757,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   test("variance") {
     val absTol = 1e-8
     val sparkAnswer = sql("SELECT VARIANCE(a) FROM testData2")
-    val expectedAnswer = Row(4.0 / 6.0)
+    val expectedAnswer = Row(0.8)
     checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
   }
 
@@ -784,16 +784,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
 
   test("stddev agg") {
     checkAnswer(
-        sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
+      sql("SELECT a, stddev(b), stddev_pop(b), stddev_samp(b) FROM testData2 GROUP BY a"),
       (1 to 3).map(i => Row(i, math.sqrt(1.0 / 2.0), math.sqrt(1.0 / 4.0), math.sqrt(1.0 / 2.0))))
   }
 
   test("variance agg") {
     val absTol = 1e-8
-    val sparkAnswer = sql("SELECT a, variance(b), var_samp(b), var_pop(b)" +
-      "FROM testData2 GROUP BY a")
-    val expectedAnswer = (1 to 3).map(i => Row(i, 1.0 / 4.0, 1.0 / 2.0, 1.0 / 4.0))
-    checkAggregatesWithTol(sparkAnswer, expectedAnswer, absTol)
+    checkAggregatesWithTol(
+      sql("SELECT a, variance(b), var_samp(b), var_pop(b) FROM testData2 GROUP BY a"),
+      (1 to 3).map(i => Row(i, 1.0 / 2.0, 1.0 / 2.0, 1.0 / 4.0)),
+      absTol)
   }
 
   test("skewness and kurtosis agg") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message