flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From twal...@apache.org
Subject flink git commit: [FLINK-3748] [table] Add CASE function to Table API
Date Tue, 19 Apr 2016 08:08:10 GMT
Repository: flink
Updated Branches:
  refs/heads/master 85fcfc4d4 -> 4be297ec2


[FLINK-3748] [table] Add CASE function to Table API

This closes #1893.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/4be297ec
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/4be297ec
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/4be297ec

Branch: refs/heads/master
Commit: 4be297ec29601a25ee5c12268d38b5f4f6203bd6
Parents: 85fcfc4
Author: twalthr <twalthr@apache.org>
Authored: Wed Apr 13 14:36:36 2016 +0200
Committer: twalthr <twalthr@apache.org>
Committed: Tue Apr 19 10:07:33 2016 +0200

----------------------------------------------------------------------
 .../flink/api/scala/table/expressionDsl.scala   | 13 +++++
 .../flink/api/table/codegen/CodeGenerator.scala |  3 +
 .../table/codegen/calls/ScalarOperators.scala   | 59 ++++++++++++++++++++
 .../table/expressions/ExpressionParser.scala    |  7 ++-
 .../flink/api/table/expressions/call.scala      |  4 +-
 .../flink/api/table/expressions/logic.scala     | 21 +++++++
 .../api/table/plan/RexNodeTranslator.scala      |  5 ++
 .../api/java/table/test/ExpressionsITCase.java  | 41 ++++++++++++++
 .../api/scala/sql/test/ExpressionsITCase.scala  | 51 +++++++++++++++++
 .../scala/table/test/ExpressionsITCase.scala    | 29 ++++++++++
 10 files changed, 229 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
index 505d872..c6f14f3 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/scala/table/expressionDsl.scala
@@ -65,6 +65,19 @@ trait ImplicitExpressionOperations {
 
   def as(name: Symbol) = Naming(expr, name.name)
 
+  /**
+    * Conditional operator that decides which of two other expressions should be evaluated
+    * based on a evaluated boolean condition.
+    *
+    * e.g. (42 > 5).eval("A", "B") leads to "A"
+    *
+    * @param ifTrue expression to be evaluated if condition holds
+    * @param ifFalse expression to be evaluated if condition does not hold
+    */
+  def eval(ifTrue: Expression, ifFalse: Expression) = {
+    Eval(expr, ifTrue, ifFalse)
+  }
+
   // scalar functions
 
   /**

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
index d41674c..e090a29 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/CodeGenerator.scala
@@ -723,6 +723,9 @@ class CodeGenerator(
         requireBoolean(operand)
         generateNot(nullCheck, operand)
 
+      case CASE =>
+        generateIfElse(nullCheck, operands, resultType)
+
       // casting
       case CAST =>
         val operand = operands.head

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
index f71b643..182b843 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/codegen/calls/ScalarOperators.scala
@@ -385,6 +385,65 @@ object ScalarOperators {
     }
   }
 
+  def generateIfElse(
+      nullCheck: Boolean,
+      operands: Seq[GeneratedExpression],
+      resultType: TypeInformation[_],
+      i: Int = 0)
+    : GeneratedExpression = {
+    // else part
+    if (i == operands.size - 1) {
+      generateCast(nullCheck, operands(i), resultType)
+    }
+    else {
+      // check that the condition is boolean
+      // we do not check for null instead we use the default value
+      // thus null is false
+      requireBoolean(operands(i))
+      val condition = operands(i)
+      val trueAction = generateCast(nullCheck, operands(i + 1), resultType)
+      val falseAction = generateIfElse(nullCheck, operands, resultType, i + 2)
+
+      val resultTerm = newName("result")
+      val nullTerm = newName("isNull")
+      val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
+
+      val operatorCode = if (nullCheck) {
+        s"""
+          |${condition.code}
+          |$resultTypeTerm $resultTerm;
+          |boolean $nullTerm;
+          |if (${condition.resultTerm}) {
+          |  ${trueAction.code}
+          |  $resultTerm = ${trueAction.resultTerm};
+          |  $nullTerm = ${trueAction.nullTerm};
+          |}
+          |else {
+          |  ${falseAction.code}
+          |  $resultTerm = ${falseAction.resultTerm};
+          |  $nullTerm = ${falseAction.nullTerm};
+          |}
+          |""".stripMargin
+      }
+      else {
+        s"""
+          |${condition.code}
+          |$resultTypeTerm $resultTerm;
+          |if (${condition.resultTerm}) {
+          |  ${trueAction.code}
+          |  $resultTerm = ${trueAction.resultTerm};
+          |}
+          |else {
+          |  ${falseAction.code}
+          |  $resultTerm = ${falseAction.resultTerm};
+          |}
+          |""".stripMargin
+      }
+
+      GeneratedExpression(resultTerm, nullTerm, operatorCode, resultType)
+    }
+  }
+
   // ----------------------------------------------------------------------------------------------
 
   private def generateUnaryOperatorIfNotNull(

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
index 8a24d3c..e488d1b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/ExpressionParser.scala
@@ -134,6 +134,11 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
     case e ~ _ ~ target ~ _ => Naming(e, target.name)
   }
 
+  lazy val eval: PackratParser[Expression] = atom ~
+      ".eval(" ~ expression ~ "," ~ expression ~ ")" ^^ {
+    case condition ~ _ ~ ifTrue ~ _ ~ ifFalse ~ _ => Eval(condition, ifTrue, ifFalse)
+  }
+
   // general function calls
 
   lazy val functionCall = ident ~ "(" ~ rep1sep(expression, ",") ~ ")" ^^ {
@@ -200,7 +205,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
 
   lazy val suffix =
     isNull | isNotNull |
-      sum | min | max | count | avg | cast | nullLiteral |
+      sum | min | max | count | avg | cast | nullLiteral | eval |
       specialFunctionCalls | functionCall | functionCallWithoutArgs |
       specialSuffixFunctionCalls | suffixFunctionCall | suffixFunctionCallWithoutArgs |
       atom

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
index 280d213..c26cd7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/call.scala
@@ -17,8 +17,6 @@
  */
 package org.apache.flink.api.table.expressions
 
-import scala.collection.JavaConversions._
-
 import org.apache.calcite.rex.RexNode
 import org.apache.calcite.sql.SqlOperator
 import org.apache.calcite.sql.fun.SqlStdOperatorTable
@@ -35,7 +33,7 @@ case class Call(functionName: String, args: Expression*) extends Expression {
   override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
     relBuilder.call(
       BuiltInFunctionNames.toSqlOperator(functionName),
-      args.map(_.toRexNode))
+      args.map(_.toRexNode): _*)
   }
 
   override def toString = s"\\$functionName(${args.mkString(", ")})"

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala
index 99da371..37a6597 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/expressions/logic.scala
@@ -18,6 +18,7 @@
 package org.apache.flink.api.table.expressions
 
 import org.apache.calcite.rex.RexNode
