flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From k...@apache.org
Subject [1/4] flink git commit: [FLINK-3849] [table] Add FilterableTableSource interface and rules for pushing it (1)
Date Fri, 17 Mar 2017 10:03:03 GMT
Repository: flink
Updated Branches:
  refs/heads/master ab014ef94 -> 78f22aaec


http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
new file mode 100644
index 0000000..c4059d5
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramExpressionExtractorTest.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.plan._
+import org.apache.calcite.plan.volcano.VolcanoPlanner
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rel.core.TableScan
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName._
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.table.calcite.{FlinkRelBuilder, FlinkTypeFactory}
+import org.apache.flink.table.expressions.{Expression, ExpressionParser}
+import org.apache.flink.table.plan.util.RexProgramExpressionExtractor._
+import org.apache.flink.table.plan.schema.CompositeRelDataType
+import org.apache.flink.table.utils.CommonTestData
+import org.junit.Test
+import org.junit.Assert._
+
+import scala.collection.JavaConverters._
+
+class RexProgramExpressionExtractorTest {
+
+  private val typeFactory = new FlinkTypeFactory(RelDataTypeSystem.DEFAULT)
+  private val allFieldTypes = List(VARCHAR, DECIMAL, INTEGER, DOUBLE).map(typeFactory.createSqlType)
+  private val allFieldTypeInfos: Array[TypeInformation[_]] =
+    Array(BasicTypeInfo.STRING_TYPE_INFO,
+      BasicTypeInfo.BIG_DEC_TYPE_INFO,
+      BasicTypeInfo.INT_TYPE_INFO,
+      BasicTypeInfo.DOUBLE_TYPE_INFO)
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  @Test
+  def testExtractExpression(): Unit = {
+    val builder: RexBuilder = new RexBuilder(typeFactory)
+    val program = buildRexProgram(
+      allFieldNames, allFieldTypes, typeFactory, builder)
+    val firstExp = ExpressionParser.parseExpression("id > 6")
+    val secondExp = ExpressionParser.parseExpression("amount * price < 100")
+    val expected: Array[Expression] = Array(firstExp, secondExp)
+    val actual = extractPredicateExpressions(
+      program,
+      builder,
+      CommonTestData.getMockTableEnvironment.getFunctionCatalog)
+
+    assertEquals(expected.length, actual.length)
+    // todo
+  }
+
+  @Test
+  def testRewriteRexProgramWithCondition(): Unit = {
+    val originalRexProgram = buildRexProgram(
+      allFieldNames, allFieldTypes, typeFactory, new RexBuilder(typeFactory))
+    val array = Array(
+      "$0",
+      "$1",
+      "$2",
+      "$3",
+      "*($t2, $t3)",
+      "100",
+      "<($t4, $t5)",
+      "6",
+      ">($t1, $t7)",
+      "AND($t6, $t8)")
+    assertTrue(extractExprStrList(originalRexProgram) sameElements array)
+
+    val tEnv = CommonTestData.getMockTableEnvironment
+    val builder = FlinkRelBuilder.create(tEnv.getFrameworkConfig)
+    val tableScan = new MockTableScan(builder.getRexBuilder)
+    val newExpression = ExpressionParser.parseExpression("amount * price < 100")
+    val newRexProgram = rewriteRexProgram(
+      originalRexProgram,
+      tableScan,
+      Array(newExpression)
+    )(builder)
+
+    val newArray = Array(
+      "$0",
+      "$1",
+      "$2",
+      "$3",
+      "*($t2, $t3)",
+      "100",
+      "<($t4, $t5)")
+    assertTrue(extractExprStrList(newRexProgram) sameElements newArray)
+  }
+
+//  @Test
+//  def testVerifyExpressions(): Unit = {
+//    val strPart = "f1 < 4"
+//    val part = parseExpression(strPart)
+//
+//    val shortFalseOrigin = parseExpression(s"f0 > 10 || $strPart")
+//    assertFalse(verifyExpressions(shortFalseOrigin, part))
+//
+//    val longFalseOrigin = parseExpression(s"(f0 > 10 || (($strPart) > POWER(f0, f1))) && 2")
+//    assertFalse(verifyExpressions(longFalseOrigin, part))
+//
+//    val shortOkayOrigin = parseExpression(s"f0 > 10 && ($strPart)")
+//    assertTrue(verifyExpressions(shortOkayOrigin, part))
+//
+//    val longOkayOrigin = parseExpression(s"f0 > 10 && (($strPart) > POWER(f0, f1))")
+//    assertTrue(verifyExpressions(longOkayOrigin, part))
+//
+//    val longOkayOrigin2 = parseExpression(s"(f0 > 10 || (2 > POWER(f0, f1))) && $strPart")
+//    assertTrue(verifyExpressions(longOkayOrigin2, part))
+//  }
+
+  private def buildRexProgram(
+      fieldNames: List[String],
+      fieldTypes: Seq[RelDataType],
+      typeFactory: JavaTypeFactory,
+      rexBuilder: RexBuilder): RexProgram = {
+
+    val inputRowType = typeFactory.createStructType(fieldTypes.asJava, fieldNames.asJava)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+
+    val t0 = rexBuilder.makeInputRef(fieldTypes(2), 2)
+    val t1 = rexBuilder.makeInputRef(fieldTypes(1), 1)
+    val t2 = rexBuilder.makeInputRef(fieldTypes(3), 3)
+    // t3 = t0 * t2
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t3, "total")
+    // t6 = t3 < t4
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    // t7 = t1 > t5
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    // condition: t6 and t7
+    // (t0 * t2 < t4 && t1 > t5)
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * extract all expression string list from input RexProgram expression lists
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return all expression string list of input RexProgram expression lists
+    */
+  private def extractExprStrList(rexProgram: RexProgram) =
+    rexProgram.getExprList.asScala.map(_.toString).toArray
+
+  class MockTableScan(
+      rexBuilder: RexBuilder)
+    extends TableScan(
+      RelOptCluster.create(new VolcanoPlanner(), rexBuilder),
+      RelTraitSet.createEmpty,
+      new MockRelOptTable)
+
+  class MockRelOptTable
+    extends RelOptAbstractTable(
+      null,
+      "mockRelTable",
+      new CompositeRelDataType(
+        new RowTypeInfo(allFieldTypeInfos, allFieldNames.toArray), typeFactory))
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
new file mode 100644
index 0000000..cea9eee
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/util/RexProgramProjectExtractorTest.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import java.math.BigDecimal
+
+import org.apache.calcite.adapter.java.JavaTypeFactory
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeSystem}
+import org.apache.calcite.rex.{RexBuilder, RexProgram, RexProgramBuilder}
+import org.apache.calcite.sql.`type`.SqlTypeName.{BIGINT, DOUBLE, INTEGER, VARCHAR}
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
+import org.apache.flink.table.plan.util.RexProgramProjectExtractor._
+import org.junit.Assert.{assertArrayEquals, assertTrue}
+import org.junit.{Before, Test}
+
+import scala.collection.JavaConverters._
+
+/**
+  * This class is responsible for testing RexProgramProjectExtractor.
+  */
+class RexProgramProjectExtractorTest {
+  private var typeFactory: JavaTypeFactory = _
+  private var rexBuilder: RexBuilder = _
+  private var allFieldTypes: Seq[RelDataType] = _
+  private val allFieldNames = List("name", "id", "amount", "price")
+
+  @Before
+  def setUp(): Unit = {
+    typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
+    rexBuilder = new RexBuilder(typeFactory)
+    allFieldTypes = List(VARCHAR, BIGINT, INTEGER, DOUBLE).map(typeFactory.createSqlType(_))
+  }
+
+  @Test
+  def testExtractRefInputFields(): Unit = {
+    val usedFields = extractRefInputFields(buildRexProgram())
+    assertArrayEquals(usedFields, Array(2, 3, 1))
+  }
+
+  @Test
+  def testRewriteRexProgram(): Unit = {
+    val originRexProgram = buildRexProgram()
+    assertTrue(extractExprStrList(originRexProgram).sameElements(Array(
+      "$0",
+      "$1",
+      "$2",
+      "$3",
+      "*($t2, $t3)",
+      "100",
+      "<($t4, $t5)",
+      "6",
+      ">($t1, $t7)",
+      "AND($t6, $t8)")))
+    // use amount, id, price fields to create a new RexProgram
+    val usedFields = Array(2, 3, 1)
+    val types = usedFields.map(allFieldTypes(_)).toList.asJava
+    val names = usedFields.map(allFieldNames(_)).toList.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val newRexProgram = rewriteRexProgram(originRexProgram, inputRowType, usedFields, rexBuilder)
+    assertTrue(extractExprStrList(newRexProgram).sameElements(Array(
+      "$0",
+      "$1",
+      "$2",
+      "*($t0, $t1)",
+      "100",
+      "<($t3, $t4)",
+      "6",
+      ">($t2, $t6)",
+      "AND($t5, $t7)")))
+  }
+
+  private def buildRexProgram(): RexProgram = {
+    val types = allFieldTypes.asJava
+    val names = allFieldNames.asJava
+    val inputRowType = typeFactory.createStructType(types, names)
+    val builder = new RexProgramBuilder(inputRowType, rexBuilder)
+    val t0 = rexBuilder.makeInputRef(types.get(2), 2)
+    val t1 = rexBuilder.makeInputRef(types.get(1), 1)
+    val t2 = rexBuilder.makeInputRef(types.get(3), 3)
+    val t3 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, t0, t2))
+    val t4 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(100L))
+    val t5 = rexBuilder.makeExactLiteral(BigDecimal.valueOf(6L))
+    // project: amount, amount * price
+    builder.addProject(t0, "amount")
+    builder.addProject(t3, "total")
+    // condition: amount * price < 100 and id > 6
+    val t6 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, t3, t4))
+    val t7 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.GREATER_THAN, t1, t5))
+    val t8 = builder.addExpr(rexBuilder.makeCall(SqlStdOperatorTable.AND, List(t6, t7).asJava))
+    builder.addCondition(t8)
+    builder.getProgram
+  }
+
+  /**
+    * extract all expression string list from input RexProgram expression lists
+    *
+    * @param rexProgram input RexProgram instance to analyze
+    * @return all expression string list of input RexProgram expression lists
+    */
+  private def extractExprStrList(rexProgram: RexProgram) = {
+    rexProgram.getExprList.asScala.map(_.toString)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/9f6cd2e7/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
index 6e4859b..a720f02 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
@@ -21,14 +21,21 @@ package org.apache.flink.table.utils
 import java.io.{File, FileOutputStream, OutputStreamWriter}
 import java.util
 
-import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource}
+import org.apache.calcite.tools.RuleSet
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
-import org.apache.flink.api.common.typeutils.TypeSerializer
-import org.apache.flink.api.java.typeutils.{PojoField, PojoTypeInfo, TypeExtractor}
+import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.api.java.{DataSet, ExecutionEnvironment}
-import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
-import org.apache.flink.table.sources.{BatchTableSource, CsvTableSource, TableSource}
-import org.apache.flink.api.scala._
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
+import org.apache.flink.table.api.{Table, TableConfig, TableEnvironment}
+import org.apache.flink.table.expressions._
+import org.apache.flink.table.sinks.TableSink
+import org.apache.flink.table.sources._
+import org.apache.flink.types.Row
+
+import scala.collection.JavaConverters._
 
 object CommonTestData {
 
@@ -98,4 +105,113 @@ object CommonTestData {
       this(null, null)
     }
   }
