spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-11536][SQL] Remove the internal implicit conversion from Expression to Column in functions.scala
Date Thu, 05 Nov 2015 23:34:08 GMT
Repository: spark
Updated Branches:
  refs/heads/master d9e30c59c -> b6974f8fe


[SPARK-11536][SQL] Remove the internal implicit conversion from Expression to Column in functions.scala

Author: Reynold Xin <rxin@databricks.com>

Closes #9505 from rxin/SPARK-11536.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b6974f8f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b6974f8f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b6974f8f

Branch: refs/heads/master
Commit: b6974f8fed1726a381636e996834111a8e7ced8d
Parents: d9e30c5
Author: Reynold Xin <rxin@databricks.com>
Authored: Thu Nov 5 15:34:05 2015 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Thu Nov 5 15:34:05 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/functions.scala  | 580 ++++++++++---------
 1 file changed, 299 insertions(+), 281 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b6974f8f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index c70c965..0462758 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -51,7 +51,7 @@ import org.apache.spark.util.Utils
 object functions {
 // scalastyle:on
 
-  private[this] implicit def toColumn(expr: Expression): Column = Column(expr)
+  private def withExpr(expr: Expression): Column = Column(expr)
 
   /**
    * Returns a [[Column]] based on the given column name.
@@ -128,7 +128,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
+  def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) }
 
   /**
    * Aggregate function: returns the approximate number of distinct items in a group.
@@ -144,7 +144,9 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
+  def approxCountDistinct(e: Column, rsd: Double): Column = withExpr {
+    ApproxCountDistinct(e.expr, rsd)
+  }
 
   /**
    * Aggregate function: returns the approximate number of distinct items in a group.
@@ -162,7 +164,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def avg(e: Column): Column = Average(e.expr)
+  def avg(e: Column): Column = withExpr { Average(e.expr) }
 
   /**
    * Aggregate function: returns the average of the values in a group.
@@ -178,8 +180,9 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def corr(column1: Column, column2: Column): Column =
+  def corr(column1: Column, column2: Column): Column = withExpr {
     Corr(column1.expr, column2.expr)
+  }
 
   /**
    * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
@@ -187,8 +190,9 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def corr(columnName1: String, columnName2: String): Column =
+  def corr(columnName1: String, columnName2: String): Column = {
     corr(Column(columnName1), Column(columnName2))
+  }
 
   /**
    * Aggregate function: returns the number of items in a group.
@@ -196,10 +200,12 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def count(e: Column): Column = e.expr match {
-    // Turn count(*) into count(1)
-    case s: Star => Count(Literal(1))
-    case _ => Count(e.expr)
+  def count(e: Column): Column = withExpr {
+    e.expr match {
+      // Turn count(*) into count(1)
+      case s: Star => Count(Literal(1))
+      case _ => Count(e.expr)
+    }
   }
 
   /**
@@ -217,8 +223,9 @@ object functions {
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def countDistinct(expr: Column, exprs: Column*): Column =
+  def countDistinct(expr: Column, exprs: Column*): Column = withExpr {
     CountDistinct((expr +: exprs).map(_.expr))
+  }
 
   /**
    * Aggregate function: returns the number of distinct items in a group.
@@ -236,7 +243,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def first(e: Column): Column = First(e.expr)
+  def first(e: Column): Column = withExpr { First(e.expr) }
 
   /**
    * Aggregate function: returns the first value of a column in a group.
@@ -252,7 +259,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def kurtosis(e: Column): Column = Kurtosis(e.expr)
+  def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) }
 
   /**
    * Aggregate function: returns the last value in a group.
@@ -260,7 +267,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def last(e: Column): Column = Last(e.expr)
+  def last(e: Column): Column = withExpr { Last(e.expr) }
 
   /**
    * Aggregate function: returns the last value of the column in a group.
@@ -276,7 +283,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def max(e: Column): Column = Max(e.expr)
+  def max(e: Column): Column = withExpr { Max(e.expr) }
 
   /**
    * Aggregate function: returns the maximum value of the column in a group.
@@ -310,7 +317,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def min(e: Column): Column = Min(e.expr)
+  def min(e: Column): Column = withExpr { Min(e.expr) }
 
   /**
    * Aggregate function: returns the minimum value of the column in a group.
@@ -326,7 +333,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def skewness(e: Column): Column = Skewness(e.expr)
+  def skewness(e: Column): Column = withExpr { Skewness(e.expr) }
 
   /**
    * Aggregate function: alias for [[stddev_samp]].
@@ -334,7 +341,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev(e: Column): Column = StddevSamp(e.expr)
+  def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) }
 
   /**
    * Aggregate function: returns the unbiased sample standard deviation of
@@ -343,7 +350,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev_samp(e: Column): Column = StddevSamp(e.expr)
+  def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) }
 
   /**
    * Aggregate function: returns the population standard deviation of
@@ -352,7 +359,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def stddev_pop(e: Column): Column = StddevPop(e.expr)
+  def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) }
 
   /**
    * Aggregate function: returns the sum of all values in the expression.
@@ -360,7 +367,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def sum(e: Column): Column = Sum(e.expr)
+  def sum(e: Column): Column = withExpr { Sum(e.expr) }
 
   /**
    * Aggregate function: returns the sum of all values in the given column.
@@ -376,7 +383,7 @@ object functions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+  def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) }
 
   /**
    * Aggregate function: returns the sum of distinct values in the expression.
@@ -392,7 +399,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def variance(e: Column): Column = VarianceSamp(e.expr)
+  def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) }
 
   /**
    * Aggregate function: returns the unbiased variance of the values in a group.
@@ -400,7 +407,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def var_samp(e: Column): Column = VarianceSamp(e.expr)
+  def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) }
 
   /**
    * Aggregate function: returns the population variance of the values in a group.
@@ -408,7 +415,7 @@ object functions {
    * @group agg_funcs
    * @since 1.6.0
    */
-  def var_pop(e: Column): Column = VariancePop(e.expr)
+  def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
@@ -429,9 +436,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def cumeDist(): Column = {
-    UnresolvedWindowFunction("cume_dist", Nil)
-  }
+  def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) }
 
   /**
    * Window function: returns the rank of rows within a window partition, without any gaps.
@@ -446,9 +451,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def denseRank(): Column = {
-    UnresolvedWindowFunction("dense_rank", Nil)
-  }
+  def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) }
 
   /**
    * Window function: returns the value that is `offset` rows before the current row, and
@@ -460,9 +463,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lag(e: Column, offset: Int): Column = {
-    lag(e, offset, null)
-  }
+  def lag(e: Column, offset: Int): Column = lag(e, offset, null)
 
   /**
    * Window function: returns the value that is `offset` rows before the current row, and
@@ -474,9 +475,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lag(columnName: String, offset: Int): Column = {
-    lag(columnName, offset, null)
-  }
+  def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null)
 
   /**
    * Window function: returns the value that is `offset` rows before the current row, and
@@ -502,7 +501,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lag(e: Column, offset: Int, defaultValue: Any): Column = {
+  def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
     UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil)
   }
 
@@ -516,9 +515,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lead(columnName: String, offset: Int): Column = {
-    lead(columnName, offset, null)
-  }
+  def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) }
 
   /**
    * Window function: returns the value that is `offset` rows after the current row, and
@@ -530,9 +527,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lead(e: Column, offset: Int): Column = {
-    lead(e, offset, null)
-  }
+  def lead(e: Column, offset: Int): Column = { lead(e, offset, null) }
 
   /**
    * Window function: returns the value that is `offset` rows after the current row, and
@@ -558,7 +553,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def lead(e: Column, offset: Int, defaultValue: Any): Column = {
+  def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr {
     UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil)
   }
 
@@ -572,9 +567,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def ntile(n: Int): Column = {
-    UnresolvedWindowFunction("ntile", lit(n).expr :: Nil)
-  }
+  def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) }
 
   /**
    * Window function: returns the relative rank (i.e. percentile) of rows within a window partition.
@@ -589,9 +582,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def percentRank(): Column = {
-    UnresolvedWindowFunction("percent_rank", Nil)
-  }
+  def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) }
 
   /**
    * Window function: returns the rank of rows within a window partition.
@@ -606,9 +597,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def rank(): Column = {
-    UnresolvedWindowFunction("rank", Nil)
-  }
+  def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) }
 
   /**
    * Window function: returns a sequential number starting at 1 within a window partition.
@@ -618,9 +607,7 @@ object functions {
    * @group window_funcs
    * @since 1.4.0
    */
