flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From twal...@apache.org
Subject flink git commit: [FLINK-5571] [table] add open and close methods for UserDefinedFunction
Date Mon, 20 Feb 2017 14:39:26 GMT
Repository: flink
Updated Branches:
  refs/heads/master 0bdf3a74c -> b820fd3ca


[FLINK-5571] [table] add open and close methods for UserDefinedFunction

This closes #3176.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b820fd3c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b820fd3c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b820fd3c

Branch: refs/heads/master
Commit: b820fd3ca038e411bc7f43e1c35637bf62981fe5
Parents: 0bdf3a7
Author: godfreyhe <godfreyhe@163.com>
Authored: Fri Jan 20 14:42:12 2017 +0800
Committer: twalthr <twalthr@apache.org>
Committed: Mon Feb 20 15:38:34 2017 +0100

----------------------------------------------------------------------
 docs/dev/table_api.md                           |  92 +++++++++++++++++
 .../flink/table/codegen/CodeGenerator.scala     |  78 +++++++++++---
 .../flink/table/functions/FunctionContext.scala |  66 ++++++++++++
 .../table/functions/UserDefinedFunction.scala   |  17 ++-
 .../table/runtime/CorrelateFlatMapRunner.scala  |   7 ++
 .../flink/table/runtime/FlatMapRunner.scala     |   7 ++
 .../table/api/scala/batch/sql/CalcITCase.scala  |   6 +-
 .../api/scala/batch/table/CalcITCase.scala      |   8 +-
 .../table/api/scala/stream/sql/SqlITCase.scala  |   6 +-
 .../api/scala/stream/table/CalcITCase.scala     |   8 +-
 .../UserDefinedScalarFunctionTest.scala         |  31 +++++-
 .../expressions/utils/ExpressionTestBase.scala  |  40 ++++++-
 .../utils/UserDefinedScalarFunctions.scala      |  97 ++++++++++++++++-
 .../runtime/dataset/DataSetCalcITCase.scala     | 103 +++++++++++++++++++
 .../dataset/DataSetCorrelateITCase.scala        |  52 +++++++++-
 .../datastream/DataStreamCalcITCase.scala       |  81 +++++++++++++++
 .../datastream/DataStreamCorrelateITCase.scala  |  67 ++++++++++--
 .../utils/UserDefinedFunctionTestUtils.scala    |  53 ++++++++++
 .../table/utils/UserDefinedTableFunctions.scala |  58 ++++++++++-
 19 files changed, 828 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/docs/dev/table_api.md
----------------------------------------------------------------------
diff --git a/docs/dev/table_api.md b/docs/dev/table_api.md
index 99ae711..22fd636 100644
--- a/docs/dev/table_api.md
+++ b/docs/dev/table_api.md
@@ -4724,6 +4724,11 @@ ELEMENT(ARRAY)
 </div>
 </div>
 
+{% top %}
+
+User-defined Functions
+----------------
+
 ### User-defined Scalar Functions
 
 If a required scalar function is not contained in the built-in functions, it is possible to define custom, user-defined scalar functions for both the Table API and SQL. A user-defined scalar functions maps zero, one, or multiple scalar values to a new scalar value.
