Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id 215DF200A5B for ; Wed, 25 May 2016 16:15:10 +0200 (CEST) Received: by cust-asf.ponee.io (Postfix) id 202CC160A0F; Wed, 25 May 2016 14:15:10 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id C2216160A29 for ; Wed, 25 May 2016 16:15:07 +0200 (CEST) Received: (qmail 77815 invoked by uid 500); 25 May 2016 14:15:07 -0000 Mailing-List: contact commits-help@flink.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@flink.apache.org Delivered-To: mailing list commits@flink.apache.org Received: (qmail 77804 invoked by uid 99); 25 May 2016 14:15:06 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 25 May 2016 14:15:06 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id C897EDFC8C; Wed, 25 May 2016 14:15:06 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: fhueske@apache.org To: commits@flink.apache.org Message-Id: <28907da49db142f79cb2c508f903dd26@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: flink git commit: [FLINK-3586] Fix potential overflow of Long AVG aggregation. Date: Wed, 25 May 2016 14:15:06 +0000 (UTC) archived-at: Wed, 25 May 2016 14:15:10 -0000 Repository: flink Updated Branches: refs/heads/master 5b9872492 -> af0f41824 [FLINK-3586] Fix potential overflow of Long AVG aggregation. - Add unit tests for Aggregates. 
This closes #2024 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/af0f4182 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/af0f4182 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/af0f4182 Branch: refs/heads/master Commit: af0f41824a1b54b71060c9ddd4f4830d45436172 Parents: 5b98724 Author: Fabian Hueske Authored: Sun May 22 16:46:43 2016 +0200 Committer: Fabian Hueske Committed: Wed May 25 15:30:04 2016 +0200 ---------------------------------------------------------------------- .../api/table/runtime/aggregate/Aggregate.scala | 61 +++++----- .../table/runtime/aggregate/AggregateUtil.scala | 29 ++--- .../table/runtime/aggregate/AvgAggregate.scala | 51 +++++--- .../runtime/aggregate/CountAggregate.scala | 6 +- .../table/runtime/aggregate/MaxAggregate.scala | 35 ++---- .../table/runtime/aggregate/MinAggregate.scala | 35 ++---- .../table/runtime/aggregate/SumAggregate.scala | 38 ++---- .../runtime/aggregate/AggregateTestBase.scala | 104 +++++++++++++++++ .../runtime/aggregate/AvgAggregateTest.scala | 115 +++++++++++++++++++ .../runtime/aggregate/CountAggregateTest.scala | 30 +++++ .../runtime/aggregate/MaxAggregateTest.scala | 93 +++++++++++++++ .../runtime/aggregate/MinAggregateTest.scala | 93 +++++++++++++++ .../runtime/aggregate/SumAggregateTest.scala | 89 ++++++++++++++ 13 files changed, 635 insertions(+), 144 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala index 496dcfb..1e91711 100644 --- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.api.table.runtime.aggregate -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.table.Row /** @@ -43,47 +43,54 @@ import org.apache.flink.api.table.Row trait Aggregate[T] extends Serializable { /** - * Initiate the intermediate aggregate value in Row. - * @param intermediate - */ - def initiate(intermediate: Row): Unit + * Transform the aggregate field value into intermediate aggregate data. + * + * @param value The value to insert into the intermediate aggregate row. + * @param intermediate The intermediate aggregate row into which the value is inserted. + */ + def prepare(value: Any, intermediate: Row): Unit /** - * Transform the aggregate field value into intermediate aggregate data. - * @param value - * @param intermediate - */ - def prepare(value: Any, intermediate: Row): Unit + * Initiate the intermediate aggregate value in Row. + * + * @param intermediate The intermediate aggregate row to initiate. + */ + def initiate(intermediate: Row): Unit /** - * Merge intermediate aggregate data into aggregate buffer. - * @param intermediate - * @param buffer - */ + * Merge intermediate aggregate data into aggregate buffer. + * + * @param intermediate The intermediate aggregate row to merge. + * @param buffer The aggregate buffer into which the intermediate is merged. + */ def merge(intermediate: Row, buffer: Row): Unit /** - * Calculate the final aggregated result based on aggregate buffer. - * @param buffer - * @return - */ + * Calculate the final aggregated result based on aggregate buffer. + * + * @param buffer The aggregate buffer from which the final aggregate is computed. + * @return The final result of the aggregate. 
+ */ def evaluate(buffer: Row): T /** - * Intermediate aggregate value types. - * @return - */ - def intermediateDataType: Array[SqlTypeName] + * Intermediate aggregate value types. + * + * @return The types of the intermediate fields of this aggregate. + */ + def intermediateDataType: Array[TypeInformation[_]] /** - * Set the aggregate data offset in Row. - * @param aggOffset - */ + * Set the aggregate data offset in Row. + * + * @param aggOffset The offset of this aggregate in the intermediate aggregate rows. + */ def setAggOffsetInRow(aggOffset: Int) /** * Whether aggregate function support partial aggregate. - * @return - */ + * + * @return True if the aggregate supports partial aggregation, False otherwise. + */ def supportPartial: Boolean = false } http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala index bdc662a..8222a2e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala @@ -26,8 +26,8 @@ import org.apache.calcite.sql.`type`.SqlTypeName._ import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName} import org.apache.calcite.sql.fun._ import org.apache.flink.api.common.functions.{GroupReduceFunction, MapFunction} +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.table.typeutils.TypeConverter -import TypeConverter._ import org.apache.flink.api.table.typeutils.RowTypeInfo import org.apache.flink.api.table.{TableException, 
Row, TableConfig} @@ -73,15 +73,9 @@ object AggregateUtil { val aggFieldIndexes = aggregateFunctionsAndFieldIndexes._1 val aggregates = aggregateFunctionsAndFieldIndexes._2 - val bufferDataType: RelRecordType = + val mapReturnType: RowTypeInfo = createAggregateBufferDataType(groupings, aggregates, inputType) - val mapReturnType = determineReturnType( - bufferDataType, - Some(TypeConverter.DEFAULT_ROW_TYPE), - config.getNullCheck, - config.getEfficientTypeUsage) - val mapFunction = new AggregateMapFunction[Row, Row]( aggregates, aggFieldIndexes, groupings, mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Row]] @@ -240,25 +234,22 @@ object AggregateUtil { private def createAggregateBufferDataType( groupings: Array[Int], aggregates: Array[Aggregate[_]], - inputType: RelDataType): RelRecordType = { + inputType: RelDataType): RowTypeInfo = { // get the field data types of group keys. - val groupingTypes: Seq[RelDataTypeField] = groupings.map(inputType.getFieldList.get(_)) + val groupingTypes: Seq[TypeInformation[_]] = groupings + .map(inputType.getFieldList.get(_).getType.getSqlTypeName) + .map(TypeConverter.sqlTypeToTypeInfo) val aggPartialNameSuffix = "agg_buffer_" val factory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT) - // get all the aggregate buffer value data type by their SqlTypeName. 
- val aggTypes: Seq[RelDataTypeField] = - aggregates.flatMap(_.intermediateDataType).zipWithIndex.map { - case (typeName: SqlTypeName, index: Int) => - val fieldDataType = factory.createSqlType(typeName) - new RelDataTypeFieldImpl(aggPartialNameSuffix + index, - groupings.length + index, fieldDataType) - } + // get all field data types of all intermediate aggregates + val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType) + // concat group key types and aggregation types val allFieldTypes = groupingTypes ++: aggTypes - val partialType = new RelRecordType(allFieldTypes.toList) + val partialType = new RowTypeInfo(allFieldTypes) partialType } http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala index 8d3a45b..8cf181a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala @@ -18,8 +18,9 @@ package org.apache.flink.api.table.runtime.aggregate import com.google.common.math.LongMath -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.table.Row +import java.math.BigInteger abstract class AvgAggregate[T] extends Aggregate[T] { protected var partialSumIndex: Int = _ @@ -34,8 +35,6 @@ abstract class AvgAggregate[T] extends Aggregate[T] { } abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] { - private final val intermediateType = Array(SqlTypeName.BIGINT, 
SqlTypeName.BIGINT) - override def initiate(partial: Row): Unit = { partial.setField(partialSumIndex, 0L) @@ -60,9 +59,9 @@ abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] { buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount)) } - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array( + BasicTypeInfo.LONG_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) def doPrepare(value: Any, partial: Row): Unit } @@ -113,21 +112,47 @@ class IntAvgAggregate extends IntegralAvgAggregate[Int] { class LongAvgAggregate extends IntegralAvgAggregate[Long] { + override def intermediateDataType = Array( + BasicTypeInfo.BIG_INT_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) + + override def initiate(partial: Row): Unit = { + partial.setField(partialSumIndex, BigInteger.ZERO) + partial.setField(partialCountIndex, 0L) + } + + override def prepare(value: Any, partial: Row): Unit = { + if (value == null) { + partial.setField(partialSumIndex, BigInteger.ZERO) + partial.setField(partialCountIndex, 0L) + } else { + doPrepare(value, partial) + } + } + override def doPrepare(value: Any, partial: Row): Unit = { val input = value.asInstanceOf[Long] - partial.setField(partialSumIndex, input) + partial.setField(partialSumIndex, BigInteger.valueOf(input)) partial.setField(partialCountIndex, 1L) } + override def merge(partial: Row, buffer: Row): Unit = { + val partialSum = partial.productElement(partialSumIndex).asInstanceOf[BigInteger] + val partialCount = partial.productElement(partialCountIndex).asInstanceOf[Long] + val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigInteger] + val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long] + buffer.setField(partialSumIndex, partialSum.add(bufferSum)) + buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount)) + } + override def evaluate(buffer: Row): Long = { - val bufferSum = 
buffer.productElement(partialSumIndex).asInstanceOf[Long] + val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigInteger] val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long] - (bufferSum / bufferCount) + bufferSum.divide(BigInteger.valueOf(bufferCount)).longValue() } } abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] { - private val partialType = Array(SqlTypeName.DOUBLE, SqlTypeName.BIGINT) override def initiate(partial: Row): Unit = { partial.setField(partialSumIndex, 0D) @@ -153,9 +178,9 @@ abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] { buffer.setField(partialCountIndex, partialCount + bufferCount) } - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array( + BasicTypeInfo.DOUBLE_TYPE_INFO, + BasicTypeInfo.LONG_TYPE_INFO) def doPrepare(value: Any, partial: Row): Unit } http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala index d615088..d9f288a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.api.table.runtime.aggregate -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.table.Row class CountAggregate extends Aggregate[Long] { @@ -45,9 +45,7 @@ class CountAggregate 
extends Aggregate[Long] { } } - override def intermediateDataType: Array[SqlTypeName] = { - Array(SqlTypeName.BIGINT) - } + override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO) override def supportPartial: Boolean = true http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala index fde1b53..8f491f2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.api.table.runtime.aggregate -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.table.Row abstract class MaxAggregate[T: Numeric] extends Aggregate[T] { @@ -27,6 +27,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] { /** * Accessed in MapFunction, prepare the input of partial aggregate. + * * @param value * @param intermediate */ @@ -41,6 +42,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] { /** * Accessed in CombineFunction and GroupReduceFunction, merge partial * aggregate result into aggregate buffer. + * * @param intermediate * @param buffer */ @@ -52,6 +54,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] { /** * Return the final aggregated result based on aggregate buffer. 
+ * * @param buffer * @return */ @@ -67,11 +70,8 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] { } class ByteMaxAggregate extends MaxAggregate[Byte] { - private val intermediateType = Array(SqlTypeName.TINYINT) - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Byte.MinValue) @@ -79,11 +79,8 @@ class ByteMaxAggregate extends MaxAggregate[Byte] { } class ShortMaxAggregate extends MaxAggregate[Short] { - private val intermediateType = Array(SqlTypeName.SMALLINT) - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Short.MinValue) @@ -91,11 +88,8 @@ class ShortMaxAggregate extends MaxAggregate[Short] { } class IntMaxAggregate extends MaxAggregate[Int] { - private val intermediateType = Array(SqlTypeName.INTEGER) - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Int.MinValue) @@ -103,11 +97,8 @@ class IntMaxAggregate extends MaxAggregate[Int] { } class LongMaxAggregate extends MaxAggregate[Long] { - private val intermediateType = Array(SqlTypeName.BIGINT) - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Long.MinValue) @@ -115,11 +106,8 @@ class LongMaxAggregate extends MaxAggregate[Long] { } class FloatMaxAggregate extends MaxAggregate[Float] { - private val intermediateType = Array(SqlTypeName.FLOAT) - 
override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Float.MinValue) @@ -127,11 +115,8 @@ class FloatMaxAggregate extends MaxAggregate[Float] { } class DoubleMaxAggregate extends MaxAggregate[Double] { - private val intermediateType = Array(SqlTypeName.DOUBLE) - override def intermediateDataType: Array[SqlTypeName] = { - intermediateType - } + override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(maxIndex, Double.MinValue) http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala index 7cc1b48..e78fb00 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.api.table.runtime.aggregate -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.api.table.Row abstract class MinAggregate[T: Numeric] extends Aggregate[T]{ @@ -27,6 +27,7 @@ abstract class MinAggregate[T: Numeric] extends Aggregate[T]{ /** * Accessed in MapFunction, prepare the input of partial aggregate. 
+ * * @param value * @param partial */ @@ -41,6 +42,7 @@ abstract class MinAggregate[T: Numeric] extends Aggregate[T]{ /** * Accessed in CombineFunction and GroupReduceFunction, merge partial * aggregate result into aggregate buffer. + * * @param partial * @param buffer */ @@ -52,6 +54,7 @@ abstract class MinAggregate[T: Numeric] extends Aggregate[T]{ /** * Return the final aggregated result based on aggregate buffer. + * * @param buffer * @return */ @@ -67,11 +70,8 @@ abstract class MinAggregate[T: Numeric] extends Aggregate[T]{ } class ByteMinAggregate extends MinAggregate[Byte] { - private val partialType = Array(SqlTypeName.TINYINT) - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Byte.MaxValue) @@ -79,11 +79,8 @@ class ByteMinAggregate extends MinAggregate[Byte] { } class ShortMinAggregate extends MinAggregate[Short] { - private val partialType = Array(SqlTypeName.SMALLINT) - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Short.MaxValue) @@ -91,11 +88,8 @@ class ShortMinAggregate extends MinAggregate[Short] { } class IntMinAggregate extends MinAggregate[Int] { - private val partialType = Array(SqlTypeName.INTEGER) - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Int.MaxValue) @@ -103,11 +97,8 @@ class IntMinAggregate extends MinAggregate[Int] { } class LongMinAggregate extends MinAggregate[Long] { - private val partialType = Array(SqlTypeName.BIGINT) - override def intermediateDataType: Array[SqlTypeName] = { - 
partialType - } + override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Long.MaxValue) @@ -115,11 +106,8 @@ class LongMinAggregate extends MinAggregate[Long] { } class FloatMinAggregate extends MinAggregate[Float] { - private val partialType = Array(SqlTypeName.FLOAT) - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Float.MaxValue) @@ -127,11 +115,8 @@ class FloatMinAggregate extends MinAggregate[Float] { } class DoubleMinAggregate extends MinAggregate[Double] { - private val partialType = Array(SqlTypeName.DOUBLE) - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO) override def initiate(intermediate: Row): Unit = { intermediate.setField(minIndex, Double.MaxValue) http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala index 25ef344..b4c56fe 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.api.table.runtime.aggregate -import org.apache.calcite.sql.`type`.SqlTypeName +import org.apache.flink.api.common.typeinfo.BasicTypeInfo import 
org.apache.flink.api.table.Row abstract class SumAggregate[T: Numeric] @@ -57,49 +57,25 @@ abstract class SumAggregate[T: Numeric] } class ByteSumAggregate extends SumAggregate[Byte] { - private val partialType = Array(SqlTypeName.TINYINT) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO) } class ShortSumAggregate extends SumAggregate[Short] { - private val partialType = Array(SqlTypeName.SMALLINT) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO) } class IntSumAggregate extends SumAggregate[Int] { - private val partialType = Array(SqlTypeName.INTEGER) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO) } class LongSumAggregate extends SumAggregate[Long] { - private val partialType = Array(SqlTypeName.BIGINT) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO) } class FloatSumAggregate extends SumAggregate[Float] { - private val partialType = Array(SqlTypeName.FLOAT) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO) } class DoubleSumAggregate extends SumAggregate[Double] { - private val partialType = Array(SqlTypeName.DOUBLE) - - override def intermediateDataType: Array[SqlTypeName] = { - partialType - } + override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO) } http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala new file mode 100644 index 0000000..78d5f8c --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.api.table.runtime.aggregate + +import org.apache.flink.api.table.Row +import org.junit.Test +import org.junit.Assert.assertEquals + +abstract class AggregateTestBase[T] { + + private val offset = 2 + private val rowArity: Int = offset + aggregator.intermediateDataType.length + + def inputValueSets: Seq[Seq[_]] + + def expectedResults: Seq[T] + + def aggregator: Aggregate[T] + + private def createAggregator(): Aggregate[T] = { + val agg = aggregator + agg.setAggOffsetInRow(offset) + agg + } + + private def createRow(): Row = { + new Row(rowArity) + } + + @Test + def testAggregate(): Unit = { + + // iterate over input sets + for((vals, expected) <- inputValueSets.zip(expectedResults)) { + + // prepare mapper + val rows: Seq[Row] = prepare(vals) + + val result = if (aggregator.supportPartial) { + // test with combiner + val (firstVals, secondVals) = rows.splitAt(rows.length / 2) + val combined = partialAgg(firstVals) :: partialAgg(secondVals) :: Nil + finalAgg(combined) + + } else { + // test without combiner + finalAgg(rows) + } + + assertEquals(expected, result) + + } + } + + private def prepare(vals: Seq[_]): Seq[Row] = { + + val agg = createAggregator() + + vals.map { v => + val row = createRow() + agg.prepare(v, row) + row + } + } + + private def partialAgg(rows: Seq[Row]): Row = { + + val agg = createAggregator() + val aggBuf = createRow() + + agg.initiate(aggBuf) + rows.foreach(v => agg.merge(v, aggBuf)) + + aggBuf + } + + private def finalAgg(rows: Seq[Row]): T = { + + val agg = createAggregator() + val aggBuf = createRow() + + agg.initiate(aggBuf) + rows.foreach(v => agg.merge(v, aggBuf)) + + agg.evaluate(partialAgg(rows)) + } + +} http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala new file mode 100644 index 0000000..2575fa2 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.flink.api.table.runtime.aggregate

/**
 * Common fixture for the AVG aggregate tests of all numeric types.
 *
 * Three input sets are provided: a run of `minVal`, a run of `maxVal`, and a mix in which
 * every non-null value is paired with its negation. Each set contains `null` entries; the
 * expected averages (`minVal`, `maxVal` and zero) only hold if those entries do not
 * contribute to the result.
 *
 * @tparam T numeric type under test
 */
abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {

  /** Lower bound used in the input sets. */
  def minVal: T

  /** Upper bound used in the input sets. */
  def maxVal: T

  override def inputValueSets: Seq[Seq[T]] = {
    val num = implicitly[Numeric[T]]
    val missing = null.asInstanceOf[T]

    // Seven copies of the same value with two nulls in between.
    def runOf(v: T): Seq[T] = Seq(v, v, missing, v, v, missing, v, v, v)

    // Both bounds, zero and the negated bounds: the non-null values add up to zero.
    val balanced: Seq[T] = Seq(
      minVal,
      maxVal,
      missing,
      num.fromInt(0),
      num.negate(maxVal),
      num.negate(minVal),
      missing
    )

    Seq(runOf(minVal), runOf(maxVal), balanced)
  }

  override def expectedResults: Seq[T] = {
    val zero = implicitly[Numeric[T]].fromInt(0)
    Seq(minVal, maxVal, zero)
  }
}

/** AVG over `Byte`; the lower bound is one above `Byte.MinValue` so its negation fits. */
class ByteAvgAggregateTest extends AvgAggregateTestBase[Byte] {

  override def aggregator = new ByteAvgAggregate()

  override def minVal: Byte = (Byte.MinValue + 1).toByte
  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
}

/** AVG over `Short`; the lower bound is one above `Short.MinValue` so its negation fits. */
class ShortAvgAggregateTest extends AvgAggregateTestBase[Short] {

  override def aggregator = new ShortAvgAggregate()

  override def minVal: Short = (Short.MinValue + 1).toShort
  override def maxVal: Short = (Short.MaxValue - 1).toShort
}

/** AVG over `Int`; the lower bound is one above `Int.MinValue` so its negation fits. */
class IntAvgAggregateTest extends AvgAggregateTestBase[Int] {

  override def aggregator = new IntAvgAggregate()

  override def minVal: Int = Int.MinValue + 1
  override def maxVal: Int = Int.MaxValue - 1
}

/** AVG over `Long`; the lower bound is one above `Long.MinValue` so its negation fits. */
class LongAvgAggregateTest extends AvgAggregateTestBase[Long] {

  override def aggregator = new LongAvgAggregate()

  override def minVal: Long = Long.MinValue + 1
  override def maxVal: Long = Long.MaxValue - 1
}

/** AVG over `Float`, using the full finite range of the type. */
class FloatAvgAggregateTest extends AvgAggregateTestBase[Float] {

  override def aggregator = new FloatAvgAggregate()

  override def minVal: Float = Float.MinValue
  override def maxVal: Float = Float.MaxValue
}

/**
 * AVG over `Double`.
 *
 * NOTE(review): the bounds are `Float`'s extremes widened to `Double`, not `Double`'s own.
 * Presumably this keeps the sum of several bounds finite -- confirm against
 * `DoubleAvgAggregate`.
 */
class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {

  override def aggregator = new DoubleAvgAggregate()

  override def minVal: Double = Float.MinValue.toDouble
  override def maxVal: Double = Float.MaxValue.toDouble
}
package org.apache.flink.api.table.runtime.aggregate

/**
 * Fixture for [[CountAggregate]]: one input set of nine entries, three of which are `null`,
 * with an expected count of six.
 */
class CountAggregateTest extends AggregateTestBase[Long] {

  override def aggregator: Aggregate[Long] = new CountAggregate()

  override def inputValueSets: Seq[Seq[_]] = {
    // Six strings interleaved with three nulls.
    val mixed: Seq[String] = Seq("a", "b", null, "c", null, "d", "e", null, "f")
    Seq(mixed)
  }

  override def expectedResults: Seq[Long] = {
    val nonNullEntries = 6L
    Seq(nonNullEntries)
  }
}
package org.apache.flink.api.table.runtime.aggregate

/**
 * Common fixture for the MAX aggregate tests of all numeric types.
 *
 * A single input set mixes small integer values, both bounds and two `null` entries; the
 * expected maximum is `maxVal`.
 *
 * @tparam T numeric type under test
 */
abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {

  /** Smallest value placed in the input set. */
  def minVal: T

  /** Largest value placed in the input set; also the expected result. */
  def maxVal: T

  override def inputValueSets: Seq[Seq[T]] = {
    val num = implicitly[Numeric[T]]
    val missing = null.asInstanceOf[T]

    def lit(i: Int): T = num.fromInt(i)

    val samples: Seq[T] = Seq(
      lit(1),
      missing,
      maxVal,
      lit(-99),
      lit(3),
      lit(56),
      lit(0),
      minVal,
      lit(-20),
      lit(17),
      missing
    )
    Seq(samples)
  }

  override def expectedResults: Seq[T] = {
    val largest = maxVal
    Seq(largest)
  }
}

/** MAX over `Byte`, with bounds one step inside the type's range. */
class ByteMaxAggregateTest extends MaxAggregateTestBase[Byte] {

  override def aggregator: Aggregate[Byte] = new ByteMaxAggregate()

  override def minVal: Byte = (Byte.MinValue + 1).toByte
  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
}

/** MAX over `Short`, with bounds one step inside the type's range. */
class ShortMaxAggregateTest extends MaxAggregateTestBase[Short] {

  override def aggregator: Aggregate[Short] = new ShortMaxAggregate()

  override def minVal: Short = (Short.MinValue + 1).toShort
  override def maxVal: Short = (Short.MaxValue - 1).toShort
}

/** MAX over `Int`, with bounds one step inside the type's range. */
class IntMaxAggregateTest extends MaxAggregateTestBase[Int] {

  override def aggregator: Aggregate[Int] = new IntMaxAggregate()

  override def minVal: Int = Int.MinValue + 1
  override def maxVal: Int = Int.MaxValue - 1
}

/** MAX over `Long`, with bounds one step inside the type's range. */
class LongMaxAggregateTest extends MaxAggregateTestBase[Long] {

  override def aggregator: Aggregate[Long] = new LongMaxAggregate()

  override def minVal: Long = Long.MinValue + 1
  override def maxVal: Long = Long.MaxValue - 1
}

/** MAX over `Float`, with bounds at half of the type's extreme values. */
class FloatMaxAggregateTest extends MaxAggregateTestBase[Float] {

  override def aggregator: Aggregate[Float] = new FloatMaxAggregate()

  override def minVal: Float = Float.MinValue / 2
  override def maxVal: Float = Float.MaxValue / 2
}

/** MAX over `Double`, with bounds at half of the type's extreme values. */
class DoubleMaxAggregateTest extends MaxAggregateTestBase[Double] {

  override def aggregator: Aggregate[Double] = new DoubleMaxAggregate()

  override def minVal: Double = Double.MinValue / 2
  override def maxVal: Double = Double.MaxValue / 2
}
package org.apache.flink.api.table.runtime.aggregate

/**
 * Common fixture for the MIN aggregate tests of all numeric types.
 *
 * A single input set mixes small integer values, both bounds and two `null` entries; the
 * expected minimum is `minVal`.
 *
 * @tparam T numeric type under test
 */
abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {

  /** Smallest value placed in the input set; also the expected result. */
  def minVal: T

  /** Largest value placed in the input set. */
  def maxVal: T

  override def inputValueSets: Seq[Seq[T]] = {
    val num = implicitly[Numeric[T]]
    val missing = null.asInstanceOf[T]

    def lit(i: Int): T = num.fromInt(i)

    val samples: Seq[T] = Seq(
      lit(1),
      missing,
      maxVal,
      lit(-99),
      lit(3),
      lit(56),
      lit(0),
      minVal,
      lit(-20),
      lit(17),
      missing
    )
    Seq(samples)
  }

  override def expectedResults: Seq[T] = {
    val smallest = minVal
    Seq(smallest)
  }
}

/** MIN over `Byte`, with bounds one step inside the type's range. */
class ByteMinAggregateTest extends MinAggregateTestBase[Byte] {

  override def aggregator: Aggregate[Byte] = new ByteMinAggregate()

  override def minVal: Byte = (Byte.MinValue + 1).toByte
  override def maxVal: Byte = (Byte.MaxValue - 1).toByte
}

/** MIN over `Short`, with bounds one step inside the type's range. */
class ShortMinAggregateTest extends MinAggregateTestBase[Short] {

  override def aggregator: Aggregate[Short] = new ShortMinAggregate()

  override def minVal: Short = (Short.MinValue + 1).toShort
  override def maxVal: Short = (Short.MaxValue - 1).toShort
}

/** MIN over `Int`, with bounds one step inside the type's range. */
class IntMinAggregateTest extends MinAggregateTestBase[Int] {

  override def aggregator: Aggregate[Int] = new IntMinAggregate()

  override def minVal: Int = Int.MinValue + 1
  override def maxVal: Int = Int.MaxValue - 1
}

/** MIN over `Long`, with bounds one step inside the type's range. */
class LongMinAggregateTest extends MinAggregateTestBase[Long] {

  override def aggregator: Aggregate[Long] = new LongMinAggregate()

  override def minVal: Long = Long.MinValue + 1
  override def maxVal: Long = Long.MaxValue - 1
}

/** MIN over `Float`, with bounds at half of the type's extreme values. */
class FloatMinAggregateTest extends MinAggregateTestBase[Float] {

  override def aggregator: Aggregate[Float] = new FloatMinAggregate()

  override def minVal: Float = Float.MinValue / 2
  override def maxVal: Float = Float.MaxValue / 2
}

/** MIN over `Double`, with bounds at half of the type's extreme values. */
class DoubleMinAggregateTest extends MinAggregateTestBase[Double] {

  override def aggregator: Aggregate[Double] = new DoubleMinAggregate()

  override def minVal: Double = Double.MinValue / 2
  override def maxVal: Double = Double.MaxValue / 2
}
package org.apache.flink.api.table.runtime.aggregate

/**
 * Common fixture for the SUM aggregate tests of all numeric types.
 *
 * The single input set contains `maxVal`, its negation, a handful of small integers and two
 * `null` entries. The two bounds cancel each other and the integers add up to 2, which is
 * the expected sum; this only holds if the `null` entries do not contribute.
 *
 * @tparam T numeric type under test
 */
abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {

  private val numeric: Numeric[T] = implicitly[Numeric[T]]

  /** Largest magnitude placed in the input set; its negation is added as well. */
  def maxVal: T

  // Deliberately a def, not a val: a val would be computed while this base-class constructor
  // runs, i.e. before a subclass that implements `maxVal` as a val has assigned it.
  private def minVal: T = numeric.negate(maxVal)

  override def inputValueSets: Seq[Seq[T]] = Seq(
    Seq(
      minVal,
      numeric.fromInt(1),
      null.asInstanceOf[T],
      numeric.fromInt(2),
      numeric.fromInt(3),
      numeric.fromInt(4),
      numeric.fromInt(5),
      numeric.fromInt(-10),
      numeric.fromInt(-20),
      numeric.fromInt(17),
      null.asInstanceOf[T],
      maxVal
    )
  )

  // 1 + 2 + 3 + 4 + 5 - 10 - 20 + 17 = 2; minVal and maxVal cancel out.
  override def expectedResults: Seq[T] = Seq(numeric.fromInt(2))

}

/** SUM over `Byte`, bounded by half of `Byte.MaxValue`. */
class ByteSumAggregateTest extends SumAggregateTestBase[Byte] {

  override def maxVal: Byte = (Byte.MaxValue / 2).toByte

  override def aggregator: Aggregate[Byte] = new ByteSumAggregate
}

/** SUM over `Short`, bounded by half of `Short.MaxValue`. */
class ShortSumAggregateTest extends SumAggregateTestBase[Short] {

  override def maxVal: Short = (Short.MaxValue / 2).toShort

  override def aggregator: Aggregate[Short] = new ShortSumAggregate
}

/** SUM over `Int`, bounded by half of `Int.MaxValue`. */
class IntSumAggregateTest extends SumAggregateTestBase[Int] {

  override def maxVal: Int = Int.MaxValue / 2

  override def aggregator: Aggregate[Int] = new IntSumAggregate
}

/** SUM over `Long`, bounded by half of `Long.MaxValue`. */
class LongSumAggregateTest extends SumAggregateTestBase[Long] {

  override def maxVal: Long = Long.MaxValue / 2

  override def aggregator: Aggregate[Long] = new LongSumAggregate
}

/** SUM over `Float`, using a moderate fractional bound. */
class FloatSumAggregateTest extends SumAggregateTestBase[Float] {

  override def maxVal: Float = 12345.6789f

  override def aggregator: Aggregate[Float] = new FloatSumAggregate
}

/** SUM over `Double`, using a moderate fractional bound. */
class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {

  override def maxVal: Double = 12345.6789d

  override def aggregator: Aggregate[Double] = new DoubleSumAggregate
}