-  def rowNumber(): Column = {
-    UnresolvedWindowFunction("row_number", Nil)
-  }
+  def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Non-aggregate functions
@@ -632,7 +619,7 @@ object functions {
    * @group normal_funcs
    * @since 1.3.0
    */
-  def abs(e: Column): Column = Abs(e.expr)
+  def abs(e: Column): Column = withExpr { Abs(e.expr) }
 
   /**
    * Creates a new array column. The input columns must all have the same data type.
@@ -641,7 +628,7 @@ object functions {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def array(cols: Column*): Column = CreateArray(cols.map(_.expr))
+  def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) }
 
   /**
    * Creates a new array column. The input columns must all have the same data type.
@@ -679,14 +666,14 @@ object functions {
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
+  def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) }
 
   /**
    * Creates a string column for the file name of the current Spark task.
    *
    * @group normal_funcs
    */
-  def inputFileName(): Column = InputFileName()
+  def inputFileName(): Column = withExpr { InputFileName() }
 
   /**
    * Return true iff the column is NaN.
@@ -694,7 +681,7 @@ object functions {
    * @group normal_funcs
    * @since 1.5.0
    */
-  def isNaN(e: Column): Column = IsNaN(e.expr)
+  def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) }
 
   /**
    * A column expression that generates monotonically increasing 64-bit integers.
@@ -711,7 +698,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID()
+  def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() }
 
   /**
    * Returns col1 if it is not NaN, or col2 if col1 is NaN.
@@ -721,7 +708,7 @@ object functions {
    * @group normal_funcs
    * @since 1.5.0
    */
