flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From twal...@apache.org
Subject flink git commit: [FLINK-8492] [table] Improve cost estimation for Calcs
Date Wed, 31 Jan 2018 13:14:07 GMT
Repository: flink
Updated Branches:
  refs/heads/master 2b76ecab8 -> 8cf2be7e2


[FLINK-8492] [table] Improve cost estimation for Calcs

This closes #5347.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8cf2be7e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8cf2be7e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8cf2be7e

Branch: refs/heads/master
Commit: 8cf2be7e2be72c3b4cc3af36e64bcf30b27ec5bb
Parents: 2b76eca
Author: hequn8128 <chenghequn@gmail.com>
Authored: Wed Jan 24 12:25:13 2018 +0800
Committer: twalthr <twalthr@apache.org>
Committed: Wed Jan 31 14:13:06 2018 +0100

----------------------------------------------------------------------
 .../flink/table/plan/nodes/CommonCalc.scala     | 27 ++++--
 .../rules/dataSet/DataSetCorrelateRule.scala    | 12 +--
 .../datastream/DataStreamCorrelateRule.scala    | 11 +--
 .../flink/table/plan/util/CorrelateUtil.scala   | 67 +++++++++++++++
 .../table/api/batch/sql/SetOperatorsTest.scala  |  6 +-
 .../flink/table/api/batch/table/CalcTest.scala  | 19 +++++
 .../table/api/batch/table/CorrelateTest.scala   | 89 +++++++++++++++++++-
 .../flink/table/api/stream/table/CalcTest.scala | 19 +++++
 .../table/api/stream/table/CorrelateTest.scala  | 85 ++++++++++++++++++-
 9 files changed, 311 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala
index 2f1871b..36df67a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCalc.scala
@@ -19,6 +19,7 @@
 package org.apache.flink.table.plan.nodes
 
 import org.apache.calcite.plan.{RelOptCost, RelOptPlanner}
+import org.apache.calcite.rel.metadata.RelMdUtil
 import org.apache.calcite.rex._
 import org.apache.flink.api.common.functions.Function
 import org.apache.flink.table.api.TableConfig
