From commits-return-15417-archive-asf-public=cust-asf.ponee.io@flink.apache.org Fri Jan 19 22:37:42 2018 Return-Path: X-Original-To: archive-asf-public@eu.ponee.io Delivered-To: archive-asf-public@eu.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by mx-eu-01.ponee.io (Postfix) with ESMTP id 43768180607 for ; Fri, 19 Jan 2018 22:37:42 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id 30570160C36; Fri, 19 Jan 2018 21:37:42 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 4335E160C1B for ; Fri, 19 Jan 2018 22:37:40 +0100 (CET) Received: (qmail 29517 invoked by uid 500); 19 Jan 2018 21:37:39 -0000 Mailing-List: contact commits-help@flink.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@flink.apache.org Delivered-To: mailing list commits@flink.apache.org Received: (qmail 29483 invoked by uid 99); 19 Jan 2018 21:37:39 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 19 Jan 2018 21:37:39 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 43D4AE0A2A; Fri, 19 Jan 2018 21:37:36 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 8bit From: fhueske@apache.org To: commits@flink.apache.org Message-Id: <3952a36f24d54a229611de8813b64555@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: flink git commit: [FLINK-8355] [table] Remove DataSetAggregateWithNullValuesRule. Date: Fri, 19 Jan 2018 21:37:36 +0000 (UTC) Repository: flink Updated Branches: refs/heads/master d6b8505f7 -> 20faf262d [FLINK-8355] [table] Remove DataSetAggregateWithNullValuesRule. This closes #5320. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/20faf262 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/20faf262 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/20faf262 Branch: refs/heads/master Commit: 20faf262de9bb52aa614ff2d989a49e8ea82b963 Parents: d6b8505 Author: 金竹 Authored: Wed Jan 3 23:13:49 2018 +0800 Committer: Fabian Hueske Committed: Fri Jan 19 21:00:12 2018 +0100 ---------------------------------------------------------------------- .../plan/nodes/dataset/DataSetAggregate.scala | 13 +- .../flink/table/plan/rules/FlinkRuleSets.scala | 1 - .../rules/dataSet/DataSetAggregateRule.scala | 6 - .../DataSetAggregateWithNullValuesRule.scala | 89 ---------- .../table/runtime/aggregate/AggregateUtil.scala | 6 +- .../runtime/aggregate/DataSetAggFunction.scala | 24 ++- .../table/api/batch/sql/AggregateTest.scala | 44 +---- .../api/batch/sql/DistinctAggregateTest.scala | 167 +++++-------------- .../table/api/batch/sql/GroupingSetsTest.scala | 28 +--- .../table/api/batch/sql/SetOperatorsTest.scala | 19 +-- .../table/api/batch/sql/SingleRowJoinTest.scala | 130 +++------------ .../table/api/batch/table/AggregateTest.scala | 47 +----- .../flink/table/api/batch/table/CalcTest.scala | 17 +- .../table/plan/QueryDecorrelationTest.scala | 75 ++++----- .../aggfunctions/AggFunctionTestBase.scala | 23 ++- .../runtime/batch/sql/AggregateITCase.scala | 11 ++ .../runtime/batch/table/AggregateITCase.scala | 28 +++- .../table/utils/UserDefinedAggFunctions.scala | 25 ++- 18 files changed, 235 insertions(+), 518 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index c65e301..7dd307b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -23,7 +23,6 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.flink.api.common.functions.GroupReduceFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.typeutils.RowTypeInfo @@ -31,8 +30,8 @@ import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.AggregationCodeGenerator import org.apache.flink.table.plan.nodes.CommonAggregate -import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction} import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair +import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetAggFunction, DataSetFinalAggFunction, DataSetPreAggFunction} import org.apache.flink.types.Row /** @@ -104,7 +103,7 @@ class DataSetAggregate( val ( preAgg: Option[DataSetPreAggFunction], preAggType: Option[TypeInformation[Row]], - finalAgg: GroupReduceFunction[Row, Row] + finalAgg: Either[DataSetAggFunction, DataSetFinalAggFunction] ) = AggregateUtil.createDataSetAggregateFunctions( generator, namedAggregates, @@ -129,13 +128,13 @@ class DataSetAggregate( .name(aggOpName) // final aggregation .groupBy(grouping.indices: _*) - .reduceGroup(finalAgg) + .reduceGroup(finalAgg.right.get) .returns(rowTypeInfo) .name(aggOpName) } else { inputDS .groupBy(grouping: _*) - .reduceGroup(finalAgg) + .reduceGroup(finalAgg.left.get) .returns(rowTypeInfo) .name(aggOpName) } @@ -151,12 +150,12 @@ class DataSetAggregate( .returns(preAggType.get) .name(aggOpName) // final aggregation - .reduceGroup(finalAgg) + .reduceGroup(finalAgg.right.get) .returns(rowTypeInfo) .name(aggOpName) } else { inputDS - .reduceGroup(finalAgg) + .mapPartition(finalAgg.left.get).setParallelism(1) .returns(rowTypeInfo) .name(aggOpName) } http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index b8a96bf..d3ad2ac 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -164,7 +164,6 @@ object FlinkRuleSets { // translate to Flink DataSet nodes DataSetWindowAggregateRule.INSTANCE, DataSetAggregateRule.INSTANCE, - DataSetAggregateWithNullValuesRule.INSTANCE, DataSetDistinctRule.INSTANCE, DataSetCalcRule.INSTANCE, DataSetJoinRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala index 9a31617..e73c76e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateRule.scala @@ -37,12 +37,6 @@ class DataSetAggregateRule override def matches(call: RelOptRuleCall): Boolean = { val agg: FlinkLogicalAggregate = call.rel(0).asInstanceOf[FlinkLogicalAggregate] - // for non-grouped agg sets we attach null row to source data - // we need to apply DataSetAggregateWithNullValuesRule - if (agg.getGroupSet.isEmpty) { - return false - } - // distinct is translated into dedicated operator if (agg.getAggCallList.isEmpty && agg.getGroupCount == agg.getRowType.getFieldCount && http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala deleted file mode 100644 index 4a1e6d6..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetAggregateWithNullValuesRule.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.plan.rules.dataSet - - -import com.google.common.collect.ImmutableList -import org.apache.calcite.plan._ -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.convert.ConverterRule -import org.apache.calcite.rex.RexLiteral -import org.apache.flink.table.plan.nodes.FlinkConventions -import org.apache.flink.table.plan.nodes.dataset.DataSetAggregate -import org.apache.flink.table.plan.nodes.logical.{FlinkLogicalAggregate, FlinkLogicalUnion, FlinkLogicalValues} - -import scala.collection.JavaConversions._ - -/** - * Rule for insert [[org.apache.flink.types.Row]] with null records into a [[DataSetAggregate]]. - * Rule apply for non grouped aggregate query. - */ -class DataSetAggregateWithNullValuesRule - extends ConverterRule( - classOf[FlinkLogicalAggregate], - FlinkConventions.LOGICAL, - FlinkConventions.DATASET, - "DataSetAggregateWithNullValuesRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: FlinkLogicalAggregate = call.rel(0).asInstanceOf[FlinkLogicalAggregate] - - // group sets shouldn't attach a null row - // we need to apply other rules. i.e. DataSetAggregateRule - if (!agg.getGroupSet.isEmpty) { - return false - } - - // check if we have distinct aggregates - val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - - !distinctAggs - } - - override def convert(rel: RelNode): RelNode = { - val agg: FlinkLogicalAggregate = rel.asInstanceOf[FlinkLogicalAggregate] - val traitSet: RelTraitSet = rel.getTraitSet.replace(FlinkConventions.DATASET) - val cluster: RelOptCluster = rel.getCluster - - val fieldTypes = agg.getInput.getRowType.getFieldList.map(_.getType) - val nullLiterals: ImmutableList[ImmutableList[RexLiteral]] = - ImmutableList.of(ImmutableList.copyOf[RexLiteral]( - for (fieldType <- fieldTypes) - yield { - cluster.getRexBuilder. - makeLiteral(null, fieldType, false).asInstanceOf[RexLiteral] - })) - - val logicalValues = FlinkLogicalValues.create(cluster, agg.getInput.getRowType, nullLiterals) - val logicalUnion = FlinkLogicalUnion.create(List(logicalValues, agg.getInput), all = true) - - new DataSetAggregate( - cluster, - traitSet, - RelOptRule.convert(logicalUnion, FlinkConventions.DATASET), - agg.getNamedAggCalls, - rel.getRowType, - agg.getInput.getRowType, - agg.getGroupSet.toArray - ) - } -} - -object DataSetAggregateWithNullValuesRule { - val INSTANCE: RelOptRule = new DataSetAggregateWithNullValuesRule -} http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 532bec6..0d07153 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -829,7 +829,7 @@ object AggregateUtil { outputType: RelDataType, groupings: Array[Int]): (Option[DataSetPreAggFunction], Option[TypeInformation[Row]], - RichGroupReduceFunction[Row, Row]) = { + Either[DataSetAggFunction, DataSetFinalAggFunction]) = { val needRetract = false val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions( @@ -899,7 +899,7 @@ object AggregateUtil { ( Some(new DataSetPreAggFunction(genPreAggFunction)), Some(preAggRowType), - new DataSetFinalAggFunction(genFinalAggFunction) + Right(new DataSetFinalAggFunction(genFinalAggFunction)) ) } else { @@ -922,7 +922,7 @@ object AggregateUtil { ( None, None, - new DataSetAggFunction(genFunction) + Left(new DataSetAggFunction(genFunction)) ) } http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala index ced1450..313dae0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.runtime.aggregate import java.lang.Iterable -import org.apache.flink.api.common.functions.RichGroupReduceFunction +import org.apache.flink.api.common.functions.{MapPartitionFunction, RichGroupReduceFunction} import org.apache.flink.configuration.Configuration import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction} import org.apache.flink.table.util.Logging @@ -27,14 +27,15 @@ import org.apache.flink.types.Row import org.apache.flink.util.Collector /** - * [[RichGroupReduceFunction]] to compute aggregates that do not support pre-aggregation for batch - * (DataSet) queries. + * [[RichGroupReduceFunction]] and [[MapPartitionFunction]] to compute aggregates that do + * not support pre-aggregation for batch(DataSet) queries. * * @param genAggregations Code-generated [[GeneratedAggregations]] */ class DataSetAggFunction( private val genAggregations: GeneratedAggregationsFunction) extends RichGroupReduceFunction[Row, Row] + with MapPartitionFunction[Row, Row] with Compiler[GeneratedAggregations] with Logging { private var output: Row = _ @@ -56,6 +57,12 @@ class DataSetAggFunction( accumulators = function.createAccumulators() } + /** + * Computes a non-pre-aggregated aggregation. + * + * @param records An iterator over all records of the group. + * @param out The collector to hand results to. + */ override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = { // reset accumulators @@ -79,4 +86,15 @@ class DataSetAggFunction( out.collect(output) } + + /** + * Computes a non-pre-aggregated aggregation and returns a row even if the input is empty. + * + * @param records An iterator over all records of the partition. + * @param out The collector to hand results to. + */ + override def mapPartition(records: Iterable[Row], out: Collector[Row]): Unit = { + reduce(records, out) + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala index f2e250b..921c139 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/AggregateTest.scala @@ -36,21 +36,9 @@ class AggregateTest extends TableTestBase { val sqlQuery = "SELECT avg(a), sum(b), count(c) FROM MyTable" - val setValues = unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null,null,null)), - term("values","a","b","c") - ) - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c") - ) - val aggregate = unaryNode( "DataSetAggregate", - union, + batchTableNode(0), term("select", "AVG(a) AS EXPR$0", "SUM(b) AS EXPR$1", @@ -73,22 +61,9 @@ class AggregateTest extends TableTestBase { term("where", "=(a, 1)") ) - val setValues = unaryNode( - "DataSetValues", - calcNode, - tuples(List(null,null,null)), - term("values","a","b","c") - ) - - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c") - ) - val aggregate = unaryNode( "DataSetAggregate", - union, + calcNode, term("select", "AVG(a) AS EXPR$0", "SUM(b) AS EXPR$1", @@ -111,22 +86,9 @@ class AggregateTest extends TableTestBase { term("where", "=(a, 1)") ) - val setValues = unaryNode( - "DataSetValues", - calcNode, - tuples(List(null,null,null,null)), - term("values","a","b","c","$f3") - ) - - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c","$f3") - ) - val aggregate = unaryNode( "DataSetAggregate", - union, + calcNode, term("select", "AVG(a) AS EXPR$0", "SUM(b) AS EXPR$1", http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala index ce008e4..ced07e4 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/DistinctAggregateTest.scala @@ -36,22 +36,13 @@ class DistinctAggregateTest extends TableTestBase { val expected = unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a") - ), - term("distinct", "a") - ), - tuples(List(null)), - term("values", "a") + "DataSetCalc", + batchTableNode(0), + term("select", "a") ), - term("union", "a") + term("distinct", "a") ), term("select", "COUNT(a) AS EXPR$0") ) @@ -69,22 +60,13 @@ class DistinctAggregateTest extends TableTestBase { val expected = unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a") - ), - term("distinct", "a") - ), - tuples(List(null)), - term("values", "a") + "DataSetCalc", + batchTableNode(0), + term("select", "a") ), - term("union", "a") + term("distinct", "a") ), term("select", "COUNT(a) AS EXPR$0", "SUM(a) AS EXPR$1", "MAX(a) AS EXPR$2") ) @@ -103,23 +85,14 @@ class DistinctAggregateTest extends TableTestBase { val expected0 = unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetAggregate", unaryNode( - "DataSetValues", - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a", "b") - ), - term("groupBy", "a"), - term("select", "a", "SUM(b) AS EXPR$1") - ), - tuples(List(null, null)), - term("values", "a", "EXPR$1") + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b") ), - term("union", "a", "EXPR$1") + term("groupBy", "a"), + term("select", "a", "SUM(b) AS EXPR$1") ), term("select", "COUNT(a) AS EXPR$0", "SUM(EXPR$1) AS EXPR$1") ) @@ -132,23 +105,14 @@ class DistinctAggregateTest extends TableTestBase { val expected1 = unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetAggregate", unaryNode( - "DataSetValues", - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a", "b") - ), - term("groupBy", "b"), - term("select", "b", "COUNT(a) AS EXPR$0") - ), - tuples(List(null, null)), - term("values", "b", "EXPR$0") + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b") ), - term("union", "b", "EXPR$0") + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS EXPR$0") ), term("select", "$SUM0(EXPR$0) AS EXPR$0", "SUM(b) AS EXPR$1") ) @@ -168,44 +132,26 @@ class DistinctAggregateTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a") - ), - term("distinct", "a") - ), - tuples(List(null)), - term("values", "a") + "DataSetCalc", + batchTableNode(0), + term("select", "a") ), - term("union", "a") + term("distinct", "a") ), term("select", "COUNT(a) AS EXPR$0") ), unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "b") - ), - term("distinct", "b") - ), - tuples(List(null)), - term("values", "b") + "DataSetCalc", + batchTableNode(0), + term("select", "b") ), - term("union", "b") + term("distinct", "b") ), term("select", "SUM(b) AS EXPR$1") ), @@ -232,37 +178,19 @@ class DistinctAggregateTest extends TableTestBase { "DataSetSingleRowJoin", unaryNode( "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null, null)), - term("values", "a, b, c") - ), - term("union", "a, b, c") - ), + batchTableNode(0), term("select", "COUNT(c) AS EXPR$2") ), unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a") - ), - term("distinct", "a") - ), - tuples(List(null)), - term("values", "a") + "DataSetCalc", + batchTableNode(0), + term("select", "a") ), - term("union", "a") + term("distinct", "a") ), term("select", "COUNT(a) AS EXPR$0") ), @@ -273,22 +201,13 @@ class DistinctAggregateTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", + "DataSetDistinct", unaryNode( - "DataSetValues", - unaryNode( - "DataSetDistinct", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "b") - ), - term("distinct", "b") - ), - tuples(List(null)), - term("values", "b") + "DataSetCalc", + batchTableNode(0), + term("select", "b") ), - term("union", "b") + term("distinct", "b") ), term("select", "SUM(b) AS EXPR$1") ), http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala index 9f3d2b6..57a4c5a 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/GroupingSetsTest.scala @@ -114,19 +114,11 @@ class GroupingSetsTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null, null)), - term("values", "a", "b", "c") - ), - term("union", "a", "b", "c") - ), + batchTableNode(0), term("select", "AVG(a) AS a") ), - term("select", "null AS b", "null AS c", "a", "0 AS g", "0 AS gb", "0 AS gc", + term( + "select", "null AS b", "null AS c", "a", "0 AS g", "0 AS gb", "0 AS gc", "0 AS gib", "0 AS gic", "0 AS gid") ) @@ -189,19 +181,11 @@ class GroupingSetsTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null, null)), - term("values", "a", "b", "c") - ), - term("union", "a", "b", "c") - ), + batchTableNode(0), term("select", "AVG(a) AS a") ), - term("select", "null AS b", "null AS c", "a", "0 AS g", "0 AS gb", "0 AS gc", + term( + "select", "null AS b", "null AS c", "a", "0 AS g", "0 AS gb", "0 AS gc", "0 AS gib", "0 AS gic", "0 AS gid") ) http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala index bff0b78..d51fc42 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SetOperatorsTest.scala @@ -101,20 +101,11 @@ class SetOperatorsTest extends TableTestBase { batchTableNode(0), unaryNode( "DataSetAggregate", - binaryNode( - "DataSetUnion", - values( - "DataSetValues", - term("tuples", "[{ null }]"), - term("values", "b") - ), - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "b"), - term("where", "OR(=(b, 6), =(b, 1))") - ), - term("union", "b") + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "b"), + term("where", "OR(=(b, 6), =(b, 1))") ), term("select", "COUNT(*) AS $f0", "COUNT(b) AS $f1") ), http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala index 8bfb61b..59156d6 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/sql/SingleRowJoinTest.scala @@ -47,16 +47,7 @@ class SingleRowJoinTest extends TableTestBase { "DataSetCalc", unaryNode( "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null)), - term("values", "a1", "a2") - ), - term("union","a1","a2") - ), + batchTableNode(0), term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") ), term("select", "+($f0, $f1) AS asum") @@ -88,18 +79,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union","a1") + "DataSetCalc", + batchTableNode(0), + term("select", "a1") ), term("select", "COUNT(a1) AS cnt") ), @@ -132,18 +114,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union", "a1") + "DataSetCalc", + batchTableNode(0), + term("select", "a1") ), term("select", "COUNT(a1) AS cnt") ), @@ -173,16 +146,7 @@ class SingleRowJoinTest extends TableTestBase { batchTableNode(0), unaryNode( "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(1), - tuples(List(null, null)), - term("values", "b1", "b2") - ), - term("union","b1","b2") - ), + batchTableNode(1), term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") ), term("where", "AND(<(a1, b1)", "=(a2, b2))"), @@ -221,17 +185,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), term("select", "COUNT(*) AS cnt") ) @@ -266,17 +222,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), term("select", "COUNT(*) AS cnt") ) @@ -308,21 +256,13 @@ class SingleRowJoinTest extends TableTestBase { ), term("select", "a1") ) + unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), - term("select", "COUNT(*) AS cnt") - ) + "\n" + + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + term("select", "COUNT(*) AS cnt") + ) + "\n" + batchTableNode(0) util.verifySql(queryRightJoin, expected) @@ -356,17 +296,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), term("select", "COUNT(*) AS cnt") ) + "\n" + batchTableNode(0) @@ -406,17 +338,9 @@ class SingleRowJoinTest extends TableTestBase { unaryNode( "DataSetAggregate", unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), term("values", "a1") - ), - term("union", "a1") + "DataSetCalc", + batchTableNode(0), + term("select", "a1") ), term("select", "SUM(a1) AS $f0") ), http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala index 0a135d1..df65481 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/AggregateTest.scala @@ -66,21 +66,9 @@ class AggregateTest extends TableTestBase { val sourceTable = util.addTable[(Int, Long, Int)]("MyTable", 'a, 'b, 'c) val resultTable = sourceTable.select('a.avg,'b.sum,'c.count) - val setValues = unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null,null,null)), - term("values","a","b","c") - ) - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c") - ) - val expected = unaryNode( "DataSetAggregate", - union, + batchTableNode(0), term("select", "AVG(a) AS TMP_0", "SUM(b) AS TMP_1", @@ -106,22 +94,9 @@ class AggregateTest extends TableTestBase { term("where", "=(a, 1)") ) - val setValues = unaryNode( - "DataSetValues", - calcNode, - tuples(List(null,null,null)), - term("values","a","b","c") - ) - - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c") - ) - val expected = unaryNode( "DataSetAggregate", - union, + calcNode, term("select", "AVG(a) AS TMP_0", "SUM(b) AS TMP_1", @@ -148,23 +123,11 @@ class AggregateTest extends TableTestBase { term("where", "=(a, 1)") ) - val setValues = unaryNode( - "DataSetValues", - calcNode, - tuples(List(null,null,null,null)), - term("values","a","b","c","$f3") - ) - - val union = unaryNode( - "DataSetUnion", - setValues, - term("union","a","b","c","$f3") - ) - val expected = unaryNode( "DataSetAggregate", - union, - term("select", + calcNode, + term( + "select", "AVG(a) AS TMP_0", "SUM(b) AS TMP_1", "COUNT(c) AS TMP_2", http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala index ff6dcf1..bba1a5b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/batch/table/CalcTest.scala @@ -139,19 +139,10 @@ class CalcTest extends TableTestBase { val expected = unaryNode( "DataSetAggregate", - binaryNode( - "DataSetUnion", - values( - "DataSetValues", - tuples(List(null, null)), - term("values", "a", "b") - ), - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a", "b") - ), - term("union", "a", "b") + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a", "b") ), term("select", "SUM(a) AS TMP_0", "MAX(b) AS TMP_1") ) http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala index 0c3796f..c952578 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/QueryDecorrelationTest.scala @@ -98,55 +98,46 @@ class QueryDecorrelationTest extends TableTestBase { val expectedQuery = unaryNode( "DataSetAggregate", - binaryNode( - "DataSetUnion", - values( - "DataSetValues", - tuples(List(null)), - term("values", "empno") - ), - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetJoin", - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetJoin", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "empno", "salary", "deptno") - ), - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "deptno") - ), - term("where", "=(deptno, deptno0)"), - term("join", "empno", "salary", "deptno", "deptno0"), - term("joinType", "InnerJoin") - ), - term("select", "empno", "salary", "deptno0") - ), - unaryNode( - "DataSetAggregate", + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetJoin", + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetJoin", unaryNode( "DataSetCalc", batchTableNode(0), - term("select", "salary", "deptno"), - term("where", "IS NOT NULL(deptno)") + term("select", "empno", "salary", "deptno") + ), + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "deptno") ), - term("groupBy", "deptno"), - term("select", "deptno", "AVG(salary) AS EXPR$0") + term("where", "=(deptno, deptno0)"), + term("join", "empno", "salary", "deptno", "deptno0"), + term("joinType", "InnerJoin") ), - term("where", "AND(=(deptno0, deptno), >(salary, EXPR$0))"), - term("join", "empno", "salary", "deptno0", "deptno", "EXPR$0"), - term("joinType", "InnerJoin") + term("select", "empno", "salary", "deptno0") + ), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "salary", "deptno"), + term("where", "IS NOT NULL(deptno)") + ), + term("groupBy", "deptno"), + term("select", "deptno", "AVG(salary) AS EXPR$0") ), - term("select", "empno") + term("where", "AND(=(deptno0, deptno), >(salary, EXPR$0))"), + term("join", "empno", "salary", "deptno0", "deptno", "EXPR$0"), + term("joinType", "InnerJoin") ), - term("union", "empno") + term("select", "empno") ), term("select", "SUM(empno) AS EXPR$0") ) http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala index 458f80d..bdd1df0 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala @@ -146,14 +146,31 @@ abstract class AggFunctionTestBase[T, ACC] { val accumulator = aggregator.createAccumulator() vals.foreach( v => - accumulateFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object]) + if (accumulateFunc.getParameterCount == 1) { + this.accumulateFunc.invoke(aggregator, accumulator.asInstanceOf[Object]) + } else { + this.accumulateFunc.invoke( + aggregator, + accumulator.asInstanceOf[Object], + v.asInstanceOf[Object]) + } ) accumulator } - private def retractVals(accumulator:ACC, vals: Seq[_]) = { + private def retractVals(accumulator: ACC, vals: Seq[_]) = { vals.foreach( - v => retractFunc.invoke(aggregator, accumulator.asInstanceOf[Object], v.asInstanceOf[Object]) + v => + if (retractFunc.getParameterCount == 1) { + this.retractFunc.invoke( + aggregator, + accumulator.asInstanceOf[Object]) + } else { + this.retractFunc.invoke( + aggregator, + accumulator.asInstanceOf[Object], + v.asInstanceOf[Object]) + } ) } } http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala index b105ec02..ac0b705 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/sql/AggregateITCase.scala @@ -27,6 +27,7 @@ import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode +import org.apache.flink.table.utils.NonMergableCount import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row import org.junit._ @@ -34,6 +35,7 @@ import org.junit.runner.RunWith import org.junit.runners.Parameterized import scala.collection.JavaConverters._ +import scala.collection.mutable @RunWith(classOf[Parameterized]) class AggregateITCase( @@ -262,6 +264,8 @@ class AggregateITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) + val myAgg = new NonMergableCount + tEnv.registerFunction("myAgg", myAgg) val sqlQuery = "SELECT avg(a), sum(a), count(b) " + "FROM MyTable where a = 4 group by a" @@ -272,6 +276,9 @@ class AggregateITCase( val sqlQuery3 = "SELECT avg(a), sum(a), count(b) " + "FROM MyTable" + val sqlQuery4 = "SELECT avg(a), sum(a), count(b), myAgg(b)" + + "FROM MyTable where a = 4" + val ds = env.fromElements( (1: Byte, 1: Short), (2: Byte, 2: Short)) @@ -282,6 +289,7 @@ class AggregateITCase( val result = tEnv.sqlQuery(sqlQuery) val result2 = tEnv.sqlQuery(sqlQuery2) val result3 = tEnv.sqlQuery(sqlQuery3) + val result4 = tEnv.sqlQuery(sqlQuery4) val results = result.toDataSet[Row].collect() val expected = Seq.empty @@ -289,11 +297,14 @@ class AggregateITCase( val expected2 = "null,null,0" val results3 = result3.toDataSet[Row].collect() val expected3 = "1,3,2" + val results4 = result4.toDataSet[Row].collect() + val expected4 = "null,null,0,0" assert(results.equals(expected), "Empty result is expected for grouped set, but actual: " + results) TestBaseUtils.compareResultAsText(results2.asJava, expected2) TestBaseUtils.compareResultAsText(results3.asJava, expected3) + TestBaseUtils.compareResultAsText(results4.asJava, expected4) } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala index e1348f6..892e4f3 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/AggregateITCase.scala @@ -28,7 +28,7 @@ import org.apache.flink.table.api.scala._ import org.apache.flink.table.functions.aggfunctions.CountAggFunction import org.apache.flink.table.runtime.utils.TableProgramsCollectionTestBase import org.apache.flink.table.runtime.utils.TableProgramsTestBase.TableConfigMode -import org.apache.flink.table.utils.Top10 +import org.apache.flink.table.utils.{NonMergableCount, Top10} import org.apache.flink.test.util.TestBaseUtils import org.apache.flink.types.Row import org.junit._ @@ -36,6 +36,7 @@ import org.junit.runner.RunWith import org.junit.runners.Parameterized import scala.collection.JavaConverters._ +import scala.collection.mutable @RunWith(classOf[Parameterized]) class AggregationsITCase( @@ -266,13 +267,36 @@ class AggregationsITCase( .select('a.sum as 'd, 'b) .groupBy('b, 'd) .select('b) - val expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) } @Test + def testAggregateEmptyDataSets(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val myAgg = new NonMergableCount + + val t1 = env.fromCollection(new mutable.MutableList[(Int, String)]).toTable(tEnv, 'a, 'b) + .select('a.sum, 'a.count) + val t2 = env.fromCollection(new mutable.MutableList[(Int, String)]).toTable(tEnv, 'a, 'b) + .select('a.sum, myAgg('b), 'a.count) + + val expected1 = "null,0" + val expected2 = "null,0,0" + + val results1 = t1.toDataSet[Row].collect() + val results2 = t2.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results1.asJava, expected1) + TestBaseUtils.compareResultAsText(results2.asJava, expected2) + + } + + @Test def testGroupedAggregateWithLongKeys(): Unit = { // This uses very long keys to force serialized comparison. // With short keys, the normalized key is sufficient. http://git-wip-us.apache.org/repos/asf/flink/blob/20faf262/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala index 7d4393c..14c8461 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/UserDefinedAggFunctions.scala @@ -45,8 +45,8 @@ class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple /** * Adds a new entry and count to the top 10 entries if necessary. * - * @param acc The current top 10 - * @param id The ID + * @param acc The current top 10 + * @param id The ID * @param value The value for the ID */ def accumulate(acc: Array[JTuple2[JInt, JFloat]], id: Int, value: Float) { @@ -91,7 +91,7 @@ class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple its: java.lang.Iterable[Array[JTuple2[JInt, JFloat]]]): Unit = { val it = its.iterator() - while(it.hasNext) { + while (it.hasNext) { val acc2 = it.next() var i = 0 @@ -124,3 +124,22 @@ class Top10 extends AggregateFunction[Array[JTuple2[JInt, JFloat]], Array[JTuple ObjectArrayTypeInfo.getInfoFor(new TupleTypeInfo[JTuple2[JInt, JFloat]](Types.INT, Types.FLOAT)) } } + +case class NonMergableCountAcc(var count: Long) + +class NonMergableCount extends AggregateFunction[Long, NonMergableCountAcc] { + + def accumulate(acc: NonMergableCountAcc, value: Any): Unit = { + if (null != value) { + acc.count = acc.count + 1 + } + } + + def resetAccumulator(acc: NonMergableCountAcc): Unit = { + acc.count = 0 + } + + override def createAccumulator(): NonMergableCountAcc = NonMergableCountAcc(0) + + override def getValue(acc: NonMergableCountAcc): Long = acc.count +}