-  def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr)
+  def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) }
 
   /**
    * Unary minus, i.e. negate the expression.
@@ -760,7 +747,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def rand(seed: Long): Column = Rand(seed)
+  def rand(seed: Long): Column = withExpr { Rand(seed) }
 
   /**
    * Generate a random column with i.i.d. samples from U[0.0, 1.0].
@@ -776,7 +763,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def randn(seed: Long): Column = Randn(seed)
+  def randn(seed: Long): Column = withExpr { Randn(seed) }
 
   /**
    * Generate a column with i.i.d. samples from the standard normal distribution.
@@ -794,7 +781,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def sparkPartitionId(): Column = SparkPartitionID()
+  def sparkPartitionId(): Column = withExpr { SparkPartitionID() }
 
   /**
    * Computes the square root of the specified float value.
@@ -802,7 +789,7 @@ object functions {
    * @group math_funcs
    * @since 1.3.0
    */
-  def sqrt(e: Column): Column = Sqrt(e.expr)
+  def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) }
 
   /**
    * Computes the square root of the specified float value.
@@ -823,9 +810,7 @@ object functions {
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def struct(cols: Column*): Column = {
-    CreateStruct(cols.map(_.expr))
-  }
+  def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) }
 
   /**
    * Creates a new struct column that composes multiple input columns.
@@ -858,7 +843,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def when(condition: Column, value: Any): Column = {
+  def when(condition: Column, value: Any): Column = withExpr {
     CaseWhen(Seq(condition.expr, lit(value).expr))
   }
 
@@ -868,7 +853,7 @@ object functions {
    * @group normal_funcs
    * @since 1.4.0
    */