@@ -149,12 +150,9 @@ trait CommonCalc {
     // conditions, etc. We only want to account for computations, not for simple projections.
     // CASTs in RexProgram are reduced as far as possible by ReduceExpressionsRule
     // in normalization stage. So we should ignore CASTs here in optimization stage.
-    val compCnt = calcProgram.getExprList.asScala.toList.count {
-      case _: RexInputRef => false
-      case _: RexLiteral => false
-      case c: RexCall if c.getOperator.getName.equals("CAST") => false
-      case _ => true
-    }
+    // Also, we add 1 to take calc RelNode number into consideration, so the cost of merged calc
+    // RelNode will be less than the total cost of un-merged calcs.
+    val compCnt = calcProgram.getExprList.asScala.toList.count(isComputation) + 1
 
     val newRowCnt = estimateRowCount(calcProgram, rowCnt)
     planner.getCostFactory.makeCost(newRowCnt, newRowCnt * compCnt, 0)
@@ -166,9 +164,24 @@ trait CommonCalc {
 
     if (calcProgram.getCondition != null) {
       // we reduce the result card to push filters down
-      (rowCnt * 0.75).max(1.0)
+      val exprs = calcProgram.expandLocalRef(calcProgram.getCondition)
+      val selectivity = RelMdUtil.guessSelectivity(exprs, false)
+      (rowCnt * selectivity).max(1.0)
     } else {
       rowCnt
     }
   }
+
+  /**
+    * Returns true if the input rexNode does not merely access a field or literal, i.e. it is
+    * a computation, condition, etc.
+    */
+  private[flink] def isComputation(rexNode: RexNode): Boolean = {
+    rexNode match {
+      case _: RexInputRef => false
+      case _: RexLiteral => false
+      case c: RexCall if c.getOperator.getName.equals("CAST") => false
+      case _ => true
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetCorrelateRule.scala
index 79cd64a..af9b516 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetCorrelateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetCorrelateRule.scala
@@ -25,6 +25,7 @@ import org.apache.calcite.rex.RexNode
 import org.apache.flink.table.plan.nodes.FlinkConventions
 import org.apache.flink.table.plan.nodes.dataset.DataSetCorrelate
 import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan}
+import org.apache.flink.table.plan.util.CorrelateUtil
 
 class DataSetCorrelateRule
   extends ConverterRule(
@@ -37,14 +38,11 @@ class DataSetCorrelateRule
       val join: FlinkLogicalCorrelate = call.rel(0).asInstanceOf[FlinkLogicalCorrelate]
       val right = join.getRight.asInstanceOf[RelSubset].getOriginal
 
-
       right match {
         // right node is a table function
         case scan: FlinkLogicalTableFunctionScan => true
         // a filter is pushed above the table function
-        case calc: FlinkLogicalCalc =>
-          calc.getInput.asInstanceOf[RelSubset]
-            .getOriginal.isInstanceOf[FlinkLogicalTableFunctionScan]
+        case calc: FlinkLogicalCalc if CorrelateUtil.getTableFunctionScan(calc).isDefined => true
         case _ => false
       }
     }
@@ -61,9 +59,11 @@ class DataSetCorrelateRule
             convertToCorrelate(rel.getRelList.get(0), condition)
 
           case calc: FlinkLogicalCalc =>
+            val tableScan = CorrelateUtil.getTableFunctionScan(calc).get
+            val newCalc = CorrelateUtil.getMergedCalc(calc)
             convertToCorrelate(
-              calc.getInput.asInstanceOf[RelSubset].getOriginal,
-              Some(calc.getProgram.expandLocalRef(calc.getProgram.getCondition)))
+              tableScan,
+              Some(newCalc.getProgram.expandLocalRef(newCalc.getProgram.getCondition)))
 
           case scan: FlinkLogicalTableFunctionScan =>
             new DataSetCorrelate(

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCorrelateRule.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCorrelateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCorrelateRule.scala
index cd0663e..ae0370d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCorrelateRule.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamCorrelateRule.scala
@@ -26,6 +26,7 @@ import org.apache.flink.table.plan.nodes.FlinkConventions
 import org.apache.flink.table.plan.nodes.datastream.DataStreamCorrelate
 import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalCorrelate, FlinkLogicalTableFunctionScan}
 import org.apache.flink.table.plan.schema.RowSchema
+import org.apache.flink.table.plan.util.CorrelateUtil
 
 class DataStreamCorrelateRule
   extends ConverterRule(
@@ -42,9 +43,7 @@ class DataStreamCorrelateRule
       // right node is a table function
       case scan: FlinkLogicalTableFunctionScan => true
       // a filter is pushed above the table function
-      case calc: FlinkLogicalCalc =>
-        calc.getInput.asInstanceOf[RelSubset]
-            .getOriginal.isInstanceOf[FlinkLogicalTableFunctionScan]
+      case calc: FlinkLogicalCalc if CorrelateUtil.getTableFunctionScan(calc).isDefined => true
       case _ => false
     }
   }
@@ -61,9 +60,11 @@ class DataStreamCorrelateRule
           convertToCorrelate(rel.getRelList.get(0), condition)
 
         case calc: FlinkLogicalCalc =>
