spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ues...@apache.org
Subject spark git commit: [SPARK-23907][SQL] Add regr_* functions
Date Thu, 10 May 2018 11:38:58 GMT
Repository: spark
Updated Branches:
  refs/heads/master e3d434994 -> 94d671448


[SPARK-23907][SQL] Add regr_* functions

## What changes were proposed in this pull request?

This PR introduces the regr_slope, regr_intercept, regr_r2, regr_sxx, regr_syy, regr_sxy, regr_avgx, regr_avgy and regr_count aggregate functions.

The implementation of these functions mirrors Hive's implementation in HIVE-15978.

## How was this patch tested?

Added unit tests (expected values were compared against Hive's results).

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21054 from mgaido91/SPARK-23907.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94d67144
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94d67144
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94d67144

Branch: refs/heads/master
Commit: 94d671448240c8f6da11d2523ba9e4ae5b56a410
Parents: e3d4349
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Thu May 10 20:38:52 2018 +0900
Committer: Takuya UESHIN <ueshin@databricks.com>
Committed: Thu May 10 20:38:52 2018 +0900

----------------------------------------------------------------------
 .../catalyst/analysis/FunctionRegistry.scala    |   9 +
 .../expressions/aggregate/Average.scala         |  47 +++--
 .../aggregate/CentralMomentAgg.scala            |  60 +++---
 .../catalyst/expressions/aggregate/Corr.scala   |  52 ++---
 .../catalyst/expressions/aggregate/Count.scala  |  47 +++--
 .../expressions/aggregate/Covariance.scala      |  36 ++--
 .../expressions/aggregate/regression.scala      | 190 +++++++++++++++++++
 .../scala/org/apache/spark/sql/functions.scala  | 172 +++++++++++++++++
 .../sql-tests/inputs/udaf-regrfunctions.sql     |  56 ++++++
 .../results/udaf-regrfunctions.sql.out          |  93 +++++++++
 .../spark/sql/DataFrameAggregateSuite.scala     |  71 ++++++-
 11 files changed, 721 insertions(+), 112 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 87b0911..087d000 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -299,6 +299,15 @@ object FunctionRegistry {
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
     expression[CountMinSketchAgg]("count_min_sketch"),
+    expression[RegrCount]("regr_count"),
+    expression[RegrSXX]("regr_sxx"),
+    expression[RegrSYY]("regr_syy"),
+    expression[RegrAvgX]("regr_avgx"),
+    expression[RegrAvgY]("regr_avgy"),
+    expression[RegrSXY]("regr_sxy"),
+    expression[RegrSlope]("regr_slope"),
+    expression[RegrR2]("regr_r2"),
+    expression[RegrIntercept]("regr_intercept"),
 
     // string functions
     expression[Ascii]("ascii"),

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 708bdbf..a133bc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,24 +23,12 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
-@ExpressionDescription(
-  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
-
-  override def prettyName: String = "avg"
-
-  override def children: Seq[Expression] = child :: Nil
+abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
 
   override def nullable: Boolean = true
-
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
-
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function average")
-
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
@@ -62,14 +50,6 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     /* count = */ Literal(0L)
   )
 
-  override lazy val updateExpressions = Seq(
-    /* sum = */
-    Add(
-      sum,
-      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
-    /* count = */ If(IsNull(child), count, count + 1L)
-  )
-
   override lazy val mergeExpressions = Seq(
     /* sum = */ sum.left + sum.right,
     /* count = */ count.left + count.right
@@ -85,4 +65,29 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case _ =>
       Cast(sum, resultType) / Cast(count, resultType)
   }
+
+  protected def updateExpressionsDef: Seq[Expression] = Seq(
+    /* sum = */
+    Add(
+      sum,
+      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+    /* count = */ If(IsNull(child), count, count + 1L)
+  )
+
+  override lazy val updateExpressions = updateExpressionsDef
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
+case class Average(child: Expression)
+  extends AverageLike(child) with ImplicitCastInputTypes {
+
+  override def prettyName: String = "avg"
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function average")
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 572d29c..6bbb083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -67,35 +67,7 @@ abstract class CentralMomentAgg(child: Expression)
 
   override val initialValues: Seq[Expression] = Array.fill(momentOrder + 1)(Literal(0.0))
 
-  override val updateExpressions: Seq[Expression] = {
-    val newN = n + Literal(1.0)
-    val delta = child - avg
-    val deltaN = delta / newN
-    val newAvg = avg + deltaN
-    val newM2 = m2 + delta * (delta - deltaN)
-
-    val delta2 = delta * delta
-    val deltaN2 = deltaN * deltaN
-    val newM3 = if (momentOrder >= 3) {
-      m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
-    } else {
-      Literal(0.0)
-    }
-    val newM4 = if (momentOrder >= 4) {
-      m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
-        delta * (delta * delta2 - deltaN * deltaN2)
-    } else {
-      Literal(0.0)
-    }
-
-    trimHigherOrder(Seq(
-      If(IsNull(child), n, newN),
-      If(IsNull(child), avg, newAvg),
-      If(IsNull(child), m2, newM2),
-      If(IsNull(child), m3, newM3),
-      If(IsNull(child), m4, newM4)
-    ))
-  }
+  override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
 
   override val mergeExpressions: Seq[Expression] = {
 
@@ -128,6 +100,36 @@ abstract class CentralMomentAgg(child: Expression)
 
     trimHigherOrder(Seq(newN, newAvg, newM2, newM3, newM4))
   }
+
+  protected def updateExpressionsDef: Seq[Expression] = {
+    val newN = n + Literal(1.0)
+    val delta = child - avg
+    val deltaN = delta / newN
+    val newAvg = avg + deltaN
+    val newM2 = m2 + delta * (delta - deltaN)
+
+    val delta2 = delta * delta
+    val deltaN2 = deltaN * deltaN
+    val newM3 = if (momentOrder >= 3) {
+      m3 - Literal(3.0) * deltaN * newM2 + delta * (delta2 - deltaN2)
+    } else {
+      Literal(0.0)
+    }
+    val newM4 = if (momentOrder >= 4) {
+      m4 - Literal(4.0) * deltaN * newM3 - Literal(6.0) * deltaN2 * newM2 +
+        delta * (delta * delta2 - deltaN * deltaN2)
+    } else {
+      Literal(0.0)
+    }
+
+    trimHigherOrder(Seq(
+      If(IsNull(child), n, newN),
+      If(IsNull(child), avg, newAvg),
+      If(IsNull(child), m2, newM2),
+      If(IsNull(child), m3, newM3),
+      If(IsNull(child), m4, newM4)
+    ))
+  }
 }
 
 // Compute the population standard deviation of a column

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 95a4a0d..3cdef72 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -22,17 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
 /**
- * Compute Pearson correlation between two expressions.
+ * Base class for computing Pearson correlation between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
  *
  * Definition of Pearson correlation can be found at
  * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
  */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
-// scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression)
+abstract class PearsonCorrelation(x: Expression, y: Expression)
   extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(x, y)
@@ -51,7 +47,26 @@ case class Corr(x: Expression, y: Expression)
 
   override val initialValues: Seq[Expression] = Array.fill(6)(Literal(0.0))
 
-  override val updateExpressions: Seq[Expression] = {
+  override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
+
+  override val mergeExpressions: Seq[Expression] = {
+    val n1 = n.left
+    val n2 = n.right
+    val newN = n1 + n2
+    val dx = xAvg.right - xAvg.left
+    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
+    val dy = yAvg.right - yAvg.left
+    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
+    val newXAvg = xAvg.left + dxN * n2
+    val newYAvg = yAvg.left + dyN * n2
+    val newCk = ck.left + ck.right + dx * dyN * n1 * n2
+    val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
+    val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
+
+    Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
+  }
+
+  protected def updateExpressionsDef: Seq[Expression] = {
     val newN = n + Literal(1.0)
     val dx = x - xAvg
     val dxN = dx / newN
@@ -73,24 +88,15 @@ case class Corr(x: Expression, y: Expression)
       If(isNull, yMk, newYMk)
     )
   }
+}
 
-  override val mergeExpressions: Seq[Expression] = {
-
-    val n1 = n.left
-    val n2 = n.right
-    val newN = n1 + n2
-    val dx = xAvg.right - xAvg.left
-    val dxN = If(newN === Literal(0.0), Literal(0.0), dx / newN)
-    val dy = yAvg.right - yAvg.left
-    val dyN = If(newN === Literal(0.0), Literal(0.0), dy / newN)
-    val newXAvg = xAvg.left + dxN * n2
-    val newYAvg = yAvg.left + dyN * n2
-    val newCk = ck.left + ck.right + dx * dyN * n1 * n2
-    val newXMk = xMk.left + xMk.right + dx * dxN * n1 * n2
-    val newYMk = yMk.left + yMk.right + dy * dyN * n1 * n2
 
-    Seq(newN, newXAvg, newYAvg, newCk, newXMk, newYMk)
-  }
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
+// scalastyle:on line.size.limit
+case class Corr(x: Expression, y: Expression)
+  extends PearsonCorrelation(x, y) {
 
   override val evaluateExpression: Expression = {
     If(n === Literal(0.0), Literal.create(null, DoubleType),

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 1990f2f..40582d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,24 +21,16 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = """
-    _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
-
-    _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
-
-    _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
-  """)
-// scalastyle:on line.size.limit
-case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
-
+/**
+ * Base class for all counting aggregators.
+ */
+abstract class CountLike extends DeclarativeAggregate {
   override def nullable: Boolean = false
 
   // Return data type.
   override def dataType: DataType = LongType
 
-  private lazy val count = AttributeReference("count", LongType, nullable = false)()
+  protected lazy val count = AttributeReference("count", LongType, nullable = false)()
 
   override lazy val aggBufferAttributes = count :: Nil
 
@@ -46,6 +38,27 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
     /* count = */ Literal(0L)
   )
 
+  override lazy val mergeExpressions = Seq(
+    /* count = */ count.left + count.right
+  )
+
+  override lazy val evaluateExpression = count
+
+  override def defaultResult: Option[Literal] = Option(Literal(0L))
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(*) - Returns the total number of retrieved rows, including rows containing null.
+
+    _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-null.
+
+    _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-null.
+  """)
+// scalastyle:on line.size.limit
+case class Count(children: Seq[Expression]) extends CountLike {
+
   override lazy val updateExpressions = {
     val nullableChildren = children.filter(_.nullable)
     if (nullableChildren.isEmpty) {
@@ -58,14 +71,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
       )
     }
   }
-
-  override lazy val mergeExpressions = Seq(
-    /* count = */ count.left + count.right
-  )
-
-  override lazy val evaluateExpression = count
-
-  override def defaultResult: Option[Literal] = Option(Literal(0L))
 }
 
 object Count {

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index fc6c34b..72a7c62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -42,23 +42,7 @@ abstract class Covariance(x: Expression, y: Expression)
 
   override val initialValues: Seq[Expression] = Array.fill(4)(Literal(0.0))
 
-  override lazy val updateExpressions: Seq[Expression] = {
-    val newN = n + Literal(1.0)
-    val dx = x - xAvg
-    val dy = y - yAvg
-    val dyN = dy / newN
-    val newXAvg = xAvg + dx / newN
-    val newYAvg = yAvg + dyN
-    val newCk = ck + dx * (y - newYAvg)
-
-    val isNull = IsNull(x) || IsNull(y)
-    Seq(
-      If(isNull, n, newN),
-      If(isNull, xAvg, newXAvg),
-      If(isNull, yAvg, newYAvg),
-      If(isNull, ck, newCk)
-    )
-  }
+  override lazy val updateExpressions: Seq[Expression] = updateExpressionsDef
 
   override val mergeExpressions: Seq[Expression] = {
 
@@ -75,6 +59,24 @@ abstract class Covariance(x: Expression, y: Expression)
 
     Seq(newN, newXAvg, newYAvg, newCk)
   }
+
+  protected def updateExpressionsDef: Seq[Expression] = {
+    val newN = n + Literal(1.0)
+    val dx = x - xAvg
+    val dy = y - yAvg
+    val dyN = dy / newN
+    val newXAvg = xAvg + dx / newN
+    val newYAvg = yAvg + dyN
+    val newCk = ck + dx * (y - newYAvg)
+
+    val isNull = IsNull(x) || IsNull(y)
+    Seq(
+      If(isNull, n, newN),
+      If(isNull, xAvg, newXAvg),
+      If(isNull, yAvg, newYAvg),
+      If(isNull, ck, newCk)
+    )
+  }
 }
 
 @ExpressionDescription(

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
new file mode 100644
index 0000000..d8f4505
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/regression.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.{AbstractDataType, DoubleType}
+
+/**
+ * Base trait for all regression functions.
+ */
+trait RegrLike extends AggregateFunction with ImplicitCastInputTypes {
+  def y: Expression
+  def x: Expression
+
+  override def children: Seq[Expression] = Seq(y, x)
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)
+
+  protected def updateIfNotNull(exprs: Seq[Expression]): Seq[Expression] = {
+    assert(aggBufferAttributes.length == exprs.length)
+    val nullableChildren = children.filter(_.nullable)
+    if (nullableChildren.isEmpty) {
+      exprs
+    } else {
+      exprs.zip(aggBufferAttributes).map { case (e, a) =>
+        If(nullableChildren.map(IsNull).reduce(Or), a, e)
+      }
+    }
+  }
+}
+
+
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the number of non-null pairs.",
+  since = "2.4.0")
+case class RegrCount(y: Expression, x: Expression)
+  extends CountLike with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(Seq(count + 1L))
+
+  override def prettyName: String = "regr_count"
+}
+
+
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+case class RegrSXX(y: Expression, x: Expression)
+  extends CentralMomentAgg(x) with RegrLike {
+
+  override protected def momentOrder = 2
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+  }
+
+  override def prettyName: String = "regr_sxx"
+}
+
+
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+case class RegrSYY(y: Expression, x: Expression)
+  extends CentralMomentAgg(y) with RegrLike {
+
+  override protected def momentOrder = 2
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType), m2)
+  }
+
+  override def prettyName: String = "regr_syy"
+}
+
+
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the average of x. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+case class RegrAvgX(y: Expression, x: Expression)
+  extends AverageLike(x) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override def prettyName: String = "regr_avgx"
+}
+
+
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the average of y. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+case class RegrAvgY(y: Expression, x: Expression)
+  extends AverageLike(y) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override def prettyName: String = "regr_avgy"
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the covariance of y and x multiplied for the number of items in the dataset. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSXY(y: Expression, x: Expression)
+  extends Covariance(y, x) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0), Literal.create(null, DoubleType), ck)
+  }
+
+  override def prettyName: String = "regr_sxy"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the slope of the linear regression line. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrSlope(y: Expression, x: Expression)
+  extends PearsonCorrelation(y, x) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType), ck / yMk)
+  }
+
+  override def prettyName: String = "regr_slope"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the coefficient of determination (also called R-squared or goodness of fit) for the regression line. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrR2(y: Expression, x: Expression)
+  extends PearsonCorrelation(y, x) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n < Literal(2.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+      If(xMk === Literal(0.0), Literal(1.0), ck * ck / yMk / xMk))
+  }
+
+  override def prettyName: String = "regr_r2"
+}
+
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(y, x) - Returns the y-intercept of the linear regression line. Any pair with a NULL is ignored.",
+  since = "2.4.0")
+// scalastyle:on line.size.limit
+case class RegrIntercept(y: Expression, x: Expression)
+  extends PearsonCorrelation(y, x) with RegrLike {
+
+  override lazy val updateExpressions: Seq[Expression] = updateIfNotNull(updateExpressionsDef)
+
+  override val evaluateExpression: Expression = {
+    If(n === Literal(0.0) || yMk === Literal(0.0), Literal.create(null, DoubleType),
+      xAvg - (ck / yMk) * yAvg)
+  }
+
+  override def prettyName: String = "regr_intercept"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8f9e4ae..28cf705 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -775,6 +775,178 @@ object functions {
    */
   def var_pop(columnName: String): Column = var_pop(Column(columnName))
 
+  /**
+   * Aggregate function: returns the number of non-null pairs.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_count(y: Column, x: Column): Column = withAggregateFunction {
+    RegrCount(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the number of non-null pairs.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_sxx(y: Column, x: Column): Column = withAggregateFunction {
+    RegrSXX(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_syy(y: Column, x: Column): Column = withAggregateFunction {
+    RegrSYY(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the average of y. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_avgy(y: Column, x: Column): Column = withAggregateFunction {
+    RegrAvgY(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the average of y. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the average of x. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_avgx(y: Column, x: Column): Column = withAggregateFunction {
+    RegrAvgX(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the average of x. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the covariance of y and x multiplied for the number of items in
+   * the dataset. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_sxy(y: Column, x: Column): Column = withAggregateFunction {
+    RegrSXY(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the covariance of y and x multiplied for the number of items in
+   * the dataset. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is
+   * ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_slope(y: Column, x: Column): Column = withAggregateFunction {
+    RegrSlope(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is
+   * ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the coefficient of determination (also called R-squared or
+   * goodness of fit) for the regression line. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_r2(y: Column, x: Column): Column = withAggregateFunction {
+    RegrR2(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the coefficient of determination (also called R-squared or
+   * goodness of fit) for the regression line. Any pair with a NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x))
+
+  /**
+   * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a
+   * NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_intercept(y: Column, x: Column): Column = withAggregateFunction {
+    RegrIntercept(y.expr, x.expr)
+  }
+
+  /**
+   * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a
+   * NULL is ignored.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x))
+
+
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
   //////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
new file mode 100644
index 0000000..92c7e26
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql
@@ -0,0 +1,56 @@
+--
+--   Licensed to the Apache Software Foundation (ASF) under one or more
+--   contributor license agreements.  See the NOTICE file distributed with
+--   this work for additional information regarding copyright ownership.
+--   The ASF licenses this file to You under the Apache License, Version 2.0
+--   (the "License"); you may not use this file except in compliance with
+--   the License.  You may obtain a copy of the License at
+--
+--      http://www.apache.org/licenses/LICENSE-2.0
+--
+--   Unless required by applicable law or agreed to in writing, software
+--   distributed under the License is distributed on an "AS IS" BASIS,
+--   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+--   See the License for the specific language governing permissions and
+--   limitations under the License.
+--
+
+-- Fixture: rows are grouped by px, and each group exercises a different case:
+--   px = 1: y = x (perfectly correlated)
+--   px = 2: y is constant while x varies
+--   px = 3: x is constant while y varies
+--   px = 4: y = x - 10
+--   px = 5: x is always NULL
+--   px = 6: y is always NULL
+--   px = 7: a single non-NULL pair
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+ (101, 1, 1, 1),
+ (201, 2, 1, 1),
+ (301, 3, 1, 1),
+ (401, 4, 1, 11),
+ (501, 5, 1, null),
+ (601, 6, null, 1),
+ (701, 6, null, null),
+ (102, 1, 2, 2),
+ (202, 2, 1, 2),
+ (302, 3, 2, 1),
+ (402, 4, 2, 12),
+ (502, 5, 2, null),
+ (602, 6, null, 2),
+ (702, 6, null, null),
+ (103, 1, 3, 3),
+ (203, 2, 1, 3),
+ (303, 3, 3, 1),
+ (403, 4, 3, 13),
+ (503, 5, 3, null),
+ (603, 6, null, 3),
+ (703, 6, null, null),
+ (104, 1, 4, 4),
+ (204, 2, 1, 4),
+ (304, 3, 4, 1),
+ (404, 4, 4, 14),
+ (504, 5, 4, null),
+ (604, 6, null, 4),
+ (704, 6, null, null),
+ (800, 7, 1, 1)
+as t1(id, px, y, x);
+
+-- Per-group comparison of the regr_* aggregates with var_pop, corr and covar_*.
+-- NOTE(review): regr_count(y,x) appears twice in this select list; it is left as-is because
+-- the golden udaf-regrfunctions.sql.out file echoes this exact query and its schema.
+select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
+ regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
+ regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
+from t1 group by px order by px;
+
+
+-- regr_count(y,x) used as a window aggregate over each px partition.
+select id, regr_count(y,x) over (partition by px) from t1 order by id;

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
new file mode 100644
index 0000000..d7d009a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out
@@ -0,0 +1,93 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 3
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+ (101, 1, 1, 1),
+ (201, 2, 1, 1),
+ (301, 3, 1, 1),
+ (401, 4, 1, 11),
+ (501, 5, 1, null),
+ (601, 6, null, 1),
+ (701, 6, null, null),
+ (102, 1, 2, 2),
+ (202, 2, 1, 2),
+ (302, 3, 2, 1),
+ (402, 4, 2, 12),
+ (502, 5, 2, null),
+ (602, 6, null, 2),
+ (702, 6, null, null),
+ (103, 1, 3, 3),
+ (203, 2, 1, 3),
+ (303, 3, 3, 1),
+ (403, 4, 3, 13),
+ (503, 5, 3, null),
+ (603, 6, null, 3),
+ (703, 6, null, null),
+ (104, 1, 4, 4),
+ (204, 2, 1, 4),
+ (304, 3, 4, 1),
+ (404, 4, 4, 14),
+ (504, 5, 4, null),
+ (604, 6, null, 4),
+ (704, 6, null, null),
+ (800, 7, 1, 1)
+as t1(id, px, y, x)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x),
+ regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x),
+ regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x)
+from t1 group by px order by px
+-- !query 1 schema
+struct<px:int,var_pop(CAST(x AS DOUBLE)):double,var_pop(CAST(y AS DOUBLE)):double,corr(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,covar_samp(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,covar_pop(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):bigint,regr_slope(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_intercept(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_r2(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_sxx(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_syy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_sxy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_avgx(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_avgy(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):double,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)):bigint>
+-- !query 1 output
+1	1.25	1.25	1.0	1.6666666666666667	1.25	4	1.0	0.0	1.0	5.0	5.0	5.0	2.5	2.5	4
+2	1.25	0.0	NULL	0.0	0.0	4	0.0	1.0	1.0	5.0	0.0	0.0	2.5	1.0	4
+3	0.0	1.25	NULL	0.0	0.0	4	NULL	NULL	NULL	0.0	5.0	0.0	1.0	2.5	4
+4	1.25	1.25	1.0	1.6666666666666667	1.25	4	1.0	-10.0	1.0	5.0	5.0	5.0	12.5	2.5	4
+5	NULL	1.25	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	0
+6	1.25	NULL	NULL	NULL	NULL	0	NULL	NULL	NULL	NULL	NULL	NULL	NULL	NULL	0
+7	0.0	0.0	NaN	NaN	0.0	1	NULL	NULL	NULL	0.0	0.0	0.0	1.0	1.0	1
+
+
+-- !query 2
+select id, regr_count(y,x) over (partition by px) from t1 order by id
+-- !query 2 schema
+struct<id:int,regr_count(CAST(y AS DOUBLE), CAST(x AS DOUBLE)) OVER (PARTITION BY px ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint>
+-- !query 2 output
+101	4
+102	4
+103	4
+104	4
+201	4
+202	4
+203	4
+204	4
+301	4
+302	4
+303	4
+304	4
+401	4
+402	4
+403	4
+404	4
+501	0
+502	0
+503	0
+504	0
+601	0
+602	0
+603	0
+604	0
+701	0
+702	0
+703	0
+704	0
+800	1

http://git-wip-us.apache.org/repos/asf/spark/blob/94d67144/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e7776e3..4337fb2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -36,6 +36,8 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub
 class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
+  // Absolute tolerance for the floating-point aggregate checks. It lives at class level so the
+  // "moments" test and the regression-function test share a single value.
+  val absTol = 1e-8
+
   test("groupBy") {
     checkAnswer(
       testData2.groupBy("a").agg(sum($"b")),
@@ -416,7 +418,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
   }
 
   test("moments") {
-    val absTol = 1e-8
 
     val sparkVariance = testData2.agg(variance('a))
     checkAggregatesWithTol(sparkVariance, Row(4.0 / 5.0), absTol)
@@ -686,4 +687,72 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-23907: regression functions") {
+    // Fixtures: an empty table, four correlated (a, b) pairs, and the same pairs with the second
+    // row's b replaced by NULL. In regr_*("a", "b"), a is the dependent variable y and b is the
+    // independent variable x. testData2 / testData3 are shared fixtures defined outside this
+    // suite (presumably SQLTestData); the expected values for them depend on their contents.
+    val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b")
+    val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12))
+      .toDF("a", "b")
+    val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)](
+      (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b")
+    // On an empty input regr_count is 0, while every other regr_* aggregate below is NULL.
+    checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6)))
+    checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1)))
+    checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0)))
+
+    // regr_sxx is computed from the second argument and regr_syy from the first. The regr_syy
+    // checks swap the arguments, so they expect the same values as the regr_sxx checks.
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol)
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol)
+
+    // Same mirroring for the averages: regr_avgx("a", "b") and regr_avgy("b", "a") agree.
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol)
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol)
+
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol)
+
+    // regr_slope, regr_r2 and regr_intercept describe the least-squares line of the first
+    // argument on the second.
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol)
+
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol)
+
+    checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol)
+    checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol)
+    checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")),
+      Row(null), absTol)
+
+
+    // Reference values with y = a and x = b:
+    //   avgx = 27.2 / 4 = 6.8 and avgy = 22.8 / 4 = 5.7,
+    //   slope = sxy / sxx, intercept = avgy - slope * avgx, r2 = sxy^2 / (sxx * syy).
+    // The decimal literals are truncated, but the truncation error stays below absTol.
+    checkAggregatesWithTol(correlatedData.groupBy().agg(
+      regr_count("a", "b"),
+      regr_avgx("a", "b"),
+      regr_avgy("a", "b"),
+      regr_sxx("a", "b"),
+      regr_syy("a", "b"),
+      regr_sxy("a", "b"),
+      regr_slope("a", "b"),
+      regr_r2("a", "b"),
+      regr_intercept("a", "b")),
+      Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092),
+      absTol)
+    // The (3.0, null) pair is ignored, so these are the same statistics over the three
+    // remaining pairs.
+    checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg(
+      regr_count("a", "b"),
+      regr_avgx("a", "b"),
+      regr_avgy("a", "b"),
+      regr_sxx("a", "b"),
+      regr_syy("a", "b"),
+      regr_sxy("a", "b"),
+      regr_slope("a", "b"),
+      regr_r2("a", "b"),
+      regr_intercept("a", "b")),
+      Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149),
+      absTol)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message