-  def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr)
+  def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) }
 
   /**
    * Parses the expression string into the column that it represents, similar to
@@ -893,7 +878,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def acos(e: Column): Column = Acos(e.expr)
+  def acos(e: Column): Column = withExpr { Acos(e.expr) }
 
   /**
    * Computes the cosine inverse of the given column; the returned angle is in the range
@@ -911,7 +896,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def asin(e: Column): Column = Asin(e.expr)
+  def asin(e: Column): Column = withExpr { Asin(e.expr) }
 
   /**
    * Computes the sine inverse of the given column; the returned angle is in the range
@@ -928,7 +913,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def atan(e: Column): Column = Atan(e.expr)
+  def atan(e: Column): Column = withExpr { Atan(e.expr) }
 
   /**
    * Computes the tangent inverse of the given column.
@@ -945,7 +930,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr)
+  def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) }
 
   /**
    * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
@@ -982,7 +967,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr)
+  def atan2(l: Column, r: Double): Column = atan2(l, lit(r))
 
   /**
    * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
@@ -1000,7 +985,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r)
+  def atan2(l: Double, r: Column): Column = atan2(lit(l), r)
 
   /**
    * Returns the angle theta from the conversion of rectangular coordinates (x, y) to
@@ -1018,7 +1003,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def bin(e: Column): Column = Bin(e.expr)
+  def bin(e: Column): Column = withExpr { Bin(e.expr) }
 
   /**
    * An expression that returns the string representation of the binary value of the given long
@@ -1035,7 +1020,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def cbrt(e: Column): Column = Cbrt(e.expr)
+  def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) }
 
   /**
    * Computes the cube-root of the given column.
@@ -1051,7 +1036,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def ceil(e: Column): Column = Ceil(e.expr)
+  def ceil(e: Column): Column = withExpr { Ceil(e.expr) }
 
   /**
    * Computes the ceiling of the given column.
@@ -1067,8 +1052,9 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def conv(num: Column, fromBase: Int, toBase: Int): Column =
+  def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr {
     Conv(num.expr, lit(fromBase).expr, lit(toBase).expr)
+  }
 
   /**
    * Computes the cosine of the given value.
@@ -1076,7 +1062,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def cos(e: Column): Column = Cos(e.expr)
+  def cos(e: Column): Column = withExpr { Cos(e.expr) }
 
   /**
    * Computes the cosine of the given column.
@@ -1092,7 +1078,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def cosh(e: Column): Column = Cosh(e.expr)
+  def cosh(e: Column): Column = withExpr { Cosh(e.expr) }
 
   /**
    * Computes the hyperbolic cosine of the given column.
@@ -1108,7 +1094,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def exp(e: Column): Column = Exp(e.expr)
+  def exp(e: Column): Column = withExpr { Exp(e.expr) }
 
   /**
    * Computes the exponential of the given column.
@@ -1124,7 +1110,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def expm1(e: Column): Column = Expm1(e.expr)
+  def expm1(e: Column): Column = withExpr { Expm1(e.expr) }
 
   /**
    * Computes the exponential of the given column.
@@ -1140,7 +1126,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def factorial(e: Column): Column = Factorial(e.expr)
+  def factorial(e: Column): Column = withExpr { Factorial(e.expr) }
 
   /**
    * Computes the floor of the given value.
@@ -1148,7 +1134,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def floor(e: Column): Column = Floor(e.expr)
+  def floor(e: Column): Column = withExpr { Floor(e.expr) }
 
   /**
    * Computes the floor of the given column.
@@ -1166,7 +1152,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def greatest(exprs: Column*): Column = {
+  def greatest(exprs: Column*): Column = withExpr {
     require(exprs.length > 1, "greatest requires at least 2 arguments.")
     Greatest(exprs.map(_.expr))
   }
@@ -1189,7 +1175,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def hex(column: Column): Column = Hex(column.expr)
+  def hex(column: Column): Column = withExpr { Hex(column.expr) }
 
   /**
    * Inverse of hex. Interprets each pair of characters as a hexadecimal number
@@ -1198,7 +1184,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def unhex(column: Column): Column = Unhex(column.expr)
+  def unhex(column: Column): Column = withExpr { Unhex(column.expr) }
 
   /**
    * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
@@ -1206,7 +1192,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr)
+  def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) }
 
   /**
    * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
@@ -1239,7 +1225,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr)
+  def hypot(l: Column, r: Double): Column = hypot(l, lit(r))
 
   /**
    * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
@@ -1255,7 +1241,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r)
+  def hypot(l: Double, r: Column): Column = hypot(lit(l), r)
 
   /**
    * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
@@ -1273,7 +1259,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def least(exprs: Column*): Column = {
+  def least(exprs: Column*): Column = withExpr {
     require(exprs.length > 1, "least requires at least 2 arguments.")
     Least(exprs.map(_.expr))
   }
@@ -1296,7 +1282,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def log(e: Column): Column = Log(e.expr)
+  def log(e: Column): Column = withExpr { Log(e.expr) }
 
   /**
    * Computes the natural logarithm of the given column.
@@ -1312,7 +1298,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr)
+  def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) }
 
   /**
    * Returns the first argument-base logarithm of the second argument.
@@ -1328,7 +1314,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def log10(e: Column): Column = Log10(e.expr)
+  def log10(e: Column): Column = withExpr { Log10(e.expr) }
 
   /**
    * Computes the logarithm of the given value in base 10.
@@ -1344,7 +1330,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def log1p(e: Column): Column = Log1p(e.expr)
+  def log1p(e: Column): Column = withExpr { Log1p(e.expr) }
 
   /**
    * Computes the natural logarithm of the given column plus one.
@@ -1360,7 +1346,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def log2(expr: Column): Column = Log2(expr.expr)
+  def log2(expr: Column): Column = withExpr { Log2(expr.expr) }
 
   /**
    * Computes the logarithm of the given value in base 2.
@@ -1376,7 +1362,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr)
+  def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) }
 
   /**
    * Returns the value of the first argument raised to the power of the second argument.
@@ -1408,7 +1394,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def pow(l: Column, r: Double): Column = pow(l, lit(r).expr)
+  def pow(l: Column, r: Double): Column = pow(l, lit(r))
 
   /**
    * Returns the value of the first argument raised to the power of the second argument.
@@ -1424,7 +1410,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def pow(l: Double, r: Column): Column = pow(lit(l).expr, r)
+  def pow(l: Double, r: Column): Column = pow(lit(l), r)
 
   /**
    * Returns the value of the first argument raised to the power of the second argument.
@@ -1440,7 +1426,9 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)
+  def pmod(dividend: Column, divisor: Column): Column = withExpr {
+    Pmod(dividend.expr, divisor.expr)
+  }
 
   /**
    * Returns the double value that is closest in value to the argument and
@@ -1449,7 +1437,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def rint(e: Column): Column = Rint(e.expr)
+  def rint(e: Column): Column = withExpr { Rint(e.expr) }
 
   /**
    * Returns the double value that is closest in value to the argument and
@@ -1466,7 +1454,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def round(e: Column): Column = round(e.expr, 0)
+  def round(e: Column): Column = round(e, 0)
 
   /**
    * Round the value of `e` to `scale` decimal places if `scale` >= 0
@@ -1475,7 +1463,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale))
+  def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) }
 
   /**
    * Shift the given value numBits left. If the given value is a long value, this function
@@ -1484,7 +1472,7 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr)
+  def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) }
 
   /**
    * Shift the given value numBits right. If the given value is a long value, it will return
@@ -1493,7 +1481,9 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr)
+  def shiftRight(e: Column, numBits: Int): Column = withExpr {
+    ShiftRight(e.expr, lit(numBits).expr)
+  }
 
   /**
    * Unsigned shift the given value numBits right. If the given value is a long value,
@@ -1502,8 +1492,9 @@ object functions {
    * @group math_funcs
    * @since 1.5.0
    */
