flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jinch...@apache.org
Subject flink git commit: [FLINK-7194] [table] Add default implementations for type hints to UDAGG interface.
Date Sat, 22 Jul 2017 11:30:45 GMT
Repository: flink
Updated Branches:
  refs/heads/master c472309c7 -> ea1edfb46


[FLINK-7194] [table] Add default implementations for type hints to UDAGG interface.

This closes #4379


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ea1edfb4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ea1edfb4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ea1edfb4

Branch: refs/heads/master
Commit: ea1edfb46f674035fd920c70100f60575600405f
Parents: c472309
Author: Fabian Hueske <fhueske@apache.org>
Authored: Thu Jul 20 15:09:06 2017 +0200
Committer: Jincheng Sun <jincheng@apache.org>
Committed: Sat Jul 22 06:55:44 2017 +0800

----------------------------------------------------------------------
 .../table/functions/AggregateFunction.scala     | 64 +++++++-------
 .../functions/aggfunctions/AvgAggFunction.scala | 16 ++--
 .../aggfunctions/CountAggFunction.scala         | 13 +--
 .../functions/aggfunctions/MaxAggFunction.scala |  4 +-
 .../MaxAggFunctionWithRetract.scala             |  8 +-
 .../functions/aggfunctions/MinAggFunction.scala |  4 +-
 .../MinAggFunctionWithRetract.scala             |  8 +-
 .../functions/aggfunctions/SumAggFunction.scala |  8 +-
 .../SumWithRetractAggFunction.scala             |  8 +-
 .../utils/UserDefinedFunctionUtils.scala        | 90 +++++++++++++-------
 .../table/api/stream/sql/AggregateTest.scala    |  2 +-
 .../aggfunctions/CountAggFunctionTest.scala     |  8 +-
 12 files changed, 129 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
index f90860b..8f50971 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -17,6 +17,8 @@
  */
 package org.apache.flink.table.functions
 
+import org.apache.flink.api.common.typeinfo.TypeInformation
+
 /**
   * Base class for User-Defined Aggregates.
   *
@@ -28,9 +30,8 @@ package org.apache.flink.table.functions
   *
   *  There are a few other methods that can be optional to have:
   *  - retract,
-  *  - merge,
-  *  - resetAccumulator, and
-  *  - getAccumulatorType.
+  *  - merge, and
+  *  - resetAccumulator
   *
   * All these methods muse be declared publicly, not static and named exactly as the names
   * mentioned above. The methods createAccumulator and getValue are defined in the
@@ -72,7 +73,7 @@ package org.apache.flink.table.functions
   *                     custom merge method.
   * @param its          an [[java.lang.Iterable]] pointed to a group of accumulators that will be
   *                     merged.
-
+  *
   * def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit
   * }}}
   *
@@ -82,39 +83,16 @@ package org.apache.flink.table.functions
   * dataset grouping aggregate.
   *
   * @param accumulator  the accumulator which needs to be reset
-
-  * def resetAccumulator(accumulator: ACC): Unit
-  * }}}
-  *
   *
-  * {{{
-  * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the accumulator. This
-  * function is optional and can be implemented if the accumulator type cannot be automatically
-  * inferred from the instance returned by createAccumulator method.
-  *
-  * @return  the type information for the accumulator.
-
-  * def getAccumulatorType: TypeInformation[_]
-  * }}}
-  *
-  *
-  * {{{
-  * Returns the [[org.apache.flink.api.common.typeinfo.TypeInformation]] of the return value. This
-  * function is optional and needed in case Flink's type extraction facilities are not sufficient
-  * to extract the TypeInformation. Flink's type extraction facilities can handle basic types or
-  * simple POJOs but might be wrong for more complex, custom, or composite types.
-  *
-  * @return  the type information for the return value.
-  *
-  * def getResultType: TypeInformation[_]
+  * def resetAccumulator(accumulator: ACC): Unit
   * }}}
   *
   *
   * @tparam T   the type of the aggregation result
-  * @tparam ACC base class for aggregate Accumulator. The accumulator is used to keep the aggregated
-  *             values which are needed to compute an aggregation result. AggregateFunction
-  *             represents its state using accumulator, thereby the state of the AggregateFunction
-  *             must be put into the accumulator.
+  * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
+  *             aggregated values which are needed to compute an aggregation result.
+  *             AggregateFunction represents its state using accumulator, thereby the state of the
+  *             AggregateFunction must be put into the accumulator.
   */
 abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
   /**
@@ -136,8 +114,26 @@ abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
     */
   def getValue(accumulator: ACC): T
 
