beam-commits mailing list archives

From taki...@apache.org
Subject [1/2] beam git commit: UDAF support: - Adds an abstract class BeamSqlUdaf for defining Calcite SQL UDAFs. - Updates built-in COUNT/SUM/AVG/MAX/MIN accumulators to use this new class.
Date Fri, 30 Jun 2017 22:41:49 GMT
Repository: beam
Updated Branches:
  refs/heads/DSL_SQL 2096da25e -> 7ba77dd43


UDAF support:
- Adds an abstract class BeamSqlUdaf for defining Calcite SQL UDAFs.
- Updates built-in COUNT/SUM/AVG/MAX/MIN accumulators to use this new class.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a13fce98
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a13fce98
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a13fce98

Branch: refs/heads/DSL_SQL
Commit: a13fce98f61867fcb5adb52c80f1cfd3eecfc436
Parents: 2096da2
Author: mingmxu <mingmxu@ebay.com>
Authored: Mon Jun 26 16:03:51 2017 -0700
Committer: Tyler Akidau <takidau@apache.org>
Committed: Fri Jun 30 15:34:07 2017 -0700

----------------------------------------------------------------------
 .../org/apache/beam/dsls/sql/BeamSqlEnv.java    |  10 +
 .../beam/dsls/sql/rel/BeamAggregationRel.java   |   2 +-
 .../beam/dsls/sql/schema/BeamSqlUdaf.java       |  72 ++
 .../transform/BeamAggregationTransforms.java    | 658 ++++---------------
 .../sql/transform/BeamBuiltinAggregations.java  | 412 ++++++++++++
 .../transform/BeamAggregationTransformTest.java |   2 +-
 6 files changed, 633 insertions(+), 523 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
index baa2617..078d9d3 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSqlEnv.java
@@ -22,6 +22,7 @@ import java.io.Serializable;
 import org.apache.beam.dsls.sql.planner.BeamQueryPlanner;
 import org.apache.beam.dsls.sql.schema.BaseBeamTable;
 import org.apache.beam.dsls.sql.schema.BeamSqlRecordType;
+import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
 import org.apache.beam.dsls.sql.utils.CalciteUtils;
 import org.apache.calcite.DataContext;
 import org.apache.calcite.linq4j.Enumerable;
@@ -32,6 +33,7 @@ import org.apache.calcite.schema.Schema;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.calcite.schema.Statistic;
 import org.apache.calcite.schema.Statistics;
+import org.apache.calcite.schema.impl.AggregateFunctionImpl;
 import org.apache.calcite.schema.impl.ScalarFunctionImpl;
 import org.apache.calcite.tools.Frameworks;
 
@@ -58,6 +60,14 @@ public class BeamSqlEnv {
   }
 
   /**
+   * Register a UDAF function which can be used in GROUP-BY expression.
+   * See {@link BeamSqlUdaf} on how to implement a UDAF.
+   */
+  public void registerUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz) {
+    schema.add(functionName, AggregateFunctionImpl.create(clazz));
+  }
+
+  /**
    * Registers a {@link BaseBeamTable} which can be used for all subsequent queries.
    *
    */

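For context, a minimal sketch of how this new hook might be used. Everything below other than registerUdaf and BeamSqlUdaf is hypothetical (the SquareSum class, the no-argument BeamSqlEnv constructor, and the table/query names); the BeamSqlUdaf contract itself is defined further down in this commit.

    // Hypothetical UDAF (a static nested class for brevity): sum of squares over
    // an INTEGER column. InputT, AccumT and OutputT are all bound to Integer, so
    // the default accumulator coder resolves to VarIntCoder via the CoderRegistry.
    public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> {
      @Override public Integer init() { return 0; }
      @Override public Integer add(Integer accum, Integer input) { return accum + input * input; }
      @Override public Integer merge(Iterable<Integer> accums) {
        int merged = 0;
        for (Integer a : accums) {
          merged += a;
        }
        return merged;
      }
      @Override public Integer result(Integer accum) { return accum; }
    }

    BeamSqlEnv sqlEnv = new BeamSqlEnv();
    sqlEnv.registerUdaf("SQUARE_SUM", SquareSum.class);
    // SQUARE_SUM is then usable in GROUP-BY queries, e.g.:
    //   SELECT f_key, SQUARE_SUM(f_int) FROM my_table GROUP BY f_key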
http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
index 9ec9e9f..9bb2902 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/rel/BeamAggregationRel.java
@@ -104,7 +104,7 @@ public class BeamAggregationRel extends Aggregate implements BeamRelNode {
     PCollection<KV<BeamSqlRow, BeamSqlRow>> aggregatedStream = exCombineByStream.apply(
         stageName + "combineBy",
         Combine.<BeamSqlRow, BeamSqlRow, BeamSqlRow>perKey(
-            new BeamAggregationTransforms.AggregationCombineFn(getAggCallList(),
+            new BeamAggregationTransforms.AggregationAdaptor(getAggCallList(),
                 CalciteUtils.toBeamRecordType(input.getRowType()))))
         .setCoder(KvCoder.of(keyCoder, aggCoder));
 

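This wires the adaptor into a plain Combine.perKey, so the runner drives it through the standard CombineFn lifecycle. A conceptual sketch for a single key (aggCalls, inputRowType, and rowsForKey are hypothetical stand-ins; real runners process bundles and combine partial accumulators via mergeAccumulators):

    BeamAggregationTransforms.AggregationAdaptor fn =
        new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType);
    BeamAggregationTransforms.AggregationAccumulator acc = fn.createAccumulator();
    for (BeamSqlRow row : rowsForKey) {
      acc = fn.addInput(acc, row);            // fans out to each aggregator's add()
    }
    BeamSqlRow result = fn.extractOutput(acc);  // collects each aggregator's result()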
http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java
new file mode 100644
index 0000000..9582ffa
--- /dev/null
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/schema/BeamSqlUdaf.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.schema;
+
+import java.io.Serializable;
+import java.lang.reflect.ParameterizedType;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+
+/**
+ * Abstract class for aggregation functions in Beam SQL.
+ *
+ * <p>There are several constraints on a UDAF:<br>
+ * 1. A constructor with an empty argument list is required;<br>
+ * 2. The type of {@code InputT} and {@code OutputT} can only be Integer/Long/Short/Byte/Double
+ * /Float/Date/BigDecimal, mapping to the SQL types INTEGER/BIGINT/SMALLINT/TINYINT/DOUBLE/FLOAT
+ * /TIMESTAMP/DECIMAL;<br>
+ * 3. Keep all intermediate data in {@code AccumT}; do not rely on instance fields of the class.<br>
+ */
+public abstract class BeamSqlUdaf<InputT, AccumT, OutputT> implements Serializable {
+  public BeamSqlUdaf(){}
+
+  /**
+   * Creates an initial aggregation object; equivalent to {@link CombineFn#createAccumulator()}.
+   */
+  public abstract AccumT init();
+
+  /**
+   * Adds an input value; equivalent to {@link CombineFn#addInput(Object, Object)}.
+   */
+  public abstract AccumT add(AccumT accumulator, InputT input);
+
+  /**
+   * Merges aggregation objects from parallel tasks; equivalent to
+   * {@link CombineFn#mergeAccumulators(Iterable)}.
+   */
+  public abstract AccumT merge(Iterable<AccumT> accumulators);
+
+  /**
+   * Extracts the output value from the aggregation object; equivalent to
+   * {@link CombineFn#extractOutput(Object)}.
+   */
+  public abstract OutputT result(AccumT accumulator);
+
+  /**
+   * Returns the {@link Coder} for {@code AccumT}, which stores the intermediate result.
+   * By default it is fetched from the {@link CoderRegistry}.
+   */
+  public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry)
+      throws CannotProvideCoderException {
+    return registry.getCoder(
+        (Class<AccumT>) ((ParameterizedType) getClass()
+        .getGenericSuperclass()).getActualTypeArguments()[1]);
+  }
+}

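A note on the default getAccumulatorCoder above: the reflection trick only works when the concrete subclass binds AccumT directly in its extends clause. That is why, further down in this commit, the fully generic Max/Min override it with an explicit coder, while Sum (which fixes AccumT to BigDecimal) can rely on the default. A minimal illustration with a hypothetical subclass:

    // AccumT is bound to Long in the extends clause, so on a MyLongCount instance
    // getClass().getGenericSuperclass() is the parameterized type
    // BeamSqlUdaf<Integer, Long, Long>, getActualTypeArguments()[1] is Long.class,
    // and the CoderRegistry supplies its default coder for Long.
    public static class MyLongCount extends BeamSqlUdaf<Integer, Long, Long> {
      @Override public Long init() { return 0L; }
      @Override public Long add(Long accum, Integer input) { return accum + 1; }
      @Override public Long merge(Iterable<Long> accums) {
        long merged = 0L;
        for (Long a : accums) {
          merged += a;
        }
        return merged;
      }
      @Override public Long result(Long accum) { return accum; }
    }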
http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
index 83d473a..9c0b4a3 100644
--- a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamAggregationTransforms.java
@@ -17,25 +17,35 @@
  */
 package org.apache.beam.dsls.sql.transform;
 
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
 import java.io.Serializable;
+import java.math.BigDecimal;
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Date;
+import java.util.Iterator;
 import java.util.List;
 import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlExpression;
 import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlInputRefExpression;
-import org.apache.beam.dsls.sql.interpreter.operator.BeamSqlPrimitive;
 import org.apache.beam.dsls.sql.schema.BeamSqlRecordType;
 import org.apache.beam.dsls.sql.schema.BeamSqlRow;
+import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
 import org.apache.beam.dsls.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.coders.BigDecimalCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.KV;
 import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.sql.SqlAggFunction;
-import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.schema.impl.AggregateFunctionImpl;
+import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.joda.time.Instant;
 
@@ -71,9 +81,7 @@ public class BeamAggregationTransforms implements Serializable{
         outRecord.addField(aggFieldNames.get(idx), kvRecord.getValue().getFieldValue(idx));
       }
 
-      // if (c.pane().isLast()) {
       c.output(outRecord);
-      // }
     }
   }
 
@@ -134,545 +142,153 @@ public class BeamAggregationTransforms implements Serializable{
   }
 
   /**
-   * Aggregation function which supports COUNT, MAX, MIN, SUM, AVG.
-   *
-   * <p>Multiple aggregation functions are combined together.
-   * For each aggregation function, it may accept part of all data types:<br>
-   *   1). COUNT works for any data type;<br>
-   *   2). MAX/MIN works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT, TIMESTAMP;<br>
-   *   3). SUM/AVG works for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT, TINYINT;<br>
-   *
+   * An adaptor class that invokes Calcite UDAF instances as a Beam {@code CombineFn}.
    */
-  public static class AggregationCombineFn extends CombineFn<BeamSqlRow, BeamSqlRow, BeamSqlRow> {
-    private BeamSqlRecordType aggDataType;
+  public static class AggregationAdaptor
+    extends CombineFn<BeamSqlRow, AggregationAccumulator, BeamSqlRow> {
+    private List<BeamSqlUdaf> aggregators;
+    private List<BeamSqlExpression> sourceFieldExps;
+    private BeamSqlRecordType finalRecordType;
 
-    private int countIndex = -1;
-
-    List<String> aggFunctions;
-    List<BeamSqlExpression> aggElementExpressions;
-
-    public AggregationCombineFn(List<AggregateCall> aggregationCalls,
+    public AggregationAdaptor(List<AggregateCall> aggregationCalls,
         BeamSqlRecordType sourceRowRecordType) {
-      this.aggFunctions = new ArrayList<>();
-      this.aggElementExpressions = new ArrayList<>();
-
-      boolean hasAvg = false;
-      boolean hasCount = false;
-      int countIndex = -1;
-      List<String> fieldNames = new ArrayList<>();
-      List<Integer> fieldTypes = new ArrayList<>();
-      for (int idx = 0; idx < aggregationCalls.size(); ++idx) {
-        AggregateCall ac = aggregationCalls.get(idx);
-        //verify it's supported.
-        verifySupportedAggregation(ac);
-
-        fieldNames.add(ac.name);
-        fieldTypes.add(CalciteUtils.toJavaType(ac.type.getSqlTypeName()));
-
-        SqlAggFunction aggFn = ac.getAggregation();
-        switch (aggFn.getName()) {
-        case "COUNT":
-          aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
-          hasCount = true;
-          countIndex = idx;
-          break;
-        case "SUM":
-        case "MAX":
-        case "MIN":
-        case "AVG":
-          int refIndex = ac.getArgList().get(0);
-          aggElementExpressions.add(new BeamSqlInputRefExpression(
-              CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex));
-          if ("AVG".equals(aggFn.getName())) {
-            hasAvg = true;
-          }
-          break;
-
-        default:
+      aggregators = new ArrayList<>();
+      sourceFieldExps = new ArrayList<>();
+      List<String> outFieldsName = new ArrayList<>();
+      List<Integer> outFieldsType = new ArrayList<>();
+      for (AggregateCall call : aggregationCalls) {
+        int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0;
+        BeamSqlExpression sourceExp = new BeamSqlInputRefExpression(
+            CalciteUtils.getFieldType(sourceRowRecordType, refIndex), refIndex);
+        sourceFieldExps.add(sourceExp);
+
+        outFieldsName.add(call.name);
+        int outFieldType = CalciteUtils.toJavaType(call.type.getSqlTypeName());
+        outFieldsType.add(outFieldType);
+
+        switch (call.getAggregation().getName()) {
+          case "COUNT":
+            aggregators.add(new BeamBuiltinAggregations.Count());
+            break;
+          case "MAX":
+            aggregators.add(BeamBuiltinAggregations.Max.create(call.type.getSqlTypeName()));
+            break;
+          case "MIN":
+            aggregators.add(BeamBuiltinAggregations.Min.create(call.type.getSqlTypeName()));
+            break;
+          case "SUM":
+            aggregators.add(BeamBuiltinAggregations.Sum.create(call.type.getSqlTypeName()));
+            break;
+          case "AVG":
+            aggregators.add(BeamBuiltinAggregations.Avg.create(call.type.getSqlTypeName()));
+            break;
+          default:
+            if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
+              // handle UDAF.
+              SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation();
+              AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function;
+              try {
+                aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance());
+              } catch (Exception e) {
+                throw new IllegalStateException(e);
+              }
+            } else {
+              throw new UnsupportedOperationException(
+                  String.format("Aggregator [%s] is not supported",
+                  call.getAggregation().getName()));
+            }
           break;
         }
-        aggFunctions.add(aggFn.getName());
       }
-
-
-      // add a COUNT holder if only have AVG
-      if (hasAvg && !hasCount) {
-        fieldNames.add("__COUNT");
-        fieldTypes.add(CalciteUtils.toJavaType(SqlTypeName.BIGINT));
-
-        aggFunctions.add("COUNT");
-        aggElementExpressions.add(BeamSqlPrimitive.<Long>of(SqlTypeName.BIGINT, 1L));
-
-        hasCount = true;
-        countIndex = aggDataType.size() - 1;
+      finalRecordType = BeamSqlRecordType.create(outFieldsName, outFieldsType);
+    }
+    @Override
+    public AggregationAccumulator createAccumulator() {
+      AggregationAccumulator initialAccu = new AggregationAccumulator();
+      for (BeamSqlUdaf agg : aggregators) {
+        initialAccu.accumulatorElements.add(agg.init());
       }
-
-      this.aggDataType = BeamSqlRecordType.create(fieldNames, fieldTypes);
-      this.countIndex = countIndex;
+      return initialAccu;
     }
-
-    private void verifySupportedAggregation(AggregateCall ac) {
-      //donot support DISTINCT
-      if (ac.isDistinct()) {
-        throw new UnsupportedOperationException("DISTINCT is not supported yet.");
+    @Override
+    public AggregationAccumulator addInput(AggregationAccumulator accumulator, BeamSqlRow input) {
+      AggregationAccumulator deltaAcc = new AggregationAccumulator();
+      for (int idx = 0; idx < aggregators.size(); ++idx) {
+        deltaAcc.accumulatorElements.add(
+            aggregators.get(idx).add(accumulator.accumulatorElements.get(idx),
+            sourceFieldExps.get(idx).evaluate(input).getValue()));
       }
-      String aggFnName = ac.getAggregation().getName();
-      switch (aggFnName) {
-      case "COUNT":
-        //COUNT works for any data type;
-        break;
-      case "SUM":
-        // SUM only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
-        // TINYINT now
-        if (!Arrays
-            .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE,
-                SqlTypeName.SMALLINT, SqlTypeName.TINYINT)
-            .contains(ac.type.getSqlTypeName())) {
-          throw new UnsupportedOperationException(
-              "SUM only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
-        }
-        break;
-      case "MAX":
-      case "MIN":
-        // MAX/MIN only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
-        // TINYINT, TIMESTAMP now
-        if (!Arrays.asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT,
-            SqlTypeName.DOUBLE, SqlTypeName.SMALLINT, SqlTypeName.TINYINT,
-            SqlTypeName.TIMESTAMP).contains(ac.type.getSqlTypeName())) {
-          throw new UnsupportedOperationException("MAX/MIN only support for INT, LONG, FLOAT,"
-              + " DOUBLE, SMALLINT, TINYINT, TIMESTAMP");
-        }
-        break;
-      case "AVG":
-        // AVG only support for INT, LONG, FLOAT, DOUBLE, DECIMAL, SMALLINT,
-        // TINYINT now
-        if (!Arrays
-            .asList(SqlTypeName.INTEGER, SqlTypeName.BIGINT, SqlTypeName.FLOAT, SqlTypeName.DOUBLE,
-                SqlTypeName.SMALLINT, SqlTypeName.TINYINT)
-            .contains(ac.type.getSqlTypeName())) {
-          throw new UnsupportedOperationException(
-              "AVG only support for INT, LONG, FLOAT, DOUBLE, SMALLINT, TINYINT");
+      return deltaAcc;
+    }
+    @Override
+    public AggregationAccumulator mergeAccumulators(Iterable<AggregationAccumulator> accumulators) {
+      AggregationAccumulator deltaAcc = new AggregationAccumulator();
+      for (int idx = 0; idx < aggregators.size(); ++idx) {
+        List accs = new ArrayList<>();
+        Iterator<AggregationAccumulator> ite = accumulators.iterator();
+        while (ite.hasNext()) {
+          accs.add(ite.next().accumulatorElements.get(idx));
         }
-        break;
-      default:
-        throw new UnsupportedOperationException(
-            String.format("[%s] is not supported.", aggFnName));
+        deltaAcc.accumulatorElements.add(aggregators.get(idx).merge(accs));
       }
+      return deltaAcc;
     }
-
     @Override
-    public BeamSqlRow createAccumulator() {
-      BeamSqlRow initialRecord = new BeamSqlRow(aggDataType);
-      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
-        BeamSqlExpression ex = aggElementExpressions.get(idx);
-        String aggFnName = aggFunctions.get(idx);
-        switch (aggFnName) {
-        case "COUNT":
-          initialRecord.addField(idx, 0L);
-          break;
-        case "AVG":
-        case "SUM":
-          //for both AVG/SUM, a summary value is hold at first.
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            initialRecord.addField(idx, 0);
-            break;
-          case BIGINT:
-            initialRecord.addField(idx, 0L);
-            break;
-          case SMALLINT:
-            initialRecord.addField(idx, (short) 0);
-            break;
-          case TINYINT:
-            initialRecord.addField(idx, (byte) 0);
-            break;
-          case FLOAT:
-            initialRecord.addField(idx, 0.0f);
-            break;
-          case DOUBLE:
-            initialRecord.addField(idx, 0.0);
-            break;
-          default:
-            break;
-          }
-          break;
-        case "MAX":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            initialRecord.addField(idx, Integer.MIN_VALUE);
-            break;
-          case BIGINT:
-            initialRecord.addField(idx, Long.MIN_VALUE);
-            break;
-          case SMALLINT:
-            initialRecord.addField(idx, Short.MIN_VALUE);
-            break;
-          case TINYINT:
-            initialRecord.addField(idx, Byte.MIN_VALUE);
-            break;
-          case FLOAT:
-            initialRecord.addField(idx, Float.MIN_VALUE);
-            break;
-          case DOUBLE:
-            initialRecord.addField(idx, Double.MIN_VALUE);
-            break;
-          case TIMESTAMP:
-            initialRecord.addField(idx, new Date(0));
-            break;
-          default:
-            break;
-          }
-          break;
-        case "MIN":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            initialRecord.addField(idx, Integer.MAX_VALUE);
-            break;
-          case BIGINT:
-            initialRecord.addField(idx, Long.MAX_VALUE);
-            break;
-          case SMALLINT:
-            initialRecord.addField(idx, Short.MAX_VALUE);
-            break;
-          case TINYINT:
-            initialRecord.addField(idx, Byte.MAX_VALUE);
-            break;
-          case FLOAT:
-            initialRecord.addField(idx, Float.MAX_VALUE);
-            break;
-          case DOUBLE:
-            initialRecord.addField(idx, Double.MAX_VALUE);
-            break;
-          case TIMESTAMP:
-            initialRecord.addField(idx, new Date(Long.MAX_VALUE));
-            break;
-          default:
-            break;
-          }
-          break;
-        default:
-          break;
-        }
+    public BeamSqlRow extractOutput(AggregationAccumulator accumulator) {
+      BeamSqlRow result = new BeamSqlRow(finalRecordType);
+      for (int idx = 0; idx < aggregators.size(); ++idx) {
+        result.addField(idx, aggregators.get(idx).result(accumulator.accumulatorElements.get(idx)));
       }
-      return initialRecord;
+      return result;
     }
-
     @Override
-    public BeamSqlRow addInput(BeamSqlRow accumulator, BeamSqlRow input) {
-      BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType);
-      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
-        BeamSqlExpression ex = aggElementExpressions.get(idx);
-        String aggFnName = aggFunctions.get(idx);
-        switch (aggFnName) {
-        case "COUNT":
-          deltaRecord.addField(idx, 1 + accumulator.getLong(idx));
-          break;
-        case "AVG":
-        case "SUM":
-          // for both AVG/SUM, a summary value is hold at first.
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            deltaRecord.addField(idx,
-                ex.evaluate(input).getInteger() + accumulator.getInteger(idx));
-            break;
-          case BIGINT:
-            deltaRecord.addField(idx, ex.evaluate(input).getLong() + accumulator.getLong(idx));
-            break;
-          case SMALLINT:
-            deltaRecord.addField(idx,
-                (short) (ex.evaluate(input).getShort() + accumulator.getShort(idx)));
-            break;
-          case TINYINT:
-            deltaRecord.addField(idx,
-                (byte) (ex.evaluate(input).getByte() + accumulator.getByte(idx)));
-            break;
-          case FLOAT:
-            deltaRecord.addField(idx,
-                (float) (ex.evaluate(input).getFloat() + accumulator.getFloat(idx)));
-            break;
-          case DOUBLE:
-            deltaRecord.addField(idx, ex.evaluate(input).getDouble() + accumulator.getDouble(idx));
-            break;
-          default:
-            break;
-          }
-          break;
-        case "MAX":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            deltaRecord.addField(idx,
-                Math.max(ex.evaluate(input).getInteger(), accumulator.getInteger(idx)));
-            break;
-          case BIGINT:
-            deltaRecord.addField(idx,
-                Math.max(ex.evaluate(input).getLong(), accumulator.getLong(idx)));
-            break;
-          case SMALLINT:
-            deltaRecord.addField(idx,
-                (short) Math.max(ex.evaluate(input).getShort(), accumulator.getShort(idx)));
-            break;
-          case TINYINT:
-            deltaRecord.addField(idx,
-                (byte) Math.max(ex.evaluate(input).getByte(), accumulator.getByte(idx)));
-            break;
-          case FLOAT:
-            deltaRecord.addField(idx,
-                Math.max(ex.evaluate(input).getFloat(), accumulator.getFloat(idx)));
-            break;
-          case DOUBLE:
-            deltaRecord.addField(idx,
-                Math.max(ex.evaluate(input).getDouble(), accumulator.getDouble(idx)));
-            break;
-          case TIMESTAMP:
-            Date preDate = accumulator.getDate(idx);
-            Date nowDate = ex.evaluate(input).getDate();
-            deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate);
-            break;
-          default:
-            break;
-          }
-          break;
-        case "MIN":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            deltaRecord.addField(idx,
-                Math.min(ex.evaluate(input).getInteger(), accumulator.getInteger(idx)));
-            break;
-          case BIGINT:
-            deltaRecord.addField(idx,
-                Math.min(ex.evaluate(input).getLong(), accumulator.getLong(idx)));
-            break;
-          case SMALLINT:
-            deltaRecord.addField(idx,
-                (short) Math.min(ex.evaluate(input).getShort(), accumulator.getShort(idx)));
-            break;
-          case TINYINT:
-            deltaRecord.addField(idx,
-                (byte) Math.min(ex.evaluate(input).getByte(), accumulator.getByte(idx)));
-            break;
-          case FLOAT:
-            deltaRecord.addField(idx,
-                Math.min(ex.evaluate(input).getFloat(), accumulator.getFloat(idx)));
-            break;
-          case DOUBLE:
-            deltaRecord.addField(idx,
-                Math.min(ex.evaluate(input).getDouble(), accumulator.getDouble(idx)));
-            break;
-          case TIMESTAMP:
-            Date preDate = accumulator.getDate(idx);
-            Date nowDate = ex.evaluate(input).getDate();
-            deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate);
-            break;
-          default:
-            break;
-          }
-          break;
-        default:
-          break;
-        }
+    public Coder<AggregationAccumulator> getAccumulatorCoder(
+        CoderRegistry registry, Coder<BeamSqlRow> inputCoder)
+        throws CannotProvideCoderException {
+      registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of());
+      List<Coder> aggAccuCoderList = new ArrayList<>();
+      for (BeamSqlUdaf udaf : aggregators) {
+        aggAccuCoderList.add(udaf.getAccumulatorCoder(registry));
       }
-      return deltaRecord;
+      return new AggregationAccumulatorCoder(aggAccuCoderList);
     }
+  }
 
-    @Override
-    public BeamSqlRow mergeAccumulators(Iterable<BeamSqlRow> accumulators) {
-      BeamSqlRow deltaRecord = new BeamSqlRow(aggDataType);
+  /**
+   * A class to hold the varied accumulator objects, one per aggregator.
+   */
+  public static class AggregationAccumulator {
+    private List accumulatorElements = new ArrayList<>();
+  }
 
-      while (accumulators.iterator().hasNext()) {
-        BeamSqlRow sa = accumulators.iterator().next();
-        for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
-          BeamSqlExpression ex = aggElementExpressions.get(idx);
-          String aggFnName = aggFunctions.get(idx);
-          switch (aggFnName) {
-          case "COUNT":
-            deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
-            break;
-          case "AVG":
-          case "SUM":
-            // for both AVG/SUM, a summary value is hold at first.
-            switch (ex.getOutputType()) {
-            case INTEGER:
-              deltaRecord.addField(idx, deltaRecord.getInteger(idx) + sa.getInteger(idx));
-              break;
-            case BIGINT:
-              deltaRecord.addField(idx, deltaRecord.getLong(idx) + sa.getLong(idx));
-              break;
-            case SMALLINT:
-              deltaRecord.addField(idx, (short) (deltaRecord.getShort(idx) + sa.getShort(idx)));
-              break;
-            case TINYINT:
-              deltaRecord.addField(idx, (byte) (deltaRecord.getByte(idx) + sa.getByte(idx)));
-              break;
-            case FLOAT:
-              deltaRecord.addField(idx, (float) (deltaRecord.getFloat(idx) + sa.getFloat(idx)));
-              break;
-            case DOUBLE:
-              deltaRecord.addField(idx, deltaRecord.getDouble(idx) + sa.getDouble(idx));
-              break;
-            default:
-              break;
-            }
-            break;
-          case "MAX":
-            switch (ex.getOutputType()) {
-            case INTEGER:
-              deltaRecord.addField(idx, Math.max(deltaRecord.getInteger(idx), sa.getInteger(idx)));
-              break;
-            case BIGINT:
-              deltaRecord.addField(idx, Math.max(deltaRecord.getLong(idx), sa.getLong(idx)));
-              break;
-            case SMALLINT:
-              deltaRecord.addField(idx,
-                  (short) Math.max(deltaRecord.getShort(idx), sa.getShort(idx)));
-              break;
-            case TINYINT:
-              deltaRecord.addField(idx, (byte) Math.max(deltaRecord.getByte(idx), sa.getByte(idx)));
-              break;
-            case FLOAT:
-              deltaRecord.addField(idx, Math.max(deltaRecord.getFloat(idx), sa.getFloat(idx)));
-              break;
-            case DOUBLE:
-              deltaRecord.addField(idx, Math.max(deltaRecord.getDouble(idx), sa.getDouble(idx)));
-              break;
-            case TIMESTAMP:
-              Date preDate = deltaRecord.getDate(idx);
-              Date nowDate = sa.getDate(idx);
-              deltaRecord.addField(idx, preDate.getTime() > nowDate.getTime() ? preDate : nowDate);
-              break;
-            default:
-              break;
-            }
-            break;
-          case "MIN":
-            switch (ex.getOutputType()) {
-            case INTEGER:
-              deltaRecord.addField(idx, Math.min(deltaRecord.getInteger(idx), sa.getInteger(idx)));
-              break;
-            case BIGINT:
-              deltaRecord.addField(idx, Math.min(deltaRecord.getLong(idx), sa.getLong(idx)));
-              break;
-            case SMALLINT:
-              deltaRecord.addField(idx,
-                  (short) Math.min(deltaRecord.getShort(idx), sa.getShort(idx)));
-              break;
-            case TINYINT:
-              deltaRecord.addField(idx, (byte) Math.min(deltaRecord.getByte(idx), sa.getByte(idx)));
-              break;
-            case FLOAT:
-              deltaRecord.addField(idx, Math.min(deltaRecord.getFloat(idx), sa.getFloat(idx)));
-              break;
-            case DOUBLE:
-              deltaRecord.addField(idx, Math.min(deltaRecord.getDouble(idx), sa.getDouble(idx)));
-              break;
-            case TIMESTAMP:
-              Date preDate = deltaRecord.getDate(idx);
-              Date nowDate = sa.getDate(idx);
-              deltaRecord.addField(idx, preDate.getTime() < nowDate.getTime() ? preDate : nowDate);
-              break;
-            default:
-              break;
-            }
-            break;
-          default:
-            break;
-          }
-        }
+  /**
+   * Coder for {@link AggregationAccumulator}.
+   */
+  public static class AggregationAccumulatorCoder extends CustomCoder<AggregationAccumulator> {
+    private VarIntCoder sizeCoder = VarIntCoder.of();
+    private List<Coder> elementCoders;
+
+    public AggregationAccumulatorCoder(List<Coder> elementCoders) {
+      this.elementCoders = elementCoders;
+    }
+
+    @Override
+    public void encode(AggregationAccumulator value, OutputStream outStream)
+        throws CoderException, IOException {
+      sizeCoder.encode(value.accumulatorElements.size(), outStream);
+      for (int idx = 0; idx < value.accumulatorElements.size(); ++idx) {
+        elementCoders.get(idx).encode(value.accumulatorElements.get(idx), outStream);
       }
-      return deltaRecord;
     }
 
     @Override
-    public BeamSqlRow extractOutput(BeamSqlRow accumulator) {
-      BeamSqlRow finalRecord = new BeamSqlRow(aggDataType);
-      for (int idx = 0; idx < aggElementExpressions.size(); ++idx) {
-        BeamSqlExpression ex = aggElementExpressions.get(idx);
-        String aggFnName = aggFunctions.get(idx);
-        switch (aggFnName) {
-        case "COUNT":
-          finalRecord.addField(idx, accumulator.getLong(idx));
-          break;
-        case "AVG":
-          long count = accumulator.getLong(countIndex);
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            finalRecord.addField(idx, (int) (accumulator.getInteger(idx) / count));
-            break;
-          case BIGINT:
-            finalRecord.addField(idx, accumulator.getLong(idx) / count);
-            break;
-          case SMALLINT:
-            finalRecord.addField(idx, (short) (accumulator.getShort(idx) / count));
-            break;
-          case TINYINT:
-            finalRecord.addField(idx, (byte) (accumulator.getByte(idx) / count));
-            break;
-          case FLOAT:
-            finalRecord.addField(idx, (float) (accumulator.getFloat(idx) / count));
-            break;
-          case DOUBLE:
-            finalRecord.addField(idx, accumulator.getDouble(idx) / count);
-            break;
-          default:
-            break;
-          }
-          break;
-        case "SUM":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            finalRecord.addField(idx, accumulator.getInteger(idx));
-            break;
-          case BIGINT:
-            finalRecord.addField(idx, accumulator.getLong(idx));
-            break;
-          case SMALLINT:
-            finalRecord.addField(idx, accumulator.getShort(idx));
-            break;
-          case TINYINT:
-            finalRecord.addField(idx, accumulator.getByte(idx));
-            break;
-          case FLOAT:
-            finalRecord.addField(idx, accumulator.getFloat(idx));
-            break;
-          case DOUBLE:
-            finalRecord.addField(idx, accumulator.getDouble(idx));
-            break;
-          default:
-            break;
-          }
-          break;
-        case "MAX":
-        case "MIN":
-          switch (ex.getOutputType()) {
-          case INTEGER:
-            finalRecord.addField(idx, accumulator.getInteger(idx));
-            break;
-          case BIGINT:
-            finalRecord.addField(idx, accumulator.getLong(idx));
-            break;
-          case SMALLINT:
-            finalRecord.addField(idx, accumulator.getShort(idx));
-            break;
-          case TINYINT:
-            finalRecord.addField(idx, accumulator.getByte(idx));
-            break;
-          case FLOAT:
-            finalRecord.addField(idx, accumulator.getFloat(idx));
-            break;
-          case DOUBLE:
-            finalRecord.addField(idx, accumulator.getDouble(idx));
-            break;
-          case TIMESTAMP:
-            finalRecord.addField(idx, accumulator.getDate(idx));
-            break;
-          default:
-            break;
-          }
-          break;
-        default:
-          break;
-        }
+    public AggregationAccumulator decode(InputStream inStream) throws CoderException, IOException {
+      AggregationAccumulator accu = new AggregationAccumulator();
+      int size = sizeCoder.decode(inStream);
+      for (int idx = 0; idx < size; ++idx) {
+        accu.accumulatorElements.add(elementCoders.get(idx).decode(inStream));
       }
-      return finalRecord;
+      return accu;
     }
   }
 }

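The AggregationAccumulatorCoder above uses a simple length-prefixed layout: a VarInt element count, then each element encoded with its own coder; decoding replays the same coder list in the same order, so both sides must build identical coder lists. A standalone sketch of just that wire format (a fragment, assumed to run inside a method that declares IOException; the two concrete element coders are arbitrary examples):

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    VarIntCoder.of().encode(2, out);         // element count
    VarLongCoder.of().encode(42L, out);      // element 0, e.g. a COUNT accumulator
    VarLongCoder.of().encode(7L, out);       // element 1, e.g. another COUNT

    ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
    int size = VarIntCoder.of().decode(in);  // 2
    for (int idx = 0; idx < size; ++idx) {
      long element = VarLongCoder.of().decode(in);  // 42, then 7
    }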
http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java
new file mode 100644
index 0000000..fab2666
--- /dev/null
+++ b/dsls/sql/src/main/java/org/apache/beam/dsls/sql/transform/BeamBuiltinAggregations.java
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.dsls.sql.transform;
+
+import java.math.BigDecimal;
+import java.util.Date;
+import java.util.Iterator;
+import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
+import org.apache.beam.sdk.coders.BigDecimalCoder;
+import org.apache.beam.sdk.coders.ByteCoder;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.DoubleCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.sql.type.SqlTypeName;
+
+/**
+ * Built-in aggregation functions for COUNT/MAX/MIN/SUM/AVG.
+ */
+class BeamBuiltinAggregations {
+  /**
+   * Built-in aggregation for COUNT.
+   */
+  public static final class Count<T> extends BeamSqlUdaf<T, Long, Long> {
+    public Count() {}
+
+    @Override
+    public Long init() {
+      return 0L;
+    }
+
+    @Override
+    public Long add(Long accumulator, T input) {
+      return accumulator + 1;
+    }
+
+    @Override
+    public Long merge(Iterable<Long> accumulators) {
+      long v = 0L;
+      Iterator<Long> ite = accumulators.iterator();
+      while (ite.hasNext()) {
+        v += ite.next();
+      }
+      return v;
+    }
+
+    @Override
+    public Long result(Long accumulator) {
+      return accumulator;
+    }
+  }
+
+  /**
+   * Built-in aggregation for MAX.
+   */
+  public static final class Max<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
+    public static Max create(SqlTypeName fieldType) {
+      switch (fieldType) {
+        case INTEGER:
+          return new BeamBuiltinAggregations.Max<Integer>(fieldType);
+        case SMALLINT:
+          return new BeamBuiltinAggregations.Max<Short>(fieldType);
+        case TINYINT:
+          return new BeamBuiltinAggregations.Max<Byte>(fieldType);
+        case BIGINT:
+          return new BeamBuiltinAggregations.Max<Long>(fieldType);
+        case FLOAT:
+          return new BeamBuiltinAggregations.Max<Float>(fieldType);
+        case DOUBLE:
+          return new BeamBuiltinAggregations.Max<Double>(fieldType);
+        case TIMESTAMP:
+          return new BeamBuiltinAggregations.Max<Date>(fieldType);
+        case DECIMAL:
+          return new BeamBuiltinAggregations.Max<BigDecimal>(fieldType);
+        default:
+          throw new UnsupportedOperationException(
+              String.format("[%s] is not support in MAX", fieldType));
+      }
+    }
+
+    private final SqlTypeName fieldType;
+    private Max(SqlTypeName fieldType) {
+      this.fieldType = fieldType;
+    }
+
+    @Override
+    public T init() {
+      return null;
+    }
+
+    @Override
+    public T add(T accumulator, T input) {
+      return (accumulator == null || accumulator.compareTo(input) < 0) ? input : accumulator;
+    }
+
+    @Override
+    public T merge(Iterable<T> accumulators) {
+      Iterator<T> ite = accumulators.iterator();
+      T mergedV = ite.next();
+      while (ite.hasNext()) {
+        T v = ite.next();
+        mergedV = mergedV.compareTo(v) > 0 ? mergedV : v;
+      }
+      return mergedV;
+    }
+
+    @Override
+    public T result(T accumulator) {
+      return accumulator;
+    }
+
+    @Override
+    public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
+      return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
+    }
+  }
+
+  /**
+   * Built-in aggregation for MIN.
+   */
+  public static final class Min<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
+    public static Min create(SqlTypeName fieldType) {
+      switch (fieldType) {
+        case INTEGER:
+          return new BeamBuiltinAggregations.Min<Integer>(fieldType);
+        case SMALLINT:
+          return new BeamBuiltinAggregations.Min<Short>(fieldType);
+        case TINYINT:
+          return new BeamBuiltinAggregations.Min<Byte>(fieldType);
+        case BIGINT:
+          return new BeamBuiltinAggregations.Min<Long>(fieldType);
+        case FLOAT:
+          return new BeamBuiltinAggregations.Min<Float>(fieldType);
+        case DOUBLE:
+          return new BeamBuiltinAggregations.Min<Double>(fieldType);
+        case TIMESTAMP:
+          return new BeamBuiltinAggregations.Min<Date>(fieldType);
+        case DECIMAL:
+          return new BeamBuiltinAggregations.Min<BigDecimal>(fieldType);
+        default:
+          throw new UnsupportedOperationException(
+              String.format("[%s] is not support in MIN", fieldType));
+      }
+    }
+
+    private final SqlTypeName fieldType;
+    private Min(SqlTypeName fieldType) {
+      this.fieldType = fieldType;
+    }
+
+    @Override
+    public T init() {
+      return null;
+    }
+
+    @Override
+    public T add(T accumulator, T input) {
+      return (accumulator == null || accumulator.compareTo(input) > 0) ? input : accumulator;
+    }
+
+    @Override
+    public T merge(Iterable<T> accumulators) {
+      Iterator<T> ite = accumulators.iterator();
+      T mergedV = ite.next();
+      while (ite.hasNext()) {
+        T v = ite.next();
+        mergedV = mergedV.compareTo(v) < 0 ? mergedV : v;
+      }
+      return mergedV;
+    }
+
+    @Override
+    public T result(T accumulator) {
+      return accumulator;
+    }
+
+    @Override
+    public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
+      return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
+    }
+  }
+
+  /**
+   * Built-in aggregation for SUM.
+   */
+  public static final class Sum<T> extends BeamSqlUdaf<T, BigDecimal, T> {
+    public static Sum create(SqlTypeName fieldType) {
+      switch (fieldType) {
+        case INTEGER:
+          return new BeamBuiltinAggregations.Sum<Integer>(fieldType);
+        case SMALLINT:
+          return new BeamBuiltinAggregations.Sum<Short>(fieldType);
+        case TINYINT:
+          return new BeamBuiltinAggregations.Sum<Byte>(fieldType);
+        case BIGINT:
+          return new BeamBuiltinAggregations.Sum<Long>(fieldType);
+        case FLOAT:
+          return new BeamBuiltinAggregations.Sum<Float>(fieldType);
+        case DOUBLE:
+          return new BeamBuiltinAggregations.Sum<Double>(fieldType);
+        case TIMESTAMP:
+          return new BeamBuiltinAggregations.Sum<Date>(fieldType);
+        case DECIMAL:
+          return new BeamBuiltinAggregations.Sum<BigDecimal>(fieldType);
+        default:
+          throw new UnsupportedOperationException(
+              String.format("[%s] is not support in SUM", fieldType));
+      }
+    }
+
+    private SqlTypeName fieldType;
+    private Sum(SqlTypeName fieldType) {
+      this.fieldType = fieldType;
+    }
+
+    @Override
+    public BigDecimal init() {
+      return new BigDecimal(0);
+    }
+
+    @Override
+    public BigDecimal add(BigDecimal accumulator, T input) {
+      return accumulator.add(new BigDecimal(input.toString()));
+    }
+
+    @Override
+    public BigDecimal merge(Iterable<BigDecimal> accumulators) {
+      BigDecimal v = new BigDecimal(0);
+      Iterator<BigDecimal> ite = accumulators.iterator();
+      while (ite.hasNext()) {
+        v = v.add(ite.next());
+      }
+      return v;
+    }
+
+    @Override
+    public T result(BigDecimal accumulator) {
+      Object result = null;
+      switch (fieldType) {
+        case INTEGER:
+          result = accumulator.intValue();
+          break;
+        case BIGINT:
+          result = accumulator.longValue();
+          break;
+        case SMALLINT:
+          result = accumulator.shortValue();
+          break;
+        case TINYINT:
+          result = accumulator.byteValue();
+          break;
+        case DOUBLE:
+          result = accumulator.doubleValue();
+          break;
+        case FLOAT:
+          result = accumulator.floatValue();
+          break;
+        case DECIMAL:
+          result = accumulator;
+          break;
+        default:
+          break;
+      }
+      return (T) result;
+    }
+  }
+
+  /**
+   * Built-in aggregation for AVG.
+   */
+  public static final class Avg<T> extends BeamSqlUdaf<T, KV<BigDecimal, Long>, T> {
+    public static Avg create(SqlTypeName fieldType) {
+      switch (fieldType) {
+        case INTEGER:
+          return new BeamBuiltinAggregations.Avg<Integer>(fieldType);
+        case SMALLINT:
+          return new BeamBuiltinAggregations.Avg<Short>(fieldType);
+        case TINYINT:
+          return new BeamBuiltinAggregations.Avg<Byte>(fieldType);
+        case BIGINT:
+          return new BeamBuiltinAggregations.Avg<Long>(fieldType);
+        case FLOAT:
+          return new BeamBuiltinAggregations.Avg<Float>(fieldType);
+        case DOUBLE:
+          return new BeamBuiltinAggregations.Avg<Double>(fieldType);
+        case TIMESTAMP:
+          return new BeamBuiltinAggregations.Avg<Date>(fieldType);
+        case DECIMAL:
+          return new BeamBuiltinAggregations.Avg<BigDecimal>(fieldType);
+        default:
+          throw new UnsupportedOperationException(
+              String.format("[%s] is not support in AVG", fieldType));
+      }
+    }
+
+    private SqlTypeName fieldType;
+    private Avg(SqlTypeName fieldType) {
+      this.fieldType = fieldType;
+    }
+
+    @Override
+    public KV<BigDecimal, Long> init() {
+      return KV.of(new BigDecimal(0), 0L);
+    }
+
+    @Override
+    public KV<BigDecimal, Long> add(KV<BigDecimal, Long> accumulator, T input) {
+      return KV.of(
+              accumulator.getKey().add(new BigDecimal(input.toString())),
+              accumulator.getValue() + 1);
+    }
+
+    @Override
+    public KV<BigDecimal, Long> merge(Iterable<KV<BigDecimal, Long>> accumulators) {
+      BigDecimal v = new BigDecimal(0);
+      long s = 0;
+      Iterator<KV<BigDecimal, Long>> ite = accumulators.iterator();
+      while (ite.hasNext()) {
+        KV<BigDecimal, Long> r = ite.next();
+        v = v.add(r.getKey());
+        s += r.getValue();
+      }
+      return KV.of(v, s);
+    }
+
+    @Override
+    public T result(KV<BigDecimal, Long> accumulator) {
+      BigDecimal decimalAvg = accumulator.getKey().divide(
+          new BigDecimal(accumulator.getValue()));
+      Object result = null;
+      switch (fieldType) {
+        case INTEGER:
+          result = decimalAvg.intValue();
+          break;
+        case BIGINT:
+          result = decimalAvg.longValue();
+          break;
+        case SMALLINT:
+          result = decimalAvg.shortValue();
+          break;
+        case TINYINT:
+          result = decimalAvg.byteValue();
+          break;
+        case DOUBLE:
+          result = decimalAvg.doubleValue();
+          break;
+        case FLOAT:
+          result = decimalAvg.floatValue();
+          break;
+        case DECIMAL:
+          result = decimalAvg;
+          break;
+        default:
+          break;
+      }
+      return (T) result;
+    }
+
+    @Override
+    public Coder<KV<BigDecimal, Long>> getAccumulatorCoder(CoderRegistry registry)
+        throws CannotProvideCoderException {
+      return KvCoder.of(BigDecimalCoder.of(), VarLongCoder.of());
+    }
+  }
+
+  /**
+   * Find {@link Coder} for Beam SQL field types.
+   */
+  private static Coder getSqlTypeCoder(SqlTypeName sqlType) {
+    switch (sqlType) {
+      case INTEGER:
+        return VarIntCoder.of();
+      case SMALLINT:
+        return SerializableCoder.of(Short.class);
+      case TINYINT:
+        return ByteCoder.of();
+      case BIGINT:
+        return VarLongCoder.of();
+      case FLOAT:
+        return SerializableCoder.of(Float.class);
+      case DOUBLE:
+        return DoubleCoder.of();
+      case TIMESTAMP:
+        return SerializableCoder.of(Date.class);
+      case DECIMAL:
+        return BigDecimalCoder.of();
+      default:
+        throw new UnsupportedOperationException(
+            String.format("Cannot find a Coder for data type [%s]", sqlType));
+    }
+  }
+}

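To make the Sum/Avg design concrete: both accumulate in BigDecimal regardless of the column type and only narrow back to the call's SQL type in result(). A worked example in plain Java mirroring Avg's arithmetic for an INTEGER column with inputs 1, 2, 3, 4 (the division is exact here; the KV accumulator would hold sum 10 and count 4):

    BigDecimal sum = new BigDecimal(0);
    long count = 0;
    for (int v : new int[] {1, 2, 3, 4}) {
      sum = sum.add(new BigDecimal(Integer.toString(v)));  // the KV's sum side
      count++;                                             // the KV's count side
    }
    BigDecimal avg = sum.divide(new BigDecimal(count));    // 10 / 4 = 2.5
    int narrowed = avg.intValue();                         // 2, truncated toward zero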
http://git-wip-us.apache.org/repos/asf/beam/blob/a13fce98/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
----------------------------------------------------------------------
diff --git a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
index 388a344..2b01254 100644
--- a/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
+++ b/dsls/sql/src/test/java/org/apache/beam/dsls/sql/schema/transform/BeamAggregationTransformTest.java
@@ -117,7 +117,7 @@ public class BeamAggregationTransformTest extends BeamTransformBaseTest{
     //3. run aggregation functions
     PCollection<KV<BeamSqlRow, BeamSqlRow>> aggregatedStream = groupedStream.apply("aggregation",
         Combine.<BeamSqlRow, BeamSqlRow, BeamSqlRow>groupedValues(
-            new BeamAggregationTransforms.AggregationCombineFn(aggCalls, inputRowType)))
+            new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType)))
         .setCoder(KvCoder.<BeamSqlRow, BeamSqlRow>of(keyCoder, aggCoder));
 
     //4. flat KV to a single record

