From: fhueske@apache.org
To: commits@flink.apache.org
Reply-To: dev@flink.apache.org
Date: Thu, 06 Apr 2017 19:29:19 -0000
Message-Id: <41e28868321942c6b2d498a58ee5b4c7@git.apache.org>
Subject: [12/12] flink git commit: [FLINK-6216] [table] Add non-windowed GroupBy aggregation for streams.

[FLINK-6216] [table] Add non-windowed GroupBy aggregation for streams.

This closes #3646.
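For orientation, here is a minimal Table API sketch of what this commit enables, modeled on the new tests in this patch (SqlITCase#testUnboundedGroupby and the GroupAggregations tests). The object name and the inline data are illustrative, not part of the commit:

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

// Illustrative driver: a non-windowed groupBy aggregation on an unbounded
// stream. Before this commit, Aggregate.validate() rejected this with
// "Aggregate on stream tables is currently not supported."
object NonWindowedGroupByExample {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    val data = List((1L, 1, "Hi"), (2L, 2, "Hello"), (4L, 2, "Hello"))
    val table = env.fromCollection(data).toTable(tEnv, 'long, 'int, 'string)

    // Plans to the new DataStreamGroupAggregate node: per-key accumulators
    // are kept in ValueState (see GroupAggProcessFunction below) and an
    // updated result row is emitted for every input row.
    val result = table
      .groupBy('string)
      .select('string, 'int.count)

    result.toDataStream[Row].print()
    env.execute()
  }
}

The equivalent SQL, from the new SqlITCase test, is "SELECT b, COUNT(a) FROM MyTable GROUP BY b"; as its expected results ("1,1", "2,1", "2,2") show, the operator emits a refreshed aggregate per key for each incoming record rather than a single final value.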
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ff262508
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ff262508
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ff262508

Branch: refs/heads/table-retraction
Commit: ff2625089e9f184326c7a8b39cbdbec35ba58869
Parents: c5173fa
Author: shaoxuan-wang
Authored: Thu Mar 30 03:57:58 2017 +0800
Committer: Fabian Hueske
Committed: Thu Apr 6 21:28:57 2017 +0200

----------------------------------------------------------------------
 .../flink/table/plan/logical/operators.scala       |   3 -
 .../nodes/datastream/DataStreamAggregate.scala     | 267 -------------------
 .../datastream/DataStreamGroupAggregate.scala      | 133 +++++++++
 .../DataStreamGroupWindowAggregate.scala           | 267 +++++++++++++++++++
 .../flink/table/plan/rules/FlinkRuleSets.scala     |   3 +-
 .../datastream/DataStreamAggregateRule.scala       |  76 ------
 .../DataStreamGroupAggregateRule.scala             |  77 ++++++
 .../DataStreamGroupWindowAggregateRule.scala       |  75 ++++++
 .../table/runtime/aggregate/AggregateUtil.scala    |  46 +++-
 .../aggregate/GroupAggProcessFunction.scala        | 100 +++++++
 .../scala/batch/table/FieldProjectionTest.scala    |   4 +-
 .../table/api/scala/stream/sql/SqlITCase.scala     |  21 ++
 .../scala/stream/sql/WindowAggregateTest.scala     |  47 ++--
 .../scala/stream/table/AggregationsITCase.scala    | 167 ------------
 .../stream/table/GroupAggregationsITCase.scala     | 132 +++++++++
 .../stream/table/GroupAggregationsTest.scala       | 218 +++++++++++++++
 .../table/GroupWindowAggregationsITCase.scala      | 167 ++++++++++++
 .../scala/stream/table/GroupWindowTest.scala       |  56 ++--
 .../scala/stream/table/UnsupportedOpsTest.scala    |   7 -
 19 files changed, 1288 insertions(+), 578 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
index 559bd75..7438082 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/logical/operators.scala
@@ -221,9 +221,6 @@ case class Aggregate(
   }
 
   override def validate(tableEnv: TableEnvironment): LogicalNode = {
-    if (tableEnv.isInstanceOf[StreamTableEnvironment]) {
-      failValidation(s"Aggregate on stream tables is currently not supported.")
-    }
     val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
     val groupingExprs = resolvedAggregate.groupingExpressions

http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
deleted file mode 100644
index 50f8281..0000000
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamAggregate.scala
+++ /dev/null
@@ -1,267 +0,0 @@
-/*
- * Licensed
to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.nodes.datastream - -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.flink.api.java.tuple.Tuple -import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} -import org.apache.flink.streaming.api.windowing.assigners._ -import org.apache.flink.streaming.api.windowing.time.Time -import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} -import org.apache.flink.table.api.StreamTableEnvironment -import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty -import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.expressions._ -import org.apache.flink.table.plan.logical._ -import org.apache.flink.table.plan.nodes.CommonAggregate -import org.apache.flink.table.plan.nodes.datastream.DataStreamAggregate._ -import org.apache.flink.table.runtime.aggregate.AggregateUtil._ -import org.apache.flink.table.runtime.aggregate._ -import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval -import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} -import org.apache.flink.types.Row - -class DataStreamAggregate( - window: LogicalWindow, - namedProperties: Seq[NamedWindowProperty], - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputNode: RelNode, - namedAggregates: Seq[CalcitePair[AggregateCall, String]], - rowRelDataType: RelDataType, - inputType: RelDataType, - grouping: Array[Int]) - extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { - - override def deriveRowType(): RelDataType = rowRelDataType - - override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { - new DataStreamAggregate( - window, - namedProperties, - cluster, - traitSet, - inputs.get(0), - namedAggregates, - getRowType, - inputType, - grouping) - } - - override def toString: String = { - s"Aggregate(${ - if (!grouping.isEmpty) { - s"groupBy: (${groupingToString(inputType, grouping)}), " - } else { - "" - } - }window: ($window), " + - s"select: (${ - aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties) - }))" - } - - override def explainTerms(pw: RelWriter): RelWriter = { - super.explainTerms(pw) - .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) - .item("window", window) - .item( - "select", aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties)) - } - - override def translateToPlan(tableEnv: 
StreamTableEnvironment): DataStream[Row] = { - - val groupingKeys = grouping.indices.toArray - val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) - - val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) - - val aggString = aggregationToString( - inputType, - grouping, - getRowType, - namedAggregates, - namedProperties) - - val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + - s"window: ($window), " + - s"select: ($aggString)" - val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" - - // grouped / keyed aggregation - if (groupingKeys.length > 0) { - val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( - window, - groupingKeys.length, - namedAggregates.size, - rowRelDataType.getFieldCount, - namedProperties) - - val keyedStream = inputDS.keyBy(groupingKeys: _*) - val windowedStream = - createKeyedWindowedStream(window, keyedStream) - .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - namedAggregates, - inputType, - rowRelDataType, - grouping) - - windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) - .name(keyedAggOpName) - } - // global / non-keyed aggregation - else { - val windowFunction = AggregateUtil.createAggregationAllWindowFunction( - window, - rowRelDataType.getFieldCount, - namedProperties) - - val windowedStream = - createNonKeyedWindowedStream(window, inputDS) - .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] - - val (aggFunction, accumulatorRowType, aggResultRowType) = - AggregateUtil.createDataStreamAggregateFunction( - namedAggregates, - inputType, - rowRelDataType, - grouping) - - windowedStream - .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) - .name(nonKeyedAggOpName) - } - } -} - -object DataStreamAggregate { - - - private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) - : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { - - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.window(TumblingProcessingTimeWindows.of(asTime(size))) - - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindow(asCount(size)) - - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.window(TumblingEventTimeWindows.of(asTime(size))) - - case EventTimeTumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) - - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindow(asCount(size), asCount(slide)) - - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) - - case EventTimeSlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSessionGroupWindow(_, gap: Expression) => - stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) - - case EventTimeSessionGroupWindow(_, _, gap) => - stream.window(EventTimeSessionWindows.withGap(asTime(gap))) - } - - private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) - : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { - - case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) - - case ProcessingTimeTumblingGroupWindow(_, size) => - stream.countWindowAll(asCount(size)) - - case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => - stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) - - case EventTimeTumblingGroupWindow(_, _, size) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. Otherwise, this would be the same as a - // ProcessingTimeTumblingGroupWindow - throw new UnsupportedOperationException( - "Event-time grouping windows on row intervals are currently not supported.") - - case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) - - case ProcessingTimeSlidingGroupWindow(_, size, slide) => - stream.countWindowAll(asCount(size), asCount(slide)) - - case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => - stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) - - case EventTimeSlidingGroupWindow(_, _, size, slide) => - // TODO: EventTimeTumblingGroupWindow should sort the stream on event time - // before applying the windowing logic. 
Otherwise, this would be the same as a
-    //  ProcessingTimeTumblingGroupWindow
-    throw new UnsupportedOperationException(
-      "Event-time grouping windows on row intervals are currently not supported.")
-
-    case ProcessingTimeSessionGroupWindow(_, gap) =>
-      stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap)))
-
-    case EventTimeSessionGroupWindow(_, _, gap) =>
-      stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap)))
-  }
-
-  def asTime(expr: Expression): Time = expr match {
-    case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value)
-    case _ => throw new IllegalArgumentException()
-  }
-
-  def asCount(expr: Expression): Long = expr match {
-    case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value
-    case _ => throw new IllegalArgumentException()
-  }
-}
-

http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
new file mode 100644
index 0000000..c2d4fb7
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.plan.nodes.datastream
+
+import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
+import org.apache.calcite.rel.`type`.RelDataType
+import org.apache.calcite.rel.core.AggregateCall
+import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
+import org.apache.flink.api.java.functions.NullByteKeySelector
+import org.apache.flink.streaming.api.datastream.DataStream
+import org.apache.flink.table.api.StreamTableEnvironment
+import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.runtime.aggregate._
+import org.apache.flink.table.plan.nodes.CommonAggregate
+import org.apache.flink.types.Row
+import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
+
+/**
+ * Flink RelNode for unbounded group aggregates on data streams.
+ *
+ * @param cluster Cluster of the RelNode, representing an environment of related
+ *                relational expressions during the optimization of a query
+ * @param traitSet Trait set of the RelNode
+ * @param inputNode The input RelNode of the aggregation
+ * @param namedAggregates List of calls to aggregate functions and their output field names
+ * @param rowRelDataType The type of the rows of the RelNode
+ * @param inputType The type of the rows of the aggregation input RelNode
+ * @param groupings The positions (in the input Row) of the grouping keys
+ */
+class DataStreamGroupAggregate(
+    cluster: RelOptCluster,
+    traitSet: RelTraitSet,
+    inputNode: RelNode,
+    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
+    rowRelDataType: RelDataType,
+    inputType: RelDataType,
+    groupings: Array[Int])
+  extends SingleRel(cluster, traitSet, inputNode)
+    with CommonAggregate
+    with DataStreamRel {
+
+  override def deriveRowType() = rowRelDataType
+
+  override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = {
+    new DataStreamGroupAggregate(
+      cluster,
+      traitSet,
+      inputs.get(0),
+      namedAggregates,
+      getRowType,
+      inputType,
+      groupings)
+  }
+
+  override def toString: String = {
+    s"Aggregate(${
+      if (!groupings.isEmpty) {
+        s"groupBy: (${groupingToString(inputType, groupings)}), "
+      } else {
+        ""
+      }
+    }select:(${aggregationToString(inputType, groupings, getRowType, namedAggregates, Nil)}))"
+  }
+
+  override def explainTerms(pw: RelWriter): RelWriter = {
+    super.explainTerms(pw)
+      .itemIf("groupBy", groupingToString(inputType, groupings), !groupings.isEmpty)
+      .item("select", aggregationToString(inputType, groupings, getRowType, namedAggregates, Nil))
+  }
+
+  override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = {
+
+    val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
+
+    val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
+
+    val aggString = aggregationToString(
+      inputType,
+      groupings,
+      getRowType,
+      namedAggregates,
+      Nil)
+
+    val keyedAggOpName = s"groupBy: (${groupingToString(inputType, groupings)}), " +
+      s"select: ($aggString)"
+    val nonKeyedAggOpName = s"select: ($aggString)"
+
+    val processFunction = AggregateUtil.createGroupAggregateFunction(
+      namedAggregates,
+      inputType,
+      groupings)
+
+    val result: DataStream[Row] =
+      // grouped / keyed aggregation
+      if (groupings.nonEmpty) {
+        inputDS
+          .keyBy(groupings: _*)
+          .process(processFunction)
+          .returns(rowTypeInfo)
+          .name(keyedAggOpName)
+          .asInstanceOf[DataStream[Row]]
+      }
+      // global / non-keyed aggregation
+      else {
+        inputDS
+          .keyBy(new NullByteKeySelector[Row])
+          .process(processFunction)
+          .setParallelism(1)
+          .setMaxParallelism(1)
+          .returns(rowTypeInfo)
+          .name(nonKeyedAggOpName)
+          .asInstanceOf[DataStream[Row]]
+      }
+    result
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
new file mode 100644
index 0000000..a0c1dec
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.plan.nodes.datastream + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} +import org.apache.flink.api.java.tuple.Tuple +import org.apache.flink.streaming.api.datastream.{AllWindowedStream, DataStream, KeyedStream, WindowedStream} +import org.apache.flink.streaming.api.windowing.assigners._ +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} +import org.apache.flink.table.api.StreamTableEnvironment +import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.expressions._ +import org.apache.flink.table.plan.logical._ +import org.apache.flink.table.plan.nodes.CommonAggregate +import org.apache.flink.table.plan.nodes.datastream.DataStreamGroupWindowAggregate._ +import org.apache.flink.table.runtime.aggregate.AggregateUtil._ +import org.apache.flink.table.runtime.aggregate._ +import org.apache.flink.table.typeutils.TypeCheckUtils.isTimeInterval +import org.apache.flink.table.typeutils.{RowIntervalTypeInfo, TimeIntervalTypeInfo} +import org.apache.flink.types.Row + +class DataStreamGroupWindowAggregate( + window: LogicalWindow, + namedProperties: Seq[NamedWindowProperty], + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputNode: RelNode, + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + rowRelDataType: RelDataType, + inputType: RelDataType, + grouping: Array[Int]) + extends SingleRel(cluster, traitSet, inputNode) with CommonAggregate with DataStreamRel { + + override def deriveRowType(): RelDataType = rowRelDataType + + override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { + new DataStreamGroupWindowAggregate( + window, + namedProperties, + cluster, + traitSet, + inputs.get(0), + namedAggregates, + getRowType, + inputType, + grouping) + } + + override def toString: String = { + s"Aggregate(${ + if (!grouping.isEmpty) { + s"groupBy: (${groupingToString(inputType, grouping)}), " + } else { + "" + } + }window: ($window), " + + s"select: (${ + aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties) + }))" + } + + override def explainTerms(pw: RelWriter): RelWriter = { + super.explainTerms(pw) + .itemIf("groupBy", groupingToString(inputType, grouping), !grouping.isEmpty) + .item("window", window) + .item( + "select", aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties)) + } + + override def translateToPlan(tableEnv: StreamTableEnvironment): DataStream[Row] = { + + val 
groupingKeys = grouping.indices.toArray + val inputDS = input.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) + + val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + + val aggString = aggregationToString( + inputType, + grouping, + getRowType, + namedAggregates, + namedProperties) + + val keyedAggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + + s"window: ($window), " + + s"select: ($aggString)" + val nonKeyedAggOpName = s"window: ($window), select: ($aggString)" + + // grouped / keyed aggregation + if (groupingKeys.length > 0) { + val windowFunction = AggregateUtil.createAggregationGroupWindowFunction( + window, + groupingKeys.length, + namedAggregates.size, + rowRelDataType.getFieldCount, + namedProperties) + + val keyedStream = inputDS.keyBy(groupingKeys: _*) + val windowedStream = + createKeyedWindowedStream(window, keyedStream) + .asInstanceOf[WindowedStream[Row, Tuple, DataStreamWindow]] + + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamGroupWindowAggregateFunction( + namedAggregates, + inputType, + rowRelDataType, + grouping) + + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .name(keyedAggOpName) + } + // global / non-keyed aggregation + else { + val windowFunction = AggregateUtil.createAggregationAllWindowFunction( + window, + rowRelDataType.getFieldCount, + namedProperties) + + val windowedStream = + createNonKeyedWindowedStream(window, inputDS) + .asInstanceOf[AllWindowedStream[Row, DataStreamWindow]] + + val (aggFunction, accumulatorRowType, aggResultRowType) = + AggregateUtil.createDataStreamGroupWindowAggregateFunction( + namedAggregates, + inputType, + rowRelDataType, + grouping) + + windowedStream + .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, rowTypeInfo) + .name(nonKeyedAggOpName) + } + } +} + +object DataStreamGroupWindowAggregate { + + + private def createKeyedWindowedStream(groupWindow: LogicalWindow, stream: KeyedStream[Row, Tuple]) + : WindowedStream[Row, Tuple, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.window(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindow(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.window(TumblingEventTimeWindows.of(asTime(size))) + + case EventTimeTumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindow(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.window(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap: Expression) => + stream.window(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.window(EventTimeSessionWindows.withGap(asTime(gap))) + } + + private def createNonKeyedWindowedStream(groupWindow: LogicalWindow, stream: DataStream[Row]) + : AllWindowedStream[Row, _ <: DataStreamWindow] = groupWindow match { + + case ProcessingTimeTumblingGroupWindow(_, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingProcessingTimeWindows.of(asTime(size))) + + case ProcessingTimeTumblingGroupWindow(_, size) => + stream.countWindowAll(asCount(size)) + + case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) => + stream.windowAll(TumblingEventTimeWindows.of(asTime(size))) + + case EventTimeTumblingGroupWindow(_, _, size) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSlidingGroupWindow(_, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingProcessingTimeWindows.of(asTime(size), asTime(slide))) + + case ProcessingTimeSlidingGroupWindow(_, size, slide) => + stream.countWindowAll(asCount(size), asCount(slide)) + + case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) => + stream.windowAll(SlidingEventTimeWindows.of(asTime(size), asTime(slide))) + + case EventTimeSlidingGroupWindow(_, _, size, slide) => + // TODO: EventTimeTumblingGroupWindow should sort the stream on event time + // before applying the windowing logic. 
Otherwise, this would be the same as a + // ProcessingTimeTumblingGroupWindow + throw new UnsupportedOperationException( + "Event-time grouping windows on row intervals are currently not supported.") + + case ProcessingTimeSessionGroupWindow(_, gap) => + stream.windowAll(ProcessingTimeSessionWindows.withGap(asTime(gap))) + + case EventTimeSessionGroupWindow(_, _, gap) => + stream.windowAll(EventTimeSessionWindows.withGap(asTime(gap))) + } + + def asTime(expr: Expression): Time = expr match { + case Literal(value: Long, TimeIntervalTypeInfo.INTERVAL_MILLIS) => Time.milliseconds(value) + case _ => throw new IllegalArgumentException() + } + + def asCount(expr: Expression): Long = expr match { + case Literal(value: Long, RowIntervalTypeInfo.INTERVAL_ROWS) => value + case _ => throw new IllegalArgumentException() + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala index 5caaf1f..6805a68 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/FlinkRuleSets.scala @@ -171,8 +171,9 @@ object FlinkRuleSets { UnionEliminatorRule.INSTANCE, // translate to DataStream nodes + DataStreamGroupAggregateRule.INSTANCE, DataStreamOverAggregateRule.INSTANCE, - DataStreamAggregateRule.INSTANCE, + DataStreamGroupWindowAggregateRule.INSTANCE, DataStreamCalcRule.INSTANCE, DataStreamScanRule.INSTANCE, DataStreamUnionRule.INSTANCE, http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala deleted file mode 100644 index 09f05d7..0000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamAggregateRule.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.table.plan.rules.datastream - -import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet} -import org.apache.calcite.rel.RelNode -import org.apache.calcite.rel.convert.ConverterRule -import org.apache.flink.table.api.TableException -import org.apache.flink.table.expressions.Alias -import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate -import org.apache.flink.table.plan.nodes.datastream.{DataStreamAggregate, DataStreamConvention} - -import scala.collection.JavaConversions._ - -class DataStreamAggregateRule - extends ConverterRule( - classOf[LogicalWindowAggregate], - Convention.NONE, - DataStreamConvention.INSTANCE, - "DataStreamAggregateRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate] - - // check if we have distinct aggregates - val distinctAggs = agg.getAggCallList.exists(_.isDistinct) - if (distinctAggs) { - throw TableException("DISTINCT aggregates are currently not supported.") - } - - // check if we have grouping sets - val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet - if (groupSets || agg.indicator) { - throw TableException("GROUPING SETS are currently not supported.") - } - - !distinctAggs && !groupSets && !agg.indicator - } - - override def convert(rel: RelNode): RelNode = { - val agg: LogicalWindowAggregate = rel.asInstanceOf[LogicalWindowAggregate] - val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE) - val convInput: RelNode = RelOptRule.convert(agg.getInput, DataStreamConvention.INSTANCE) - - new DataStreamAggregate( - agg.getWindow, - agg.getNamedProperties, - rel.getCluster, - traitSet, - convInput, - agg.getNamedAggCalls, - rel.getRowType, - agg.getInput.getRowType, - agg.getGroupSet.toArray) - } - } - -object DataStreamAggregateRule { - val INSTANCE: RelOptRule = new DataStreamAggregateRule -} http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala new file mode 100644 index 0000000..82d7104 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupAggregateRule.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.calcite.rel.logical.LogicalAggregate +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.nodes.datastream.{DataStreamGroupAggregate, DataStreamConvention} + +import scala.collection.JavaConversions._ + +/** + * Rule to convert a [[LogicalAggregate]] into a [[DataStreamGroupAggregate]]. + */ +class DataStreamGroupAggregateRule + extends ConverterRule( + classOf[LogicalAggregate], + Convention.NONE, + DataStreamConvention.INSTANCE, + "DataStreamGroupAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: LogicalAggregate = call.rel(0).asInstanceOf[LogicalAggregate] + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: LogicalAggregate = rel.asInstanceOf[LogicalAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE) + val convInput: RelNode = RelOptRule.convert(agg.getInput, DataStreamConvention.INSTANCE) + + new DataStreamGroupAggregate( + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + rel.getRowType, + agg.getInput.getRowType, + agg.getGroupSet.toArray) + } +} + +object DataStreamGroupAggregateRule { + val INSTANCE: RelOptRule = new DataStreamGroupAggregateRule +} + http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala new file mode 100644 index 0000000..7ec1d40 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/datastream/DataStreamGroupWindowAggregateRule.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.plan.rules.datastream + +import org.apache.calcite.plan.{Convention, RelOptRule, RelOptRuleCall, RelTraitSet} +import org.apache.calcite.rel.RelNode +import org.apache.calcite.rel.convert.ConverterRule +import org.apache.flink.table.api.TableException +import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate +import org.apache.flink.table.plan.nodes.datastream.{DataStreamConvention, DataStreamGroupWindowAggregate} + +import scala.collection.JavaConversions._ + +class DataStreamGroupWindowAggregateRule + extends ConverterRule( + classOf[LogicalWindowAggregate], + Convention.NONE, + DataStreamConvention.INSTANCE, + "DataStreamGroupWindowAggregateRule") { + + override def matches(call: RelOptRuleCall): Boolean = { + val agg: LogicalWindowAggregate = call.rel(0).asInstanceOf[LogicalWindowAggregate] + + // check if we have distinct aggregates + val distinctAggs = agg.getAggCallList.exists(_.isDistinct) + if (distinctAggs) { + throw TableException("DISTINCT aggregates are currently not supported.") + } + + // check if we have grouping sets + val groupSets = agg.getGroupSets.size() != 1 || agg.getGroupSets.get(0) != agg.getGroupSet + if (groupSets || agg.indicator) { + throw TableException("GROUPING SETS are currently not supported.") + } + + !distinctAggs && !groupSets && !agg.indicator + } + + override def convert(rel: RelNode): RelNode = { + val agg: LogicalWindowAggregate = rel.asInstanceOf[LogicalWindowAggregate] + val traitSet: RelTraitSet = rel.getTraitSet.replace(DataStreamConvention.INSTANCE) + val convInput: RelNode = RelOptRule.convert(agg.getInput, DataStreamConvention.INSTANCE) + + new DataStreamGroupWindowAggregate( + agg.getWindow, + agg.getNamedProperties, + rel.getCluster, + traitSet, + convInput, + agg.getNamedAggCalls, + rel.getRowType, + agg.getInput.getRowType, + agg.getGroupSet.toArray) + } + } + +object DataStreamGroupWindowAggregateRule { + val INSTANCE: RelOptRule = new DataStreamGroupWindowAggregateRule +} http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 09d1a13..634f7c8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -79,8 +79,7 @@ object AggregateUtil { inputType, needRetraction = false) - val aggregationStateType: RowTypeInfo = - createDataSetAggregateBufferDataType(Array(), aggregates, inputType) + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray @@ -125,7 +124,36 @@ object AggregateUtil { } /** - * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for group (without + * window) aggregate to evaluate final aggregate value. 
+ * + * @param namedAggregates List of calls to aggregate functions and their output field names + * @param inputType Input row type + * @param groupings the position (in the input Row) of the grouping keys + * @return [[org.apache.flink.streaming.api.functions.ProcessFunction]] + */ + private[flink] def createGroupAggregateFunction( + namedAggregates: Seq[CalcitePair[AggregateCall, String]], + inputType: RelDataType, + groupings: Array[Int]): ProcessFunction[Row, Row] = { + + val (aggFields, aggregates) = + transformToAggregateFunctions( + namedAggregates.map(_.getKey), + inputType, + needRetraction = false) + + val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates) + + new GroupAggProcessFunction( + aggregates, + aggFields, + groupings, + aggregationStateType) + } + + /** + * Create an [[org.apache.flink.streaming.api.functions.ProcessFunction]] for ROWS clause * bounded OVER window to evaluate final aggregate value. * * @param generator code generator instance @@ -238,7 +266,7 @@ object AggregateUtil { needRetraction = false) val mapReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -321,7 +349,7 @@ object AggregateUtil { inputType, needRetraction = false)._2 - val returnType: RowTypeInfo = createDataSetAggregateBufferDataType( + val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -550,7 +578,7 @@ object AggregateUtil { window match { case EventTimeSessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -600,7 +628,7 @@ object AggregateUtil { case EventTimeSessionGroupWindow(_, _, gap) => val combineReturnType: RowTypeInfo = - createDataSetAggregateBufferDataType( + createRowTypeForKeysAndAggregates( groupings, aggregates, inputType, @@ -745,7 +773,7 @@ object AggregateUtil { } } - private[flink] def createDataStreamAggregateFunction( + private[flink] def createDataStreamGroupWindowAggregateFunction( namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, outputType: RelDataType, @@ -1125,7 +1153,7 @@ object AggregateUtil { aggTypes } - private def createDataSetAggregateBufferDataType( + private def createRowTypeForKeysAndAggregates( groupings: Array[Int], aggregates: Array[TableAggregateFunction[_]], inputType: RelDataType, http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala new file mode 100644 index 0000000..3813aa0 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime.aggregate
+
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.streaming.api.functions.ProcessFunction
+import org.apache.flink.types.Row
+import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.api.common.state.ValueStateDescriptor
+import org.apache.flink.api.java.typeutils.RowTypeInfo
+import org.apache.flink.api.common.state.ValueState
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Aggregate Function used for the groupby (without window) aggregate
+ *
+ * @param aggregates the list of all
+ *                   [[org.apache.flink.table.functions.AggregateFunction]] used for
+ *                   this aggregation
+ * @param aggFields the position (in the input Row) of the input value for each
+ *                  aggregate
+ * @param groupings the position (in the input Row) of the grouping keys
+ * @param aggregationStateType the row type info of aggregation
+ */
+class GroupAggProcessFunction(
+    private val aggregates: Array[AggregateFunction[_]],
+    private val aggFields: Array[Array[Int]],
+    private val groupings: Array[Int],
+    private val aggregationStateType: RowTypeInfo)
+  extends ProcessFunction[Row, Row] {
+
+  Preconditions.checkNotNull(aggregates)
+  Preconditions.checkNotNull(aggFields)
+  Preconditions.checkArgument(aggregates.length == aggFields.length)
+
+  private var output: Row = _
+  private var state: ValueState[Row] = _
+
+  override def open(config: Configuration) {
+    output = new Row(groupings.length + aggregates.length)
+    val stateDescriptor: ValueStateDescriptor[Row] =
+      new ValueStateDescriptor[Row]("GroupAggregateState", aggregationStateType)
+    state = getRuntimeContext.getState(stateDescriptor)
+  }
+
+  override def processElement(
+      input: Row,
+      ctx: ProcessFunction[Row, Row]#Context,
+      out: Collector[Row]): Unit = {
+
+    var i = 0
+
+    var accumulators = state.value()
+
+    if (null == accumulators) {
+      accumulators = new Row(aggregates.length)
+      i = 0
+      while (i < aggregates.length) {
+        accumulators.setField(i, aggregates(i).createAccumulator())
+        i += 1
+      }
+    }
+
+    // Set group keys value to the final output
+    i = 0
+    while (i < groupings.length) {
+      output.setField(i, input.getField(groupings(i)))
+      i += 1
+    }
+
+    // Set aggregate result to the final output
+    i = 0
+    while (i < aggregates.length) {
+      val index = groupings.length + i
+      val accumulator = accumulators.getField(i).asInstanceOf[Accumulator]
+      aggregates(i).accumulate(accumulator, input.getField(aggFields(i)(0)))
+      output.setField(index, aggregates(i).getValue(accumulator))
+      i += 1
+    }
+    state.update(accumulators)
+
+    out.collect(output)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala index 4d0d9aa..ebdf9de 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/FieldProjectionTest.scala @@ -231,7 +231,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -259,7 +259,7 @@ class FieldProjectionTest extends TableTestBase { val expected = unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala index 67d13b0..f7bdccf 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/SqlITCase.scala @@ -47,6 +47,27 @@ class SqlITCase extends StreamingWithStateTestBase { (8L, 8, "Hello World"), (20L, 20, "Hello World")) + /** test unbounded groupby (without window) **/ + @Test + def testUnboundedGroupby(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val sqlQuery = "SELECT b, COUNT(a) FROM MyTable GROUP BY b" + + val t = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("MyTable", t) + + val result = tEnv.sql(sqlQuery).toDataStream[Row] + result.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList("1,1", "2,1", "2,2") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + /** test selection **/ @Test def testSelectExpressionFromTable(): Unit = { http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala index 1c1752f..8fa9e6f 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/sql/WindowAggregateTest.scala @@ -17,8 +17,6 @@ */ package org.apache.flink.table.api.scala.stream.sql -import java.sql.Timestamp - import org.apache.flink.api.scala._ import org.apache.flink.table.api.TableException import org.apache.flink.table.api.scala._ @@ -63,7 +61,7 @@ class WindowAggregateTest extends TableTestBase { val sqlQuery = "SELECT a, AVG(c) OVER (PARTITION BY a ORDER BY procTime()" + 
"RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW) AS avgA " + "FROM MyTable" - val expected = + val expected = unaryNode( "DataStreamCalc", unaryNode( @@ -73,7 +71,7 @@ class WindowAggregateTest extends TableTestBase { streamTableNode(0), term("select", "a", "c", "PROCTIME() AS $2") ), - term("partitionBy","a"), + term("partitionBy", "a"), term("orderBy", "PROCTIME"), term("range", "BETWEEN 7200000 PRECEDING AND CURRENT ROW"), term("select", "a", "c", "PROCTIME", "COUNT(c) AS w0$o0", "$SUM0(c) AS w0$o1") @@ -84,6 +82,26 @@ class WindowAggregateTest extends TableTestBase { streamUtil.verifySql(sqlQuery, expected) } + def testGroupbyWithoutWindow() = { + val sql = "SELECT COUNT(a) FROM MyTable GROUP BY b" + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS EXPR$0") + ), + term("select", "EXPR$0") + ) + streamUtil.verifySql(sql, expected) + } + @Test def testNonPartitionedTumbleWindow() = { val sql = "SELECT COUNT(*) FROM MyTable GROUP BY FLOOR(rowtime() TO HOUR)" @@ -91,7 +109,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -112,7 +130,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -134,7 +152,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -156,7 +174,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -177,7 +195,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -199,7 +217,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -222,7 +240,7 @@ class WindowAggregateTest extends TableTestBase { unaryNode( "DataStreamCalc", unaryNode( - "DataStreamAggregate", + "DataStreamGroupWindowAggregate", unaryNode( "DataStreamCalc", streamTableNode(0), @@ -280,13 +298,6 @@ class WindowAggregateTest extends TableTestBase { streamUtil.verifySql(sql, expected) } - @Test(expected = classOf[TableException]) - def testInvalidWindowExpression() = { - val sql = "SELECT COUNT(*) FROM MyTable GROUP BY FLOOR(localTimestamp TO HOUR)" - val expected = "" - streamUtil.verifySql(sql, expected) - } - @Test def testUnboundPartitionedProcessingWindowWithRange() = { val sql = "SELECT " + http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala 
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala deleted file mode 100644 index 3e7b66b..0000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/AggregationsITCase.scala +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.api.scala.stream.table - -import org.apache.flink.api.scala._ -import org.apache.flink.streaming.api.TimeCharacteristic -import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks -import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment -import org.apache.flink.streaming.api.watermark.Watermark -import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase -import org.apache.flink.table.api.TableEnvironment -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.api.scala.stream.table.AggregationsITCase.TimestampWithEqualWatermark -import org.apache.flink.table.api.scala.stream.utils.StreamITCase -import org.apache.flink.types.Row -import org.junit.Assert._ -import org.junit.Test - -import scala.collection.mutable - -/** - * We only test some aggregations until better testing of constructed DataStream - * programs is possible. 
- */ -class AggregationsITCase extends StreamingMultipleProgramsTestBase { - - val data = List( - (1L, 1, "Hi"), - (2L, 2, "Hello"), - (4L, 2, "Hello"), - (8L, 3, "Hello world"), - (16L, 3, "Hello world")) - - @Test - def testProcessingTimeSlidingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Slide over 2.rows every 1.rows as 'w) - .groupBy('w, 'string) - .select('string, 'int.count, 'int.avg) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello world,1,3", "Hello world,2,3", "Hello,1,2", "Hello,2,2", "Hi,1,1") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeSessionGroupWindowOverTime(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env - .fromCollection(data) - .assignTimestampsAndWatermarks(new TimestampWithEqualWatermark()) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Session withGap 7.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, 'int.count) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("Hello world,1", "Hello world,1", "Hello,2", "Hi,1") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env.fromCollection(data) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Tumble over 2.rows as 'w) - .groupBy('w) - .select('int.count) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq("2", "2") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } - - @Test - def testEventTimeTumblingWindow(): Unit = { - val env = StreamExecutionEnvironment.getExecutionEnvironment - env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) - val tEnv = TableEnvironment.getTableEnvironment(env) - StreamITCase.testResults = mutable.MutableList() - - val stream = env - .fromCollection(data) - .assignTimestampsAndWatermarks(new TimestampWithEqualWatermark()) - val table = stream.toTable(tEnv, 'long, 'int, 'string) - - val windowedTable = table - .window(Tumble over 5.milli on 'rowtime as 'w) - .groupBy('w, 'string) - .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) - - val results = windowedTable.toDataStream[Row] - results.addSink(new StreamITCase.StringSink) - env.execute() - - val expected = Seq( - "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", - "Hello world,1,3,3,3,3,1970-01-01 00:00:00.015,1970-01-01 
00:00:00.02", - "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", - "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") - assertEquals(expected.sorted, StreamITCase.testResults.sorted) - } -} - -object AggregationsITCase { - class TimestampWithEqualWatermark extends AssignerWithPunctuatedWatermarks[(Long, Int, String)] { - - override def checkAndGetNextWatermark( - lastElement: (Long, Int, String), - extractedTimestamp: Long) - : Watermark = { - new Watermark(extractedTimestamp) - } - - override def extractTimestamp( - element: (Long, Int, String), - previousElementTimestamp: Long): Long = { - element._1 - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala new file mode 100644 index 0000000..271e90b --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsITCase.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase} +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.types.Row +import org.junit.Assert.assertEquals +import org.junit.Test + +import scala.collection.mutable + +/** + * Tests of groupby (without window) aggregations + */ +class GroupAggregationsITCase extends StreamingWithStateTestBase { + + @Test + def testNonKeyedGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .select('a.sum, 'b.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "3,3", "6,5", "10,8", "15,11", "21,14", "28,18", "36,22", "45,26", "55,30", "66,35", + "78,40", "91,45", "105,50", "120,55", "136,61", "153,67", "171,73", "190,79", "210,85", + "231,91") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregate(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('b, 'a.sum) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1,1", "2,2", "2,5", "3,4", "3,9", "3,15", "4,7", "4,15", + "4,24", "4,34", "5,11", "5,23", "5,36", "5,50", "5,65", "6,16", "6,33", "6,51", "6,70", + "6,90", "6,111") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testDoubleGroupAggregation(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get3TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c) + .groupBy('b) + .select('a.sum as 'd, 'b) + .groupBy('b, 'd) + .select('b) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "1", + "2", "2", + "3", "3", "3", + "4", "4", "4", "4", + "5", "5", "5", "5", "5", + "6", "6", "6", "6", "6", "6") + + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testGroupAggregateWithExpression(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStateBackend(getStateBackend) + env.setParallelism(1) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val t = StreamTestData.get5TupleDataStream(env).toTable(tEnv, 'a, 'b, 'c, 'd, 'e) + .groupBy('e, 'b % 3) + .select('c.min, 'e, 'a.avg, 'd.count) + + val results = t.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = mutable.MutableList( + "0,1,1,1", "1,2,2,1", "2,1,2,1", "3,2,3,1", "1,2,2,2", + "5,3,3,1", "3,2,3,2", "7,1,4,1", "2,1,3,2", "3,2,3,3", "7,1,4,2", 
"5,3,4,2", "12,3,5,1", + "1,2,3,3", "14,2,5,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala new file mode 100644 index 0000000..1f4a694 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupAggregationsTest.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.utils.TableTestBase +import org.junit.Test +import org.apache.flink.table.api.scala._ +import org.apache.flink.api.scala._ +import org.apache.flink.table.utils.TableTestUtil._ + +class GroupAggregationsTest extends TableTestBase { + + @Test(expected = classOf[ValidationException]) + def testGroupingOnNonExistentField(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + // must fail. '_foo is not a valid field + .groupBy('_foo) + .select('a.avg) + } + + @Test(expected = classOf[ValidationException]) + def testGroupingInvalidSelection(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val ds = table + .groupBy('a, 'b) + // must fail. 
'c is not a grouping key or aggregation + .select('c) + } + + @Test + def testGroupbyWithoutWindow() = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('a.count) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "a", "b") + ), + term("groupBy", "b"), + term("select", "b", "COUNT(a) AS TMP_0") + ), + term("select", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + + @Test + def testGroupAggregateWithConstant1(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a, 4 as 'four, 'b) + .groupBy('four, 'a) + .select('four, 'b.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "b", "a") + ), + term("groupBy", "four", "a"), + term("select", "four", "a", "SUM(b) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithConstant2(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('b, 4 as 'four, 'a) + .groupBy('b, 'four) + .select('four, 'a.sum) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "4 AS four", "a", "b") + ), + term("groupBy", "four", "b"), + term("select", "four", "b", "SUM(a) AS TMP_0") + ), + term("select", "4 AS four", "TMP_0") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithExpressionInSelect(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .select('a as 'a, 'b % 3 as 'd, 'c as 'c) + .groupBy('d) + .select('c.min, 'a.avg) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "c", "a", "MOD(b, 3) AS d") + ), + term("groupBy", "d"), + term("select", "d", "MIN(c) AS TMP_0", "AVG(a) AS TMP_1") + ), + term("select", "TMP_0", "TMP_1") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithFilter(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.sum) + .where('b === 2) + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a") + ), + term("groupBy", "b"), + term("select", "b", "SUM(a) AS TMP_0") + ), + term("select", "b", "TMP_0"), + term("where", "=(b, 2)") + ) + util.verifyTable(resultTable, expected) + } + + @Test + def testGroupAggregateWithAverage(): Unit = { + val util = streamTestUtil() + val table = util.addTable[(Long, Int, String)]('a, 'b, 'c) + + val resultTable = table + .groupBy('b) + .select('b, 'a.cast(BasicTypeInfo.DOUBLE_TYPE_INFO).avg) + + val expected = + unaryNode( + "DataStreamGroupAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(0), + term("select", "b", "a", "CAST(a) AS a0") + ), + term("groupBy", "b"), + term("select", "b", "AVG(a0) 
AS TMP_0") + ) + + util.verifyTable(resultTable, expected) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/ff262508/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala new file mode 100644 index 0000000..b8fc49b --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/stream/table/GroupWindowAggregationsITCase.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.stream.table + +import org.apache.flink.api.scala._ +import org.apache.flink.streaming.api.TimeCharacteristic +import org.apache.flink.streaming.api.functions.AssignerWithPunctuatedWatermarks +import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment +import org.apache.flink.streaming.api.watermark.Watermark +import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase +import org.apache.flink.table.api.TableEnvironment +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.api.scala.stream.table.GroupWindowAggregationsITCase.TimestampWithEqualWatermark +import org.apache.flink.table.api.scala.stream.utils.StreamITCase +import org.apache.flink.types.Row +import org.junit.Assert._ +import org.junit.Test + +import scala.collection.mutable + +/** + * We only test some aggregations until better testing of constructed DataStream + * programs is possible. 
+ */ +class GroupWindowAggregationsITCase extends StreamingMultipleProgramsTestBase { + + val data = List( + (1L, 1, "Hi"), + (2L, 2, "Hello"), + (4L, 2, "Hello"), + (8L, 3, "Hello world"), + (16L, 3, "Hello world")) + + @Test + def testProcessingTimeSlidingGroupWindowOverCount(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env.fromCollection(data) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Slide over 2.rows every 1.rows as 'w) + .groupBy('w, 'string) + .select('string, 'int.count, 'int.avg) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("Hello world,1,3", "Hello world,2,3", "Hello,1,2", "Hello,2,2", "Hi,1,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testEventTimeSessionGroupWindowOverTime(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env + .fromCollection(data) + .assignTimestampsAndWatermarks(new TimestampWithEqualWatermark()) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Session withGap 7.milli on 'rowtime as 'w) + .groupBy('w, 'string) + .select('string, 'int.count) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("Hello world,1", "Hello world,1", "Hello,2", "Hi,1") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testAllProcessingTimeTumblingGroupWindowOverCount(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env.fromCollection(data) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Tumble over 2.rows as 'w) + .groupBy('w) + .select('int.count) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq("2", "2") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } + + @Test + def testEventTimeTumblingWindow(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime) + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.testResults = mutable.MutableList() + + val stream = env + .fromCollection(data) + .assignTimestampsAndWatermarks(new TimestampWithEqualWatermark()) + val table = stream.toTable(tEnv, 'long, 'int, 'string) + + val windowedTable = table + .window(Tumble over 5.milli on 'rowtime as 'w) + .groupBy('w, 'string) + .select('string, 'int.count, 'int.avg, 'int.min, 'int.max, 'int.sum, 'w.start, 'w.end) + + val results = windowedTable.toDataStream[Row] + results.addSink(new StreamITCase.StringSink) + env.execute() + + val expected = Seq( + "Hello world,1,3,3,3,3,1970-01-01 00:00:00.005,1970-01-01 00:00:00.01", + "Hello world,1,3,3,3,3,1970-01-01 
00:00:00.015,1970-01-01 00:00:00.02", + "Hello,2,2,2,2,4,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005", + "Hi,1,1,1,1,1,1970-01-01 00:00:00.0,1970-01-01 00:00:00.005") + assertEquals(expected.sorted, StreamITCase.testResults.sorted) + } +} + +object GroupWindowAggregationsITCase { + class TimestampWithEqualWatermark extends AssignerWithPunctuatedWatermarks[(Long, Int, String)] { + + override def checkAndGetNextWatermark( + lastElement: (Long, Int, String), + extractedTimestamp: Long) + : Watermark = { + new Watermark(extractedTimestamp) + } + + override def extractTimestamp( + element: (Long, Int, String), + previousElementTimestamp: Long): Long = { + element._1 + } + } +}
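----------------------------------------------------------------------

For readers following the tests above, here is a minimal usage sketch of what this commit enables: a non-windowed GroupBy aggregation on a stream table, in both the Table API and SQL forms. It is assembled only from APIs exercised in this diff (toTable, groupBy/select, tEnv.sql, registerTable, toDataStream); the object name and the input elements are made up for illustration.

import org.apache.flink.api.scala._
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._
import org.apache.flink.types.Row

object UnboundedGroupByExample {

  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // Illustrative input; any DataStream[(Long, Int, String)] works.
    val stream = env.fromElements((1L, 1, "Hi"), (2L, 2, "Hello"), (4L, 2, "Hello"))
    val table = stream.toTable(tEnv, 'a, 'b, 'c)

    // Table API: GroupBy aggregation without a window, which this commit
    // makes valid on stream tables. The plan uses the new
    // DataStreamGroupAggregate node, which emits an updated result row
    // for every input row.
    val result = table
      .groupBy('b)
      .select('b, 'a.sum)
    result.toDataStream[Row].print()

    // Equivalent SQL form of the same unbounded GROUP BY.
    tEnv.registerTable("MyTable", table)
    tEnv.sql("SELECT b, SUM(a) FROM MyTable GROUP BY b")
      .toDataStream[Row].print()

    env.execute()
  }
}

Because each incoming row refines the aggregate, the result stream carries one row per update, backed by per-key accumulator state in the new GroupAggProcessFunction; this is why the expected lists in GroupAggregationsITCase contain intermediate values (e.g. "2,2" followed by "2,5") and the assertions compare sorted result lists rather than final values only.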