flink-commits mailing list archives

From fhue...@apache.org
Subject flink git commit: [FLINK-3586] Fix potential overflow of Long AVG aggregation.
Date Wed, 25 May 2016 14:15:06 GMT
Repository: flink
Updated Branches:
  refs/heads/master 5b9872492 -> af0f41824


[FLINK-3586] Fix potential overflow of Long AVG aggregation.

- Add unit tests for Aggregates.

This closes #2024
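
For background, a minimal sketch (not part of the commit) of the overflow this patch guards against: the previous LongAvgAggregate kept its partial sum as a Long, which cannot represent a sum of values close to Long.MaxValue, so with plain Long arithmetic the sum wraps around and the average comes out wrong (the sketch below shows the wrap case). The patched aggregate accumulates the sum in a java.math.BigInteger instead, as the diff below shows.

import java.math.BigInteger

object LongAvgOverflowSketch {
  def main(args: Array[String]): Unit = {
    val a = Long.MaxValue - 1
    val b = Long.MaxValue - 1

    // Accumulating in a Long wraps around, so the "average" is wrong.
    val overflowed = (a + b) / 2
    println(overflowed)            // -2

    // Accumulating in a BigInteger keeps the intermediate sum exact.
    val exact = BigInteger.valueOf(a)
      .add(BigInteger.valueOf(b))
      .divide(BigInteger.valueOf(2))
      .longValue()
    println(exact)                 // 9223372036854775806
  }
}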


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/af0f4182
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/af0f4182
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/af0f4182

Branch: refs/heads/master
Commit: af0f41824a1b54b71060c9ddd4f4830d45436172
Parents: 5b98724
Author: Fabian Hueske <fhueske@apache.org>
Authored: Sun May 22 16:46:43 2016 +0200
Committer: Fabian Hueske <fhueske@apache.org>
Committed: Wed May 25 15:30:04 2016 +0200

----------------------------------------------------------------------
 .../api/table/runtime/aggregate/Aggregate.scala |  61 +++++-----
 .../table/runtime/aggregate/AggregateUtil.scala |  29 ++---
 .../table/runtime/aggregate/AvgAggregate.scala  |  51 +++++---
 .../runtime/aggregate/CountAggregate.scala      |   6 +-
 .../table/runtime/aggregate/MaxAggregate.scala  |  35 ++----
 .../table/runtime/aggregate/MinAggregate.scala  |  35 ++----
 .../table/runtime/aggregate/SumAggregate.scala  |  38 ++----
 .../runtime/aggregate/AggregateTestBase.scala   | 104 +++++++++++++++++
 .../runtime/aggregate/AvgAggregateTest.scala    | 115 +++++++++++++++++++
 .../runtime/aggregate/CountAggregateTest.scala  |  30 +++++
 .../runtime/aggregate/MaxAggregateTest.scala    |  93 +++++++++++++++
 .../runtime/aggregate/MinAggregateTest.scala    |  93 +++++++++++++++
 .../runtime/aggregate/SumAggregateTest.scala    |  89 ++++++++++++++
 13 files changed, 635 insertions(+), 144 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala
index 496dcfb..1e91711 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/Aggregate.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.table.Row
 
 /**
@@ -43,47 +43,54 @@ import org.apache.flink.api.table.Row
 trait Aggregate[T] extends Serializable {
 
   /**
-   * Initiate the intermediate aggregate value in Row.
-   * @param intermediate
-   */
-  def initiate(intermediate: Row): Unit
+    * Transform the aggregate field value into intermediate aggregate data.
+    *
+    * @param value The value to insert into the intermediate aggregate row.
+    * @param intermediate The intermediate aggregate row into which the value is inserted.
+    */
+  def prepare(value: Any, intermediate: Row): Unit
 
   /**
-   * Transform the aggregate field value into intermediate aggregate data.
-   * @param value
-   * @param intermediate
-   */
-  def prepare(value: Any, intermediate: Row): Unit
+    * Initiate the intermediate aggregate value in Row.
+    *
+    * @param intermediate The intermediate aggregate row to initiate.
+    */
+  def initiate(intermediate: Row): Unit
 
   /**
-   * Merge intermediate aggregate data into aggregate buffer.
-   * @param intermediate
-   * @param buffer
-   */
+    * Merge intermediate aggregate data into aggregate buffer.
+    *
+    * @param intermediate The intermediate aggregate row to merge.
+    * @param buffer The aggregate buffer into which the intermediate is merged.
+    */
   def merge(intermediate: Row, buffer: Row): Unit
 
   /**
-   * Calculate the final aggregated result based on aggregate buffer.
-   * @param buffer
-   * @return
-   */
+    * Calculate the final aggregated result based on aggregate buffer.
+    *
+    * @param buffer The aggregate buffer from which the final aggregate is computed.
+    * @return The final result of the aggregate.
+    */
   def evaluate(buffer: Row): T
 
   /**
-   * Intermediate aggregate value types.
-   * @return
-   */
-  def intermediateDataType: Array[SqlTypeName]
+    * Intermediate aggregate value types.
+    *
+    * @return The types of the intermediate fields of this aggregate.
+    */
+  def intermediateDataType: Array[TypeInformation[_]]
 
   /**
-   * Set the aggregate data offset in Row.
-   * @param aggOffset
-   */
+    * Set the aggregate data offset in Row.
+    *
+    * @param aggOffset The offset of this aggregate in the intermediate aggregate rows.
+    */
   def setAggOffsetInRow(aggOffset: Int)
 
   /**
    * Whether the aggregate function supports partial aggregation.
-   * @return
-   */
+    *
+    * @return True if the aggregate supports partial aggregation, False otherwise.
+    */
   def supportPartial: Boolean = false
 }
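
The trait methods above are driven in two phases: prepare builds one intermediate Row per input record (in a MapFunction), initiate and merge fold intermediate rows into an aggregate buffer (in the combine and reduce phases), and evaluate computes the final value. A minimal sketch of that flow, using the LongAvgAggregate from this commit and assuming its fields sit at offset 0 of the intermediate Row (AggregateTestBase further down uses the same convention, just with a non-zero offset):

package org.apache.flink.api.table.runtime.aggregate

import org.apache.flink.api.table.Row

object AggregateLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val agg = new LongAvgAggregate
    agg.setAggOffsetInRow(0)
    val rowArity = agg.intermediateDataType.length

    val inputs = Seq(1L, 2L, 3L, 4L)

    // MapFunction phase: one intermediate row per input value.
    val prepared = inputs.map { v =>
      val row = new Row(rowArity)
      agg.prepare(v, row)
      row
    }

    // Combine/reduce phase: merge all intermediate rows into one buffer.
    val buffer = new Row(rowArity)
    agg.initiate(buffer)
    prepared.foreach(agg.merge(_, buffer))

    // Final result: (1 + 2 + 3 + 4) / 4 = 2 with integral division.
    println(agg.evaluate(buffer))
  }
}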

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
index bdc662a..8222a2e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
@@ -26,8 +26,8 @@ import org.apache.calcite.sql.`type`.SqlTypeName._
 import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName}
 import org.apache.calcite.sql.fun._
 import org.apache.flink.api.common.functions.{GroupReduceFunction, MapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.table.typeutils.TypeConverter
-import TypeConverter._
 import org.apache.flink.api.table.typeutils.RowTypeInfo
 import org.apache.flink.api.table.{TableException, Row, TableConfig}
 
@@ -73,15 +73,9 @@ object AggregateUtil {
     val aggFieldIndexes = aggregateFunctionsAndFieldIndexes._1
     val aggregates = aggregateFunctionsAndFieldIndexes._2
 
-    val bufferDataType: RelRecordType =
+    val mapReturnType: RowTypeInfo =
       createAggregateBufferDataType(groupings, aggregates, inputType)
 
-    val mapReturnType = determineReturnType(
-        bufferDataType,
-        Some(TypeConverter.DEFAULT_ROW_TYPE),
-        config.getNullCheck,
-        config.getEfficientTypeUsage)
-
     val mapFunction = new AggregateMapFunction[Row, Row](
         aggregates, aggFieldIndexes, groupings,
         mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Row]]
@@ -240,25 +234,22 @@ object AggregateUtil {
   private def createAggregateBufferDataType(
       groupings: Array[Int],
       aggregates: Array[Aggregate[_]],
-      inputType: RelDataType): RelRecordType = {
+      inputType: RelDataType): RowTypeInfo = {
 
     // get the field data types of group keys.
-    val groupingTypes: Seq[RelDataTypeField] = groupings.map(inputType.getFieldList.get(_))
+    val groupingTypes: Seq[TypeInformation[_]] = groupings
+      .map(inputType.getFieldList.get(_).getType.getSqlTypeName)
+      .map(TypeConverter.sqlTypeToTypeInfo)
 
     val aggPartialNameSuffix = "agg_buffer_"
     val factory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT)
 
-    // get all the aggregate buffer value data type by their SqlTypeName.
-    val aggTypes: Seq[RelDataTypeField] =
-      aggregates.flatMap(_.intermediateDataType).zipWithIndex.map {
-        case (typeName: SqlTypeName, index: Int) =>
-          val fieldDataType = factory.createSqlType(typeName)
-          new RelDataTypeFieldImpl(aggPartialNameSuffix + index,
-            groupings.length + index, fieldDataType)
-      }
+    // get all field data types of all intermediate aggregates
+    val aggTypes: Seq[TypeInformation[_]] = aggregates.flatMap(_.intermediateDataType)
 
+    // concat group key types and aggregation types
     val allFieldTypes = groupingTypes ++: aggTypes
-    val partialType = new RelRecordType(allFieldTypes.toList)
+    val partialType = new RowTypeInfo(allFieldTypes)
     partialType
   }
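
With this change, the intermediate buffer type is assembled directly from Flink TypeInformation instead of being routed through Calcite's record types. A small sketch of the kind of RowTypeInfo that createAggregateBufferDataType now builds, assuming a single INT grouping key and one Long AVG aggregate (object and value names here are illustrative only):

package org.apache.flink.api.table.runtime.aggregate

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.typeutils.RowTypeInfo

object BufferTypeSketch {
  // One INT group key, followed by the two intermediate fields of a
  // LongAvgAggregate (BigInteger sum, Long count).
  val groupingTypes: Seq[TypeInformation[_]] = Seq(BasicTypeInfo.INT_TYPE_INFO)
  val aggTypes: Seq[TypeInformation[_]] = new LongAvgAggregate().intermediateDataType
  val bufferType: RowTypeInfo = new RowTypeInfo(groupingTypes ++ aggTypes)
}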
 

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
index 8d3a45b..8cf181a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregate.scala
@@ -18,8 +18,9 @@
 package org.apache.flink.api.table.runtime.aggregate
 
 import com.google.common.math.LongMath
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.table.Row
+import java.math.BigInteger
 
 abstract class AvgAggregate[T] extends Aggregate[T] {
   protected var partialSumIndex: Int = _
@@ -34,8 +35,6 @@ abstract class AvgAggregate[T] extends Aggregate[T] {
 }
 
 abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] {
-  private final val intermediateType = Array(SqlTypeName.BIGINT, SqlTypeName.BIGINT)
-
 
   override def initiate(partial: Row): Unit = {
     partial.setField(partialSumIndex, 0L)
@@ -60,9 +59,9 @@ abstract class IntegralAvgAggregate[T] extends AvgAggregate[T] {
     buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
   }
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(
+    BasicTypeInfo.LONG_TYPE_INFO,
+    BasicTypeInfo.LONG_TYPE_INFO)
 
   def doPrepare(value: Any, partial: Row): Unit
 }
@@ -113,21 +112,47 @@ class IntAvgAggregate extends IntegralAvgAggregate[Int] {
 
 class LongAvgAggregate extends IntegralAvgAggregate[Long] {
 
+  override def intermediateDataType = Array(
+    BasicTypeInfo.BIG_INT_TYPE_INFO,
+    BasicTypeInfo.LONG_TYPE_INFO)
+
+  override def initiate(partial: Row): Unit = {
+    partial.setField(partialSumIndex, BigInteger.ZERO)
+    partial.setField(partialCountIndex, 0L)
+  }
+
+  override def prepare(value: Any, partial: Row): Unit = {
+    if (value == null) {
+      partial.setField(partialSumIndex, BigInteger.ZERO)
+      partial.setField(partialCountIndex, 0L)
+    } else {
+      doPrepare(value, partial)
+    }
+  }
+
   override def doPrepare(value: Any, partial: Row): Unit = {
     val input = value.asInstanceOf[Long]
-    partial.setField(partialSumIndex, input)
+    partial.setField(partialSumIndex, BigInteger.valueOf(input))
     partial.setField(partialCountIndex, 1L)
   }
 
+  override def merge(partial: Row, buffer: Row): Unit = {
+    val partialSum = partial.productElement(partialSumIndex).asInstanceOf[BigInteger]
+    val partialCount = partial.productElement(partialCountIndex).asInstanceOf[Long]
+    val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigInteger]
+    val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
+    buffer.setField(partialSumIndex, partialSum.add(bufferSum))
+    buffer.setField(partialCountIndex, LongMath.checkedAdd(partialCount, bufferCount))
+  }
+
   override def evaluate(buffer: Row): Long = {
-    val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[Long]
+    val bufferSum = buffer.productElement(partialSumIndex).asInstanceOf[BigInteger]
     val bufferCount = buffer.productElement(partialCountIndex).asInstanceOf[Long]
-    (bufferSum / bufferCount)
+    bufferSum.divide(BigInteger.valueOf(bufferCount)).longValue()
   }
 }
 
 abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] {
-  private val partialType = Array(SqlTypeName.DOUBLE, SqlTypeName.BIGINT)
 
   override def initiate(partial: Row): Unit = {
     partial.setField(partialSumIndex, 0D)
@@ -153,9 +178,9 @@ abstract class FloatingAvgAggregate[T: Numeric] extends AvgAggregate[T] {
     buffer.setField(partialCountIndex, partialCount + bufferCount)
   }
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(
+    BasicTypeInfo.DOUBLE_TYPE_INFO,
+    BasicTypeInfo.LONG_TYPE_INFO)
 
   def doPrepare(value: Any, partial: Row): Unit
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala
index d615088..d9f288a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregate.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.table.Row
 
 class CountAggregate extends Aggregate[Long] {
@@ -45,9 +45,7 @@ class CountAggregate extends Aggregate[Long] {
     }
   }
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    Array(SqlTypeName.BIGINT)
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
 
   override def supportPartial: Boolean = true
 

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
index fde1b53..8f491f2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregate.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
 abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
@@ -27,6 +27,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
 
   /**
    * Accessed in MapFunction, prepare the input of partial aggregate.
+   *
    * @param value
    * @param intermediate
    */
@@ -41,6 +42,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
   /**
    * Accessed in CombineFunction and GroupReduceFunction, merge partial
    * aggregate result into aggregate buffer.
+   *
    * @param intermediate
    * @param buffer
    */
@@ -52,6 +54,7 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
 
   /**
    * Return the final aggregated result based on aggregate buffer.
+   *
    * @param buffer
    * @return
    */
@@ -67,11 +70,8 @@ abstract class MaxAggregate[T: Numeric] extends Aggregate[T] {
 }
 
 class ByteMaxAggregate extends MaxAggregate[Byte] {
-  private val intermediateType = Array(SqlTypeName.TINYINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Byte.MinValue)
@@ -79,11 +79,8 @@ class ByteMaxAggregate extends MaxAggregate[Byte] {
 }
 
 class ShortMaxAggregate extends MaxAggregate[Short] {
-  private val intermediateType = Array(SqlTypeName.SMALLINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Short.MinValue)
@@ -91,11 +88,8 @@ class ShortMaxAggregate extends MaxAggregate[Short] {
 }
 
 class IntMaxAggregate extends MaxAggregate[Int] {
-  private val intermediateType = Array(SqlTypeName.INTEGER)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Int.MinValue)
@@ -103,11 +97,8 @@ class IntMaxAggregate extends MaxAggregate[Int] {
 }
 
 class LongMaxAggregate extends MaxAggregate[Long] {
-  private val intermediateType = Array(SqlTypeName.BIGINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Long.MinValue)
@@ -115,11 +106,8 @@ class LongMaxAggregate extends MaxAggregate[Long] {
 }
 
 class FloatMaxAggregate extends MaxAggregate[Float] {
-  private val intermediateType = Array(SqlTypeName.FLOAT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Float.MinValue)
@@ -127,11 +115,8 @@ class FloatMaxAggregate extends MaxAggregate[Float] {
 }
 
 class DoubleMaxAggregate extends MaxAggregate[Double] {
-  private val intermediateType = Array(SqlTypeName.DOUBLE)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    intermediateType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(maxIndex, Double.MinValue)

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
index 7cc1b48..e78fb00 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregate.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
 abstract  class MinAggregate[T: Numeric] extends Aggregate[T]{
@@ -27,6 +27,7 @@ abstract  class MinAggregate[T: Numeric] extends Aggregate[T]{
 
   /**
    * Accessed in MapFunction, prepare the input of partial aggregate.
+   *
    * @param value
    * @param partial
    */
@@ -41,6 +42,7 @@ abstract  class MinAggregate[T: Numeric] extends Aggregate[T]{
   /**
    * Accessed in CombineFunction and GroupReduceFunction, merge partial
    * aggregate result into aggregate buffer.
+   *
    * @param partial
    * @param buffer
    */
@@ -52,6 +54,7 @@ abstract  class MinAggregate[T: Numeric] extends Aggregate[T]{
 
   /**
    * Return the final aggregated result based on aggregate buffer.
+   *
    * @param buffer
    * @return
    */
@@ -67,11 +70,8 @@ abstract  class MinAggregate[T: Numeric] extends Aggregate[T]{
 }
 
 class ByteMinAggregate extends MinAggregate[Byte] {
-  private val partialType = Array(SqlTypeName.TINYINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Byte.MaxValue)
@@ -79,11 +79,8 @@ class ByteMinAggregate extends MinAggregate[Byte] {
 }
 
 class ShortMinAggregate extends MinAggregate[Short] {
-  private val partialType = Array(SqlTypeName.SMALLINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Short.MaxValue)
@@ -91,11 +88,8 @@ class ShortMinAggregate extends MinAggregate[Short] {
 }
 
 class IntMinAggregate extends MinAggregate[Int] {
-  private val partialType = Array(SqlTypeName.INTEGER)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Int.MaxValue)
@@ -103,11 +97,8 @@ class IntMinAggregate extends MinAggregate[Int] {
 }
 
 class LongMinAggregate extends MinAggregate[Long] {
-  private val partialType = Array(SqlTypeName.BIGINT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Long.MaxValue)
@@ -115,11 +106,8 @@ class LongMinAggregate extends MinAggregate[Long] {
 }
 
 class FloatMinAggregate extends MinAggregate[Float] {
-  private val partialType = Array(SqlTypeName.FLOAT)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Float.MaxValue)
@@ -127,11 +115,8 @@ class FloatMinAggregate extends MinAggregate[Float] {
 }
 
 class DoubleMinAggregate extends MinAggregate[Double] {
-  private val partialType = Array(SqlTypeName.DOUBLE)
 
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
 
   override def initiate(intermediate: Row): Unit = {
     intermediate.setField(minIndex, Double.MaxValue)

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
index 25ef344..b4c56fe 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregate.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.api.table.runtime.aggregate
 
-import org.apache.calcite.sql.`type`.SqlTypeName
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo
 import org.apache.flink.api.table.Row
 
 abstract class SumAggregate[T: Numeric]
@@ -57,49 +57,25 @@ abstract class SumAggregate[T: Numeric]
 }
 
 class ByteSumAggregate extends SumAggregate[Byte] {
-  private val partialType = Array(SqlTypeName.TINYINT)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.BYTE_TYPE_INFO)
 }
 
 class ShortSumAggregate extends SumAggregate[Short] {
-  private val partialType = Array(SqlTypeName.SMALLINT)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.SHORT_TYPE_INFO)
 }
 
 class IntSumAggregate extends SumAggregate[Int] {
-  private val partialType = Array(SqlTypeName.INTEGER)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.INT_TYPE_INFO)
 }
 
 class LongSumAggregate extends SumAggregate[Long] {
-  private val partialType = Array(SqlTypeName.BIGINT)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.LONG_TYPE_INFO)
 }
 
 class FloatSumAggregate extends SumAggregate[Float] {
-  private val partialType = Array(SqlTypeName.FLOAT)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.FLOAT_TYPE_INFO)
 }
 
 class DoubleSumAggregate extends SumAggregate[Double] {
-  private val partialType = Array(SqlTypeName.DOUBLE)
-
-  override def intermediateDataType: Array[SqlTypeName] = {
-    partialType
-  }
+  override def intermediateDataType = Array(BasicTypeInfo.DOUBLE_TYPE_INFO)
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
new file mode 100644
index 0000000..78d5f8c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AggregateTestBase.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+import org.apache.flink.api.table.Row
+import org.junit.Test
+import org.junit.Assert.assertEquals
+
+abstract class AggregateTestBase[T] {
+
+  private val offset = 2
+  private val rowArity: Int = offset + aggregator.intermediateDataType.length
+
+  def inputValueSets: Seq[Seq[_]]
+
+  def expectedResults: Seq[T]
+
+  def aggregator: Aggregate[T]
+
+  private def createAggregator(): Aggregate[T] = {
+    val agg = aggregator
+    agg.setAggOffsetInRow(offset)
+    agg
+  }
+
+  private def createRow(): Row = {
+    new Row(rowArity)
+  }
+
+  @Test
+  def testAggregate(): Unit = {
+
+    // iterate over input sets
+    for((vals, expected) <- inputValueSets.zip(expectedResults)) {
+
+      // prepare mapper
+      val rows: Seq[Row] = prepare(vals)
+
+      val result = if (aggregator.supportPartial) {
+        // test with combiner
+        val (firstVals, secondVals) = rows.splitAt(rows.length / 2)
+        val combined = partialAgg(firstVals) :: partialAgg(secondVals) :: Nil
+        finalAgg(combined)
+
+      } else {
+        // test without combiner
+        finalAgg(rows)
+      }
+
+      assertEquals(expected, result)
+
+    }
+  }
+
+  private def prepare(vals: Seq[_]): Seq[Row] = {
+
+    val agg = createAggregator()
+
+    vals.map { v =>
+      val row = createRow()
+      agg.prepare(v, row)
+      row
+    }
+  }
+
+  private def partialAgg(rows: Seq[Row]): Row = {
+
+    val agg = createAggregator()
+    val aggBuf = createRow()
+
+    agg.initiate(aggBuf)
+    rows.foreach(v => agg.merge(v, aggBuf))
+
+    aggBuf
+  }
+
+  private def finalAgg(rows: Seq[Row]): T = {
+
+    val agg = createAggregator()
+    val aggBuf = createRow()
+
+    agg.initiate(aggBuf)
+    rows.foreach(v => agg.merge(v, aggBuf))
+
+    agg.evaluate(partialAgg(rows))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
new file mode 100644
index 0000000..2575fa2
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/AvgAggregateTest.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+abstract class AvgAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def minVal: T
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = Seq(
+    Seq(
+      minVal,
+      minVal,
+      null.asInstanceOf[T],
+      minVal,
+      minVal,
+      null.asInstanceOf[T],
+      minVal,
+      minVal,
+      minVal
+    ),
+    Seq(
+      maxVal,
+      maxVal,
+      null.asInstanceOf[T],
+      maxVal,
+      maxVal,
+      null.asInstanceOf[T],
+      maxVal,
+      maxVal,
+      maxVal
+    ),
+    Seq(
+      minVal,
+      maxVal,
+      null.asInstanceOf[T],
+      numeric.fromInt(0),
+      numeric.negate(maxVal),
+      numeric.negate(minVal),
+      null.asInstanceOf[T]
+    )
+  )
+
+  override def expectedResults: Seq[T] = Seq(
+    minVal,
+    maxVal,
+    numeric.fromInt(0)
+  )
+}
+
+class ByteAvgAggregateTest extends AvgAggregateTestBase[Byte] {
+
+  override def minVal = (Byte.MinValue + 1).toByte
+  override def maxVal = (Byte.MaxValue - 1).toByte
+
+  override def aggregator = new ByteAvgAggregate()
+}
+
+class ShortAvgAggregateTest extends AvgAggregateTestBase[Short] {
+
+  override def minVal = (Short.MinValue + 1).toShort
+  override def maxVal = (Short.MaxValue - 1).toShort
+
+  override def aggregator = new ShortAvgAggregate()
+}
+
+class IntAvgAggregateTest extends AvgAggregateTestBase[Int] {
+
+  override def minVal = Int.MinValue + 1
+  override def maxVal = Int.MaxValue - 1
+
+  override def aggregator = new IntAvgAggregate()
+}
+
+class LongAvgAggregateTest extends AvgAggregateTestBase[Long] {
+
+  override def minVal = Long.MinValue + 1
+  override def maxVal = Long.MaxValue - 1
+
+  override def aggregator = new LongAvgAggregate()
+}
+
+class FloatAvgAggregateTest extends AvgAggregateTestBase[Float] {
+
+  override def minVal = Float.MinValue
+  override def maxVal = Float.MaxValue
+
+  override def aggregator = new FloatAvgAggregate()
+}
+
+class DoubleAvgAggregateTest extends AvgAggregateTestBase[Double] {
+
+  override def minVal = Float.MinValue
+  override def maxVal = Float.MaxValue
+
+  override def aggregator = new DoubleAvgAggregate()
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregateTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregateTest.scala
new file mode 100644
index 0000000..ce27d7c
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/CountAggregateTest.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+class CountAggregateTest extends AggregateTestBase[Long] {
+
+  override def inputValueSets: Seq[Seq[_]] = Seq(
+    Seq("a", "b", null, "c", null, "d", "e", null, "f")
+  )
+
+  override def expectedResults: Seq[Long] = Seq(6L)
+
+  override def aggregator: Aggregate[Long] = new CountAggregate()
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
new file mode 100644
index 0000000..f3951e4
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MaxAggregateTest.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+abstract class MaxAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def minVal: T
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = Seq(
+    Seq(
+      numeric.fromInt(1),
+      null.asInstanceOf[T],
+      maxVal,
+      numeric.fromInt(-99),
+      numeric.fromInt(3),
+      numeric.fromInt(56),
+      numeric.fromInt(0),
+      minVal,
+      numeric.fromInt(-20),
+      numeric.fromInt(17),
+      null.asInstanceOf[T]
+    )
+  )
+
+  override def expectedResults: Seq[T] = Seq(maxVal)
+}
+
+class ByteMaxAggregateTest extends MaxAggregateTestBase[Byte] {
+
+  override def minVal = (Byte.MinValue + 1).toByte
+  override def maxVal = (Byte.MaxValue - 1).toByte
+
+  override def aggregator: Aggregate[Byte] = new ByteMaxAggregate()
+}
+
+class ShortMaxAggregateTest extends MaxAggregateTestBase[Short] {
+
+  override def minVal = (Short.MinValue + 1).toShort
+  override def maxVal = (Short.MaxValue - 1).toShort
+
+  override def aggregator: Aggregate[Short] = new ShortMaxAggregate()
+}
+
+class IntMaxAggregateTest extends MaxAggregateTestBase[Int] {
+
+  override def minVal = Int.MinValue + 1
+  override def maxVal = Int.MaxValue - 1
+
+  override def aggregator: Aggregate[Int] = new IntMaxAggregate()
+}
+
+class LongMaxAggregateTest extends MaxAggregateTestBase[Long] {
+
+  override def minVal = Long.MinValue + 1
+  override def maxVal = Long.MaxValue - 1
+
+  override def aggregator: Aggregate[Long] = new LongMaxAggregate()
+}
+
+class FloatMaxAggregateTest extends MaxAggregateTestBase[Float] {
+
+  override def minVal = Float.MinValue / 2
+  override def maxVal = Float.MaxValue / 2
+
+  override def aggregator: Aggregate[Float] = new FloatMaxAggregate()
+}
+
+class DoubleMaxAggregateTest extends MaxAggregateTestBase[Double] {
+
+  override def minVal = Double.MinValue / 2
+  override def maxVal = Double.MaxValue / 2
+
+  override def aggregator: Aggregate[Double] = new DoubleMaxAggregate()
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
new file mode 100644
index 0000000..3a4b111
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/MinAggregateTest.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+abstract class MinAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def minVal: T
+  def maxVal: T
+
+  override def inputValueSets: Seq[Seq[T]] = Seq(
+    Seq(
+      numeric.fromInt(1),
+      null.asInstanceOf[T],
+      maxVal,
+      numeric.fromInt(-99),
+      numeric.fromInt(3),
+      numeric.fromInt(56),
+      numeric.fromInt(0),
+      minVal,
+      numeric.fromInt(-20),
+      numeric.fromInt(17),
+      null.asInstanceOf[T]
+    )
+  )
+
+  override def expectedResults: Seq[T] = Seq(minVal)
+}
+
+class ByteMinAggregateTest extends MinAggregateTestBase[Byte] {
+
+  override def minVal = (Byte.MinValue + 1).toByte
+  override def maxVal = (Byte.MaxValue - 1).toByte
+
+  override def aggregator: Aggregate[Byte] = new ByteMinAggregate()
+}
+
+class ShortMinAggregateTest extends MinAggregateTestBase[Short] {
+
+  override def minVal = (Short.MinValue + 1).toShort
+  override def maxVal = (Short.MaxValue - 1).toShort
+
+  override def aggregator: Aggregate[Short] = new ShortMinAggregate()
+}
+
+class IntMinAggregateTest extends MinAggregateTestBase[Int] {
+
+  override def minVal = Int.MinValue + 1
+  override def maxVal = Int.MaxValue - 1
+
+  override def aggregator: Aggregate[Int] = new IntMinAggregate()
+}
+
+class LongMinAggregateTest extends MinAggregateTestBase[Long] {
+
+  override def minVal = Long.MinValue + 1
+  override def maxVal = Long.MaxValue - 1
+
+  override def aggregator: Aggregate[Long] = new LongMinAggregate()
+}
+
+class FloatMinAggregateTest extends MinAggregateTestBase[Float] {
+
+  override def minVal = Float.MinValue / 2
+  override def maxVal = Float.MaxValue / 2
+
+  override def aggregator: Aggregate[Float] = new FloatMinAggregate()
+}
+
+class DoubleMinAggregateTest extends MinAggregateTestBase[Double] {
+
+  override def minVal = Double.MinValue / 2
+  override def maxVal = Double.MaxValue / 2
+
+  override def aggregator: Aggregate[Double] = new DoubleMinAggregate()
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/af0f4182/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
new file mode 100644
index 0000000..f5de3fc
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/table/runtime/aggregate/SumAggregateTest.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.table.runtime.aggregate
+
+abstract class SumAggregateTestBase[T: Numeric] extends AggregateTestBase[T] {
+
+  private val numeric: Numeric[T] = implicitly[Numeric[T]]
+
+  def maxVal: T
+  private val minVal = numeric.negate(maxVal)
+
+  override def inputValueSets: Seq[Seq[T]] = Seq(
+    Seq(
+      minVal,
+      numeric.fromInt(1),
+      null.asInstanceOf[T],
+      numeric.fromInt(2),
+      numeric.fromInt(3),
+      numeric.fromInt(4),
+      numeric.fromInt(5),
+      numeric.fromInt(-10),
+      numeric.fromInt(-20),
+      numeric.fromInt(17),
+      null.asInstanceOf[T],
+      maxVal
+    )
+  )
+
+  override def expectedResults: Seq[T] = Seq(numeric.fromInt(2))
+
+}
+
+class ByteSumAggregateTest extends SumAggregateTestBase[Byte] {
+
+  override def maxVal = (Byte.MaxValue / 2).toByte
+
+  override def aggregator: Aggregate[Byte] = new ByteSumAggregate
+}
+
+class ShortSumAggregateTest extends SumAggregateTestBase[Short] {
+
+  override def maxVal = (Short.MaxValue / 2).toShort
+
+  override def aggregator: Aggregate[Short] = new ShortSumAggregate
+}
+
+class IntSumAggregateTest extends SumAggregateTestBase[Int] {
+
+  override def maxVal = Int.MaxValue / 2
+
+  override def aggregator: Aggregate[Int] = new IntSumAggregate
+}
+
+class LongSumAggregateTest extends SumAggregateTestBase[Long] {
+
+  override def maxVal = Long.MaxValue / 2
+
+  override def aggregator: Aggregate[Long] = new LongSumAggregate
+}
+
+class FloatSumAggregateTest extends SumAggregateTestBase[Float] {
+
+  override def maxVal = 12345.6789f
+
+  override def aggregator: Aggregate[Float] = new FloatSumAggregate
+}
+
+class DoubleSumAggregateTest extends SumAggregateTestBase[Double] {
+
+  override def maxVal = 12345.6789d
+
+  override def aggregator: Aggregate[Double] = new DoubleSumAggregate
+}