-  def shiftRightUnsigned(e: Column, numBits: Int): Column =
+  def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr {
     ShiftRightUnsigned(e.expr, lit(numBits).expr)
+  }
 
   /**
    * Computes the signum of the given value.
@@ -1511,7 +1502,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def signum(e: Column): Column = Signum(e.expr)
+  def signum(e: Column): Column = withExpr { Signum(e.expr) }
 
   /**
    * Computes the signum of the given column.
@@ -1527,7 +1518,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def sin(e: Column): Column = Sin(e.expr)
+  def sin(e: Column): Column = withExpr { Sin(e.expr) }
 
   /**
    * Computes the sine of the given column.
@@ -1543,7 +1534,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def sinh(e: Column): Column = Sinh(e.expr)
+  def sinh(e: Column): Column = withExpr { Sinh(e.expr) }
 
   /**
    * Computes the hyperbolic sine of the given column.
@@ -1559,7 +1550,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def tan(e: Column): Column = Tan(e.expr)
+  def tan(e: Column): Column = withExpr { Tan(e.expr) }
 
   /**
    * Computes the tangent of the given column.
@@ -1575,7 +1566,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def tanh(e: Column): Column = Tanh(e.expr)
+  def tanh(e: Column): Column = withExpr { Tanh(e.expr) }
 
   /**
    * Computes the hyperbolic tangent of the given column.
@@ -1591,7 +1582,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def toDegrees(e: Column): Column = ToDegrees(e.expr)
+  def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) }
 
   /**
    * Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
@@ -1607,7 +1598,7 @@ object functions {
    * @group math_funcs
    * @since 1.4.0
    */
-  def toRadians(e: Column): Column = ToRadians(e.expr)
+  def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) }
 
   /**
    * Converts an angle measured in degrees to an approximately equivalent angle measured in radians.
@@ -1628,7 +1619,7 @@ object functions {
    * @group misc_funcs
    * @since 1.5.0
    */
-  def md5(e: Column): Column = Md5(e.expr)
+  def md5(e: Column): Column = withExpr { Md5(e.expr) }
 
   /**
    * Calculates the SHA-1 digest of a binary column and returns the value
@@ -1637,7 +1628,7 @@ object functions {
    * @group misc_funcs
    * @since 1.5.0
    */
-  def sha1(e: Column): Column = Sha1(e.expr)
+  def sha1(e: Column): Column = withExpr { Sha1(e.expr) }
 
   /**
    * Calculates the SHA-2 family of hash functions of a binary column and
@@ -1652,7 +1643,7 @@ object functions {
   def sha2(e: Column, numBits: Int): Column = {
     require(Seq(0, 224, 256, 384, 512).contains(numBits),
       s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)")
-    Sha2(e.expr, lit(numBits).expr)
+    withExpr { Sha2(e.expr, lit(numBits).expr) }
   }
 
   /**
@@ -1662,7 +1653,7 @@ object functions {
    * @group misc_funcs
    * @since 1.5.0
    */