+import org.apache.calcite.sql.fun.SqlStdOperatorTable
 import org.apache.calcite.tools.RelBuilder
 
 abstract class BinaryPredicate extends BinaryExpression { self: Product => }
@@ -54,3 +55,23 @@ case class Or(left: Expression, right: Expression) extends BinaryPredicate {
     relBuilder.or(left.toRexNode, right.toRexNode)
   }
 }
+
+case class Eval(
+    condition: Expression,
+    ifTrue: Expression,
+    ifFalse: Expression)
+  extends Expression {
+  def children = Seq(condition, ifTrue, ifFalse)
+
+  override def toString = s"($condition)? $ifTrue : $ifFalse"
+
+  override val name = Expression.freshName("if-" + condition.name +
+    "-then-" + ifTrue.name + "-else-" + ifFalse.name)
+
+  override def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    val c = condition.toRexNode
+    val t = ifTrue.toRexNode
+    val f = ifFalse.toRexNode
+    relBuilder.call(SqlStdOperatorTable.CASE, c, t, f)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
index 0c1db28..f946ed9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/RexNodeTranslator.scala
@@ -55,6 +55,11 @@ object RexNodeTranslator {
         val l = extractAggCalls(b.left, tableEnv)
         val r = extractAggCalls(b.right, tableEnv)
         (b.makeCopy(List(l._1, r._1)), l._2 ::: r._2)
+      case e: Eval =>
+        val c = extractAggCalls(e.condition, tableEnv)
+        val t = extractAggCalls(e.ifTrue, tableEnv)
+        val f = extractAggCalls(e.ifFalse, tableEnv)
+        (e.makeCopy(List(c._1, t._1, f._1)), c._2 ::: t._2 ::: f._2)
 
       // Scalar functions
       case c@Call(name, args@_*) =>

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/ExpressionsITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/ExpressionsITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/ExpressionsITCase.java
index 879c74f..bbd9352 100644
--- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/ExpressionsITCase.java
+++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/ExpressionsITCase.java
@@ -132,5 +132,46 @@ public class ExpressionsITCase extends TableProgramsTestBase {
 		}
 	}
 
+	@Test
+	public void testEval() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
+
+		DataSource<Tuple2<Integer, Boolean>> input =
+				env.fromElements(new Tuple2<>(5, true));
+
+		Table table =
+				tableEnv.fromDataSet(input, "a, b");
+
+		Table result = table.select(
+				"(b && true).eval('true', 'false')," +
+					"false.eval('true', 'false')," +
+					"true.eval(true.eval(true.eval(10, 4), 4), 4)");
+
+		DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
+		List<Row> results = ds.collect();
+		String expected = "true,false,10";
+		compareResultAsText(results, expected);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void testEvalInvalidTypes() throws Exception {
+		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		BatchTableEnvironment tableEnv = TableEnvironment.getTableEnvironment(env, config());
+
+		DataSource<Tuple2<Integer, Boolean>> input =
+				env.fromElements(new Tuple2<>(5, true));
+
+		Table table =
+				tableEnv.fromDataSet(input, "a, b");
+
+		Table result = table.select("(b && true).eval(5, 'false')");
+
+		DataSet<Row> ds = tableEnv.toDataSet(result, Row.class);
+		List<Row> results = ds.collect();
+		String expected = "true,false,3,10";
+		compareResultAsText(results, expected);
+	}
+
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala
index fb69d08..9aac4f8 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/sql/test/ExpressionsITCase.scala
@@ -69,4 +69,55 @@ class ExpressionsITCase(
     }
   }
 
