spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From lix...@apache.org
Subject spark git commit: [SPARK-20674][SQL] Support registering UserDefinedFunction as named UDF
Date Tue, 09 May 2017 16:24:41 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.2 08e1b78f0 -> 73aa23b8e


[SPARK-20674][SQL] Support registering UserDefinedFunction as named UDF

## What changes were proposed in this pull request?
For some reason we don't have an API to register a UserDefinedFunction as a named UDF. It is a
no-brainer to add one, in addition to the existing register functions we have.

## How was this patch tested?
Added a test case in UDFSuite for the new API.

Author: Reynold Xin <rxin@databricks.com>

Closes #17915 from rxin/SPARK-20674.

(cherry picked from commit d099f414d2cb53f5a61f6e77317c736be6f953a0)
Signed-off-by: Xiao Li <gatorsmile@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/73aa23b8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/73aa23b8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/73aa23b8

Branch: refs/heads/branch-2.2
Commit: 73aa23b8ef64960e7f171aa07aec396667a2339d
Parents: 08e1b78
Author: Reynold Xin <rxin@databricks.com>
Authored: Tue May 9 09:24:28 2017 -0700
Committer: Xiao Li <gatorsmile@gmail.com>
Committed: Tue May 9 09:24:36 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/UDFRegistration.scala  | 22 +++++++++++++++++---
 .../scala/org/apache/spark/sql/UDFSuite.scala   |  7 +++++++
 2 files changed, 26 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/73aa23b8/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index a576733..6accf1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -70,15 +70,31 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @param name the name of the UDAF.
    * @param udaf the UDAF needs to be registered.
    * @return the registered UDAF.
+   *
+   * @since 1.5.0
    */
-  def register(
-      name: String,
-      udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
+  def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
     def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
     functionRegistry.registerFunction(name, builder)
     udaf
   }
 
+  /**
+   * Register a user-defined function (UDF), for a UDF that's already defined using the DataFrame
+   * API (i.e. of type UserDefinedFunction).
+   *
+   * @param name the name of the UDF.
+   * @param udf the UDF needs to be registered.
+   * @return the registered UDF.
+   *
+   * @since 2.2.0
+   */
+  def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = {
+    def builder(children: Seq[Expression]) = udf.apply(children.map(Column.apply) : _*).expr
+    functionRegistry.registerFunction(name, builder)
+    udf
+  }
+
   // scalastyle:off line.size.limit
 
   /* register 0-22 were generated by this script

http://git-wip-us.apache.org/repos/asf/spark/blob/73aa23b8/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ae6b2bc..6f8723a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -93,6 +93,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
   }
 
+  test("UDF defined using UserDefinedFunction") {
+    import functions.udf
+    val foo = udf((x: Int) => x + 1)
+    spark.udf.register("foo", foo)
+    assert(sql("select foo(5)").head().getInt(0) == 6)
+  }
+
   test("ZeroArgument UDF") {
     spark.udf.register("random0", () => { Math.random()})
     assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message