-  def crc32(e: Column): Column = Crc32(e.expr)
+  def crc32(e: Column): Column = withExpr { Crc32(e.expr) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // String functions
@@ -1675,7 +1666,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def ascii(e: Column): Column = Ascii(e.expr)
+  def ascii(e: Column): Column = withExpr { Ascii(e.expr) }
 
   /**
    * Computes the BASE64 encoding of a binary column and returns it as a string column.
@@ -1684,7 +1675,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def base64(e: Column): Column = Base64(e.expr)
+  def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
    * Concatenates multiple input string columns together into a single string column.
@@ -1693,7 +1684,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def concat(exprs: Column*): Column = Concat(exprs.map(_.expr))
+  def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
 
   /**
    * Concatenates multiple input string columns together into a single string column,
@@ -1703,7 +1694,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def concat_ws(sep: String, exprs: Column*): Column = {
+  def concat_ws(sep: String, exprs: Column*): Column = withExpr {
     ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr))
   }
 
@@ -1715,7 +1706,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)
+  def decode(value: Column, charset: String): Column = withExpr {
+    Decode(value.expr, lit(charset).expr)
+  }
 
   /**
    * Computes the first argument into a binary from a string using the provided character set
@@ -1725,7 +1718,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)
+  def encode(value: Column, charset: String): Column = withExpr {
+    Encode(value.expr, lit(charset).expr)
+  }
 
   /**
    * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places,
@@ -1737,7 +1732,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr)
+  def format_number(x: Column, d: Int): Column = withExpr {
+    FormatNumber(x.expr, lit(d).expr)
+  }
 
   /**
    * Formats the arguments in printf-style and returns the result as a string column.
@@ -1746,7 +1743,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def format_string(format: String, arguments: Column*): Column = {
+  def format_string(format: String, arguments: Column*): Column = withExpr {
     FormatString((lit(format) +: arguments).map(_.expr): _*)
   }
 
@@ -1759,7 +1756,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def initcap(e: Column): Column = InitCap(e.expr)
+  def initcap(e: Column): Column = withExpr { InitCap(e.expr) }
 
   /**
    * Locate the position of the first occurrence of substr column in the given string.
@@ -1771,7 +1768,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr)
+  def instr(str: Column, substring: String): Column = withExpr {
+    StringInstr(str.expr, lit(substring).expr)
+  }
 
   /**
    * Computes the length of a given string or binary column.
@@ -1779,7 +1778,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def length(e: Column): Column = Length(e.expr)
+  def length(e: Column): Column = withExpr { Length(e.expr) }
 
   /**
    * Converts a string column to lower case.
@@ -1787,14 +1786,14 @@ object functions {
    * @group string_funcs
    * @since 1.3.0
    */
-  def lower(e: Column): Column = Lower(e.expr)
+  def lower(e: Column): Column = withExpr { Lower(e.expr) }
 
   /**
    * Computes the Levenshtein distance of the two given string columns.
    * @group string_funcs
    * @since 1.5.0
    */
-  def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr)
+  def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr) }
 
   /**
    * Locate the position of the first occurrence of substr.
@@ -1804,7 +1803,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def locate(substr: String, str: Column): Column = {
+  def locate(substr: String, str: Column): Column = withExpr {
     new StringLocate(lit(substr).expr, str.expr)
   }
 
@@ -1817,7 +1816,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def locate(substr: String, str: Column, pos: Int): Column = {
+  def locate(substr: String, str: Column, pos: Int): Column = withExpr {
     StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
   }
 
@@ -1827,7 +1826,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def lpad(str: Column, len: Int, pad: String): Column = {
+  def lpad(str: Column, len: Int, pad: String): Column = withExpr {
     StringLPad(str.expr, lit(len).expr, lit(pad).expr)
   }
 
@@ -1837,7 +1836,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def ltrim(e: Column): Column = StringTrimLeft(e.expr)
+  def ltrim(e: Column): Column = withExpr { StringTrimLeft(e.expr) }
 
   /**
    * Extract a specific(idx) group identified by a java regex, from the specified string column.
@@ -1845,7 +1844,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = {
+  def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr {
     RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
   }
 
@@ -1855,7 +1854,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def regexp_replace(e: Column, pattern: String, replacement: String): Column = {
+  def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr {
     RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr)
   }
 
@@ -1866,7 +1865,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def unbase64(e: Column): Column = UnBase64(e.expr)
+  def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) }
 
   /**
    * Right-padded with pad to a length of len.
@@ -1874,7 +1873,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def rpad(str: Column, len: Int, pad: String): Column = {
+  def rpad(str: Column, len: Int, pad: String): Column = withExpr {
     StringRPad(str.expr, lit(len).expr, lit(pad).expr)
   }
 
@@ -1884,7 +1883,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def repeat(str: Column, n: Int): Column = {
+  def repeat(str: Column, n: Int): Column = withExpr {
     StringRepeat(str.expr, lit(n).expr)
   }
 
@@ -1894,9 +1893,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def reverse(str: Column): Column = {
-    StringReverse(str.expr)
-  }
+  def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }
 
   /**
    * Trim the spaces from right end for the specified string value.
@@ -1904,7 +1901,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def rtrim(e: Column): Column = StringTrimRight(e.expr)
+  def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) }
 
   /**
    * * Return the soundex code for the specified expression.
@@ -1912,7 +1909,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def soundex(e: Column): Column = SoundEx(e.expr)
+  def soundex(e: Column): Column = withExpr { SoundEx(e.expr) }
 
   /**
    * Splits str around pattern (pattern is a regular expression).
@@ -1921,7 +1918,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def split(str: Column, pattern: String): Column = {
+  def split(str: Column, pattern: String): Column = withExpr {
     StringSplit(str.expr, lit(pattern).expr)
   }
 
@@ -1933,8 +1930,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def substring(str: Column, pos: Int, len: Int): Column =
+  def substring(str: Column, pos: Int, len: Int): Column = withExpr {
     Substring(str.expr, lit(pos).expr, lit(len).expr)
+  }
 
   /**
    * Returns the substring from string str before count occurrences of the delimiter delim.
@@ -1944,8 +1942,9 @@ object functions {
    *
    * @group string_funcs
    */
-  def substring_index(str: Column, delim: String, count: Int): Column =
+  def substring_index(str: Column, delim: String, count: Int): Column = withExpr {
     SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+  }
 
   /**
    * Translate any character in the src by a character in replaceString.
@@ -1956,8 +1955,9 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def translate(src: Column, matchingString: String, replaceString: String): Column =
+  def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr {
     StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr)
+  }
 
   /**
    * Trim the spaces from both ends for the specified string column.
@@ -1965,7 +1965,7 @@ object functions {
    * @group string_funcs
    * @since 1.5.0
    */