-  /**
-    * whether this aggregate only used in OVER clause
+    /**
+    * Returns true if this AggregateFunction can only be applied in an OVER window.
+    *
+    * @return true if the AggregateFunction requires an OVER window, false otherwise.
     */
   def requiresOver: Boolean = false
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's result.
+    *
+    * @return The TypeInformation of the AggregateFunction's result or null if the result type
+    *         should be automatically inferred.
+    */
+  def getResultType: TypeInformation[T] = null
+
+  /**
+    * Returns the TypeInformation of the AggregateFunction's accumulator.
+    *
+    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
+    *         accumulator type should be automatically inferred.
+    */
+  def getAccumulatorType: TypeInformation[ACC] = null
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index 3f4e5db..b651c42 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -80,9 +80,9 @@ abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T, IntegralAv
     acc.f1 = 0L
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[IntegralAvgAccumulator] = {
     new TupleTypeInfo(
-      new IntegralAvgAccumulator().getClass,
+      classOf[IntegralAvgAccumulator],
       BasicTypeInfo.LONG_TYPE_INFO,
       BasicTypeInfo.LONG_TYPE_INFO)
   }
@@ -175,9 +175,9 @@ abstract class BigIntegralAvgAggFunction[T]
     acc.f1 = 0
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[BigIntegralAvgAccumulator] = {
     new TupleTypeInfo(
-      new BigIntegralAvgAccumulator().getClass,
+      classOf[BigIntegralAvgAccumulator],
       BasicTypeInfo.BIG_INT_TYPE_INFO,
       BasicTypeInfo.LONG_TYPE_INFO)
   }
@@ -255,9 +255,9 @@ abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T, FloatingAv
     acc.f1 = 0L
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[FloatingAvgAccumulator] = {
     new TupleTypeInfo(
-      new FloatingAvgAccumulator().getClass,
+      classOf[FloatingAvgAccumulator],
       BasicTypeInfo.DOUBLE_TYPE_INFO,
       BasicTypeInfo.LONG_TYPE_INFO)
   }