+  @Test
+  def testCase(): Unit = {
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val sqlQuery = "SELECT " +
+      "CASE 11 WHEN 1 THEN 'a' ELSE 'b' END," +
+      "CASE 2 WHEN 1 THEN 'a' ELSE 'b' END," +
+      "CASE 1 WHEN 1, 2 THEN '1 or 2' WHEN 2 THEN 'not possible' WHEN 3, 2 THEN '3' " +
+      "  ELSE 'none of the above' END" +
+      " FROM MyTable"
+
+    val ds = env.fromElements((1, 0))
+    tEnv.registerDataSet("MyTable", ds, 'a, 'b)
+
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "b,b,1 or 2"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test
+  def testCaseWithNull(): Unit = {
+    if (!config.getNullCheck) {
+      return
+    }
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val sqlQuery = "SELECT " +
+      "CASE WHEN 'a'='a' THEN 1 END," +
+      "CASE 2 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END," +
+      "CASE a WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END," +
+      "CASE b WHEN 1 THEN 11 WHEN 2 THEN 4 ELSE NULL END," +
+      "CASE 42 WHEN 1 THEN 'a' WHEN 2 THEN 'bcd' END," +
+      "CASE 1 WHEN 1 THEN true WHEN 2 THEN false ELSE NULL END" +
+      " FROM MyTable"
+
+    val ds = env.fromElements((1, 0))
+    tEnv.registerDataSet("MyTable", ds, 'a, 'b)
+
+    val result = tEnv.sql(sqlQuery)
+
+    val expected = "1,bcd,11,null,null,true"
+    val results = result.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/4be297ec/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/ExpressionsITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/ExpressionsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/ExpressionsITCase.scala
index 0e07eaa..59b835c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/ExpressionsITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/ExpressionsITCase.scala
@@ -129,6 +129,35 @@ class ExpressionsITCase(
     }
   }
 
+  @Test
+  def testEval(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val t = env.fromElements((5, true)).toTable(tEnv, 'a, 'b)
+      .select(
+        ('b && true).eval("true", "false"),
+        false.eval("true", "false"),
+        true.eval(true.eval(true.eval(10, 4), 4), 4))
+
+    val expected = "true,false,10"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
+  @Test(expected = classOf[IllegalArgumentException])
+  def testEvalInvalidTypes(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val t = env.fromElements((5, true)).toTable(tEnv, 'a, 'b)
+      .select(('b && true).eval(5, "false"))
+
+    val expected = "true,false,3,10"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   // Date literals not yet supported
   @Ignore
   @Test


Mime
View raw message