+          val tableScan = CorrelateUtil.getTableFunctionScan(calc).get
+          val newCalc = CorrelateUtil.getMergedCalc(calc)
           convertToCorrelate(
-            calc.getInput.asInstanceOf[RelSubset].getOriginal,
-            Some(calc.getProgram.expandLocalRef(calc.getProgram.getCondition)))
+            tableScan,
+            Some(newCalc.getProgram.expandLocalRef(newCalc.getProgram.getCondition)))
 
         case scan: FlinkLogicalTableFunctionScan =>
           new DataStreamCorrelate(

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/CorrelateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/CorrelateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/CorrelateUtil.scala
new file mode 100644
index 0000000..a74fa35
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/CorrelateUtil.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.util
+
+import org.apache.calcite.plan.volcano.RelSubset
+import org.apache.calcite.rex.{RexProgram, RexProgramBuilder}
+import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalCalc, FlinkLogicalTableFunctionScan}
+
+/**
+  * A utility for datastream and dataset correlate rules.
+  */
+object CorrelateUtil {
+
+  /**
+    * Get [[FlinkLogicalTableFunctionScan]] from the input calc. Returns None if there is no table
+    * function at the end.
+    */
+  def getTableFunctionScan(calc: FlinkLogicalCalc): Option[FlinkLogicalTableFunctionScan] = {
+    val child = calc.getInput.asInstanceOf[RelSubset].getOriginal
+    child match {
+      case scan: FlinkLogicalTableFunctionScan => Some(scan)
+      case calc: FlinkLogicalCalc => getTableFunctionScan(calc)
+      case _ => None
+    }
+  }
+
+  /**
+    * Merge consecutive calcs.
+    *
+    * @param calc the input calc
+    * @return the single merged calc
+    */
+  def getMergedCalc(calc: FlinkLogicalCalc): FlinkLogicalCalc = {
+    val child = calc.getInput.asInstanceOf[RelSubset].getOriginal
+    child match {
+      case logicalCalc: FlinkLogicalCalc =>
+        val bottomCalc = getMergedCalc(logicalCalc)
+        val topCalc = calc
+        val topProgram: RexProgram = topCalc.getProgram
+        val mergedProgram: RexProgram = RexProgramBuilder
+          .mergePrograms(
+            topCalc.getProgram,
+            bottomCalc.getProgram,
+            topCalc.getCluster.getRexBuilder)
+        assert(mergedProgram.getOutputRowType eq topProgram.getOutputRowType)
+        topCalc.copy(topCalc.getTraitSet, bottomCalc.getInput, mergedProgram)
+          .asInstanceOf[FlinkLogicalCalc]
+      case _ =>
+        calc
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala
index d51fc42..e6f4a46 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala
@@ -113,7 +113,7 @@ class SetOperatorsTest extends TableTestBase {
             term("join", "a", "b", "c", "$f0", "$f1"),
             term("joinType", "NestedLoopInnerJoin")
           ),
-          term("select", "a AS $f0", "b AS $f1", "c AS $f2", "$f0 AS $f3", "$f1 AS $f4", "b AS $f5")
+          term("select", "a AS $f0", "c AS $f2", "$f0 AS $f3", "$f1 AS $f4", "b AS $f5")
         ),
         unaryNode(
           "DataSetAggregate",
@@ -127,11 +127,11 @@ class SetOperatorsTest extends TableTestBase {
           term("select", "$f0", "MIN($f1) AS $f1")
         ),
         term("where", "=($f5, $f00)"),
-        term("join", "$f0", "$f1", "$f2", "$f3", "$f4", "$f5", "$f00", "$f10"),
+        term("join", "$f0", "$f2", "$f3", "$f4", "$f5", "$f00", "$f1"),
         term("joinType", "LeftOuterJoin")
       ),
       term("select", "$f0 AS a", "$f2 AS c"),
-      term("where", "OR(=($f3, 0), AND(IS NULL($f10), >=($f4, $f3), IS NOT NULL($f5)))")
+      term("where", "OR(=($f3, 0), AND(IS NULL($f1), >=($f4, $f3), IS NOT NULL($f5)))")
     )
 
     util.verifySql(

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
index c2fa647..59ecec0 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala
@@ -312,6 +312,25 @@ class CalcTest extends TableTestBase {
 
     util.verifyTable(resultTable, expected)
   }
+
+  @Test
+  def testMultiFilter(): Unit = {
+    val util = batchTestUtil()
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a, 'b)
+      .filter('a > 0)
+      .filter('b < 2)
+      .filter(('a % 2) === 1)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      batchTableNode(0),
+      term("select", "a", "b"),
+      term("where", "AND(AND(>(a, 0), <(b, 2)), =(MOD(a, 2), 1))")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
 }
 
 object CalcTest {

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
index bcaa8b7..8381e97 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CorrelateTest.scala
@@ -18,13 +18,18 @@
 
 package org.apache.flink.table.api.batch.table
 
+import org.apache.calcite.rel.rules.{CalcMergeRule, FilterCalcMergeRule, ProjectCalcMergeRule}
+import org.apache.calcite.tools.RuleSets
 import org.apache.flink.api.scala._
-import org.apache.flink.table.api.ValidationException
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
+import org.apache.flink.table.plan.rules.FlinkRuleSets
 import org.apache.flink.table.utils.TableTestUtil._
-import org.apache.flink.table.utils.{TableFunc1, TableTestBase}
+import org.apache.flink.table.utils.{TableFunc0, TableFunc1, TableTestBase}
 import org.junit.Test
 
+import scala.collection.JavaConversions._
+
 class CorrelateTest extends TableTestBase {
 
   @Test
@@ -125,4 +130,84 @@ class CorrelateTest extends TableTestBase {
 
     util.verifyTable(result, expected)
   }
+
+  @Test
+  def testCorrelateWithMultiFilter(): Unit = {
+    val util = batchTestUtil()
+    val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = util.addFunction("func1", new TableFunc0)
+
+    val result = sourceTable.select('a, 'b, 'c)
+      .join(function('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .where('e > 10)
+      .where('e > 20)
+      .select('c, 'd)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", s"${function.functionIdentifier}($$2)"),
+        term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+        term("select", "a", "b", "c", "d", "e"),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
+               "VARCHAR(65536) d, INTEGER e)"),
+        term("joinType", "INNER"),
+        term("condition", "AND(>($1, 10), >($1, 20))")
+      ),
+      term("select", "c", "d")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  @Test
+  def testCorrelateWithMultiFilterAndWithoutCalcMergeRules(): Unit = {
+    val util = batchTestUtil()
+
+    val logicalRuleSet = FlinkRuleSets.LOGICAL_OPT_RULES.filter {
+        case CalcMergeRule.INSTANCE => false
+        case FilterCalcMergeRule.INSTANCE => false
+        case ProjectCalcMergeRule.INSTANCE => false
+        case _ => true
+      }
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceLogicalOptRuleSet(RuleSets.ofList(logicalRuleSet.toList))
+      .build()
+
+    util.tableEnv.getConfig.setCalciteConfig(cc)
+
+    val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = util.addFunction("func1", new TableFunc0)
+
+    val result = sourceTable.select('a, 'b, 'c)
+      .join(function('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .where('e > 10)
+      .where('e > 20)
+      .select('c, 'd)
+
+    val expected = unaryNode(
+      "DataSetCalc",
+      unaryNode(
+        "DataSetCorrelate",
+        batchTableNode(0),
+        term("invocation", s"${function.functionIdentifier}($$2)"),
+        term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+        term("select", "a", "b", "c", "d", "e"),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
+               "VARCHAR(65536) d, INTEGER e)"),
+        term("joinType", "INNER"),
+        term("condition", "AND(>($1, 10), >($1, 20))")
+      ),
+      term("select", "c", "d")
+    )
+
+    util.verifyTable(result, expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala
index 02f84c0..8cbc03c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CalcTest.scala
@@ -94,6 +94,25 @@ class CalcTest extends TableTestBase {
 
     util.verifyTable(resultTable, expected)
   }
+
+  @Test
+  def testMultiFilter(): Unit = {
+    val util = streamTestUtil()
+    val sourceTable = util.addTable[(Int, Long, String, Double)]("MyTable", 'a, 'b, 'c, 'd)
+    val resultTable = sourceTable.select('a, 'b)
+      .filter('a > 0)
+      .filter('b < 2)
+      .filter(('a % 2) === 1)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      streamTableNode(0),
+      term("select", "a", "b"),
+      term("where", "AND(AND(>(a, 0), <(b, 2)), =(MOD(a, 2), 1))")
+    )
+
+    util.verifyTable(resultTable, expected)
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8cf2be7e/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
index 72421a8..7e766ec 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/table/CorrelateTest.scala
@@ -17,14 +17,19 @@
  */
 package org.apache.flink.table.api.stream.table
 
+import org.apache.calcite.rel.rules.{CalcMergeRule, FilterCalcMergeRule, ProjectCalcMergeRule}
+import org.apache.calcite.tools.RuleSets
 import org.apache.flink.api.scala._
-import org.apache.flink.table.api.{TableException, ValidationException}
 import org.apache.flink.table.api.scala._
+import org.apache.flink.table.calcite.{CalciteConfig, CalciteConfigBuilder}
 import org.apache.flink.table.expressions.utils.Func13
+import org.apache.flink.table.plan.rules.FlinkRuleSets
 import org.apache.flink.table.utils.TableTestUtil._
 import org.apache.flink.table.utils._
 import org.junit.Test
 
+import scala.collection.JavaConversions._
+
 class CorrelateTest extends TableTestBase {
 
   @Test
@@ -229,4 +234,82 @@ class CorrelateTest extends TableTestBase {
     util.verifyTable(result, expected)
   }
 
+  @Test
+  def testCorrelateWithMultiFilter(): Unit = {
+    val util = streamTestUtil()
+    val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = util.addFunction("func1", new TableFunc0)
+
+    val result = sourceTable.select('a, 'b, 'c)
+      .join(function('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .where('e > 10)
+      .where('e > 20)
+      .select('c, 'd)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", s"${function.functionIdentifier}($$2)"),
+        term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+        term("select", "a", "b", "c", "d", "e"),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
+               "VARCHAR(65536) d, INTEGER e)"),
+        term("joinType", "INNER"),
+        term("condition", "AND(>($1, 10), >($1, 20))")
+      ),
+      term("select", "c", "d")
+    )
+
+    util.verifyTable(result, expected)
+  }
+
+  @Test
+  def testCorrelateWithMultiFilterAndWithoutCalcMergeRules(): Unit = {
+    val util = streamTestUtil()
+
+    val logicalRuleSet = FlinkRuleSets.LOGICAL_OPT_RULES.filter {
+      case CalcMergeRule.INSTANCE => false
+      case FilterCalcMergeRule.INSTANCE => false
+      case ProjectCalcMergeRule.INSTANCE => false
+      case _ => true
+    }
+
+    val cc: CalciteConfig = new CalciteConfigBuilder()
+      .replaceLogicalOptRuleSet(RuleSets.ofList(logicalRuleSet.toList))
+      .build()
+
+    util.tableEnv.getConfig.setCalciteConfig(cc)
+
+    val sourceTable = util.addTable[(Int, Long, String)]("MyTable", 'a, 'b, 'c)
+    val function = util.addFunction("func1", new TableFunc0)
+    val result = sourceTable.select('a, 'b, 'c)
+      .join(function('c) as('d, 'e))
+      .select('c, 'd, 'e)
+      .where('e > 10)
+      .where('e > 20)
+      .select('c, 'd)
+
+    val expected = unaryNode(
+      "DataStreamCalc",
+      unaryNode(
+        "DataStreamCorrelate",
+        streamTableNode(0),
+        term("invocation", s"${function.functionIdentifier}($$2)"),
+        term("correlate", s"table(${function.getClass.getSimpleName}(c))"),
+        term("select", "a", "b", "c", "d", "e"),
+        term("rowType",
+             "RecordType(INTEGER a, BIGINT b, VARCHAR(65536) c, " +
+               "VARCHAR(65536) d, INTEGER e)"),
+        term("joinType", "INNER"),
+        term("condition", "AND(>($1, 10), >($1, 20))")
+      ),
+      term("select", "c", "d")
+    )
+
+    util.verifyTable(result, expected)
+  }
 }


Mime
View raw message