+
+  def getMockTableEnvironment: TableEnvironment = new MockTableEnvironment
+
+  def getFilterableTableSource = new TestFilterableTableSource
+}
+
+class MockTableEnvironment extends TableEnvironment(new TableConfig) {
+
+  override private[flink] def writeToSink[T](table: Table, sink: TableSink[T]): Unit = ???
+
+  override protected def checkValidTableName(name: String): Unit = ???
+
+  override def sql(query: String): Table = ???
+
+  override def registerTableSource(name: String, tableSource: TableSource[_]): Unit = ???
+
+  override protected def getBuiltInNormRuleSet: RuleSet = ???
+
+  override protected def getBuiltInOptRuleSet: RuleSet = ???
+}
+
+class TestFilterableTableSource
+    extends BatchTableSource[Row]
+    with StreamTableSource[Row]
+    with FilterableTableSource
+    with DefinedFieldNames {
+
+  import org.apache.flink.table.api.Types._
+
+  val fieldNames = Array("name", "id", "amount", "price")
+  val fieldTypes = Array[TypeInformation[_]](STRING, LONG, INT, DOUBLE)
+
+  private var filterLiteral: Literal = _
+  private var filterPredicates: Array[Expression] = Array.empty
+
+  /** Returns the data of the table as a [[DataSet]]. */
+  override def getDataSet(execEnv: ExecutionEnvironment): DataSet[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
+  }
+
+  /** Returns the data of the table as a [[DataStream]]. */
+  def getDataStream(execEnv: StreamExecutionEnvironment): DataStream[Row] = {
+    execEnv.fromCollection[Row](generateDynamicCollection(33).asJava, getReturnType)
+  }
+
+  private def generateDynamicCollection(num: Int): Seq[Row] = {
+
+    if (filterLiteral == null) {
+      throw new RuntimeException("filter expression was not set")
+    }
+
+    val filterValue = filterLiteral.value.asInstanceOf[Number].intValue()
+
+    def shouldCreateRow(value: Int): Boolean = {
+      value > filterValue
+    }
+
+    for {
+      cnt <- 0 until num
+      if shouldCreateRow(cnt)
+    } yield {
+        val row = new Row(fieldNames.length)
+        fieldNames.zipWithIndex.foreach { case (name, index) =>
+          name match {
+            case "name" =>
+              row.setField(index, s"Record_$cnt")
+            case "id" =>
+              row.setField(index, cnt.toLong)
+            case "amount" =>
+              row.setField(index, cnt.toInt)
+            case "price" =>
+              row.setField(index, cnt.toDouble)
+          }
+        }
+      row
+      }
+  }
+
+  /** Returns the [[TypeInformation]] for the return type. */
+  override def getReturnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
+
+  /** Returns the names of the table fields. */
+  override def getFieldNames: Array[String] = fieldNames
+
+  /** Returns the indices of the table fields. */
+  override def getFieldIndices: Array[Int] = fieldNames.indices.toArray
+
+  override def getPredicate: Array[Expression] = filterPredicates
+
+  /** Return an unsupported predicates expression. */
+  override def setPredicate(predicates: Array[Expression]): Array[Expression] = {
+    predicates(0) match {
+      case gt: GreaterThan =>
+        gt.left match {
+          case f: ResolvedFieldReference =>
+            gt.right match {
+              case l: Literal =>
+                if (f.name.equals("amount")) {
+                  filterLiteral = l
+                  filterPredicates = Array(predicates(0))
+                  Array(predicates(1))
+                } else predicates
+              case _ => predicates
+            }
+          case _ => predicates
+        }
+      case _ => predicates
+    }
+  }
 }


Mime
View raw message