-  def trim(e: Column): Column = StringTrim(e.expr)
+  def trim(e: Column): Column = withExpr { StringTrim(e.expr) }
 
   /**
    * Converts a string column to upper case.
@@ -1973,7 +1973,7 @@ object functions {
    * @group string_funcs
    * @since 1.3.0
    */
-  def upper(e: Column): Column = Upper(e.expr)
+  def upper(e: Column): Column = withExpr { Upper(e.expr) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // DateTime functions
@@ -1985,8 +1985,9 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def add_months(startDate: Column, numMonths: Int): Column =
+  def add_months(startDate: Column, numMonths: Int): Column = withExpr {
     AddMonths(startDate.expr, Literal(numMonths))
+  }
 
   /**
    * Returns the current date as a date column.
@@ -1994,7 +1995,7 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def current_date(): Column = CurrentDate()
+  def current_date(): Column = withExpr { CurrentDate() }
 
   /**
    * Returns the current timestamp as a timestamp column.
@@ -2002,7 +2003,7 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def current_timestamp(): Column = CurrentTimestamp()
+  def current_timestamp(): Column = withExpr { CurrentTimestamp() }
 
   /**
    * Converts a date/timestamp/string to a value of string in the format specified by the date
@@ -2017,71 +2018,72 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def date_format(dateExpr: Column, format: String): Column =
+  def date_format(dateExpr: Column, format: String): Column = withExpr {
     DateFormatClass(dateExpr.expr, Literal(format))
+  }
 
   /**
    * Returns the date that is `days` days after `start`
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days))
+  def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) }
 
   /**
    * Returns the date that is `days` days before `start`
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days))
+  def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) }
 
   /**
    * Returns the number of days from `start` to `end`.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr)
+  def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) }
 
   /**
    * Extracts the year as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def year(e: Column): Column = Year(e.expr)
+  def year(e: Column): Column = withExpr { Year(e.expr) }
 
   /**
    * Extracts the quarter as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def quarter(e: Column): Column = Quarter(e.expr)
+  def quarter(e: Column): Column = withExpr { Quarter(e.expr) }
 
   /**
    * Extracts the month as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def month(e: Column): Column = Month(e.expr)
+  def month(e: Column): Column = withExpr { Month(e.expr) }
 
   /**
    * Extracts the day of the month as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def dayofmonth(e: Column): Column = DayOfMonth(e.expr)
+  def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) }
 
   /**
    * Extracts the day of the year as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def dayofyear(e: Column): Column = DayOfYear(e.expr)
+  def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) }
 
   /**
    * Extracts the hours as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def hour(e: Column): Column = Hour(e.expr)
+  def hour(e: Column): Column = withExpr { Hour(e.expr) }
 
   /**
    * Given a date column, returns the last day of the month which the given date belongs to.
@@ -2091,21 +2093,23 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def last_day(e: Column): Column = LastDay(e.expr)
+  def last_day(e: Column): Column = withExpr { LastDay(e.expr) }
 
   /**
    * Extracts the minutes as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def minute(e: Column): Column = Minute(e.expr)
+  def minute(e: Column): Column = withExpr { Minute(e.expr) }
 
   /*
    * Returns number of months between dates `date1` and `date2`.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr)
+  def months_between(date1: Column, date2: Column): Column = withExpr {
+    MonthsBetween(date1.expr, date2.expr)
+  }
 
   /**
    * Given a date column, returns the first date which is later than the value of the date column
@@ -2120,21 +2124,23 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr)
+  def next_day(date: Column, dayOfWeek: String): Column = withExpr {
+    NextDay(date.expr, lit(dayOfWeek).expr)
+  }
 
   /**
    * Extracts the seconds as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def second(e: Column): Column = Second(e.expr)
+  def second(e: Column): Column = withExpr { Second(e.expr) }
 
   /**
    * Extracts the week number as an integer from a given date/timestamp/string.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def weekofyear(e: Column): Column = WeekOfYear(e.expr)
+  def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) }
 
   /**
    * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
@@ -2143,7 +2149,9 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+  def from_unixtime(ut: Column): Column = withExpr {
+    FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+  }
 
   /**
    * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
@@ -2152,14 +2160,18 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f))
+  def from_unixtime(ut: Column, f: String): Column = withExpr {
+    FromUnixTime(ut.expr, Literal(f))
+  }
 
   /**
    * Gets current Unix timestamp in seconds.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss"))
+  def unix_timestamp(): Column = withExpr {
+    UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss"))
+  }
 
   /**
    * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds),
@@ -2167,7 +2179,9 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+  def unix_timestamp(s: Column): Column = withExpr {
+    UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+  }
 
   /**
    * Convert time string with given pattern
@@ -2176,7 +2190,7 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p))
+  def unix_timestamp(s: Column, p: String): Column = withExpr { UnixTimestamp(s.expr, Literal(p)) }
 
   /**
    * Converts the column into DateType.
@@ -2184,7 +2198,7 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def to_date(e: Column): Column = ToDate(e.expr)
+  def to_date(e: Column): Column = withExpr { ToDate(e.expr) }
 
   /**
    * Returns date truncated to the unit specified by the format.
@@ -2195,22 +2209,27 @@ object functions {
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format))
+  def trunc(date: Column, format: String): Column = withExpr {
+    TruncDate(date.expr, Literal(format))
+  }
 
   /**
    * Assumes given timestamp is UTC and converts to given timezone.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def from_utc_timestamp(ts: Column, tz: String): Column =
-    FromUTCTimestamp(ts.expr, Literal(tz).expr)
+  def from_utc_timestamp(ts: Column, tz: String): Column = withExpr {
+    FromUTCTimestamp(ts.expr, Literal(tz))
+  }
 
   /**
    * Assumes given timestamp is in given timezone and converts to UTC.
    * @group datetime_funcs
    * @since 1.5.0
    */
-  def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr)
+  def to_utc_timestamp(ts: Column, tz: String): Column = withExpr {
+    ToUTCTimestamp(ts.expr, Literal(tz))
+  }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Collection functions
@@ -2221,8 +2240,9 @@ object functions {
    * @group collection_funcs
    * @since 1.5.0
    */
-  def array_contains(column: Column, value: Any): Column =
+  def array_contains(column: Column, value: Any): Column = withExpr {
     ArrayContains(column.expr, Literal(value))
+  }
 
   /**
    * Creates a new row for each element in the given array or map column.
@@ -2230,7 +2250,7 @@ object functions {
    * @group collection_funcs
    * @since 1.3.0
    */
-  def explode(e: Column): Column = Explode(e.expr)
+  def explode(e: Column): Column = withExpr { Explode(e.expr) }
 
   /**
    * Returns length of array or map.
@@ -2238,7 +2258,7 @@ object functions {
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(e: Column): Column = Size(e.expr)
+  def size(e: Column): Column = withExpr { Size(e.expr) }
 
   /**
    * Sorts the input array for the given column in ascending order,
@@ -2256,7 +2276,7 @@ object functions {
    * @group collection_funcs
    * @since 1.5.0
    */
-  def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr)
+  def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
@@ -2296,11 +2316,10 @@ object functions {
      * @deprecated As of 1.5.0, since it's redundant with udf()
      */
     @deprecated("Use udf", "1.5.0")
-    def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
+    def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr {
       ScalaUDF(f, returnType, Seq($argsInUDF))
     }""")
   }
-  }
   */
   /**
    * Defines a user-defined function of 0 arguments as user-defined function (UDF).
@@ -2435,147 +2454,146 @@ object functions {
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////
-
   /**
-   * Call a Scala function of 0 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 0 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function0[_], returnType: DataType): Column = {
+  def callUDF(f: Function0[_], returnType: DataType): Column = withExpr {
     ScalaUDF(f, returnType, Seq())
   }
 
   /**
-   * Call a Scala function of 1 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 1 argument as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
+  def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr))
   }
 
   /**
-   * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 2 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
+  def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr))
   }
 
   /**
-   * Call a Scala function of 3 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 3 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
+  def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
   }
 
   /**
-   * Call a Scala function of 4 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 4 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
+  def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
   }
 
   /**
-   * Call a Scala function of 5 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 5 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
+  def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
   }
 
   /**
-   * Call a Scala function of 6 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 6 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
+  def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
   }
 
   /**
-   * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 7 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
+  def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
   }
 
   /**
-   * Call a Scala function of 8 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 8 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
+  def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
   }
 
   /**
-   * Call a Scala function of 9 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 9 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
+  def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
   }
 
   /**
-   * Call a Scala function of 10 arguments as user-defined function (UDF). This requires
-   * you to specify the return data type.
-   *
-   * @group udf_funcs
-   * @since 1.3.0
-   * @deprecated As of 1.5.0, since it's redundant with udf()
-   */
+   * Call a Scala function of 10 arguments as user-defined function (UDF). This requires
+   * you to specify the return data type.
+   *
+   * @group udf_funcs
+   * @since 1.3.0
+   * @deprecated As of 1.5.0, since it's redundant with udf()
+   */
   @deprecated("Use udf", "1.5.0")
-  def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
+  def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr {
     ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
   }
 
@@ -2597,7 +2615,7 @@ object functions {
    * @since 1.5.0
    */
   @scala.annotation.varargs
-  def callUDF(udfName: String, cols: Column*): Column = {
+  def callUDF(udfName: String, cols: Column*): Column = withExpr {
     UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
 
@@ -2618,7 +2636,7 @@ object functions {
    * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF
    */
   @deprecated("Use callUDF", "1.5.0")
-  def callUdf(udfName: String, cols: Column*): Column = {
+  def callUdf(udfName: String, cols: Column*): Column = withExpr {
     // Note: we avoid using closures here because on file systems that are case-insensitive, the
     // compiled class file for the closure here will conflict with the one in callUDF (upper case).
     val exprs = new Array[Expression](cols.size)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message