From: fhueske@apache.org
To: commits@flink.apache.org
Date: Fri, 16 Dec 2016 15:46:51 -0000
Subject: [22/47] flink git commit: [FLINK-4704] [table] Refactor package structure of flink-table.

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
new file mode 100644
index 0000000..b31367c
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.flink.table.expressions + +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.typeutils.TypeCheckUtils +import org.apache.flink.table.validate._ + +case class Abs(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = child.resultType + + override private[flink] def validateInput(): ValidationResult = + TypeCheckUtils.assertNumericExpr(child.resultType, "Abs") + + override def toString: String = s"abs($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.ABS, child.toRexNode) + } +} + +case class Ceil(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = + TypeCheckUtils.assertNumericExpr(child.resultType, "Ceil") + + override def toString: String = s"ceil($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.CEIL, child.toRexNode) + } +} + +case class Exp(child: Expression) extends UnaryExpression with InputTypeSpec { + override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = DOUBLE_TYPE_INFO :: Nil + + override def toString: String = s"exp($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.EXP, child.toRexNode) + } +} + + +case class Floor(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = + TypeCheckUtils.assertNumericExpr(child.resultType, "Floor") + + override def toString: String = s"floor($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.FLOOR, child.toRexNode) + } +} + +case class Log10(child: Expression) extends UnaryExpression with InputTypeSpec { + override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = DOUBLE_TYPE_INFO :: Nil + + override def toString: String = s"log10($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.LOG10, child.toRexNode) + } +} + +case class Ln(child: Expression) extends UnaryExpression with InputTypeSpec { + override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = DOUBLE_TYPE_INFO :: Nil + + override def toString: String = s"ln($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.LN, child.toRexNode) + } +} + +case class Power(left: Expression, right: Expression) extends BinaryExpression with InputTypeSpec { + override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = + DOUBLE_TYPE_INFO :: DOUBLE_TYPE_INFO :: Nil + 
+ override def toString: String = s"pow($left, $right)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.POWER, left.toRexNode, right.toRexNode) + } +} + +case class Sqrt(child: Expression) extends UnaryExpression with InputTypeSpec { + override private[flink] def resultType: TypeInformation[_] = DOUBLE_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = + Seq(DOUBLE_TYPE_INFO) + + override def toString: String = s"sqrt($child)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.POWER, child.toRexNode, Literal(0.5).toRexNode) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ordering.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ordering.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ordering.scala new file mode 100644 index 0000000..7f03827 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/ordering.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.flink.table.expressions
+
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.tools.RelBuilder
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.validate._
+
+abstract class Ordering extends UnaryExpression {
+  override private[flink] def validateInput(): ValidationResult = {
+    if (!child.isInstanceOf[NamedExpression]) {
+      ValidationFailure(s"Sort should only be based on a field reference.")
+    } else {
+      ValidationSuccess
+    }
+  }
+}
+
+case class Asc(child: Expression) extends Ordering {
+  override def toString: String = s"($child).asc"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    child.toRexNode
+  }
+
+  override private[flink] def resultType: TypeInformation[_] = child.resultType
+}
+
+case class Desc(child: Expression) extends Ordering {
+  override def toString: String = s"($child).desc"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    relBuilder.desc(child.toRexNode)
+  }
+
+  override private[flink] def resultType: TypeInformation[_] = child.resultType
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/package.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/package.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/package.scala
new file mode 100644
index 0000000..41e0c9f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/package.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table
+
+/**
+ * This package contains the base class of AST nodes and all the expression language AST classes.
+ * Expression trees should not be manually constructed by users. They are implicitly constructed
+ * from the implicit DSL conversions in
+ * [[org.apache.flink.table.api.scala.ImplicitExpressionConversions]] and
+ * [[org.apache.flink.table.api.scala.ImplicitExpressionOperations]]. For the Java API,
+ * expression trees should be generated from a string parser that parses expressions and creates
+ * AST nodes.
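+ *
+ * For example (an illustrative sketch; the exact tree for a given expression may differ), the
+ * Scala Table API expression {{{ 'amount + 42 }}} is implicitly converted into an AST of the
+ * form {{{ Plus(UnresolvedFieldReference("amount"), Literal(42)) }}}.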
+ */ +package object expressions http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala new file mode 100644 index 0000000..f4b58cc --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/stringExpressions.scala @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.expressions + +import scala.collection.JavaConversions._ +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.expressions.TrimMode.TrimMode +import org.apache.flink.table.validate._ + +/** + * Returns the length of this `str`. + */ +case class CharLength(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = INT_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (child.resultType == STRING_TYPE_INFO) { + ValidationSuccess + } else { + ValidationFailure(s"CharLength operator requires String input, " + + s"but $child is of type ${child.resultType}") + } + } + + override def toString: String = s"($child).charLength()" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.CHAR_LENGTH, child.toRexNode) + } +} + +/** + * Returns str with the first letter of each word in uppercase. + * All other letters are in lowercase. Words are delimited by white space. + */ +case class InitCap(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (child.resultType == STRING_TYPE_INFO) { + ValidationSuccess + } else { + ValidationFailure(s"InitCap operator requires String input, " + + s"but $child is of type ${child.resultType}") + } + } + + override def toString: String = s"($child).initCap()" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.INITCAP, child.toRexNode) + } +} + +/** + * Returns true if `str` matches `pattern`. 
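+ *
+ * For example (standard SQL LIKE semantics; values for illustration only),
+ * {{{ 'Flink' LIKE 'F%' }}} evaluates to true.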
+ */ +case class Like(str: Expression, pattern: Expression) extends BinaryExpression { + private[flink] def left: Expression = str + private[flink] def right: Expression = pattern + + override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (str.resultType == STRING_TYPE_INFO && pattern.resultType == STRING_TYPE_INFO) { + ValidationSuccess + } else { + ValidationFailure(s"Like operator requires (String, String) input, " + + s"but ($str, $pattern) is of type (${str.resultType}, ${pattern.resultType})") + } + } + + override def toString: String = s"($str).like($pattern)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.LIKE, children.map(_.toRexNode)) + } +} + +/** + * Returns str with all characters changed to lowercase. + */ +case class Lower(child: Expression) extends UnaryExpression { + override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (child.resultType == STRING_TYPE_INFO) { + ValidationSuccess + } else { + ValidationFailure(s"Lower operator requires String input, " + + s"but $child is of type ${child.resultType}") + } + } + + override def toString: String = s"($child).toLowerCase()" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.LOWER, child.toRexNode) + } +} + +/** + * Returns true if `str` is similar to `pattern`. + */ +case class Similar(str: Expression, pattern: Expression) extends BinaryExpression { + private[flink] def left: Expression = str + private[flink] def right: Expression = pattern + + override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (str.resultType == STRING_TYPE_INFO && pattern.resultType == STRING_TYPE_INFO) { + ValidationSuccess + } else { + ValidationFailure(s"Similar operator requires (String, String) input, " + + s"but ($str, $pattern) is of type (${str.resultType}, ${pattern.resultType})") + } + } + + override def toString: String = s"($str).similarTo($pattern)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.SIMILAR_TO, children.map(_.toRexNode)) + } +} + +/** + * Returns substring of `str` from `begin`(inclusive) for `length`. + */ +case class Substring( + str: Expression, + begin: Expression, + length: Expression) extends Expression with InputTypeSpec { + + def this(str: Expression, begin: Expression) = this(str, begin, CharLength(str)) + + override private[flink] def children: Seq[Expression] = str :: begin :: length :: Nil + + override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = + Seq(STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO) + + override def toString: String = s"($str).substring($begin, $length)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.SUBSTRING, children.map(_.toRexNode)) + } +} + +/** + * Trim `trimString` from `str` according to `trimMode`. 
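+ *
+ * For example, using the symbols and constants defined in this package,
+ * {{{ Trim(TrimMode.BOTH, TrimConstants.TRIM_DEFAULT_CHAR, str) }}} corresponds to the SQL
+ * expression {{{ TRIM(BOTH ' ' FROM str) }}}.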
+ */
+case class Trim(
+    trimMode: Expression,
+    trimString: Expression,
+    str: Expression) extends Expression {
+
+  override private[flink] def children: Seq[Expression] = trimMode :: trimString :: str :: Nil
+
+  override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO
+
+  override private[flink] def validateInput(): ValidationResult = {
+    trimMode match {
+      case SymbolExpression(_: TrimMode) =>
+        if (trimString.resultType != STRING_TYPE_INFO) {
+          ValidationFailure(s"String expected for trimString, got ${trimString.resultType}")
+        } else if (str.resultType != STRING_TYPE_INFO) {
+          ValidationFailure(s"String expected for str, got ${str.resultType}")
+        } else {
+          ValidationSuccess
+        }
+      case _ => ValidationFailure("TrimMode symbol expected.")
+    }
+  }
+
+  override def toString: String = s"($str).trim($trimMode, $trimString)"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    relBuilder.call(SqlStdOperatorTable.TRIM, children.map(_.toRexNode))
+  }
+}
+
+/**
+ * Constants for the TRIM function.
+ */
+object TrimConstants {
+  val TRIM_DEFAULT_CHAR = Literal(" ")
+}
+
+/**
+ * Returns str with all characters changed to uppercase.
+ */
+case class Upper(child: Expression) extends UnaryExpression with InputTypeSpec {
+
+  override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO
+
+  override private[flink] def expectedTypes: Seq[TypeInformation[_]] =
+    Seq(STRING_TYPE_INFO)
+
+  override def toString: String = s"($child).upperCase()"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    relBuilder.call(SqlStdOperatorTable.UPPER, child.toRexNode)
+  }
+}
+
+/**
+ * Returns the position of string needle in string haystack.
+ */
+case class Position(needle: Expression, haystack: Expression)
+  extends Expression with InputTypeSpec {
+
+  override private[flink] def children: Seq[Expression] = Seq(needle, haystack)
+
+  override private[flink] def resultType: TypeInformation[_] = INT_TYPE_INFO
+
+  override private[flink] def expectedTypes: Seq[TypeInformation[_]] =
+    Seq(STRING_TYPE_INFO, STRING_TYPE_INFO)
+
+  override def toString: String = s"($needle).position($haystack)"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
+    relBuilder.call(SqlStdOperatorTable.POSITION, needle.toRexNode, haystack.toRexNode)
+  }
+}
+
+/**
+ * Replaces a substring of a string with a replacement string,
+ * starting at a given position and for a given length.
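+ *
+ * For example (standard SQL OVERLAY semantics; values for illustration only),
+ * {{{ OVERLAY('abcdef' PLACING 'xx' FROM 2 FOR 2) }}} yields {{{ 'axxdef' }}}.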
+ */ +case class Overlay( + str: Expression, + replacement: Expression, + starting: Expression, + position: Expression) + extends Expression with InputTypeSpec { + + def this(str: Expression, replacement: Expression, starting: Expression) = + this(str, replacement, starting, CharLength(replacement)) + + override private[flink] def children: Seq[Expression] = + Seq(str, replacement, starting, position) + + override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = + Seq(STRING_TYPE_INFO, STRING_TYPE_INFO, INT_TYPE_INFO, INT_TYPE_INFO) + + override def toString: String = s"($str).overlay($replacement, $starting, $position)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call( + SqlStdOperatorTable.OVERLAY, + str.toRexNode, + replacement.toRexNode, + starting.toRexNode, + position.toRexNode) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/symbols.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/symbols.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/symbols.scala new file mode 100644 index 0000000..0d71fb2 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/symbols.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.calcite.avatica.util.{TimeUnit, TimeUnitRange} +import org.apache.calcite.rex.RexNode +import org.apache.calcite.sql.fun.SqlTrimFunction +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.TypeInformation + +import scala.language.{existentials, implicitConversions} + +/** + * General expression class to represent a symbol. + */ +case class SymbolExpression(symbol: TableSymbol) extends LeafExpression { + + override private[flink] def resultType: TypeInformation[_] = + throw new UnsupportedOperationException("This should not happen. A symbol has no result type.") + + def toExpr = this // triggers implicit conversion + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + // dirty hack to pass Java enums to Java from Scala + val enum = symbol.enum.asInstanceOf[Enum[T] forSome { type T <: Enum[T] }] + relBuilder.getRexBuilder.makeFlag(enum) + } + + override def toString: String = s"${symbol.symbols}.${symbol.name}" + +} + +/** + * Symbol that wraps a Calcite symbol in form of a Java enum. 
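+ *
+ * For example, {{{ TrimMode.BOTH }}} (defined below) wraps the Calcite enum value
+ * {{{ SqlTrimFunction.Flag.BOTH }}}.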
+ */ +trait TableSymbol { + def symbols: TableSymbols + def name: String + def enum: Enum[_] +} + +/** + * Enumeration of symbols. + */ +abstract class TableSymbols extends Enumeration { + + class TableSymbolValue(e: Enum[_]) extends Val(e.name()) with TableSymbol { + override def symbols: TableSymbols = TableSymbols.this + + override def enum: Enum[_] = e + + override def name: String = toString() + } + + protected final def Value(enum: Enum[_]): TableSymbolValue = new TableSymbolValue(enum) + + implicit def symbolToExpression(symbol: TableSymbolValue): SymbolExpression = + SymbolExpression(symbol) + +} + +/** + * Units for working with time intervals. + */ +object TimeIntervalUnit extends TableSymbols { + + type TimeIntervalUnit = TableSymbolValue + + val YEAR = Value(TimeUnitRange.YEAR) + val YEAR_TO_MONTH = Value(TimeUnitRange.YEAR_TO_MONTH) + val MONTH = Value(TimeUnitRange.MONTH) + val DAY = Value(TimeUnitRange.DAY) + val DAY_TO_HOUR = Value(TimeUnitRange.DAY_TO_HOUR) + val DAY_TO_MINUTE = Value(TimeUnitRange.DAY_TO_MINUTE) + val DAY_TO_SECOND = Value(TimeUnitRange.DAY_TO_SECOND) + val HOUR = Value(TimeUnitRange.HOUR) + val HOUR_TO_MINUTE = Value(TimeUnitRange.HOUR_TO_MINUTE) + val HOUR_TO_SECOND = Value(TimeUnitRange.HOUR_TO_SECOND) + val MINUTE = Value(TimeUnitRange.MINUTE) + val MINUTE_TO_SECOND = Value(TimeUnitRange.MINUTE_TO_SECOND) + val SECOND = Value(TimeUnitRange.SECOND) + +} + +/** + * Units for working with time points. + */ +object TimePointUnit extends TableSymbols { + + type TimePointUnit = TableSymbolValue + + val YEAR = Value(TimeUnit.YEAR) + val MONTH = Value(TimeUnit.MONTH) + val DAY = Value(TimeUnit.DAY) + val HOUR = Value(TimeUnit.HOUR) + val MINUTE = Value(TimeUnit.MINUTE) + val SECOND = Value(TimeUnit.SECOND) + val QUARTER = Value(TimeUnit.QUARTER) + val WEEK = Value(TimeUnit.WEEK) + val MILLISECOND = Value(TimeUnit.MILLISECOND) + val MICROSECOND = Value(TimeUnit.MICROSECOND) + +} + +/** + * Modes for trimming strings. + */ +object TrimMode extends TableSymbols { + + type TrimMode = TableSymbolValue + + val BOTH = Value(SqlTrimFunction.Flag.BOTH) + val LEADING = Value(SqlTrimFunction.Flag.LEADING) + val TRAILING = Value(SqlTrimFunction.Flag.TRAILING) + +} http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/time.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/time.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/time.scala new file mode 100644 index 0000000..f09e2ad --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/time.scala @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions + +import org.apache.calcite.avatica.util.{TimeUnit, TimeUnitRange} +import org.apache.calcite.rex._ +import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.calcite.sql.fun.SqlStdOperatorTable +import org.apache.calcite.tools.RelBuilder +import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ +import org.apache.flink.api.common.typeinfo.{SqlTimeTypeInfo, TypeInformation} +import org.apache.flink.table.calcite.FlinkRelBuilder +import org.apache.flink.table.expressions.ExpressionUtils.{divide, getFactor, mod} +import org.apache.flink.table.expressions.TimeIntervalUnit.TimeIntervalUnit +import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval +import org.apache.flink.table.typeutils.{TimeIntervalTypeInfo, TypeCheckUtils} +import org.apache.flink.table.validate.{ValidationResult, ValidationFailure, ValidationSuccess} + +import scala.collection.JavaConversions._ + +case class Extract(timeIntervalUnit: Expression, temporal: Expression) extends Expression { + + override private[flink] def children: Seq[Expression] = timeIntervalUnit :: temporal :: Nil + + override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (!TypeCheckUtils.isTemporal(temporal.resultType)) { + return ValidationFailure(s"Extract operator requires Temporal input, " + + s"but $temporal is of type ${temporal.resultType}") + } + + timeIntervalUnit match { + case SymbolExpression(TimeIntervalUnit.YEAR) + | SymbolExpression(TimeIntervalUnit.MONTH) + | SymbolExpression(TimeIntervalUnit.DAY) + if temporal.resultType == SqlTimeTypeInfo.DATE + || temporal.resultType == SqlTimeTypeInfo.TIMESTAMP + || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MILLIS + || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MONTHS => + ValidationSuccess + + case SymbolExpression(TimeIntervalUnit.HOUR) + | SymbolExpression(TimeIntervalUnit.MINUTE) + | SymbolExpression(TimeIntervalUnit.SECOND) + if temporal.resultType == SqlTimeTypeInfo.TIME + || temporal.resultType == SqlTimeTypeInfo.TIMESTAMP + || temporal.resultType == TimeIntervalTypeInfo.INTERVAL_MILLIS => + ValidationSuccess + + case _ => + ValidationFailure(s"Extract operator does not support unit '$timeIntervalUnit' for input" + + s" of type '${temporal.resultType}'.") + } + } + + override def toString: String = s"($temporal).extract($timeIntervalUnit)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + // get wrapped Calcite unit + val timeUnitRange = timeIntervalUnit + .asInstanceOf[SymbolExpression] + .symbol + .enum + .asInstanceOf[TimeUnitRange] + + // convert RexNodes + convertExtract( + timeIntervalUnit.toRexNode, + timeUnitRange, + temporal.toRexNode, + relBuilder.asInstanceOf[FlinkRelBuilder]) + } + + /** + * Standard conversion of the EXTRACT operator. 
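+ * As a rough sketch of the conversion below: for most units the result is computed as
+ * {{{ (castToLong(temporal) MOD factor(unit)) / multiplier(unit) }}}, while YEAR/MONTH/DAY
+ * on DATE or TIMESTAMP inputs take a separate EXTRACT_DATE code path.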
+ * Source: [[org.apache.calcite.sql2rel.StandardConvertletTable#convertExtract()]] + */ + private def convertExtract( + timeUnitRangeRexNode: RexNode, + timeUnitRange: TimeUnitRange, + temporal: RexNode, + relBuilder: FlinkRelBuilder) + : RexNode = { + + // TODO convert this into Table API expressions to make the code more readable + val rexBuilder = relBuilder.getRexBuilder + val resultType = relBuilder.getTypeFactory().createTypeFromTypeInfo(LONG_TYPE_INFO) + var result = rexBuilder.makeReinterpretCast( + resultType, + temporal, + rexBuilder.makeLiteral(false)) + + val unit = timeUnitRange.startUnit + val sqlTypeName = temporal.getType.getSqlTypeName + unit match { + case TimeUnit.YEAR | TimeUnit.MONTH | TimeUnit.DAY => + sqlTypeName match { + case SqlTypeName.TIMESTAMP => + result = divide(rexBuilder, result, TimeUnit.DAY.multiplier) + return rexBuilder.makeCall( + resultType, + SqlStdOperatorTable.EXTRACT_DATE, + Seq(timeUnitRangeRexNode, result)) + + case SqlTypeName.DATE => + return rexBuilder.makeCall( + resultType, + SqlStdOperatorTable.EXTRACT_DATE, + Seq(timeUnitRangeRexNode, result)) + + case _ => // do nothing + } + + case _ => // do nothing + } + + result = mod(rexBuilder, resultType, result, getFactor(unit)) + result = divide(rexBuilder, result, unit.multiplier) + result + } +} + +abstract class TemporalCeilFloor( + timeIntervalUnit: Expression, + temporal: Expression) + extends Expression { + + override private[flink] def children: Seq[Expression] = timeIntervalUnit :: temporal :: Nil + + override private[flink] def resultType: TypeInformation[_] = temporal.resultType + + override private[flink] def validateInput(): ValidationResult = { + if (!TypeCheckUtils.isTimePoint(temporal.resultType)) { + return ValidationFailure(s"Temporal ceil/floor operator requires Time Point input, " + + s"but $temporal is of type ${temporal.resultType}") + } + val unit = timeIntervalUnit match { + case SymbolExpression(u: TimeIntervalUnit) => Some(u) + case _ => None + } + if (unit.isEmpty) { + return ValidationFailure(s"Temporal ceil/floor operator requires Time Interval Unit " + + s"input, but $timeIntervalUnit is of type ${timeIntervalUnit.resultType}") + } + + (unit.get, temporal.resultType) match { + case (TimeIntervalUnit.YEAR | TimeIntervalUnit.MONTH, + SqlTimeTypeInfo.DATE | SqlTimeTypeInfo.TIMESTAMP) => + ValidationSuccess + case (TimeIntervalUnit.DAY, SqlTimeTypeInfo.TIMESTAMP) => + ValidationSuccess + case (TimeIntervalUnit.HOUR | TimeIntervalUnit.MINUTE | TimeIntervalUnit.SECOND, + SqlTimeTypeInfo.TIME | SqlTimeTypeInfo.TIMESTAMP) => + ValidationSuccess + case _ => + ValidationFailure(s"Temporal ceil/floor operator does not support " + + s"unit '$timeIntervalUnit' for input of type '${temporal.resultType}'.") + } + } +} + +case class TemporalFloor( + timeIntervalUnit: Expression, + temporal: Expression) + extends TemporalCeilFloor( + timeIntervalUnit, + temporal) { + + override def toString: String = s"($temporal).floor($timeIntervalUnit)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.FLOOR, temporal.toRexNode, timeIntervalUnit.toRexNode) + } +} + +case class TemporalCeil( + timeIntervalUnit: Expression, + temporal: Expression) + extends TemporalCeilFloor( + timeIntervalUnit, + temporal) { + + override def toString: String = s"($temporal).ceil($timeIntervalUnit)" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + relBuilder.call(SqlStdOperatorTable.CEIL, 
temporal.toRexNode, timeIntervalUnit.toRexNode) + } +} + +abstract class CurrentTimePoint( + targetType: TypeInformation[_], + local: Boolean) + extends LeafExpression { + + override private[flink] def resultType: TypeInformation[_] = targetType + + override private[flink] def validateInput(): ValidationResult = { + if (!TypeCheckUtils.isTimePoint(targetType)) { + ValidationFailure(s"CurrentTimePoint operator requires Time Point target type, " + + s"but get $targetType.") + } else if (local && targetType == SqlTimeTypeInfo.DATE) { + ValidationFailure(s"Localized CurrentTimePoint operator requires Time or Timestamp target " + + s"type, but get $targetType.") + } else { + ValidationSuccess + } + } + + override def toString: String = if (local) { + s"local$targetType()" + } else { + s"current$targetType()" + } + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + val operator = targetType match { + case SqlTimeTypeInfo.TIME if local => SqlStdOperatorTable.LOCALTIME + case SqlTimeTypeInfo.TIMESTAMP if local => SqlStdOperatorTable.LOCALTIMESTAMP + case SqlTimeTypeInfo.DATE => SqlStdOperatorTable.CURRENT_DATE + case SqlTimeTypeInfo.TIME => SqlStdOperatorTable.CURRENT_TIME + case SqlTimeTypeInfo.TIMESTAMP => SqlStdOperatorTable.CURRENT_TIMESTAMP + } + relBuilder.call(operator) + } +} + +case class CurrentDate() extends CurrentTimePoint(SqlTimeTypeInfo.DATE, local = false) + +case class CurrentTime() extends CurrentTimePoint(SqlTimeTypeInfo.TIME, local = false) + +case class CurrentTimestamp() extends CurrentTimePoint(SqlTimeTypeInfo.TIMESTAMP, local = false) + +case class LocalTime() extends CurrentTimePoint(SqlTimeTypeInfo.TIME, local = true) + +case class LocalTimestamp() extends CurrentTimePoint(SqlTimeTypeInfo.TIMESTAMP, local = true) + +/** + * Extracts the quarter of a year from a SQL date. + */ +case class Quarter(child: Expression) extends UnaryExpression with InputTypeSpec { + + override private[flink] def expectedTypes: Seq[TypeInformation[_]] = Seq(SqlTimeTypeInfo.DATE) + + override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + + override def toString: String = s"($child).quarter()" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + /** + * Standard conversion of the QUARTER operator. + * Source: [[org.apache.calcite.sql2rel.StandardConvertletTable#convertQuarter()]] + */ + Plus( + Div( + Minus( + Extract(TimeIntervalUnit.MONTH, child), + Literal(1L)), + Literal(TimeUnit.QUARTER.multiplier.longValue())), + Literal(1L) + ).toRexNode + } +} + +/** + * Determines whether two anchored time intervals overlap. 
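+ *
+ * For example (values for illustration only), an interval anchored at 14:00 with a length of
+ * one hour and an interval anchored at 14:30 with a length of two hours overlap, so the
+ * expression evaluates to true.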
+ */ +case class TemporalOverlaps( + leftTimePoint: Expression, + leftTemporal: Expression, + rightTimePoint: Expression, + rightTemporal: Expression) + extends Expression { + + override private[flink] def children: Seq[Expression] = + Seq(leftTimePoint, leftTemporal, rightTimePoint, rightTemporal) + + override private[flink] def resultType: TypeInformation[_] = BOOLEAN_TYPE_INFO + + override private[flink] def validateInput(): ValidationResult = { + if (!TypeCheckUtils.isTimePoint(leftTimePoint.resultType)) { + return ValidationFailure(s"TemporalOverlaps operator requires leftTimePoint to be of type " + + s"Time Point, but get ${leftTimePoint.resultType}.") + } + if (!TypeCheckUtils.isTimePoint(rightTimePoint.resultType)) { + return ValidationFailure(s"TemporalOverlaps operator requires rightTimePoint to be of " + + s"type Time Point, but get ${rightTimePoint.resultType}.") + } + if (leftTimePoint.resultType != rightTimePoint.resultType) { + return ValidationFailure(s"TemporalOverlaps operator requires leftTimePoint and " + + s"rightTimePoint to be of same type.") + } + + // leftTemporal is point, then it must be comparable with leftTimePoint + if (TypeCheckUtils.isTimePoint(leftTemporal.resultType)) { + if (leftTemporal.resultType != leftTimePoint.resultType) { + return ValidationFailure(s"TemporalOverlaps operator requires leftTemporal and " + + s"leftTimePoint to be of same type if leftTemporal is of type Time Point.") + } + } else if (!isTimeInterval(leftTemporal.resultType)) { + return ValidationFailure(s"TemporalOverlaps operator requires leftTemporal to be of " + + s"type Time Point or Time Interval.") + } + + // rightTemporal is point, then it must be comparable with rightTimePoint + if (TypeCheckUtils.isTimePoint(rightTemporal.resultType)) { + if (rightTemporal.resultType != rightTimePoint.resultType) { + return ValidationFailure(s"TemporalOverlaps operator requires rightTemporal and " + + s"rightTimePoint to be of same type if rightTemporal is of type Time Point.") + } + } else if (!isTimeInterval(rightTemporal.resultType)) { + return ValidationFailure(s"TemporalOverlaps operator requires rightTemporal to be of " + + s"type Time Point or Time Interval.") + } + ValidationSuccess + } + + override def toString: String = s"temporalOverlaps(${children.mkString(", ")})" + + override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = { + convertOverlaps( + leftTimePoint.toRexNode, + leftTemporal.toRexNode, + rightTimePoint.toRexNode, + rightTemporal.toRexNode, + relBuilder.asInstanceOf[FlinkRelBuilder]) + } + + /** + * Standard conversion of the OVERLAPS operator. 
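+ * After normalizing interval arguments to end points, the predicate below reduces to
+ * {{{ leftEnd >= rightStart AND rightEnd >= leftStart }}}.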
+ * Source: [[org.apache.calcite.sql2rel.StandardConvertletTable#convertOverlaps()]] + */ + private def convertOverlaps( + leftP: RexNode, + leftT: RexNode, + rightP: RexNode, + rightT: RexNode, + relBuilder: FlinkRelBuilder) + : RexNode = { + // leftT = leftP + leftT if leftT is an interval + val convLeftT = if (isTimeInterval(leftTemporal.resultType)) { + relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, leftP, leftT) + } else { + leftT + } + // rightT = rightP + rightT if rightT is an interval + val convRightT = if (isTimeInterval(rightTemporal.resultType)) { + relBuilder.call(SqlStdOperatorTable.DATETIME_PLUS, rightP, rightT) + } else { + rightT + } + // leftT >= rightP + val leftPred = relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, convLeftT, rightP) + // rightT >= leftP + val rightPred = relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, convRightT, leftP) + + // leftT >= rightP and rightT >= leftP + relBuilder.call(SqlStdOperatorTable.AND, leftPred, rightPred) + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/windowProperties.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/windowProperties.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/windowProperties.scala new file mode 100644 index 0000000..990d928 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/expressions/windowProperties.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.flink.table.expressions
+
+import org.apache.calcite.rex.RexNode
+import org.apache.calcite.tools.RelBuilder
+import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo
+import org.apache.flink.table.calcite.FlinkRelBuilder
+import FlinkRelBuilder.NamedWindowProperty
+import org.apache.flink.table.validate.{ValidationFailure, ValidationSuccess}
+
+abstract class WindowProperty(child: Expression) extends UnaryExpression {
+
+  override def toString = s"WindowProperty($child)"
+
+  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
+    throw new UnsupportedOperationException("WindowProperty cannot be transformed to RexNode.")
+
+  override private[flink] def validateInput() =
+    if (child.isInstanceOf[WindowReference]) {
+      ValidationSuccess
+    } else {
+      ValidationFailure("Child must be a window reference.")
+    }
+
+  private[flink] def toNamedWindowProperty(name: String)(implicit relBuilder: RelBuilder)
+    : NamedWindowProperty = NamedWindowProperty(name, this)
+}
+
+case class WindowStart(child: Expression) extends WindowProperty(child) {
+
+  override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP
+
+  override def toString: String = s"start($child)"
+}
+
+case class WindowEnd(child: Expression) extends WindowProperty(child) {
+
+  override private[flink] def resultType = SqlTimeTypeInfo.TIMESTAMP
+
+  override def toString: String = s"end($child)"
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
new file mode 100644
index 0000000..d01cf68
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/ScalarFunction.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import org.apache.flink.api.common.functions.InvalidTypesException
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.table.api.ValidationException
+import org.apache.flink.table.expressions.{Expression, ScalarFunctionCall}
+
+/**
+ * Base class for a user-defined scalar function. A user-defined scalar function maps zero, one,
+ * or multiple scalar values to a new scalar value.
+ *
+ * The behavior of a [[ScalarFunction]] can be defined by implementing a custom evaluation
+ * method.
+ * An evaluation method must be declared publicly and named "eval". Evaluation methods
+ * can also be overloaded by implementing multiple methods named "eval".
+ *
+ * User-defined functions must have a default constructor and must be instantiable during runtime.
+ *
+ * By default the result type of an evaluation method is determined by Flink's type extraction
+ * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more
+ * complex, custom, or composite types. In these cases [[TypeInformation]] of the result type
+ * can be manually defined by overriding [[getResultType()]].
+ *
+ * Internally, the Table/SQL API code generation works with primitive values as much as possible.
+ * If a user-defined scalar function should not introduce much overhead during runtime, it is
+ * recommended to declare parameters and result types as primitive types instead of their boxed
+ * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long.
+ */
+abstract class ScalarFunction extends UserDefinedFunction {
+
+  /**
+    * Creates a call to a [[ScalarFunction]] in Scala Table API.
+    *
+    * @param params actual parameters of function
+    * @return [[Expression]] in form of a [[ScalarFunctionCall]]
+    */
+  final def apply(params: Expression*): Expression = {
+    ScalarFunctionCall(this, params)
+  }
+
+  override def toString: String = getClass.getCanonicalName
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns the result type of the evaluation method with a given signature.
+    *
+    * This method needs to be overridden in case Flink's type extraction facilities are not
+    * sufficient to extract the [[TypeInformation]] based on the return type of the evaluation
+    * method. Flink's type extraction facilities can handle basic types or
+    * simple POJOs but might be wrong for more complex, custom, or composite types.
+    *
+    * @param signature signature of the method the return type needs to be determined
+    * @return [[TypeInformation]] of result type or null if Flink should determine the type
+    */
+  def getResultType(signature: Array[Class[_]]): TypeInformation[_] = null
+
+  /**
+    * Returns [[TypeInformation]] about the operands of the evaluation method with a given
+    * signature.
+    *
+    * In order to perform operand type inference in SQL (especially when NULL is used) it might be
+    * necessary to determine the parameter [[TypeInformation]] of an evaluation method.
+    * By default Flink's type extraction facilities are used for this but might be wrong for
+    * more complex, custom, or composite types.
+    *
+    * @param signature signature of the method the operand types need to be determined
+    * @return [[TypeInformation]] of operand types
+    */
+  def getParameterTypes(signature: Array[Class[_]]): Array[TypeInformation[_]] = {
+    signature.map { c =>
+      try {
+        TypeExtractor.getForClass(c)
+      } catch {
+        case ite: InvalidTypesException =>
+          throw new ValidationException(
+            s"Parameter types of scalar function '${this.getClass.getCanonicalName}' cannot be " +
+            s"automatically determined. Please provide type information manually.")
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
new file mode 100644
index 0000000..653793e
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions
+
+import java.util
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.expressions.{Expression, TableFunctionCall}
+
+/**
+ * Base class for a user-defined table function (UDTF). A user-defined table function works on
+ * zero, one, or multiple scalar values as input and returns multiple rows as output.
+ *
+ * The behavior of a [[TableFunction]] can be defined by implementing a custom evaluation
+ * method. An evaluation method must be declared publicly, must not be static, and must be
+ * named "eval". Evaluation methods can also be overloaded by implementing multiple methods
+ * named "eval".
+ *
+ * User-defined functions must have a default constructor and must be instantiable during runtime.
+ *
+ * By default the result type of an evaluation method is determined by Flink's type extraction
+ * facilities. This is sufficient for basic types or simple POJOs but might be wrong for more
+ * complex, custom, or composite types. In these cases [[TypeInformation]] of the result type
+ * can be manually defined by overriding [[getResultType()]].
+ *
+ * Internally, the Table/SQL API code generation works with primitive values as much as possible.
+ * If a user-defined table function should not introduce much overhead during runtime, it is
+ * recommended to declare parameters and result types as primitive types instead of their boxed
+ * classes. DATE/TIME is equal to int, TIMESTAMP is equal to long.
+ *
+ * Example:
+ *
+ * {{{
+ *
+ *   public class Split extends TableFunction<String> {
+ *
+ *     // implement an "eval" method with as many parameters as you want
+ *     public void eval(String str) {
+ *       for (String s : str.split(" ")) {
+ *         collect(s);   // use collect(...) to emit an output row
+ *       }
+ *     }
+ *
+ *     // you can overload the eval method here ...
+ *   }
+ *
+ *   val tEnv: TableEnvironment = ...
+ *   val table: Table = ...         // schema: [a: String]
+ *
+ *   // for Scala users
+ *   val split = new Split()
+ *   table.join(split('c) as ('s)).select('a, 's)
+ *
+ *   // for Java users
+ *   tEnv.registerFunction("split", new Split())   // register table function first
+ *   table.join("split(a) as (s)").select("a, s")
+ *
+ *   // for SQL users
+ *   tEnv.registerFunction("split", new Split())   // register table function first
+ *   tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) as T(s)")
+ *
+ * }}}
+ *
+ * @tparam T The type of the output row
+ */
+abstract class TableFunction[T] extends UserDefinedFunction {
+
+  /**
+    * Creates a call to a [[TableFunction]] in Scala Table API.
+    *
+    * @param params actual parameters of function
+    * @return [[Expression]] in form of a [[TableFunctionCall]]
+    */
+  final def apply(params: Expression*)(implicit typeInfo: TypeInformation[T]): Expression = {
+    val resultType = if (getResultType == null) {
+      typeInfo
+    } else {
+      getResultType
+    }
+    TableFunctionCall(getClass.getSimpleName, this, params, resultType)
+  }
+
+  override def toString: String = getClass.getCanonicalName
+
+  // ----------------------------------------------------------------------------------------------
+
+  private val rows: util.ArrayList[T] = new util.ArrayList[T]()
+
+  /**
+    * Emits an output row.
+    *
+    * @param row the output row
+    */
+  protected def collect(row: T): Unit = {
+    // cache rows for now; they might be processed immediately in the future
+    rows.add(row)
+  }
+
+  /**
+    * Internal use. Gets an iterator over the buffered rows.
+    */
+  def getRowsIterator = rows.iterator()

+  /**
+    * Internal use. Clears the buffered rows.
+    */
+  def clear() = rows.clear()
+
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns the result type of the evaluation method with a given signature.
+    *
+    * This method needs to be overridden in case Flink's type extraction facilities are not
+    * sufficient to extract the [[TypeInformation]] based on the return type of the evaluation
+    * method. Flink's type extraction facilities can handle basic types or
+    * simple POJOs but might be wrong for more complex, custom, or composite types.
+    *
+    * @return [[TypeInformation]] of result type or null if Flink should determine the type
+    */
+  def getResultType: TypeInformation[T] = null
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
new file mode 100644
index 0000000..b99ab8d
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/UserDefinedFunction.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+/**
+ * Base class for all user-defined functions such as scalar functions, table functions,
+ * or aggregation functions.
+ *
+ * User-defined functions must have a default constructor and must be instantiable during runtime.
+ */
+trait UserDefinedFunction {
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/MathFunctions.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/MathFunctions.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/MathFunctions.scala
new file mode 100644
index 0000000..64e4bc4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/MathFunctions.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.utils
+
+import java.math.{BigDecimal => JBigDecimal}
+
+class MathFunctions {}
+
+object MathFunctions {
+  def power(a: Double, b: JBigDecimal): Double = {
+    Math.pow(a, b.doubleValue())
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
new file mode 100644
index 0000000..da652e0
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.utils + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.sql._ +import org.apache.calcite.sql.`type`.SqlOperandTypeChecker.Consistency +import org.apache.calcite.sql.`type`._ +import org.apache.calcite.sql.parser.SqlParserPos +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.functions.utils.ScalarSqlFunction.{createOperandTypeChecker, createOperandTypeInference, createReturnTypeInference} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.{getResultType, getSignature, getSignatures, signatureToString, signaturesToString} + +import scala.collection.JavaConverters._ + +/** + * Calcite wrapper for user-defined scalar functions. + * + * @param name function name (used by SQL parser) + * @param scalarFunction scalar function to be called + * @param typeFactory type factory for converting Flink's between Calcite's types + */ +class ScalarSqlFunction( + name: String, + scalarFunction: ScalarFunction, + typeFactory: FlinkTypeFactory) + extends SqlFunction( + new SqlIdentifier(name, SqlParserPos.ZERO), + createReturnTypeInference(name, scalarFunction, typeFactory), + createOperandTypeInference(scalarFunction, typeFactory), + createOperandTypeChecker(name, scalarFunction), + null, + SqlFunctionCategory.USER_DEFINED_FUNCTION) { + + def getScalarFunction = scalarFunction + +} + +object ScalarSqlFunction { + + private[flink] def createReturnTypeInference( + name: String, + scalarFunction: ScalarFunction, + typeFactory: FlinkTypeFactory) + : SqlReturnTypeInference = { + /** + * Return type inference based on [[ScalarFunction]] given information. + */ + new SqlReturnTypeInference { + override def inferReturnType(opBinding: SqlOperatorBinding): RelDataType = { + val parameters = opBinding + .collectOperandTypes() + .asScala + .map { operandType => + if (operandType.getSqlTypeName == SqlTypeName.NULL) { + null + } else { + FlinkTypeFactory.toTypeInfo(operandType) + } + } + val foundSignature = getSignature(scalarFunction, parameters) + if (foundSignature.isEmpty) { + throw new ValidationException( + s"Given parameters of function '$name' do not match any signature. \n" + + s"Actual: ${signatureToString(parameters)} \n" + + s"Expected: ${signaturesToString(scalarFunction)}") + } + val resultType = getResultType(scalarFunction, foundSignature.get) + typeFactory.createTypeFromTypeInfo(resultType) + } + } + } + + private[flink] def createOperandTypeInference( + scalarFunction: ScalarFunction, + typeFactory: FlinkTypeFactory) + : SqlOperandTypeInference = { + /** + * Operand type inference based on [[ScalarFunction]] given information. 
+      */
+    new SqlOperandTypeInference {
+      override def inferOperandTypes(
+          callBinding: SqlCallBinding,
+          returnType: RelDataType,
+          operandTypes: Array[RelDataType]): Unit = {
+
+        val operandTypeInfo = getOperandTypeInfo(callBinding)
+
+        val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+          .getOrElse(throw new ValidationException("Operand types could not be inferred."))
+
+        val inferredTypes = scalarFunction
+          .getParameterTypes(foundSignature)
+          .map(typeFactory.createTypeFromTypeInfo)
+
+        inferredTypes.zipWithIndex.foreach {
+          case (inferredType, i) =>
+            operandTypes(i) = inferredType
+        }
+      }
+    }
+  }
+
+  private[flink] def createOperandTypeChecker(
+      name: String,
+      scalarFunction: ScalarFunction)
+    : SqlOperandTypeChecker = {
+
+    val signatures = getSignatures(scalarFunction)
+
+    /**
+      * Operand type checker based on the given [[ScalarFunction]]'s information.
+      */
+    new SqlOperandTypeChecker {
+      override def getAllowedSignatures(op: SqlOperator, opName: String): String = {
+        s"$opName[${signaturesToString(scalarFunction)}]"
+      }
+
+      override def getOperandCountRange: SqlOperandCountRange = {
+        val signatureLengths = signatures.map(_.length)
+        SqlOperandCountRanges.between(signatureLengths.min, signatureLengths.max)
+      }
+
+      override def checkOperandTypes(
+          callBinding: SqlCallBinding,
+          throwOnFailure: Boolean)
+        : Boolean = {
+        val operandTypeInfo = getOperandTypeInfo(callBinding)
+
+        val foundSignature = getSignature(scalarFunction, operandTypeInfo)
+
+        if (foundSignature.isEmpty) {
+          if (throwOnFailure) {
+            throw new ValidationException(
+              s"Given parameters of function '$name' do not match any signature. \n" +
+                s"Actual: ${signatureToString(operandTypeInfo)} \n" +
+                s"Expected: ${signaturesToString(scalarFunction)}")
+          } else {
+            false
+          }
+        } else {
+          true
+        }
+      }
+
+      override def isOptional(i: Int): Boolean = false
+
+      override def getConsistency: Consistency = Consistency.NONE
+
+    }
+  }
+
+  private[flink] def getOperandTypeInfo(callBinding: SqlCallBinding): Seq[TypeInformation[_]] = {
+    val operandTypes = for (i <- 0 until callBinding.getOperandCount)
+      yield callBinding.getOperandType(i)
+    operandTypes.map { operandType =>
+      if (operandType.getSqlTypeName == SqlTypeName.NULL) {
+        null
+      } else {
+        FlinkTypeFactory.toTypeInfo(operandType)
+      }
+    }
+  }
+}
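[Illustration, not part of this commit: a scalar function with two overloaded eval methods,
to make the checker above concrete. The name HashCode and the method bodies are assumptions.]

    import org.apache.flink.table.functions.ScalarFunction

    // Two signatures: (String) and (String, Int).
    class HashCode extends ScalarFunction {
      def eval(s: String): Int = s.hashCode
      def eval(s: String, factor: Int): Int = s.hashCode * factor
    }

For such a function, getSignatures yields signatures of lengths 1 and 2, so the derived
getOperandCountRange is between(1, 2); a call like HASHCODE(12) fails checkOperandTypes with
the "do not match any signature" message because no eval method takes a numeric first argument.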
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
new file mode 100644
index 0000000..74f3374
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/TableSqlFunction.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.utils
+
+import com.google.common.base.Predicate
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.sql._
+import org.apache.calcite.sql.`type`._
+import org.apache.calcite.sql.parser.SqlParserPos
+import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction
+import org.apache.calcite.util.Util
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.functions.TableFunction
+import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
+
+import scala.collection.JavaConverters._
+import java.util
+
+/**
+  * Calcite wrapper for user-defined table functions.
+  */
+class TableSqlFunction(
+    name: String,
+    udtf: TableFunction[_],
+    rowTypeInfo: TypeInformation[_],
+    returnTypeInference: SqlReturnTypeInference,
+    operandTypeInference: SqlOperandTypeInference,
+    operandTypeChecker: SqlOperandTypeChecker,
+    paramTypes: util.List[RelDataType],
+    functionImpl: FlinkTableFunctionImpl[_])
+  extends SqlUserDefinedTableFunction(
+    new SqlIdentifier(name, SqlParserPos.ZERO),
+    returnTypeInference,
+    operandTypeInference,
+    operandTypeChecker,
+    paramTypes,
+    functionImpl) {
+
+  /**
+    * Returns the user-defined table function.
+    */
+  def getTableFunction = udtf
+
+  /**
+    * Returns the type information of the table returned by the table function.
+    */
+  def getRowTypeInfo = rowTypeInfo
+
+  /**
+    * Returns additional mapping information if the returned table type is a POJO
+    * (POJO types have no deterministic field order).
+    */
+  def getPojoFieldMapping = functionImpl.fieldIndexes
+
+}
+
+object TableSqlFunction {
+
+  /**
+    * Utility function to create a [[TableSqlFunction]].
+    *
+    * @param name function name (used by the SQL parser)
+    * @param udtf user-defined table function to be called
+    * @param rowTypeInfo the row type information generated by the table function
+    * @param typeFactory type factory for converting between Flink's and Calcite's types
+    * @param functionImpl Calcite table function schema
+    * @return [[TableSqlFunction]]
+    */
+  def apply(
+      name: String,
+      udtf: TableFunction[_],
+      rowTypeInfo: TypeInformation[_],
+      typeFactory: FlinkTypeFactory,
+      functionImpl: FlinkTableFunctionImpl[_]): TableSqlFunction = {
+
+    val argTypes: util.List[RelDataType] = new util.ArrayList[RelDataType]
+    val typeFamilies: util.List[SqlTypeFamily] = new util.ArrayList[SqlTypeFamily]
+    // derive the operands' data types and type families
+    functionImpl.getParameters.asScala.foreach { o =>
+      val relType: RelDataType = o.getType(typeFactory)
+      argTypes.add(relType)
+      typeFamilies.add(Util.first(relType.getSqlTypeName.getFamily, SqlTypeFamily.ANY))
+    }
+    // predicate deciding whether the input-th parameter of the method is optional
+    val optional: Predicate[Integer] = new Predicate[Integer]() {
+      def apply(input: Integer): Boolean = {
+        functionImpl.getParameters.get(input).isOptional
+      }
+    }
+    // create the type checker for the operands
+    val typeChecker: FamilyOperandTypeChecker = OperandTypes.family(typeFamilies, optional)
+
+    new TableSqlFunction(
+      name,
+      udtf,
+      rowTypeInfo,
+      ReturnTypes.CURSOR,
+      InferTypes.explicit(argTypes),
+      typeChecker,
+      argTypes,
+      functionImpl)
+  }
+}
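[Illustration, not part of this commit: how this wrapper is reached from user code, reusing
the hypothetical Split function from the earlier sketch; tEnv and the table/field names are
assumptions taken from the TableFunction class comment.]

    // Registration creates one TableSqlFunction per eval method
    // (see createTableSqlFunctions in UserDefinedFunctionUtils below).
    tEnv.registerFunction("split", new Split())

    // During SQL validation the call below resolves to that TableSqlFunction:
    // its return type is a CURSOR and its operands are checked by the
    // FamilyOperandTypeChecker derived in apply() above.
    tEnv.sql("SELECT a, s FROM MyTable, LATERAL TABLE(split(a)) as T(s)")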
http://git-wip-us.apache.org/repos/asf/flink/blob/ffe9ec8e/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
new file mode 100644
index 0000000..aa3fab0
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.table.functions.utils
+
+import java.lang.reflect.{Method, Modifier}
+import java.sql.{Date, Time, Timestamp}
+
+import com.google.common.primitives.Primitives
+import org.apache.calcite.sql.SqlFunction
+import org.apache.flink.api.common.functions.InvalidTypesException
+import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.api.{TableException, ValidationException}
+import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
+import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
+import org.apache.flink.util.InstantiationUtil
+
+object UserDefinedFunctionUtils {
+
+  /**
+    * Instantiates a user-defined function.
+    */
+  def instantiate[T <: UserDefinedFunction](clazz: Class[T]): T = {
+    val constructor = clazz.getDeclaredConstructor()
+    constructor.setAccessible(true)
+    constructor.newInstance()
+  }
+
+  /**
+    * Checks if a user-defined function can be instantiated.
+    */
+  def checkForInstantiation(clazz: Class[_]): Unit = {
+    if (!InstantiationUtil.isPublic(clazz)) {
+      throw ValidationException("Function class is not public.")
+    }
+    else if (!InstantiationUtil.isProperClass(clazz)) {
+      throw ValidationException("Function class is not a proper class; it is either abstract," +
+        " an interface, or a primitive type.")
+    }
+    else if (InstantiationUtil.isNonStaticInnerClass(clazz)) {
+      throw ValidationException("The class is an inner class, but not statically accessible.")
+    }
+
+    // check for default constructor (can be private)
+    clazz
+      .getDeclaredConstructors
+      .find(_.getParameterTypes.isEmpty)
+      .getOrElse(throw ValidationException("Function class needs a default constructor."))
+  }
+
+  /**
+    * Checks whether the class is a Scala object. Using a [[TableFunction]] implemented
+    * as a Scala object is forbidden because of concurrency risks.
+    */
+  def checkNotSingleton(clazz: Class[_]): Unit = {
+    // TODO: checking for a MODULE$ field is not a robust singleton check; improve it later.
+    if (clazz.getFields.map(_.getName) contains "MODULE$") {
+      throw new ValidationException(
+        s"TableFunction implemented by class ${clazz.getCanonicalName} " +
+          s"is a Scala object. This is forbidden because of concurrency risks.")
+    }
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for eval methods
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns the signature matching the given signature of [[TypeInformation]].
+    * Elements of the signature can be null (and then act as a wildcard).
+    */
+  def getSignature(
+      function: UserDefinedFunction,
+      signature: Seq[TypeInformation[_]])
+    : Option[Array[Class[_]]] = {
+    // We compare the raw Java classes, not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+    val actualSignature = typeInfoToClass(signature)
+    val signatures = getSignatures(function)
+
+    signatures
+      // go over all signatures and find one matching the actual signature
+      .find { curSig =>
+        // match parameters of the signature to the actual parameters
+        actualSignature.length == curSig.length &&
+          curSig.zipWithIndex.forall { case (clazz, i) =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          }
+      }
+  }
+
+  /**
+    * Returns the eval method matching the given signature of [[TypeInformation]].
+    */
+  def getEvalMethod(
+      function: UserDefinedFunction,
+      signature: Seq[TypeInformation[_]])
+    : Option[Method] = {
+    // We compare the raw Java classes, not the TypeInformation.
+    // TypeInformation does not matter during runtime (e.g. within a MapFunction).
+    val actualSignature = typeInfoToClass(signature)
+    val evalMethods = checkAndExtractEvalMethods(function)
+
+    evalMethods
+      // go over all eval methods and find a matching one
+      .find { cur =>
+        val signatures = cur.getParameterTypes
+        // match parameters of the signature to the actual parameters
+        actualSignature.length == signatures.length &&
+          signatures.zipWithIndex.forall { case (clazz, i) =>
+            parameterTypeEquals(actualSignature(i), clazz)
+          }
+      }
+  }
+
+  /**
+    * Extracts the "eval" methods and throws a [[ValidationException]] if no implementation
+    * can be found.
+    */
+  def checkAndExtractEvalMethods(function: UserDefinedFunction): Array[Method] = {
+    val methods = function
+      .getClass
+      .getDeclaredMethods
+      .filter { m =>
+        val modifiers = m.getModifiers
+        m.getName == "eval" &&
+          Modifier.isPublic(modifiers) &&
+          !Modifier.isAbstract(modifiers) &&
+          !(function.isInstanceOf[TableFunction[_]] && Modifier.isStatic(modifiers))
+      }
+
+    if (methods.isEmpty) {
+      throw new ValidationException(
+        s"Function class '${function.getClass.getCanonicalName}' does not implement at least " +
          s"one method named 'eval' which is public, not abstract and " +
+          s"(in case of table functions) not static.")
+    } else {
+      methods
+    }
+  }
+
+  def getSignatures(function: UserDefinedFunction): Array[Array[Class[_]]] = {
+    checkAndExtractEvalMethods(function).map(_.getParameterTypes)
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for SQL functions
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Creates a [[SqlFunction]] for a [[ScalarFunction]].
+    *
+    * @param name function name
+    * @param function scalar function
+    * @param typeFactory type factory
+    * @return the ScalarSqlFunction
+    */
+  def createScalarSqlFunction(
+      name: String,
+      function: ScalarFunction,
+      typeFactory: FlinkTypeFactory)
+    : SqlFunction = {
+    new ScalarSqlFunction(name, function, typeFactory)
+  }
+
+  /**
+    * Creates a [[SqlFunction]] for every eval method of a [[TableFunction]].
+    *
+    * @param name function name
+    * @param tableFunction table function
+    * @param resultType the type information of the returned table
+    * @param typeFactory type factory
+    * @return the TableSqlFunctions
+    */
+  def createTableSqlFunctions(
+      name: String,
+      tableFunction: TableFunction[_],
+      resultType: TypeInformation[_],
+      typeFactory: FlinkTypeFactory)
+    : Seq[SqlFunction] = {
+    val (fieldNames, fieldIndexes, _) = UserDefinedFunctionUtils.getFieldInfo(resultType)
+    val evalMethods = checkAndExtractEvalMethods(tableFunction)
+
+    evalMethods.map { method =>
+      val function = new FlinkTableFunctionImpl(resultType, fieldIndexes, fieldNames, method)
+      TableSqlFunction(name, tableFunction, resultType, typeFactory, function)
+    }
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Utilities for scalar functions
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
+    * [[TypeExtractor]] as the default return type inference.
+    */
+  def getResultType(
+      function: ScalarFunction,
+      signature: Array[Class[_]])
+    : TypeInformation[_] = {
+    // find the method for the given signature
+    val evalMethod = checkAndExtractEvalMethods(function)
+      .find(m => signature.sameElements(m.getParameterTypes))
+      .getOrElse(throw new ValidationException("Given signature is invalid."))
+
+    val userDefinedTypeInfo = function.getResultType(signature)
+    if (userDefinedTypeInfo != null) {
+      userDefinedTypeInfo
+    } else {
+      try {
+        TypeExtractor.getForClass(evalMethod.getReturnType)
+      } catch {
+        case ite: InvalidTypesException =>
+          throw new ValidationException(
+            s"Return type of scalar function '${function.getClass.getCanonicalName}' cannot be " +
+              s"automatically determined. Please provide type information manually.")
+      }
+    }
+  }
+
+  /**
+    * Returns the return type class of the evaluation method matching the given signature.
+    */
+  def getResultTypeClass(
+      function: ScalarFunction,
+      signature: Array[Class[_]])
+    : Class[_] = {
+    // find the method for the given signature
+    val evalMethod = checkAndExtractEvalMethods(function)
+      .find(m => signature.sameElements(m.getParameterTypes))
+      .getOrElse(throw new IllegalArgumentException("Given signature is invalid."))
+    evalMethod.getReturnType
+  }
+
+  // ----------------------------------------------------------------------------------------------
+  // Miscellaneous
+  // ----------------------------------------------------------------------------------------------
+
+  /**
+    * Returns field names, field positions, and field types for a given [[TypeInformation]].
+    *
+    * Field names are automatically extracted for
+    * [[org.apache.flink.api.common.typeutils.CompositeType]].
+    *
+    * @param inputType The TypeInformation to extract the field names, positions, and types from.
+    * @return A tuple of three arrays holding the field names, the corresponding field positions,
+    *         and the field types.
+    */
+  def getFieldInfo(inputType: TypeInformation[_])
+    : (Array[String], Array[Int], Array[TypeInformation[_]]) = {
+
+    val fieldNames: Array[String] = inputType match {
+      case t: CompositeType[_] => t.getFieldNames
+      case a: AtomicType[_] => Array("f0")
+      case tpe =>
+        throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
+          s"Type $tpe lacks explicit field naming.")
+    }
+    val fieldIndexes = fieldNames.indices.toArray
+    val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
+      inputType match {
+        case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
+        case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
+        case tpe =>
+          throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
+      }
+    }
+    (fieldNames, fieldIndexes, fieldTypes)
+  }
+
+  /**
+    * Formats one signature consisting of classes as a string.
+    */
+  def signatureToString(signature: Array[Class[_]]): String =
+    signature.map { clazz =>
+      if (clazz == null) {
+        "null"
+      } else {
+        clazz.getCanonicalName
+      }
+    }.mkString("(", ", ", ")")
+
+  /**
+    * Formats one signature consisting of TypeInformation as a string.
+    */
+  def signatureToString(signature: Seq[TypeInformation[_]]): String = {
+    signatureToString(typeInfoToClass(signature))
+  }
+
+  /**
+    * Formats all eval method signatures of a class as a string.
+    */
+  def signaturesToString(function: UserDefinedFunction): String = {
+    getSignatures(function).map(signatureToString).mkString(", ")
+  }
+
+  /**
+    * Extracts the type classes of [[TypeInformation]] in a null-aware way.
+    */
+  private def typeInfoToClass(typeInfos: Seq[TypeInformation[_]]): Array[Class[_]] =
+    typeInfos.map { typeInfo =>
+      if (typeInfo == null) {
+        null
+      } else {
+        typeInfo.getTypeClass
+      }
+    }.toArray
+
+  /**
+    * Compares a parameter candidate class with an expected class. If this method returns true,
+    * the parameters match. The candidate can be null (and then acts as a wildcard).
+    */
+  private def parameterTypeEquals(candidate: Class[_], expected: Class[_]): Boolean =
+    candidate == null ||
+      candidate == expected ||
+      expected.isPrimitive && Primitives.wrap(expected) == candidate ||
+      candidate == classOf[Date] && expected == classOf[Int] ||
+      candidate == classOf[Time] && expected == classOf[Int] ||
+      candidate == classOf[Timestamp] && expected == classOf[Long]
+
+}
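[Illustration, not part of this commit: expected behavior of the signature matching above,
using the hypothetical HashCode function from the earlier sketch.]

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, STRING_TYPE_INFO}
    import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getSignature

    val fn = new HashCode

    // Match on the raw classes extracted from the TypeInformation; the boxed
    // Integer candidate matches the primitive Int parameter via Primitives.wrap:
    getSignature(fn, Seq(STRING_TYPE_INFO, INT_TYPE_INFO))
    // -> Some(signature of eval(String, Int))

    // A null element acts as a wildcard for its position:
    getSignature(fn, Seq(STRING_TYPE_INFO, null))
    // -> Some(signature of eval(String, Int))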