flink-commits mailing list archives

From twal...@apache.org
Subject [1/5] flink git commit: [FLINK-4469] [table] Add support for user defined table function in Table API & SQL
Date Wed, 07 Dec 2016 15:57:20 GMT
Repository: flink
Updated Branches:
  refs/heads/master c024b0b6c -> 684defbf3
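
For orientation, here is a minimal sketch (not part of the patch) of how the new
user-defined table function support is used, distilled from the test cases in this
commit. TableFunc0 splits strings like "Jack#22" into (name, age) pairs; for SQL it
is registered under the name "split".

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.table._
    import org.apache.flink.api.table.TableEnvironment
    import org.apache.flink.api.table.expressions.utils.TableFunc0
    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment

    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // input rows carry strings in the "<name>#<age>" format expected by TableFunc0
    val t = env.fromElements((1, 1L, "Jack#22")).toTable(tEnv).as('a, 'b, 'c)

    // Table API: correlate each input row with the rows emitted by the function
    val func0 = new TableFunc0
    val apiResult = t
      .crossApply(func0('c) as ('name, 'age))
      .select('c, 'name, 'age)

    // SQL: the same correlation via LATERAL TABLE; the function must be registered first
    tEnv.registerTable("MyTable", t)
    tEnv.registerFunction("split", func0)
    val sqlResult = tEnv.sql(
      "SELECT MyTable.c, t.n, t.a FROM MyTable, LATERAL TABLE(split(c)) AS t(n, a)")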


http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
new file mode 100644
index 0000000..f19f7f9
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionITCase.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.stream
+
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.stream.utils.StreamITCase
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils.{TableFunc0, TableFunc1}
+import org.apache.flink.api.table.{Row, TableEnvironment}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
+import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase
+import org.junit.Assert._
+import org.junit.Test
+
+import scala.collection.mutable
+
+class UserDefinedTableFunctionITCase extends StreamingMultipleProgramsTestBase {
+
+  @Test
+  def testSQLCrossApply(): Unit = {
+
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("MyTable", t)
+
+    tEnv.registerFunction("split", new TableFunc0)
+
+    val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable, LATERAL TABLE(split(c)) AS t(n,a)"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testSQLOuterApply(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("MyTable", t)
+
+    tEnv.registerFunction("split", new TableFunc0)
+
+    val sqlQuery = "SELECT MyTable.c, t.n, t.a FROM MyTable " +
+      "LEFT JOIN LATERAL TABLE(split(c)) AS t(n,a) ON TRUE"
+
+    val result = tEnv.sql(sqlQuery).toDataStream[Row]
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "nosharp,null,null", "Jack#22,Jack,22",
+      "John#19,John,19", "Anna#44,Anna,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testTableAPICrossApply(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = t
+      .crossApply(func0('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .toDataStream[Row]
+
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList("Jack#22,Jack,22", "John#19,John,19", "Anna#44,Anna,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testTableAPIOuterApply(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = t
+      .outerApply(func0('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .toDataStream[Row]
+
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList(
+      "nosharp,null,null", "Jack#22,Jack,22",
+      "John#19,John,19", "Anna#44,Anna,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testTableAPIWithFilter(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    val func0 = new TableFunc0
+
+    val result = t
+      .crossApply(func0('c) as('name, 'age))
+      .select('c, 'name, 'age)
+      .filter('age > 20)
+      .toDataStream[Row]
+
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList("Jack#22,Jack,22", "Anna#44,Anna,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  @Test
+  def testTableAPIWithScalarFunction(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    val t = getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    val func1 = new TableFunc1
+
+    val result = t
+      .crossApply(func1('c.substring(2)) as 's)
+      .select('c, 's)
+      .toDataStream[Row]
+
+    result.addSink(new StreamITCase.StringSink)
+    env.execute()
+
+    val expected = mutable.MutableList("Jack#22,ack", "Jack#22,22", "John#19,ohn",
+                                       "John#19,19", "Anna#44,nna", "Anna#44,44")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
+
+  private def getSmall3TupleDataStream(
+    env: StreamExecutionEnvironment)
+  : DataStream[(Int, Long, String)] = {
+
+    val data = new mutable.MutableList[(Int, Long, String)]
+    data.+=((1, 1L, "Jack#22"))
+    data.+=((2, 2L, "John#19"))
+    data.+=((3, 2L, "Anna#44"))
+    data.+=((4, 3L, "nosharp"))
+    env.fromCollection(data)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
new file mode 100644
index 0000000..bc01819
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/stream/UserDefinedTableFunctionTest.scala
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.scala.stream
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala._
+import org.apache.flink.api.scala.table._
+import org.apache.flink.api.table.expressions.utils._
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+import org.apache.flink.api.table.utils.TableTestBase
+import org.apache.flink.api.table.utils.TableTestUtil._
+import org.apache.flink.api.table._
+import org.apache.flink.streaming.api.environment.{StreamExecutionEnvironment => JavaExecutionEnv}
+import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
+import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment => ScalaExecutionEnv}
+import org.junit.Assert.{assertTrue, fail}
+import org.junit.Test
+import org.mockito.Mockito._
+
+class UserDefinedTableFunctionTest extends TableTestBase {
+
+  @Test
+  def testTableAPI(): Unit = {
+    // mock
+    val ds = mock(classOf[DataStream[Row]])
+    val jDs = mock(classOf[JDataStream[Row]])
+    val typeInfo: TypeInformation[Row] = new RowTypeInfo(Seq(Types.INT, Types.LONG, Types.STRING))
+    when(ds.javaStream).thenReturn(jDs)
+    when(jDs.getType).thenReturn(typeInfo)
+
+    // Scala environment
+    val env = mock(classOf[ScalaExecutionEnv])
+    val tableEnv = TableEnvironment.getTableEnvironment(env)
+    val in1 = ds.toTable(tableEnv).as('a, 'b, 'c)
+
+    // Java environment
+    val javaEnv = mock(classOf[JavaExecutionEnv])
+    val javaTableEnv = TableEnvironment.getTableEnvironment(javaEnv)
+    val in2 = javaTableEnv.fromDataStream(jDs).as("a, b, c")
+
+    // test cross apply
+    val func1 = new TableFunc1
+    javaTableEnv.registerFunction("func1", func1)
+    var scalaTable = in1.crossApply(func1('c) as ('s)).select('c, 's)
+    var javaTable = in2.crossApply("func1(c) as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test outer apply
+    scalaTable = in1.outerApply(func1('c) as ('s)).select('c, 's)
+    javaTable = in2.outerApply("func1(c) as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test overloading
+    scalaTable = in1.crossApply(func1('c, "$") as ('s)).select('c, 's)
+    javaTable = in2.crossApply("func1(c, '$') as (s)").select("c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test custom result type
+    val func2 = new TableFunc2
+    javaTableEnv.registerFunction("func2", func2)
+    scalaTable = in1.crossApply(func2('c) as ('name, 'len)).select('c, 'name, 'len)
+    javaTable = in2.crossApply("func2(c) as (name, len)").select("c, name, len")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test hierarchy generic type
+    val hierarchy = new HierarchyTableFunction
+    javaTableEnv.registerFunction("hierarchy", hierarchy)
+    scalaTable = in1.crossApply(hierarchy('c) as ('name, 'adult, 'len))
+      .select('c, 'name, 'len, 'adult)
+    javaTable = in2.crossApply("hierarchy(c) as (name, adult, len)")
+      .select("c, name, len, adult")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test pojo type
+    val pojo = new PojoTableFunc
+    javaTableEnv.registerFunction("pojo", pojo)
+    scalaTable = in1.crossApply(pojo('c))
+      .select('c, 'name, 'age)
+    javaTable = in2.crossApply("pojo(c)")
+      .select("c, name, age")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test with filter
+    scalaTable = in1.crossApply(func2('c) as ('name, 'len))
+      .select('c, 'name, 'len).filter('len > 2)
+    javaTable = in2.crossApply("func2(c) as (name, len)")
+      .select("c, name, len").filter("len > 2")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // test with scalar function
+    scalaTable = in1.crossApply(func1('c.substring(2)) as ('s))
+      .select('a, 'c, 's)
+    javaTable = in2.crossApply("func1(substring(c, 2)) as (s)")
+      .select("a, c, s")
+    verifyTableEquals(scalaTable, javaTable)
+
+    // check that Scala objects are forbidden
+    expectExceptionThrown(
+      tableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+    expectExceptionThrown(
+      javaTableEnv.registerFunction("func3", ObjectTableFunction), "Scala object")
+    expectExceptionThrown(
+      in1.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+  }
+
+
+  @Test
+  def testInvalidTableFunction(): Unit = {
+    // mock
+    val util = streamTestUtil()
+    val t = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val tEnv = TableEnvironment.getTableEnvironment(mock(classOf[JavaExecutionEnv]))
+
+    //=============== check that Scala objects are forbidden ==================
+    // registration via the Scala table environment
+    expectExceptionThrown(util.addFunction("udtf", ObjectTableFunction), "Scala object")
+    // registration via the Java table environment
+    expectExceptionThrown(tEnv.registerFunction("udtf", ObjectTableFunction), "Scala object")
+    // direct call through the Scala Table API
+    expectExceptionThrown(t.crossApply(ObjectTableFunction('a, 1)), "Scala object")
+
+
+    //======== throw exception when the table function is not registered =========
+    // Java Table API call
+    expectExceptionThrown(t.crossApply("nonexist(a)"), "Undefined function: NONEXIST")
+    // SQL API call
+    expectExceptionThrown(
+      util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(nonexist(a))"),
+      "No match found for function signature nonexist(<NUMERIC>)")
+
+
+    //========= throw exception when the called function is a scalar function ====
+    util.addFunction("func0", Func0)
+    // Java Table API call
+    expectExceptionThrown(
+      t.crossApply("func0(a)"),
+      "only accept TableFunction",
+      classOf[TableException])
+    // SQL API call
+    // NOTE: this fails with an AssertionError rather than a proper exception, possibly a Calcite bug
+    expectExceptionThrown(
+      util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func0(a))"),
+      null,
+      classOf[AssertionError])
+
+    //========== throw exception when the parameters are not correct ==============
+    // Java Table API call
+    util.addFunction("func2", new TableFunc2)
+    expectExceptionThrown(
+      t.crossApply("func2(c, c)"),
+      "Given parameters of function 'FUNC2' do not match any signature")
+    // SQL API call
+    expectExceptionThrown(
+      util.tEnv.sql("SELECT * FROM MyTable, LATERAL TABLE(func2(c, c))"),
+      "No match found for function signature func2(<CHARACTER>, <CHARACTER>)")
+  }
+
+  private def expectExceptionThrown(
+      function: => Unit,
+      keywords: String,
+      clazz: Class[_ <: Throwable] = classOf[ValidationException])
+    : Unit = {
+    try {
+      function
+      fail(s"Expected a $clazz, but no exception is thrown.")
+    } catch {
+      case e if e.getClass == clazz =>
+        if (keywords != null) {
+          assertTrue(
+            s"The exception message '${e.getMessage}' doesn't contain keyword '$keywords'",
+            e.getMessage.contains(keywords))
+        }
+      case e: Throwable => fail(s"Expected a ${clazz.getSimpleName} to be thrown, but got $e.")
+    }
+  }
+
+  @Test
+  def testSQLWithCrossApply(): Unit = {
+    val util = streamTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func1($cor0.c)"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647)
f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+
+    // test overloading
+
+    val sqlQuery2 = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(c, '$')) AS T(s)"
+
+    val expected2 = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func1($cor0.c, '$')"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647)
f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery2, expected2)
+  }
+
+  @Test
+  def testSQLWithOuterApply(): Unit = {
+    val util = streamTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable LEFT JOIN LATERAL TABLE(func1(c)) AS T(s) ON
TRUE"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func1($cor0.c)"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647)
f0)"),
+        term("joinType", "LEFT")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithCustomType(): Unit = {
+    val util = streamTestUtil()
+    val func2 = new TableFunc2
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func2", func2)
+
+    val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name,
len)"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func2($cor0.c)"),
+        term("function", func2.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+               "VARCHAR(2147483647) f0, INTEGER f1)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithHierarchyType(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = new HierarchyTableFunction
+    util.addFunction("hierarchy", function)
+
+    val sqlQuery = "SELECT c, T.* FROM MyTable, LATERAL TABLE(hierarchy(c)) AS T(name, adult,
len)"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "hierarchy($cor0.c)"),
+        term("function", function.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+               " VARCHAR(2147483647) f0, BOOLEAN f1, INTEGER f2)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS adult", "f2 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithPojoType(): Unit = {
+    val util = streamTestUtil()
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = new PojoTableFunc
+    util.addFunction("pojo", function)
+
+    val sqlQuery = "SELECT c, name, age FROM MyTable, LATERAL TABLE(pojo(c))"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "pojo($cor0.c)"),
+        term("function", function.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c," +
+               " INTEGER age, VARCHAR(2147483647) name)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "name", "age")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+  @Test
+  def testSQLWithFilter(): Unit = {
+    val util = streamTestUtil()
+    val func2 = new TableFunc2
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func2", func2)
+
+    val sqlQuery = "SELECT c, name, len FROM MyTable, LATERAL TABLE(func2(c)) AS T(name,
len) " +
+      "WHERE len > 2"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func2($cor0.c)"),
+        term("function", func2.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, " +
+               "VARCHAR(2147483647) f0, INTEGER f1)"),
+        term("joinType", "INNER"),
+        term("condition", ">($1, 2)")
+      ),
+      term("select", "c", "f0 AS name", "f1 AS len")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+
+  @Test
+  def testSQLWithScalarFunction(): Unit = {
+    val util = streamTestUtil()
+    val func1 = new TableFunc1
+    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    util.addFunction("func1", func1)
+
+    val sqlQuery = "SELECT c, s FROM MyTable, LATERAL TABLE(func1(SUBSTRING(c, 2))) AS T(s)"
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", "func1(SUBSTRING($cor0.c, 2))"),
+        term("function", func1.getClass.getCanonicalName),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, VARCHAR(2147483647)
f0)"),
+        term("joinType", "INNER")
+      ),
+      term("select", "c", "f0 AS s")
+    )
+
+    util.verifySql(sqlQuery, expected)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
index 95cb331..ffe3cd3 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/UserDefinedScalarFunctionTest.scala
@@ -24,7 +24,7 @@ import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.scala.table._
 import org.apache.flink.api.table.expressions.utils._
-import org.apache.flink.api.table.functions.UserDefinedFunction
+import org.apache.flink.api.table.functions.ScalarFunction
 import org.apache.flink.api.table.typeutils.RowTypeInfo
 import org.apache.flink.api.table.{Row, Types}
 import org.junit.Test
@@ -208,7 +208,7 @@ class UserDefinedScalarFunctionTest extends ExpressionTestBase {
     )).asInstanceOf[TypeInformation[Any]]
   }
 
-  override def functions: Map[String, UserDefinedFunction] = Map(
+  override def functions: Map[String, ScalarFunction] = Map(
     "Func0" -> Func0,
     "Func1" -> Func1,
     "Func2" -> Func2,

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
index 84b61da..958fd25 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/ExpressionTestBase.scala
@@ -30,7 +30,7 @@ import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
 import org.apache.flink.api.table._
 import org.apache.flink.api.table.codegen.{CodeGenerator, Compiler, GeneratedFunction}
 import org.apache.flink.api.table.expressions.{Expression, ExpressionParser}
-import org.apache.flink.api.table.functions.UserDefinedFunction
+import org.apache.flink.api.table.functions.ScalarFunction
 import org.apache.flink.api.table.plan.nodes.dataset.{DataSetCalc, DataSetConvention}
 import org.apache.flink.api.table.plan.rules.FlinkRuleSets
 import org.apache.flink.api.table.typeutils.RowTypeInfo
@@ -79,7 +79,7 @@ abstract class ExpressionTestBase {
 
   def typeInfo: TypeInformation[Any]
 
-  def functions: Map[String, UserDefinedFunction] = Map()
+  def functions: Map[String, ScalarFunction] = Map()
 
   @Before
   def resetTestExprs() = {

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
new file mode 100644
index 0000000..1e6bdb8
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/expressions/utils/UserDefinedTableFunctions.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.api.table.expressions.utils
+
+import java.lang.Boolean
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.tuple.Tuple3
+import org.apache.flink.api.table.Row
+import org.apache.flink.api.table.functions.TableFunction
+import org.apache.flink.api.table.typeutils.RowTypeInfo
+
+
+case class SimpleUser(name: String, age: Int)
+
+class TableFunc0 extends TableFunction[SimpleUser] {
+  // expects input elements in the format "<string>#<int>"; other elements are ignored
+  def eval(user: String): Unit = {
+    if (user.contains("#")) {
+      val splits = user.split("#")
+      collect(SimpleUser(splits(0), splits(1).toInt))
+    }
+  }
+}
+
+class TableFunc1 extends TableFunction[String] {
+  def eval(str: String): Unit = {
+    if (str.contains("#")){
+      str.split("#").foreach(collect)
+    }
+  }
+
+  def eval(str: String, prefix: String): Unit = {
+    if (str.contains("#")) {
+      str.split("#").foreach(s => collect(prefix + s))
+    }
+  }
+}
+
+
+class TableFunc2 extends TableFunction[Row] {
+  def eval(str: String): Unit = {
+    if (str.contains("#")) {
+      str.split("#").foreach({ s =>
+        val row = new Row(2)
+        row.setField(0, s)
+        row.setField(1, s.length)
+        collect(row)
+      })
+    }
+  }
+
+  override def getResultType: TypeInformation[Row] = {
+    new RowTypeInfo(Seq(BasicTypeInfo.STRING_TYPE_INFO,
+                        BasicTypeInfo.INT_TYPE_INFO))
+  }
+}
+
+class HierarchyTableFunction extends SplittableTableFunction[Boolean, Integer] {
+  def eval(user: String) {
+    if (user.contains("#")) {
+      val splits = user.split("#")
+      val age = splits(1).toInt
+      collect(new Tuple3[String, Boolean, Integer](splits(0), age >= 20, age))
+    }
+  }
+}
+
+abstract class SplittableTableFunction[A, B] extends TableFunction[Tuple3[String, A, B]] {}
+
+class PojoTableFunc extends TableFunction[PojoUser] {
+  def eval(user: String) {
+    if (user.contains("#")) {
+      val splits = user.split("#")
+      collect(new PojoUser(splits(0), splits(1).toInt))
+    }
+  }
+}
+
+class PojoUser() {
+  var name: String = _
+  var age: Int = 0
+
+  def this(name: String, age: Int) {
+    this()
+    this.name = name
+    this.age = age
+  }
+}
+
+// ----------------------------------------------------------------------------------------------
+// Invalid Table Functions
+// ----------------------------------------------------------------------------------------------
+
+
+// used to check that registering a Scala object as a table function is forbidden
+object ObjectTableFunction extends TableFunction[Integer] {
+  def eval(a: Int, b: Int): Unit = {
+    collect(a)
+    collect(b)
+  }
+}
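
As these helpers illustrate, a TableFunction emits zero or more rows per input row by
calling collect(), overloaded eval methods provide overloading at the call site, and
getResultType can declare an explicit result schema (as TableFunc2 does). A minimal
sketch in the same style; the class name CommaSplit is hypothetical and not part of
the patch:

    // Hypothetical function in the style of TableFunc1: emits one output row
    // per comma-separated token and nothing for null or empty input.
    class CommaSplit extends TableFunction[String] {
      def eval(s: String): Unit = {
        if (s != null && s.nonEmpty) {
          s.split(",").foreach(collect)
        }
      }
    }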

http://git-wip-us.apache.org/repos/asf/flink/blob/e139f59c/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
index 539bb61..73f50f5 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/utils/TableTestBase.scala
@@ -24,6 +24,7 @@ import org.apache.flink.api.java.{DataSet => JDataSet}
 import org.apache.flink.api.scala.table._
 import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
 import org.apache.flink.api.table.expressions.Expression
+import org.apache.flink.api.table.functions.{ScalarFunction, TableFunction}
 import org.apache.flink.api.table.{Table, TableEnvironment}
 import org.apache.flink.streaming.api.datastream.{DataStream => JDataStream}
 import org.apache.flink.streaming.api.scala.{DataStream, StreamExecutionEnvironment}
@@ -43,6 +44,12 @@ class TableTestBase {
     StreamTableTestUtil()
   }
 
+  def verifyTableEquals(expected: Table, actual: Table): Unit = {
+    assertEquals("Logical Plan do not match",
+                 RelOptUtil.toString(expected.getRelNode),
+                 RelOptUtil.toString(actual.getRelNode))
+  }
+
 }
 
 abstract class TableTestUtil {
@@ -54,6 +61,9 @@ abstract class TableTestUtil {
   }
 
   def addTable[T: TypeInformation](name: String, fields: Expression*): Table
+  def addFunction[T: TypeInformation](name: String, function: TableFunction[T]): Unit
+  def addFunction(name: String, function: ScalarFunction): Unit
+
   def verifySql(query: String, expected: String): Unit
   def verifyTable(resultTable: Table, expected: String): Unit
 
@@ -119,6 +129,17 @@ case class BatchTableTestUtil() extends TableTestUtil {
     t
   }
 
+  def addFunction[T: TypeInformation](
+      name: String,
+      function: TableFunction[T])
+    : Unit = {
+    tEnv.registerFunction(name, function)
+  }
+
+  def addFunction(name: String, function: ScalarFunction): Unit = {
+    tEnv.registerFunction(name, function)
+  }
+
   def verifySql(query: String, expected: String): Unit = {
     verifyTable(tEnv.sql(query), expected)
   }
@@ -164,6 +185,17 @@ case class StreamTableTestUtil() extends TableTestUtil {
     t
   }
 
+  def addFunction[T: TypeInformation](
+      name: String,
+      function: TableFunction[T])
+    : Unit = {
+    tEnv.registerFunction(name, function)
+  }
+
+  def addFunction(name: String, function: ScalarFunction): Unit = {
+    tEnv.registerFunction(name, function)
+  }
+
   def verifySql(query: String, expected: String): Unit = {
     verifyTable(tEnv.sql(query), expected)
   }
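
For reference, the new addFunction overloads are used by the plan tests above roughly
as follows (condensed from UserDefinedTableFunctionTest; the expected plan string is
elided here):

    val util = streamTestUtil()
    util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
    util.addFunction("func1", new TableFunc1) // TableFunction[T] overload
    util.addFunction("func0", Func0)          // ScalarFunction overload

    val expected = "..." // expected logical plan, elided
    util.verifySql("SELECT c, s FROM MyTable, LATERAL TABLE(func1(c)) AS T(s)", expected)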