@@ -339,9 +339,9 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccu
     acc.f1 = 0L
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[DecimalAvgAccumulator] = {
     new TupleTypeInfo(
-      new DecimalAvgAccumulator().getClass,
+      classOf[DecimalAvgAccumulator],
       BasicTypeInfo.BIG_DEC_TYPE_INFO,
       BasicTypeInfo.LONG_TYPE_INFO)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
index 2b8ec14..c94e053 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala
@@ -18,6 +18,7 @@
 package org.apache.flink.table.functions.aggfunctions
 
 import java.lang.{Iterable => JIterable}
+import java.lang.{Long => JLong}
 
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
@@ -32,7 +33,7 @@ class CountAccumulator extends JTuple1[Long] {
 /**
   * built-in count aggregate function
   */
-class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
+class CountAggFunction extends AggregateFunction[JLong, CountAccumulator] {
 
   def accumulate(acc: CountAccumulator, value: Any): Unit = {
     if (value != null) {
@@ -46,7 +47,7 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
     }
   }
 
-  override def getValue(acc: CountAccumulator): Long = {
+  override def getValue(acc: CountAccumulator): JLong = {
     acc.f0
   }
 
@@ -65,10 +66,10 @@ class CountAggFunction extends AggregateFunction[Long, CountAccumulator] {
     acc.f0 = 0L
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
-    new TupleTypeInfo((new CountAccumulator).getClass, BasicTypeInfo.LONG_TYPE_INFO)
+  override def getAccumulatorType: TypeInformation[CountAccumulator] = {
+    new TupleTypeInfo(classOf[CountAccumulator], BasicTypeInfo.LONG_TYPE_INFO)
   }
 
-  def getResultType(): TypeInformation[_] =
-    BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[_]]
+  override def getResultType: TypeInformation[JLong] =
+    BasicTypeInfo.LONG_TYPE_INFO
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
index 96ee8d1..0789bee 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -76,9 +76,9 @@ abstract class MaxAggFunction[T](implicit ord: Ordering[T])
     acc.f1 = false
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[MaxAccumulator[T]] = {
     new TupleTypeInfo(
-      new MaxAccumulator[T].getClass,
+      classOf[MaxAccumulator[T]],
       getValueTypeInfo,
       BasicTypeInfo.BOOLEAN_TYPE_INFO)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
index 6f18739..c79c06a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala
@@ -82,7 +82,7 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
           val iterator = acc.f1.keySet().iterator()
           var key = iterator.next()
           acc.f0 = key
-          while (iterator.hasNext()) {
+          while (iterator.hasNext) {
             key = iterator.next()
             if (ord.compare(acc.f0, key) < 0) {
               acc.f0 = key
@@ -116,7 +116,7 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
         }
         // merge the count for each key
         val iterator = a.f1.keySet().iterator()
-        while (iterator.hasNext()) {
+        while (iterator.hasNext) {
           val key = iterator.next()
           if (acc.f1.containsKey(key)) {
             acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
@@ -133,9 +133,9 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T])
     acc.f1.clear()
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[MaxWithRetractAccumulator[T]] = {
     new TupleTypeInfo(
-      new MaxWithRetractAccumulator[T].getClass,
+      classOf[MaxWithRetractAccumulator[T]],
       getValueTypeInfo,
       new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO))
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
index 88d7afd..d2132c2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -76,9 +76,9 @@ abstract class MinAggFunction[T](implicit ord: Ordering[T])
     acc.f1 = false
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[MinAccumulator[T]] = {
     new TupleTypeInfo(
-      new MinAccumulator[T].getClass,
+      classOf[MinAccumulator[T]],
       getValueTypeInfo,
       BasicTypeInfo.BOOLEAN_TYPE_INFO)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
index 2d3348b..faa6725 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala
@@ -82,7 +82,7 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
           val iterator = acc.f1.keySet().iterator()
           var key = iterator.next()
           acc.f0 = key
-          while (iterator.hasNext()) {
+          while (iterator.hasNext) {
             key = iterator.next()
             if (ord.compare(acc.f0, key) > 0) {
               acc.f0 = key
@@ -116,7 +116,7 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
         }
         // merge the count for each key
         val iterator = a.f1.keySet().iterator()
-        while (iterator.hasNext()) {
+        while (iterator.hasNext) {
           val key = iterator.next()
           if (acc.f1.containsKey(key)) {
             acc.f1.put(key, acc.f1.get(key) + a.f1.get(key))
@@ -133,9 +133,9 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T])
     acc.f1.clear()
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[MinWithRetractAccumulator[T]] = {
     new TupleTypeInfo(
-      new MinWithRetractAccumulator[T].getClass,
+      classOf[MinWithRetractAccumulator[T]],
       getValueTypeInfo,
       new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO))
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
index 43fc7ff..5c0b14b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
@@ -76,9 +76,9 @@ abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T, SumAccumu
     acc.f1 = false
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[SumAccumulator[T]] = {
     new TupleTypeInfo(
-      (new SumAccumulator).getClass,
+      classOf[SumAccumulator[T]],
       getValueTypeInfo,
       BasicTypeInfo.BOOLEAN_TYPE_INFO)
   }
@@ -175,9 +175,9 @@ class DecimalSumAggFunction extends AggregateFunction[BigDecimal, DecimalSumAccu
     acc.f1 = false
   }
 
-  def getAccumulatorType: TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[DecimalSumAccumulator] = {
     new TupleTypeInfo(
-      (new DecimalSumAccumulator).getClass,
+      classOf[DecimalSumAccumulator],
       BasicTypeInfo.BIG_DEC_TYPE_INFO,
       BasicTypeInfo.BOOLEAN_TYPE_INFO)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
index 7f68d11..fc51b9b 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumWithRetractAggFunction.scala
@@ -84,9 +84,9 @@ abstract class SumWithRetractAggFunction[T: Numeric]
     acc.f1 = 0L
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[SumWithRetractAccumulator[T]] = {
     new TupleTypeInfo(
-      (new SumWithRetractAccumulator).getClass,
+      classOf[SumWithRetractAccumulator[T]],
       getValueTypeInfo,
       BasicTypeInfo.LONG_TYPE_INFO)
   }
@@ -191,9 +191,9 @@ class DecimalSumWithRetractAggFunction
     acc.f1 = 0L
   }
 
-  def getAccumulatorType(): TypeInformation[_] = {
+  override def getAccumulatorType: TypeInformation[DecimalSumWithRetractAccumulator] = {
     new TupleTypeInfo(
-      (new DecimalSumWithRetractAccumulator).getClass,
+      classOf[DecimalSumWithRetractAccumulator],
       BasicTypeInfo.BIG_DEC_TYPE_INFO,
       BasicTypeInfo.LONG_TYPE_INFO)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 5e34586..47469d1 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -307,64 +307,90 @@ object UserDefinedFunctionUtils {
   // ----------------------------------------------------------------------------------------------
 
   /**
-    * Internal method of AggregateFunction#getResultType() that does some pre-checking and uses
-    * [[TypeExtractor]] as default return type inference.
+    * Tries to infer the TypeInformation of an AggregateFunction's return type.
+    *
+    * @param aggregateFunction The AggregateFunction for which the return type is inferred.
+    * @param extractedType The implicitly inferred type of the result type.
+    *
+    * @return The inferred result type of the AggregateFunction.
     */
   def getResultTypeOfAggregateFunction(
       aggregateFunction: AggregateFunction[_, _],
       extractedType: TypeInformation[_] = null)
     : TypeInformation[_] = {
-    getParameterTypeOfAggregateFunction(aggregateFunction, "getResultType", 0, extractedType)
+
+    val resultType = aggregateFunction.getResultType
+    if (resultType != null) {
+      resultType
+    } else if (extractedType != null) {
+      extractedType
+    } else {
+      try {
+        extractTypeFromAggregateFunction(aggregateFunction, 0)
+      } catch {
+        case ite: InvalidTypesException =>
+          throw new TableException(
+            s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
+              "You can override AggregateFunction.getResultType() to specify the type.",
+            ite
+          )
+      }
+    }
   }
 
   /**
-    * Internal method of AggregateFunction#getAccumulatorType() that does some pre-checking
-    * and uses [[TypeExtractor]] as default return type inference.
+    * Tries to infer the TypeInformation of an AggregateFunction's accumulator type.
+    *
+    * @param aggregateFunction The AggregateFunction for which the accumulator type is inferred.
+    * @param extractedType The implicitly inferred type of the accumulator type.
+    *
+    * @return The inferred accumulator type of the AggregateFunction.
     */
   def getAccumulatorTypeOfAggregateFunction(
     aggregateFunction: AggregateFunction[_, _],
     extractedType: TypeInformation[_] = null)
   : TypeInformation[_] = {
-    getParameterTypeOfAggregateFunction(aggregateFunction, "getAccumulatorType", 1, extractedType)
-  }
-
-  private def getParameterTypeOfAggregateFunction(
-    aggregateFunction: AggregateFunction[_, _],
-    getTypeMethod: String,
-    parameterTypePos: Int,
-    extractedType: TypeInformation[_] = null)
-  : TypeInformation[_] = {
 
-    val resultType = try {
-      val method: Method = aggregateFunction.getClass.getMethod(getTypeMethod)
-      method.invoke(aggregateFunction).asInstanceOf[TypeInformation[_]]
-    } catch {
-      case _: NoSuchMethodException => null
-      case ite: Throwable => throw new TableException("Unexpected exception:", ite)
-    }
-    if (resultType != null) {
-      resultType
+    val accType = aggregateFunction.getAccumulatorType
+    if (accType != null) {
+      accType
     } else if (extractedType != null) {
       extractedType
     } else {
       try {
-        TypeExtractor
-        .createTypeInfo(aggregateFunction,
-                        classOf[AggregateFunction[_, _]],
-                        aggregateFunction.getClass,
-                        parameterTypePos)
-        .asInstanceOf[TypeInformation[_]]
+        extractTypeFromAggregateFunction(aggregateFunction, 1)
       } catch {
         case ite: InvalidTypesException =>
           throw new TableException(
-            s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
-              s"You can override AggregateFunction.$getTypeMethod() to specify the type.",
-            ite)
+            s"Cannot infer generic type of ${aggregateFunction.getClass}. " +
+              "You can override AggregateFunction.getAccumulatorType() to specify the type.",
+            ite
+          )
       }
     }
   }
 
   /**
+    * Internal method to extract a type from an AggregateFunction's type parameters.
+    *
+    * @param aggregateFunction The AggregateFunction for which the type is extracted.
+    * @param parameterTypePos The position of the type parameter for which the type is extracted.
+    *
+    * @return The extracted type.
+    */
+  @throws(classOf[InvalidTypesException])
+  private def extractTypeFromAggregateFunction(
+      aggregateFunction: AggregateFunction[_, _],
+      parameterTypePos: Int): TypeInformation[_] = {
+
+    TypeExtractor.createTypeInfo(
+      aggregateFunction,
+      classOf[AggregateFunction[_, _]],
+      aggregateFunction.getClass,
+      parameterTypePos).asInstanceOf[TypeInformation[_]]
+  }
+
+  /**
     * Internal method of [[ScalarFunction#getResultType()]] that does some pre-checking and uses
     * [[TypeExtractor]] as default return type inference.
     */

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
index 70d1d21..76d33c2 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/stream/sql/AggregateTest.scala
@@ -153,5 +153,5 @@ class MyAgg2 extends AggregateFunction[Long, Row] {
 
   override def getValue(accumulator: Row): Long = 1L
 
-  def getAccumulatorType: TypeInformation[_] = new RowTypeInfo(Types.LONG, Types.INT)
+  override def getAccumulatorType: TypeInformation[Row] = new RowTypeInfo(Types.LONG, Types.INT)
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ea1edfb4/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
index f9dd474..87aaff9 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/CountAggFunctionTest.scala
@@ -18,22 +18,24 @@
 
 package org.apache.flink.table.runtime.aggfunctions
 
+import java.lang.{Long => JLong}
+
 import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.functions.aggfunctions.{CountAccumulator, CountAggFunction}
 
 /**
   * Test case for built-in count aggregate function
   */
-class CountAggFunctionTest extends AggFunctionTestBase[Long, CountAccumulator] {
+class CountAggFunctionTest extends AggFunctionTestBase[JLong, CountAccumulator] {
 
   override def inputValueSets: Seq[Seq[_]] = Seq(
     Seq("a", "b", null, "c", null, "d", "e", null, "f"),
     Seq(null, null, null, null, null, null)
   )
 
-  override def expectedResults: Seq[Long] = Seq(6L, 0L)
+  override def expectedResults: Seq[JLong] = Seq(6L, 0L)
 
-  override def aggregator: AggregateFunction[Long, CountAccumulator] = new CountAggFunction()
+  override def aggregator: AggregateFunction[JLong, CountAccumulator] = new CountAggFunction()
 
   override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
 }


Mime
View raw message