@@ -4933,6 +4938,93 @@ class CustomTypeSplit extends TableFunction[Row] {
 </div>
 </div>
 
+### Advanced Function Features
+
+Sometimes it might be necessary for a user-defined function to get global runtime information or do some setup/clean-up work before and after the actual work. User-defined functions provide `open()` and `close()` methods that can be overridden and provide similar functionality as the methods in `RichFunction` of the DataSet or DataStream API.
+
+The `open()` method is called once before the first call to the evaluation method. The `close()` method is called after the last call to the evaluation method.
+
+The `open()` method provides a `FunctionContext` that contains information about the context in which user-defined functions are executed, such as the metric group, the distributed cache files, or the global job parameters.
+
+The following information can be obtained by calling the corresponding methods of `FunctionContext`:
+
+| Method                                | Description                                            |
+| :------------------------------------ | :----------------------------------------------------- |
+| `getMetricGroup()`                    | Metric group for this parallel subtask.                |
+| `getCachedFile(name)`                 | Local temporary file copy of a distributed cache file. |
+| `getJobParameter(name, defaultValue)` | Global job parameter value associated with given key.  |
+
+The following example snippet shows how to use `FunctionContext` in a scalar function for accessing a global job parameter:
+
+<div class="codetabs" markdown="1">
+<div data-lang="java" markdown="1">
+{% highlight java %}
+public class HashCode extends ScalarFunction {
+
+    private int factor = 0;
+
+    @Override
+    public void open(FunctionContext context) throws Exception {
+        // access "hashcode_factor" parameter
+        // "12" would be the default value if parameter does not exist
+        factor = Integer.valueOf(context.getJobParameter("hashcode_factor", "12")); 
+    }
+
+    public int eval(String s) {
+        return s.hashCode() * factor;
+    }
+}
+
+ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env);
+
+// set job parameter
+Configuration conf = new Configuration();
+conf.setString("hashcode_factor", "31");
+env.getConfig().setGlobalJobParameters(conf);
+
+// register the function
+tableEnv.registerFunction("hashCode", new HashCode());
+
+// use the function in Java Table API
+myTable.select("string, string.hashCode(), hashCode(string)");
+
+// use the function in SQL
+tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable");
+{% endhighlight %}
+</div>
+
+<div data-lang="scala" markdown="1">
+{% highlight scala %}
+object hashCode extends ScalarFunction {
+
+  var hashcode_factor = 12;
+
+  override def open(context: FunctionContext): Unit = {
+    // access "hashcode_factor" parameter
+    // "12" would be the default value if parameter does not exist
+    hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt
+  }
+
+  def eval(s: String): Int = {
+    s.hashCode() * hashcode_factor
+  }
+}
+
+val tableEnv = TableEnvironment.getTableEnvironment(env)
+
+// use the function in Scala Table API
+myTable.select('string, hashCode('string))
+
+// register and use the function in SQL
+tableEnv.registerFunction("hashCode", hashCode)
+tableEnv.sql("SELECT string, HASHCODE(string) FROM MyTable");
+{% endhighlight %}
+
+</div>
+</div>
+
+
 ### Limitations
 
 The following operations are not supported yet:

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index c679bd8..441b1c0 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -25,12 +25,13 @@ import org.apache.calcite.rex._
 import org.apache.calcite.sql.SqlOperator
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql.fun.SqlStdOperatorTable._
-import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction, Function, MapFunction}
+import org.apache.flink.api.common.functions._
 import org.apache.flink.api.common.io.GenericInputFormat
 import org.apache.flink.api.common.typeinfo.{AtomicType, SqlTimeTypeInfo, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.typeutils.{GenericTypeInfo, PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
 import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
+import org.apache.flink.configuration.Configuration
 import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.CodeGenUtils._
@@ -38,7 +39,7 @@ import org.apache.flink.table.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
 import org.apache.flink.table.codegen.Indenter.toISC
 import org.apache.flink.table.codegen.calls.FunctionGenerator
 import org.apache.flink.table.codegen.calls.ScalarOperators._
-import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.functions.{FunctionContext, UserDefinedFunction}
 import org.apache.flink.table.runtime.TableFunctionCollector
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 import org.apache.flink.types.Row
@@ -122,6 +123,14 @@ class CodeGenerator(
   // we use a LinkedHashSet to keep the insertion order
   private val reusableInitStatements = mutable.LinkedHashSet[String]()
 
+  // set of open statements for RichFunction that will be added only once
+  // we use a LinkedHashSet to keep the insertion order
+  private val reusableOpenStatements = mutable.LinkedHashSet[String]()
+
+  // set of close statements for RichFunction that will be added only once
+  // we use a LinkedHashSet to keep the insertion order
+  private val reusableCloseStatements = mutable.LinkedHashSet[String]()
+
   // set of statements that will be added only once per record
   // we use a LinkedHashSet to keep the insertion order
   private val reusablePerRecordStatements = mutable.LinkedHashSet[String]()
@@ -150,6 +159,20 @@ class CodeGenerator(
   }
 
   /**
+    * @return code block of statements that need to be placed in the open() method of RichFunction
+    */
+  def reuseOpenCode(): String = {
+    reusableOpenStatements.mkString("", "\n", "\n")
+  }
+
+  /**
+    * @return code block of statements that need to be placed in the close() method of RichFunction
+    */
+  def reuseCloseCode(): String = {
+    reusableCloseStatements.mkString("", "\n", "\n")
+  }
+
+  /**
     * @return code block of statements that need to be placed in the SAM of the Function
     */
   def reusePerRecordCode(): String = {
@@ -240,27 +263,33 @@ class CodeGenerator(
     // manual casting here
     val samHeader =
       // FlatMapFunction
-      if (clazz == classOf[FlatMapFunction[_,_]]) {
+      if (clazz == classOf[FlatMapFunction[_, _]]) {
+        val baseClass = classOf[RichFlatMapFunction[_, _]]
         val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
-        (s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
+        (baseClass,
+          s"void flatMap(Object _in1, org.apache.flink.util.Collector $collectorTerm)",
           List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
       }
 
       // MapFunction
-      else if (clazz == classOf[MapFunction[_,_]]) {
+      else if (clazz == classOf[MapFunction[_, _]]) {
+        val baseClass = classOf[RichMapFunction[_, _]]
         val inputTypeTerm = boxedTypeTermForTypeInfo(input1)
-        ("Object map(Object _in1)",
+        (baseClass,
+          "Object map(Object _in1)",
           List(s"$inputTypeTerm $input1Term = ($inputTypeTerm) _in1;"))
       }
 
       // FlatJoinFunction
-      else if (clazz == classOf[FlatJoinFunction[_,_,_]]) {
+      else if (clazz == classOf[FlatJoinFunction[_, _, _]]) {
+        val baseClass = classOf[RichFlatJoinFunction[_, _, _]]
         val inputTypeTerm1 = boxedTypeTermForTypeInfo(input1)
         val inputTypeTerm2 = boxedTypeTermForTypeInfo(input2.getOrElse(
-            throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
-        (s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
+          throw new CodeGenException("Input 2 for FlatJoinFunction should not be null")))
+        (baseClass,
+          s"void join(Object _in1, Object _in2, org.apache.flink.util.Collector $collectorTerm)",
           List(s"$inputTypeTerm1 $input1Term = ($inputTypeTerm1) _in1;",
-          s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
+               s"$inputTypeTerm2 $input2Term = ($inputTypeTerm2) _in2;"))
       }
       else {
         // TODO more functions
@@ -269,7 +298,7 @@ class CodeGenerator(
 
     val funcCode = j"""
       public class $funcName
-          implements ${clazz.getCanonicalName} {
+          extends ${samHeader._1.getCanonicalName} {
 
         ${reuseMemberCode()}
 
@@ -280,12 +309,22 @@ class CodeGenerator(
         ${reuseConstructorCode(funcName)}
 
         @Override
-        public ${samHeader._1} throws Exception {
-          ${samHeader._2.mkString("\n")}
+        public void open(${classOf[Configuration].getCanonicalName} parameters) throws Exception {
+          ${reuseOpenCode()}
+        }
+
+        @Override
+        public ${samHeader._2} throws Exception {
+          ${samHeader._3.mkString("\n")}
           ${reusePerRecordCode()}
           ${reuseInputUnboxingCode()}
           $bodyCode
         }
+
+        @Override
+        public void close() throws Exception {
+          ${reuseCloseCode()}
+        }
       }
     """.stripMargin
 
@@ -1480,6 +1519,19 @@ class CodeGenerator(
         |$fieldTerm = ($classQualifier) $constructorTerm.newInstance();
        """.stripMargin
     reusableInitStatements.add(constructorAccessibility)
+
+    val openFunction =
+      s"""
+         |$fieldTerm.open(new ${classOf[FunctionContext].getCanonicalName}(getRuntimeContext()));
+       """.stripMargin
+    reusableOpenStatements.add(openFunction)
+
+    val closeFunction =
+      s"""
+         |$fieldTerm.close();
+       """.stripMargin
+    reusableCloseStatements.add(closeFunction)
+
     fieldTerm
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
new file mode 100644
index 0000000..beeb686
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/FunctionContext.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import java.io.File
+
+import org.apache.flink.api.common.functions.RuntimeContext
+import org.apache.flink.metrics.MetricGroup
+
+/**
+  * A FunctionContext allows obtaining global runtime information about the context in which the
+  * user-defined function is executed. This information includes the metric group,
+  * the distributed cache files, and the global job parameters.
+  *
+  * @param context the runtime context in which the Flink Function is executed
+  */
+class FunctionContext(context: RuntimeContext) {
+
+  /**
+    * Returns the metric group for this parallel subtask.
+    *
+    * @return metric group for this parallel subtask.
+    */
+  def getMetricGroup: MetricGroup = context.getMetricGroup
+
+  /**
+    * Gets the local temporary file copy of a distributed cache file.
+    *
+    * @param name distributed cache file name
+    * @return local temporary file copy of a distributed cache file.
+    */
+  def getCachedFile(name: String): File = context.getDistributedCache.getFile(name)
+
+  /**
+    * Gets the global job parameter value associated with the given key as a string.
+    *
+    * @param key          key pointing to the associated value
+    * @param defaultValue default value which is returned in case the global job parameters are
+    *                     null or there is no value associated with the given key
+    * @return (default) value associated with the given key
+    */
+  def getJobParameter(key: String, defaultValue: String): String = {
+    val conf = context.getExecutionConfig.getGlobalJobParameters // may be null; checked below
+    if (conf != null && conf.toMap.containsKey(key)) {
+      conf.toMap.get(key)
+    } else {
+      defaultValue
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
index b99ab8d..c313d80 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
@@ -23,5 +23,20 @@ package org.apache.flink.table.functions
   *
   * User-defined functions must have a default constructor and must be instantiable during runtime.
   */
-trait UserDefinedFunction {
+abstract class UserDefinedFunction {
+  /**
+    * Setup method for the user-defined function, called once before the first evaluation call.
+    * It can be used for initialization work. By default, this method does nothing.
+    * @param context provides the metric group, cached files, and global job parameters
+    */
+  @throws(classOf[Exception])
+  def open(context: FunctionContext): Unit = {}
+
+  /**
+    * Tear-down method for the user-defined function. It can be used for clean-up work.
+    *
+    * It is called after the last evaluation call. By default, this method does nothing.
+    */
+  @throws(classOf[Exception])
+  def close(): Unit = {}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
index 4e803da..a0415e1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.runtime
 
+import org.apache.flink.api.common.functions.util.FunctionUtils
 import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
@@ -52,6 +53,8 @@ class CorrelateFlatMapRunner[IN, OUT](
     val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]])
     LOG.debug("Instantiating FlatMapFunction.")
     function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]]
+    FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext)
+    FunctionUtils.openFunction(function, parameters)
   }
 
   override def flatMap(in: IN, out: Collector[OUT]): Unit = {
@@ -62,4 +65,8 @@ class CorrelateFlatMapRunner[IN, OUT](
   }
 
   override def getProducedType: TypeInformation[OUT] = returnType
+
+  override def close(): Unit = {
+    FunctionUtils.closeFunction(function)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
index a7bd980..b446306 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/FlatMapRunner.scala
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.runtime
 
+import org.apache.flink.api.common.functions.util.FunctionUtils
 import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
@@ -43,10 +44,16 @@ class FlatMapRunner[IN, OUT](
     val clazz = compile(getRuntimeContext.getUserCodeClassLoader, name, code)
     LOG.debug("Instantiating FlatMapFunction.")
     function = clazz.newInstance()
+    FunctionUtils.setFunctionRuntimeContext(function, getRuntimeContext)
+    FunctionUtils.openFunction(function, parameters)
   }
 
   override def flatMap(in: IN, out: Collector[OUT]): Unit =
     function.flatMap(in, out)
 
   override def getProducedType: TypeInformation[OUT] = returnType
+
+  override def close(): Unit = {
+    FunctionUtils.closeFunction(function)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
index 3710642..00f4782 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/CalcITCase.scala
@@ -23,11 +23,11 @@ import java.sql.{Date, Time, Timestamp}
 import java.util
 
 import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.scala._
 import org.apache.flink.table.api.scala.batch.sql.FilterITCase.MyHashCode
-import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
 import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
-import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
 import org.apache.flink.table.api.{TableEnvironment, ValidationException}
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.test.util.TestBaseUtils

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
index 2f853f3..b78dd91 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
@@ -22,15 +22,15 @@ import java.sql.{Date, Time, Timestamp}
 import java.util
 
 import org.apache.flink.api.scala._
-import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
-import org.apache.flink.types.Row
 import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala.batch.utils.{TableProgramsCollectionTestBase, TableProgramsTestBase}
 import org.apache.flink.table.expressions.Literal
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
 import org.junit._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
index 97e76fa..70bec72 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala
@@ -19,12 +19,12 @@
 package org.apache.flink.table.api.scala.stream.sql
 
 import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.api.scala._
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
 import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit._
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
index f541eb4..5969e91 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/CalcITCase.scala
@@ -19,13 +19,13 @@
 package org.apache.flink.table.api.scala.stream.table
 
 import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.expressions.Literal
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
 import org.apache.flink.table.api.{TableEnvironment, TableException}
+import org.apache.flink.table.expressions.Literal
+import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit.Test
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
index da8c748..a6c1760 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -179,7 +179,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "Func12(f8)",
       "+0 00:00:01.000")
   }
-
+  
   @Test
   def testJavaBoxedPrimitives(): Unit = {
     val JavaFunc0 = new JavaFunc0()
@@ -211,6 +211,30 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
       "null and 15 and null")
   }
 
+  @Test
+  def testRichFunctions(): Unit = {
+    val richFunc0 = new RichFunc0
+    val richFunc1 = new RichFunc1
+    val richFunc2 = new RichFunc2
+    testAllApis(
+      richFunc0('f0),
+      "RichFunc0(f0)",
+      "RichFunc0(f0)",
+      "43")
+
+    testAllApis(
+      richFunc1('f0),
+      "RichFunc1(f0)",
+      "RichFunc1(f0)",
+      "42")
+
+    testAllApis(
+      richFunc2('f1),
+      "RichFunc2(f1)",
+      "RichFunc2(f1)",
+      "#Test")
+  }
+
   // ----------------------------------------------------------------------------------------------
 
   override def testData: Any = {
@@ -256,7 +280,10 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     "Func11" -> Func11,
     "Func12" -> Func12,
     "JavaFunc0" -> new JavaFunc0,
-    "JavaFunc1" -> new JavaFunc1
+    "JavaFunc1" -> new JavaFunc1,
+    "RichFunc0" -> new RichFunc0,
+    "RichFunc1" -> new RichFunc1,
+    "RichFunc2" -> new RichFunc2
   )
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
index 679942c..30da5ba 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/ExpressionTestBase.scala
@@ -19,17 +19,23 @@
 package org.apache.flink.table.expressions.utils
 
 import org.apache.calcite.plan.hep.{HepMatchOrder, HepPlanner, HepProgramBuilder}
+import java.util
+import java.util.concurrent.Future
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql2rel.RelDecorrelator
 import org.apache.calcite.tools.{Programs, RelBuilder}
-import org.apache.flink.api.common.functions.{Function, MapFunction}
+import org.apache.flink.api.common.TaskInfo
+import org.apache.flink.api.common.accumulators.Accumulator
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.functions.util.RuntimeUDFContext
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.java.{DataSet => JDataSet}
 import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
-import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.types.Row
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.fs.Path
 import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig, TableEnvironment}
 import org.apache.flink.table.calcite.FlinkPlannerImpl
 import org.apache.flink.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
@@ -37,6 +43,7 @@ import org.apache.flink.table.expressions.{Expression, ExpressionParser}
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
 import org.apache.flink.table.plan.rules.FlinkRuleSets
+import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit.{After, Before}
 import org.mockito.Mockito._
@@ -69,7 +76,8 @@ abstract class ExpressionTestBase {
     new HepPlanner(builder.build, context._2.getFrameworkConfig.getContext)
   }
 
-  private def prepareContext(typeInfo: TypeInformation[Any]): (RelBuilder, TableEnvironment) = {
+  private def prepareContext(typeInfo: TypeInformation[Any])
+    : (RelBuilder, TableEnvironment, ExecutionEnvironment) = {
     // create DataSetTable
     val dataSetMock = mock(classOf[DataSet[Any]])
     val jDataSetMock = mock(classOf[JDataSet[Any]])
@@ -85,7 +93,7 @@ abstract class ExpressionTestBase {
     val relBuilder = tEnv.getRelBuilder
     relBuilder.scan(tableName)
 
-    (relBuilder, tEnv)
+    (relBuilder, tEnv, env)
   }
 
   def testData: Any
@@ -130,8 +138,30 @@ abstract class ExpressionTestBase {
     // compile and evaluate
     val clazz = new TestCompiler[MapFunction[Any, Row], Row]().compile(genFunc)
     val mapper = clazz.newInstance()
+
+    val isRichFunction = mapper.isInstanceOf[RichFunction]
+
+    // call setRuntimeContext method and open method for RichFunction
+    if (isRichFunction) {
+      val richMapper = mapper.asInstanceOf[RichMapFunction[_, _]]
+      val t = new RuntimeUDFContext(
+        new TaskInfo("ExpressionTest", 1, 0, 1, 1),
+        null,
+        context._3.getConfig,
+        new util.HashMap[String, Future[Path]](),
+        new util.HashMap[String, Accumulator[_, _]](),
+        null)
+      richMapper.setRuntimeContext(t)
+      richMapper.open(new Configuration())
+    }
+
     val result = mapper.map(testData)
 
+    // call close method for RichFunction
+    if (isRichFunction) {
+      mapper.asInstanceOf[RichMapFunction[_, _]].close()
+    }
+
     // compare
     testExprs
       .zipWithIndex

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
index 4e9b6d3..f0b347d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/utils/UserDefinedScalarFunctions.scala
@@ -22,7 +22,11 @@ import java.sql.{Date, Time, Timestamp}
 
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.table.api.Types
-import org.apache.flink.table.functions.ScalarFunction
+import org.apache.flink.table.functions.{ScalarFunction, FunctionContext}
+import org.junit.Assert
+
+import scala.collection.mutable
+import scala.io.Source
 
 case class SimplePojo(name: String, age: Int)
 
@@ -119,3 +123,94 @@ object Func12 extends ScalarFunction {
     Types.INTERVAL_MILLIS
   }
 }
+
+/**
+ * Scalar function that checks its own lifecycle while a test runs: `open` must be called
+ * exactly once, `eval` only between `open` and `close`, and `close` exactly once after
+ * `open`. Any violation fails the current test via `Assert.fail`.
+ */
+class RichFunc0 extends ScalarFunction {
+  var openCalled = false
+  var closeCalled = false
+
+  override def open(context: FunctionContext): Unit = {
+    super.open(context)
+    failIf(openCalled, "Open called more than once.")
+    openCalled = true
+    failIf(closeCalled, "Close called before open.")
+  }
+
+  /** Returns `index + 1` after verifying the function is currently open. */
+  def eval(index: Int): Int = {
+    failIf(!openCalled, "Open was not called before eval.")
+    failIf(closeCalled, "Close called before eval.")
+    index + 1
+  }
+
+  override def close(): Unit = {
+    super.close()
+    failIf(closeCalled, "Close called more than once.")
+    closeCalled = true
+    failIf(!openCalled, "Open was not called before close.")
+  }
+
+  /** Fails the current test with `message` when `violated` is true (Assert.fail throws). */
+  private def failIf(violated: Boolean, message: String): Unit =
+    if (violated) Assert.fail(message)
+}
+
+/**
+ * Scalar function that adds an integer offset to its argument. The offset is read in `open`
+ * from the global job parameter `int.value` (default "0"). Before `open` and after `close`
+ * it holds `Int.MaxValue`.
+ */
+class RichFunc1 extends ScalarFunction {
+  // value outside the open/close window; replaced in `open`, restored in `close`
+  var added = Int.MaxValue
+
+  override def open(context: FunctionContext): Unit = {
+    val configured = context.getJobParameter("int.value", "0")
+    added = configured.toInt
+  }
+
+  def eval(index: Int): Int = added + index
+
+  override def close(): Unit =
+    added = Int.MaxValue
+}
+
+/**
+ * Scalar function that returns `<prefix>#<value>`. The prefix is read in `open` from the
+ * global job parameter `string.value` (default ""). Before `open` and after `close`
+ * it is "ERROR_VALUE".
+ */
+class RichFunc2 extends ScalarFunction {
+  // value outside the open/close window; replaced in `open`, restored in `close`
+  var prefix = "ERROR_VALUE"
+
+  override def open(context: FunctionContext): Unit =
+    prefix = context.getJobParameter("string.value", "")
+
+  def eval(value: String): String = s"$prefix#$value"
+
+  override def close(): Unit =
+    prefix = "ERROR_VALUE"
+}
+
+/**
+ * Scalar function that reports whether its argument is one of the words listed in the
+ * distributed-cache file registered under the name "words". The file has one word per line,
+ * and surrounding whitespace on each line is trimmed.
+ */
+class RichFunc3 extends ScalarFunction {
+  private val words = mutable.HashSet[String]()
+
+  override def open(context: FunctionContext): Unit = {
+    val file = context.getCachedFile("words")
+    // read with the same charset UserDefinedFunctionTestUtils.writeCacheFile uses (UTF-8)
+    val source = Source.fromFile(file.getCanonicalPath, "UTF-8")
+    // release the file handle even if reading fails
+    try {
+      source.getLines().foreach(line => words.add(line.trim))
+    } finally {
+      source.close()
+    }
+  }
+
+  /** Returns true iff `value` was listed in the cached words file. */
+  def eval(value: String): Boolean = {
+    words.contains(value)
+  }
+
+  override def close(): Unit = {
+    words.clear()
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
new file mode 100644
index 0000000..f0b3b44
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCalcITCase.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.dataset
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2, RichFunc3}
+import org.apache.flink.table.utils._
+import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+
+import scala.collection.JavaConverters._
+
+/**
+ * Batch integration tests for scalar functions that read job parameters or the distributed
+ * cache in `open`. Each test builds its TableEnvironment from the parameterized `config`,
+ * so every TableConfigMode passed to the constructor is actually exercised.
+ */
+@RunWith(classOf[Parameterized])
+class DataSetCalcITCase(
+  configMode: TableConfigMode)
+  extends TableProgramsClusterTestBase(configMode) {
+
+  @Test
+  def testUserDefinedScalarFunctionWithParameter(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    // pass the parameterized config; without it configMode has no effect on this test
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "ABC"))
+
+    val ds = CollectionDataSets.getSmall3TupleDataSet(env)
+    tEnv.registerDataSet("t1", ds, 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT c FROM t1 where RichFunc2(c)='ABC#Hello'"
+
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "Hello"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testUserDefinedScalarFunctionWithDistributedCache(): Unit = {
+    val words = "Hello\nWord"
+    val filePath = UserDefinedFunctionTestUtils.writeCacheFile("test_words", words)
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.registerCachedFile(filePath, "words")
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.registerFunction("RichFunc3", new RichFunc3)
+
+    val ds = CollectionDataSets.getSmall3TupleDataSet(env)
+    tEnv.registerDataSet("t1", ds, 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT c FROM t1 where RichFunc3(c)=true"
+
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "Hello"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testMultipleUserDefinedScalarFunctions(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.registerFunction("RichFunc1", new RichFunc1)
+    tEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(env, Map("string.value" -> "Abc"))
+
+    val ds = CollectionDataSets.getSmall3TupleDataSet(env)
+    tEnv.registerDataSet("t1", ds, 'a, 'b, 'c)
+
+    val sqlQuery = "SELECT c FROM t1 where " +
+      "RichFunc2(c)='Abc#Hello' or RichFunc1(a)=3 and b=2"
+
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "Hello\nHello world"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
index 818f52b..cd1ffb5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/dataset/DataSetCorrelateITCase.scala
@@ -20,14 +20,16 @@ package org.apache.flink.table.runtime.dataset
 import java.sql.{Date, Timestamp}
 
 import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
-import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
-import org.apache.flink.table.api.scala._
+import org.apache.flink.api.scala.util.CollectionDataSets
 import org.apache.flink.table.api.TableEnvironment
 import org.apache.flink.table.api.java.utils.UserDefinedTableFunctions.JavaTableFunc0
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsTestBase.TableConfigMode
+import org.apache.flink.table.api.scala.batch.utils.TableProgramsClusterTestBase
+import org.apache.flink.table.expressions.utils.RichFunc2
 import org.apache.flink.table.utils._
 import org.apache.flink.test.util.TestBaseUtils
+import org.apache.flink.types.Row
 import org.junit.Test
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
@@ -147,7 +149,7 @@ class DataSetCorrelateITCase(
   }
 
   @Test
-  def testUDTFWithScalarFunction(): Unit = {
+  def testUserDefinedTableFunctionWithScalarFunction(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tableEnv = TableEnvironment.getTableEnvironment(env, config)
     val in = testData(env).toTable(tableEnv).as('a, 'b, 'c)
@@ -185,6 +187,46 @@ class DataSetCorrelateITCase(
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  /**
+   * Joins each row with RichTableFunc1, which splits `c` on the job parameter
+   * `word_separator` ("#" here). The environment uses the parameterized `config`.
+   */
+  @Test
+  def testUserDefinedTableFunctionWithParameter(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    // pass the parameterized config, as testUserDefinedTableFunctionWithScalarFunction does
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val richTableFunc1 = new RichTableFunc1
+    tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+    UserDefinedFunctionTestUtils.setJobParameters(env, Map("word_separator" -> "#"))
+
+    val result = testData(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .join(richTableFunc1('c) as 's)
+      .select('a, 's)
+
+    val expected = "1,Jack\n" + "1,22\n" + "2,John\n" + "2,19\n" + "3,Anna\n" + "3,44"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  /**
+   * Feeds RichFunc2's output (prefix "test" from `string.value`) into RichTableFunc1
+   * (separator "#" from `word_separator`). The environment uses the parameterized `config`.
+   */
+  @Test
+  def testUserDefinedTableFunctionWithScalarFunctionWithParameters(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    // pass the parameterized config so configMode takes effect
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    val richTableFunc1 = new RichTableFunc1
+    tEnv.registerFunction("RichTableFunc1", richTableFunc1)
+    val richFunc2 = new RichFunc2
+    tEnv.registerFunction("RichFunc2", richFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(
+      env,
+      Map("word_separator" -> "#", "string.value" -> "test"))
+
+    val result = CollectionDataSets.getSmall3TupleDataSet(env)
+      .toTable(tEnv, 'a, 'b, 'c)
+      .join(richTableFunc1(richFunc2('c)) as 's)
+      .select('a, 's)
+
+    val expected = "1,Hi\n1,test\n2,Hello\n2,test\n3,Hello world\n3,test"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   private def testData(
       env: ExecutionEnvironment)
     : DataSet[(Int, Long, String)] = {

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
new file mode 100644
index 0000000..1d48f2c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCalcITCase.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.datastream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.table.expressions.utils.{RichFunc1, RichFunc2}
+import org.apache.flink.table.utils.UserDefinedFunctionTestUtils
+import org.apache.flink.types.Row
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+/**
+ * Streaming integration tests for scalar functions that read global job parameters in
+ * `open`. Results are collected through StreamITCase.StringSink into
+ * StreamITCase.testResults.
+ */
+class DataStreamCalcITCase extends StreamingMultipleProgramsTestBase {
+
+  @Test
+  def testUserDefinedFunctionWithParameter(): Unit = {
+    val streamEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(streamEnv)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(streamEnv, Map("string.value" -> "ABC"))
+    StreamITCase.testResults = mutable.MutableList()
+
+    // RichFunc2 returns "<string.value>#<c>", so the filter keeps the row whose c is "Hello"
+    StreamTestData.get3TupleDataStream(streamEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .where("RichFunc2(c)='ABC#Hello'")
+      .select('c)
+      .toDataStream[Row]
+      .addSink(new StreamITCase.StringSink)
+    streamEnv.execute()
+
+    assertEquals(mutable.MutableList("Hello").sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testMultipleUserDefinedFunctions(): Unit = {
+    val streamEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(streamEnv)
+    tableEnv.registerFunction("RichFunc1", new RichFunc1)
+    tableEnv.registerFunction("RichFunc2", new RichFunc2)
+    UserDefinedFunctionTestUtils.setJobParameters(streamEnv, Map("string.value" -> "Abc"))
+    StreamITCase.testResults = mutable.MutableList()
+
+    // both rich functions appear in a single predicate and must each be opened
+    StreamTestData.get3TupleDataStream(streamEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .where("RichFunc2(c)='Abc#Hello' || RichFunc1(a)=3 && b=2")
+      .select('c)
+      .toDataStream[Row]
+      .addSink(new StreamITCase.StringSink)
+    streamEnv.execute()
+
+    assertEquals(
+      mutable.MutableList("Hello", "Hello world").sorted,
+      StreamITCase.testResults.sorted)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
index eb20517..f8a697d 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/datastream/DataStreamCorrelateITCase.scala
@@ -18,13 +18,14 @@
 package org.apache.flink.table.runtime.datastream
 
 import org.apache.flink.api.scala._
-import org.apache.flink.types.Row
-import org.apache.flink.table.api.scala.stream.utils.StreamITCase
-import org.apache.flink.table.api.scala._
-import org.apache.flink.table.utils.TableFunc0
 import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
 import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
 import org.apache.flink.table.api.TableEnvironment
+import org.apache.flink.table.api.scala._
+import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData}
+import org.apache.flink.table.expressions.utils.RichFunc2
+import org.apache.flink.table.utils.{RichTableFunc1, TableFunc0, UserDefinedFunctionTestUtils}
+import org.apache.flink.types.Row
 import org.junit.Assert._
 import org.junit.Test
 
@@ -76,9 +77,63 @@ class DataStreamCorrelateITCase extends StreamingMultipleProgramsTestBase {
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
 
+  /** Splits `c` with RichTableFunc1 using a single-space `word_separator` job parameter. */
+  @Test
+  def testUserDefinedTableFunctionWithParameter(): Unit = {
+    val streamEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(streamEnv)
+    val splitter = new RichTableFunc1
+    tableEnv.registerFunction("RichTableFunc1", splitter)
+    UserDefinedFunctionTestUtils.setJobParameters(streamEnv, Map("word_separator" -> " "))
+    StreamITCase.testResults = mutable.MutableList()
+
+    StreamTestData.getSmall3TupleDataStream(streamEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .join(splitter('c) as 's)
+      .select('a, 's)
+      .toDataStream[Row]
+      .addSink(new StreamITCase.StringSink)
+    streamEnv.execute()
+
+    assertEquals(
+      mutable.MutableList("3,Hello", "3,world").sorted,
+      StreamITCase.testResults.sorted)
+  }
+
+  /**
+   * Applies RichTableFunc1 (separator "#") to RichFunc2's output (prefix "test"), so both
+   * rich functions must be opened with the same global job parameters.
+   */
+  @Test
+  def testUserDefinedTableFunctionWithUserDefinedScalarFunction(): Unit = {
+    val streamEnv = StreamExecutionEnvironment.getExecutionEnvironment
+    val tableEnv = TableEnvironment.getTableEnvironment(streamEnv)
+    val splitter = new RichTableFunc1
+    val prefixer = new RichFunc2
+    tableEnv.registerFunction("RichTableFunc1", splitter)
+    tableEnv.registerFunction("RichFunc2", prefixer)
+    UserDefinedFunctionTestUtils.setJobParameters(
+      streamEnv,
+      Map("word_separator" -> "#", "string.value" -> "test"))
+    StreamITCase.testResults = mutable.MutableList()
+
+    StreamTestData.getSmall3TupleDataStream(streamEnv)
+      .toTable(tableEnv, 'a, 'b, 'c)
+      .join(splitter(prefixer('c)) as 's)
+      .select('a, 's)
+      .toDataStream[Row]
+      .addSink(new StreamITCase.StringSink)
+    streamEnv.execute()
+
+    val expectedRows = mutable.MutableList(
+      "1,Hi", "1,test",
+      "2,Hello", "2,test",
+      "3,Hello world", "3,test")
+    assertEquals(expectedRows.sorted, StreamITCase.testResults.sorted)
+  }
+
   private def testData(
-    env: StreamExecutionEnvironment)
-  : DataStream[(Int, Long, String)] = {
+      env: StreamExecutionEnvironment)
+    : DataStream[(Int, Long, String)] = {
 
     val data = new mutable.MutableList[(Int, Long, String)]
     data.+=((1, 1L, "Jack#22"))

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
new file mode 100644
index 0000000..deaedc9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedFunctionTestUtils.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.utils
+
+import java.io.File
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
+
+/** Helpers shared by the user-defined function integration tests. */
+object UserDefinedFunctionTestUtils {
+
+  /** Installs `parameters` as the global job parameters of a batch environment. */
+  def setJobParameters(env: ExecutionEnvironment, parameters: Map[String, String]): Unit =
+    env.getConfig.setGlobalJobParameters(toConfiguration(parameters))
+
+  /** Installs `parameters` as the global job parameters of a streaming environment. */
+  def setJobParameters(env: StreamExecutionEnvironment, parameters: Map[String, String]): Unit =
+    env.getConfig.setGlobalJobParameters(toConfiguration(parameters))
+
+  /**
+   * Writes `contents` as UTF-8 to a new temporary file whose name contains `fileName`.
+   * The file is marked for deletion on JVM exit.
+   *
+   * @return the absolute path of the written file
+   */
+  def writeCacheFile(fileName: String, contents: String): String = {
+    // ".tmp" (with the dot) so the name ends in a conventional extension
+    val tempFile = File.createTempFile(this.getClass.getName + "-" + fileName, ".tmp")
+    tempFile.deleteOnExit()
+    Files.write(contents, tempFile, Charsets.UTF_8)
+    tempFile.getAbsolutePath
+  }
+
+  /** Copies every key/value pair of `parameters` into a new Configuration as strings. */
+  private def toConfiguration(parameters: Map[String, String]): Configuration = {
+    val conf = new Configuration()
+    parameters.foreach { case (key, value) => conf.setString(key, value) }
+    conf
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/b820fd3c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
index 54861ea..5db9d5f 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedTableFunctions.scala
@@ -21,9 +21,11 @@ import java.lang.Boolean
 
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.tuple.Tuple3
-import org.apache.flink.types.Row
 import org.apache.flink.api.java.typeutils.RowTypeInfo
-import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.functions.{TableFunction, FunctionContext}
+import org.apache.flink.types.Row
+import org.junit.Assert
 
 
 case class SimpleUser(name: String, age: Int)
@@ -115,3 +117,55 @@ object ObjectTableFunction extends TableFunction[Integer] {
     collect(b)
   }
 }
+
+/**
+ * Table function that checks its lifecycle while a test runs and emits each input that does
+ * not contain '#'. `eval` must run after `open` and before `close`, `open` must not follow
+ * `close`, and `close` must follow `open`. Violations fail the test via `Assert.fail`.
+ */
+class RichTableFunc0 extends TableFunction[String] {
+  var openCalled = false
+  var closeCalled = false
+
+  override def open(context: FunctionContext): Unit = {
+    super.open(context)
+    failIf(closeCalled, "Close called before open.")
+    openCalled = true
+  }
+
+  def eval(str: String): Unit = {
+    failIf(!openCalled, "Open was not called before eval.")
+    failIf(closeCalled, "Close called before eval.")
+    // inputs containing '#' produce no output row
+    val emit = !str.contains("#")
+    if (emit) collect(str)
+  }
+
+  override def close(): Unit = {
+    super.close()
+    failIf(!openCalled, "Open was not called before close.")
+    closeCalled = true
+  }
+
+  /** Fails the current test with `message` when `violated` is true (Assert.fail throws). */
+  private def failIf(violated: Boolean, message: String): Unit =
+    if (violated) Assert.fail(message)
+}
+
+/**
+ * Table function that splits its input on a literal separator read in `open` from the
+ * global job parameter `word_separator` (default ""). Inputs that do not contain the
+ * separator produce no rows. Calling `eval` outside the open/close window throws
+ * ValidationException("no separator").
+ */
+class RichTableFunc1 extends TableFunction[String] {
+  var separator: Option[String] = None
+
+  override def open(context: FunctionContext): Unit = {
+    separator = Some(context.getJobParameter("word_separator", ""))
+  }
+
+  def eval(str: String): Unit = {
+    val sep = separator.getOrElse(throw new ValidationException("no separator"))
+    if (str.contains(sep)) {
+      // String.split takes a regex; quote it so separators like "." or "|" are literal
+      str.split(java.util.regex.Pattern.quote(sep)).foreach(collect)
+    }
+  }
+
+  override def close(): Unit = {
+    separator = None
+  }
+}


Mime
View raw message