spark-commits mailing list archives

From lix...@apache.org
Subject spark git commit: [SPARK-24957][SQL][FOLLOW-UP] Clean the code for AVERAGE
Date Thu, 02 Aug 2018 06:00:21 GMT
Repository: spark
Updated Branches:
  refs/heads/master c9914cf04 -> 166f34618


[SPARK-24957][SQL][FOLLOW-UP] Clean the code for AVERAGE

## What changes were proposed in this pull request?
This PR refactors the code in AVERAGE to use the Catalyst DSL (a brief sketch of the style change follows).
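
For illustration only (not part of the commit): a minimal sketch of the before/after construction style, assuming spark-catalyst is on the classpath and using `DoubleType` and `Literal(1)` as hypothetical stand-ins for `sumDataType` and the aggregate's child expression.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Cast, Coalesce, Literal}
import org.apache.spark.sql.types.DoubleType

object AverageDslSketch extends App {
  val child = Literal(1) // hypothetical stand-in for the aggregate's child expression

  // Before: explicit expression constructors, as Average.scala read previously.
  val before = Coalesce(Cast(child, DoubleType) :: Cast(Literal(0), DoubleType) :: Nil)

  // After: the Catalyst DSL, including the coalesce helper added in this commit.
  val after = coalesce(child.cast(DoubleType), Literal(0).cast(DoubleType))

  println(before.sql) // e.g. coalesce(CAST(1 AS DOUBLE), CAST(0 AS DOUBLE))
  println(after.sql)  // same expression shape, terser construction
}
```

The DSL form builds the same expression trees as the explicit constructors; only the construction syntax changes.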

## How was this patch tested?
N/A

Author: Xiao Li <gatorsmile@gmail.com>

Closes #21951 from gatorsmile/refactor1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/166f3461
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/166f3461
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/166f3461

Branch: refs/heads/master
Commit: 166f346185cc0b27a7e2b2a3b42df277e5901f2f
Parents: c9914cf
Author: Xiao Li <gatorsmile@gmail.com>
Authored: Wed Aug 1 23:00:17 2018 -0700
Committer: Xiao Li <gatorsmile@gmail.com>
Committed: Wed Aug 1 23:00:17 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/catalyst/dsl/package.scala |  1 +
 .../sql/catalyst/expressions/aggregate/Average.scala      | 10 ++++------
 2 files changed, 5 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/166f3461/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 89e8c99..9870854 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -166,6 +166,7 @@ package object dsl {
     def maxDistinct(e: Expression): Expression = Max(e).toAggregateExpression(isDistinct = true)
     def upper(e: Expression): Expression = Upper(e)
     def lower(e: Expression): Expression = Lower(e)
+    def coalesce(args: Expression*): Expression = Coalesce(args)
     def sqrt(e: Expression): Expression = Sqrt(e)
     def abs(e: Expression): Expression = Abs(e)
     def star(names: String*): Expression = names match {

http://git-wip-us.apache.org/repos/asf/spark/blob/166f3461/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 9ccf5aa..f1fad77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -46,7 +46,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
   override lazy val aggBufferAttributes = sum :: count :: Nil
 
   override lazy val initialValues = Seq(
-    /* sum = */ Cast(Literal(0), sumDataType),
+    /* sum = */ Literal(0).cast(sumDataType),
     /* count = */ Literal(0L)
   )
 
@@ -58,18 +58,16 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
   // If all input are nulls, count will be 0 and we will get null after the division.
   override lazy val evaluateExpression = child.dataType match {
     case _: DecimalType =>
-      Cast(
-        DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
-        resultType)
+      DecimalPrecision.decimalAndDecimal(sum / count.cast(DecimalType.LongDecimal)).cast(resultType)
     case _ =>
-      Cast(sum, resultType) / Cast(count, resultType)
+      sum.cast(resultType) / count.cast(resultType)
   }
 
   protected def updateExpressionsDef: Seq[Expression] = Seq(
     /* sum = */
     Add(
       sum,
-      Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
+      coalesce(child.cast(sumDataType), Literal(0).cast(sumDataType))),
     /* count = */ If(IsNull(child), count, count + 1L)
   )
 


