beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From taki...@apache.org
Subject [1/2] beam git commit: take CombineFn as UDAF.
Date Sat, 12 Aug 2017 00:14:11 GMT
Repository: beam
Updated Branches:
  refs/heads/DSL_SQL f37a7a19c -> 9eec6a030


take CombineFn as UDAF.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1770c861
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1770c861
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1770c861

Branch: refs/heads/DSL_SQL
Commit: 1770c86121d7edc388cadc0e2791c19b027cc50f
Parents: f37a7a1
Author: mingmxu <mingmxu@ebay.com>
Authored: Thu Aug 10 17:42:29 2017 -0700
Committer: Tyler Akidau <takidau@apache.org>
Committed: Fri Aug 11 17:11:23 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/sdk/coders/BeamRecordCoder.java |  16 +-
 .../apache/beam/sdk/extensions/sql/BeamSql.java |  22 +-
 .../beam/sdk/extensions/sql/BeamSqlEnv.java     |  11 +-
 .../operator/BeamSqlInputRefExpression.java     |   4 +
 .../sql/impl/interpreter/operator/UdafImpl.java |  87 ++++
 .../transform/BeamAggregationTransforms.java    |  44 +-
 .../impl/transform/BeamBuiltinAggregations.java | 504 +++++++------------
 .../sdk/extensions/sql/schema/BeamSqlUdaf.java  |  72 ---
 .../extensions/sql/BeamSqlDslUdfUdafTest.java   |  22 +-
 9 files changed, 344 insertions(+), 438 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
index cbed87d..7b1b681 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/BeamRecordCoder.java
@@ -35,11 +35,11 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
   private static final BitSetCoder nullListCoder = BitSetCoder.of();
 
   private BeamRecordType recordType;
-  private List<Coder> coderArray;
+  private List<Coder> coders;
 
-  private BeamRecordCoder(BeamRecordType recordType, List<Coder> coderArray) {
+  private BeamRecordCoder(BeamRecordType recordType, List<Coder> coders) {
     this.recordType = recordType;
-    this.coderArray = coderArray;
+    this.coders = coders;
   }
 
   public static BeamRecordCoder of(BeamRecordType recordType, List<Coder> coderArray){
@@ -62,7 +62,7 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
         continue;
       }
 
-      coderArray.get(idx).encode(value.getFieldValue(idx), outStream);
+      coders.get(idx).encode(value.getFieldValue(idx), outStream);
     }
   }
 
@@ -75,7 +75,7 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
       if (nullFields.get(idx)) {
         fieldValues.add(null);
       } else {
-        fieldValues.add(coderArray.get(idx).decode(inStream));
+        fieldValues.add(coders.get(idx).decode(inStream));
       }
     }
     BeamRecord record = new BeamRecord(recordType, fieldValues);
@@ -99,8 +99,12 @@ public class BeamRecordCoder extends CustomCoder<BeamRecord> {
   @Override
   public void verifyDeterministic()
       throws org.apache.beam.sdk.coders.Coder.NonDeterministicException {
-    for (Coder c : coderArray) {
+    for (Coder c : coders) {
       c.verifyDeterministic();
     }
   }
+
+  public List<Coder> getCoders() {
+    return coders;
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
index a1e9877..bf6a9c0 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSql.java
@@ -23,8 +23,8 @@ import org.apache.beam.sdk.coders.BeamRecordCoder;
 import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
 import org.apache.beam.sdk.extensions.sql.schema.BeamPCollectionTable;
 import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
 import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.BeamRecord;
@@ -155,10 +155,10 @@ public class BeamSql {
       }
 
      /**
-      * register a UDAF function used in this query.
+      * Register a {@link CombineFn} as a UDAF used in this query.
       */
-     public QueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
-       getSqlEnv().registerUdaf(functionName, clazz);
+     public QueryTransform withUdaf(String functionName, CombineFn combineFn){
+       getSqlEnv().registerUdaf(functionName, combineFn);
        return this;
      }
 
@@ -231,13 +231,13 @@ public class BeamSql {
         return this;
       }
 
-     /**
-      * register a UDAF function used in this query.
-      */
-     public SimpleQueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
-       getSqlEnv().registerUdaf(functionName, clazz);
-       return this;
-     }
+      /**
+       * Register a {@link CombineFn} as a UDAF used in this query.
+       */
+      public SimpleQueryTransform withUdaf(String functionName, CombineFn combineFn){
+        getSqlEnv().registerUdaf(functionName, combineFn);
+        return this;
+      }
 
     private void validateQuery() {
       SqlNode sqlNode;

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
index 0737c49..79f2b32 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/BeamSqlEnv.java
@@ -18,12 +18,13 @@
 package org.apache.beam.sdk.extensions.sql;
 
 import java.io.Serializable;
+import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
 import org.apache.beam.sdk.extensions.sql.impl.planner.BeamQueryPlanner;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.extensions.sql.schema.BaseBeamTable;
 import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
 import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.calcite.DataContext;
 import org.apache.calcite.linq4j.Enumerable;
@@ -34,7 +35,6 @@ import org.apache.calcite.schema.Schema;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.calcite.schema.Statistic;
 import org.apache.calcite.schema.Statistics;
-import org.apache.calcite.schema.impl.AggregateFunctionImpl;
 import org.apache.calcite.schema.impl.ScalarFunctionImpl;
 import org.apache.calcite.tools.Frameworks;
 
@@ -69,11 +69,10 @@ public class BeamSqlEnv implements Serializable{
   }
 
   /**
-   * Register a UDAF function which can be used in GROUP-BY expression.
-   * See {@link BeamSqlUdaf} on how to implement a UDAF.
+   * Register a {@link CombineFn} as a UDAF that can be used in GROUP BY expressions.
    */
-  public void registerUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz) {
-    schema.add(functionName, AggregateFunctionImpl.create(clazz));
+  public void registerUdaf(String functionName, CombineFn combineFn) {
+    schema.add(functionName, new UdafImpl(combineFn));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
index a2d1624..2c321f7 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/BeamSqlInputRefExpression.java
@@ -41,4 +41,8 @@ public class BeamSqlInputRefExpression extends BeamSqlExpression {
   public BeamSqlPrimitive evaluate(BeamRecord inputRow, BoundedWindow window) {
     return BeamSqlPrimitive.of(outputType, inputRow.getFieldValue(inputRef));
   }
+
+  public int getInputRef() {
+    return inputRef;
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
new file mode 100644
index 0000000..83ed7f8
--- /dev/null
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/interpreter/operator/UdafImpl.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.impl.interpreter.operator;
+
+import java.io.Serializable;
+import java.lang.reflect.ParameterizedType;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.calcite.adapter.enumerable.AggImplementor;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.schema.AggregateFunction;
+import org.apache.calcite.schema.FunctionParameter;
+import org.apache.calcite.schema.ImplementableAggFunction;
+
+/**
+ * {@link AggregateFunction} adapter that lets a {@link CombineFn} be registered as a UDAF.
+ */
+public final class UdafImpl<InputT, AccumT, OutputT>
+    implements AggregateFunction, ImplementableAggFunction, Serializable{
+  private CombineFn<InputT, AccumT, OutputT> combineFn;
+
+  public UdafImpl(CombineFn<InputT, AccumT, OutputT> combineFn) {
+    this.combineFn = combineFn;
+  }
+
+  public CombineFn<InputT, AccumT, OutputT> getCombineFn() {
+    return combineFn;
+  }
+
+  @Override
+  public List<FunctionParameter> getParameters() {
+    List<FunctionParameter> para = new ArrayList<>();
+    para.add(new FunctionParameter() {
+          public int getOrdinal() {
+            return 0; //up to one parameter is supported in UDAF.
+          }
+
+          public String getName() {
+            // not used as Beam SQL uses its own execution engine
+            return null;
+          }
+
+          public RelDataType getType(RelDataTypeFactory typeFactory) {
+            //the first generic type of CombineFn is the input type.
+            ParameterizedType parameterizedType = (ParameterizedType) combineFn.getClass()
+                .getGenericSuperclass();
+            return typeFactory.createJavaType(
+                (Class) parameterizedType.getActualTypeArguments()[0]);
+          }
+
+          public boolean isOptional() {
+            // not used as Beam SQL uses its own execution engine
+            return false;
+          }
+        });
+    return para;
+  }
+
+  @Override
+  public AggImplementor getImplementor(boolean windowContext) {
+    // not used as Beam SQL uses its own execution engine
+    return null;
+  }
+
+  @Override
+  public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
+    return typeFactory.createJavaType((Class) combineFn.getOutputType().getType());
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
index 0f90bee..40b7b58 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
@@ -25,6 +25,7 @@ import java.math.BigDecimal;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.List;
+import org.apache.beam.sdk.coders.BeamRecordCoder;
 import org.apache.beam.sdk.coders.BigDecimalCoder;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
@@ -32,13 +33,13 @@ import org.apache.beam.sdk.coders.CoderException;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.CustomCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlExpression;
 import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.BeamSqlInputRefExpression;
+import org.apache.beam.sdk.extensions.sql.impl.interpreter.operator.UdafImpl;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
 import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRecordHelper;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -46,7 +47,6 @@ import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
 import org.apache.beam.sdk.values.BeamRecord;
 import org.apache.beam.sdk.values.KV;
 import org.apache.calcite.rel.core.AggregateCall;
-import org.apache.calcite.schema.impl.AggregateFunctionImpl;
 import org.apache.calcite.sql.validate.SqlUserDefinedAggFunction;
 import org.apache.calcite.util.ImmutableBitSet;
 import org.joda.time.Instant;
@@ -151,8 +151,8 @@ public class BeamAggregationTransforms implements Serializable{
    */
   public static class AggregationAdaptor
     extends CombineFn<BeamRecord, AggregationAccumulator, BeamRecord> {
-    private List<BeamSqlUdaf> aggregators;
-    private List<BeamSqlExpression> sourceFieldExps;
+    private List<CombineFn> aggregators;
+    private List<BeamSqlInputRefExpression> sourceFieldExps;
     private BeamRecordSqlType finalRowType;
 
     public AggregationAdaptor(List<AggregateCall> aggregationCalls,
@@ -163,7 +163,7 @@ public class BeamAggregationTransforms implements Serializable{
       List<Integer> outFieldsType = new ArrayList<>();
       for (AggregateCall call : aggregationCalls) {
         int refIndex = call.getArgList().size() > 0 ? call.getArgList().get(0) : 0;
-        BeamSqlExpression sourceExp = new BeamSqlInputRefExpression(
+        BeamSqlInputRefExpression sourceExp = new BeamSqlInputRefExpression(
             CalciteUtils.getFieldType(sourceRowType, refIndex), refIndex);
         sourceFieldExps.add(sourceExp);
 
@@ -173,27 +173,27 @@ public class BeamAggregationTransforms implements Serializable{
 
         switch (call.getAggregation().getName()) {
           case "COUNT":
-            aggregators.add(new BeamBuiltinAggregations.Count());
+            aggregators.add(Count.combineFn());
             break;
           case "MAX":
-            aggregators.add(BeamBuiltinAggregations.Max.create(call.type.getSqlTypeName()));
+            aggregators.add(BeamBuiltinAggregations.createMax(call.type.getSqlTypeName()));
             break;
           case "MIN":
-            aggregators.add(BeamBuiltinAggregations.Min.create(call.type.getSqlTypeName()));
+            aggregators.add(BeamBuiltinAggregations.createMin(call.type.getSqlTypeName()));
             break;
           case "SUM":
-            aggregators.add(BeamBuiltinAggregations.Sum.create(call.type.getSqlTypeName()));
+            aggregators.add(BeamBuiltinAggregations.createSum(call.type.getSqlTypeName()));
             break;
           case "AVG":
-            aggregators.add(BeamBuiltinAggregations.Avg.create(call.type.getSqlTypeName()));
+            aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName()));
             break;
           default:
             if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
               // handle UDAF.
               SqlUserDefinedAggFunction udaf = (SqlUserDefinedAggFunction) call.getAggregation();
-              AggregateFunctionImpl fn = (AggregateFunctionImpl) udaf.function;
+              UdafImpl fn = (UdafImpl) udaf.function;
               try {
-                aggregators.add((BeamSqlUdaf) fn.declaringClass.newInstance());
+                aggregators.add(fn.getCombineFn());
               } catch (Exception e) {
                 throw new IllegalStateException(e);
               }
@@ -210,8 +210,8 @@ public class BeamAggregationTransforms implements Serializable{
     @Override
     public AggregationAccumulator createAccumulator() {
       AggregationAccumulator initialAccu = new AggregationAccumulator();
-      for (BeamSqlUdaf agg : aggregators) {
-        initialAccu.accumulatorElements.add(agg.init());
+      for (CombineFn agg : aggregators) {
+        initialAccu.accumulatorElements.add(agg.createAccumulator());
       }
       return initialAccu;
     }
@@ -220,7 +220,7 @@ public class BeamAggregationTransforms implements Serializable{
       AggregationAccumulator deltaAcc = new AggregationAccumulator();
       for (int idx = 0; idx < aggregators.size(); ++idx) {
         deltaAcc.accumulatorElements.add(
-            aggregators.get(idx).add(accumulator.accumulatorElements.get(idx),
+            aggregators.get(idx).addInput(accumulator.accumulatorElements.get(idx),
             sourceFieldExps.get(idx).evaluate(input, null).getValue()));
       }
       return deltaAcc;
@@ -234,7 +234,7 @@ public class BeamAggregationTransforms implements Serializable{
         while (ite.hasNext()) {
           accs.add(ite.next().accumulatorElements.get(idx));
         }
-        deltaAcc.accumulatorElements.add(aggregators.get(idx).merge(accs));
+        deltaAcc.accumulatorElements.add(aggregators.get(idx).mergeAccumulators(accs));
       }
       return deltaAcc;
     }
@@ -242,7 +242,8 @@ public class BeamAggregationTransforms implements Serializable{
     public BeamRecord extractOutput(AggregationAccumulator accumulator) {
       List<Object> fieldValues = new ArrayList<>(aggregators.size());
       for (int idx = 0; idx < aggregators.size(); ++idx) {
-        fieldValues.add(aggregators.get(idx).result(accumulator.accumulatorElements.get(idx)));
+        fieldValues
+            .add(aggregators.get(idx).extractOutput(accumulator.accumulatorElements.get(idx)));
       }
       return new BeamRecord(finalRowType, fieldValues);
     }
@@ -250,10 +251,13 @@ public class BeamAggregationTransforms implements Serializable{
     public Coder<AggregationAccumulator> getAccumulatorCoder(
         CoderRegistry registry, Coder<BeamRecord> inputCoder)
         throws CannotProvideCoderException {
+      BeamRecordCoder beamRecordCoder = (BeamRecordCoder) inputCoder;
       registry.registerCoderForClass(BigDecimal.class, BigDecimalCoder.of());
       List<Coder> aggAccuCoderList = new ArrayList<>();
-      for (BeamSqlUdaf udaf : aggregators) {
-        aggAccuCoderList.add(udaf.getAccumulatorCoder(registry));
+      for (int idx = 0; idx < aggregators.size(); ++idx) {
+        int srcFieldIndex = sourceFieldExps.get(idx).getInputRef();
+        Coder srcFieldCoder = beamRecordCoder.getCoders().get(srcFieldIndex);
+        aggAccuCoderList.add(aggregators.get(idx).getAccumulatorCoder(registry, srcFieldCoder));
       }
       return new AggregationAccumulatorCoder(aggAccuCoderList);
     }

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index 1fc8cf6..03edf13 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -21,16 +21,16 @@ import java.math.BigDecimal;
 import java.util.Date;
 import java.util.Iterator;
 import org.apache.beam.sdk.coders.BigDecimalCoder;
-import org.apache.beam.sdk.coders.ByteCoder;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.coders.DoubleCoder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.SerializableCoder;
-import org.apache.beam.sdk.coders.VarIntCoder;
-import org.apache.beam.sdk.coders.VarLongCoder;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Max;
+import org.apache.beam.sdk.transforms.Min;
+import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.values.KV;
 import org.apache.calcite.sql.type.SqlTypeName;
 
@@ -39,374 +39,258 @@ import org.apache.calcite.sql.type.SqlTypeName;
  */
 class BeamBuiltinAggregations {
   /**
-   * Built-in aggregation for COUNT.
+   * {@link CombineFn} for MAX based on {@link Max} and {@link Combine.BinaryCombineFn}.
    */
-  public static final class Count<T> extends BeamSqlUdaf<T, Long, Long> {
-    public Count() {}
-
-    @Override
-    public Long init() {
-      return 0L;
-    }
-
-    @Override
-    public Long add(Long accumulator, T input) {
-      return accumulator + 1;
-    }
+  public static CombineFn createMax(SqlTypeName fieldType) {
+    switch (fieldType) {
+    case INTEGER:
+      return Max.ofIntegers();
+    case SMALLINT:
+      return new CustMax<Short>();
+    case TINYINT:
+      return new CustMax<Byte>();
+    case BIGINT:
+      return Max.ofLongs();
+    case FLOAT:
+      return new CustMax<Float>();
+    case DOUBLE:
+      return Max.ofDoubles();
+    case TIMESTAMP:
+      return new CustMax<Date>();
+    case DECIMAL:
+      return new CustMax<BigDecimal>();
+    default:
+      throw new UnsupportedOperationException(
+          String.format("[%s] is not support in MAX", fieldType));
+  }
+  }
 
-    @Override
-    public Long merge(Iterable<Long> accumulators) {
-      long v = 0L;
-      Iterator<Long> ite = accumulators.iterator();
-      while (ite.hasNext()) {
-        v += ite.next();
-      }
-      return v;
-    }
+  /**
+   * {@link CombineFn} for MIN based on {@link Min} and {@link Combine.BinaryCombineFn}.
+   */
+  public static CombineFn createMin(SqlTypeName fieldType) {
+    switch (fieldType) {
+    case INTEGER:
+      return Min.ofIntegers();
+    case SMALLINT:
+      return new CustMin<Short>();
+    case TINYINT:
+      return new CustMin<Byte>();
+    case BIGINT:
+      return Min.ofLongs();
+    case FLOAT:
+      return new CustMin<Float>();
+    case DOUBLE:
+      return Min.ofDoubles();
+    case TIMESTAMP:
+      return new CustMin<Date>();
+    case DECIMAL:
+      return new CustMin<BigDecimal>();
+    default:
+      throw new UnsupportedOperationException(
+          String.format("[%s] is not support in MIN", fieldType));
+  }
+  }
 
-    @Override
-    public Long result(Long accumulator) {
-      return accumulator;
-    }
+  /**
+   * {@link CombineFn} for SUM based on {@link Sum} and {@link Combine.BinaryCombineFn}.
+   */
+  public static CombineFn createSum(SqlTypeName fieldType) {
+    switch (fieldType) {
+    case INTEGER:
+      return Sum.ofIntegers();
+    case SMALLINT:
+      return new ShortSum();
+    case TINYINT:
+      return new ByteSum();
+    case BIGINT:
+      return Sum.ofLongs();
+    case FLOAT:
+      return new FloatSum();
+    case DOUBLE:
+      return Sum.ofDoubles();
+    case DECIMAL:
+      return new BigDecimalSum();
+    default:
+      throw new UnsupportedOperationException(
+          String.format("[%s] is not support in SUM", fieldType));
+  }
   }
 
   /**
-   * Built-in aggregation for MAX.
+   * {@link CombineFn} for AVG.
    */
-  public static final class Max<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
-    public static Max create(SqlTypeName fieldType) {
-      switch (fieldType) {
-        case INTEGER:
-          return new BeamBuiltinAggregations.Max<Integer>(fieldType);
-        case SMALLINT:
-          return new BeamBuiltinAggregations.Max<Short>(fieldType);
-        case TINYINT:
-          return new BeamBuiltinAggregations.Max<Byte>(fieldType);
-        case BIGINT:
-          return new BeamBuiltinAggregations.Max<Long>(fieldType);
-        case FLOAT:
-          return new BeamBuiltinAggregations.Max<Float>(fieldType);
-        case DOUBLE:
-          return new BeamBuiltinAggregations.Max<Double>(fieldType);
-        case TIMESTAMP:
-          return new BeamBuiltinAggregations.Max<Date>(fieldType);
-        case DECIMAL:
-          return new BeamBuiltinAggregations.Max<BigDecimal>(fieldType);
-        default:
-          throw new UnsupportedOperationException(
-              String.format("[%s] is not support in MAX", fieldType));
-      }
-    }
+  public static CombineFn createAvg(SqlTypeName fieldType) {
+    switch (fieldType) {
+    case INTEGER:
+      return new IntegerAvg();
+    case SMALLINT:
+      return new ShortAvg();
+    case TINYINT:
+      return new ByteAvg();
+    case BIGINT:
+      return new LongAvg();
+    case FLOAT:
+      return new FloatAvg();
+    case DOUBLE:
+      return new DoubleAvg();
+    case DECIMAL:
+      return new BigDecimalAvg();
+    default:
+      throw new UnsupportedOperationException(
+          String.format("[%s] is not support in AVG", fieldType));
+  }
+  }
 
-    private final SqlTypeName fieldType;
-    private Max(SqlTypeName fieldType) {
-      this.fieldType = fieldType;
+  static class CustMax<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
+    public T apply(T left, T right) {
+      return (right == null || right.compareTo(left) < 0) ? left : right;
     }
+  }
 
-    @Override
-    public T init() {
-      return null;
+  static class CustMin<T extends Comparable<T>> extends Combine.BinaryCombineFn<T> {
+    public T apply(T left, T right) {
+      return (left == null || left.compareTo(right) < 0) ? left : right;
     }
+  }
 
-    @Override
-    public T add(T accumulator, T input) {
-      return (accumulator == null || accumulator.compareTo(input) < 0) ? input : accumulator;
+  static class ShortSum extends Combine.BinaryCombineFn<Short> {
+    public Short apply(Short left, Short right) {
+      return (short) (left + right);
     }
+  }
 
-    @Override
-    public T merge(Iterable<T> accumulators) {
-      Iterator<T> ite = accumulators.iterator();
-      T mergedV = ite.next();
-      while (ite.hasNext()) {
-        T v = ite.next();
-        mergedV = mergedV.compareTo(v) > 0 ? mergedV : v;
-      }
-      return mergedV;
+  static class ByteSum extends Combine.BinaryCombineFn<Byte> {
+    public Byte apply(Byte left, Byte right) {
+      return (byte) (left + right);
     }
+  }
 
-    @Override
-    public T result(T accumulator) {
-      return accumulator;
+  static class FloatSum extends Combine.BinaryCombineFn<Float> {
+    public Float apply(Float left, Float right) {
+      return left + right;
     }
+  }
 
-    @Override
-    public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
-      return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
+  static class BigDecimalSum extends Combine.BinaryCombineFn<BigDecimal> {
+    public BigDecimal apply(BigDecimal left, BigDecimal right) {
+      return left.add(right);
     }
   }
 
   /**
-   * Built-in aggregation for MIN.
+   * {@link CombineFn} for <em>AVG</em> on {@link Number} types.
    */
-  public static final class Min<T extends Comparable<T>> extends BeamSqlUdaf<T, T, T> {
-    public static Min create(SqlTypeName fieldType) {
-      switch (fieldType) {
-        case INTEGER:
-          return new BeamBuiltinAggregations.Min<Integer>(fieldType);
-        case SMALLINT:
-          return new BeamBuiltinAggregations.Min<Short>(fieldType);
-        case TINYINT:
-          return new BeamBuiltinAggregations.Min<Byte>(fieldType);
-        case BIGINT:
-          return new BeamBuiltinAggregations.Min<Long>(fieldType);
-        case FLOAT:
-          return new BeamBuiltinAggregations.Min<Float>(fieldType);
-        case DOUBLE:
-          return new BeamBuiltinAggregations.Min<Double>(fieldType);
-        case TIMESTAMP:
-          return new BeamBuiltinAggregations.Min<Date>(fieldType);
-        case DECIMAL:
-          return new BeamBuiltinAggregations.Min<BigDecimal>(fieldType);
-        default:
-          throw new UnsupportedOperationException(
-              String.format("[%s] is not support in MIN", fieldType));
-      }
-    }
-
-    private final SqlTypeName fieldType;
-    private Min(SqlTypeName fieldType) {
-      this.fieldType = fieldType;
-    }
-
+  abstract static class Avg<T extends Number>
+      extends CombineFn<T, KV<Integer, BigDecimal>, T> {
     @Override
-    public T init() {
-      return null;
+    public KV<Integer, BigDecimal> createAccumulator() {
+      return KV.of(0, new BigDecimal(0));
     }
 
     @Override
-    public T add(T accumulator, T input) {
-      return (accumulator == null || accumulator.compareTo(input) > 0) ? input : accumulator;
+    public KV<Integer, BigDecimal> addInput(KV<Integer, BigDecimal> accumulator, T input) {
+      return KV.of(accumulator.getKey() + 1, accumulator.getValue().add(toBigDecimal(input)));
     }
 
     @Override
-    public T merge(Iterable<T> accumulators) {
-      Iterator<T> ite = accumulators.iterator();
-      T mergedV = ite.next();
+    public KV<Integer, BigDecimal> mergeAccumulators(
+        Iterable<KV<Integer, BigDecimal>> accumulators) {
+      int size = 0;
+      BigDecimal acc = new BigDecimal(0);
+      Iterator<KV<Integer, BigDecimal>> ite = accumulators.iterator();
       while (ite.hasNext()) {
-        T v = ite.next();
-        mergedV = mergedV.compareTo(v) < 0 ? mergedV : v;
+        KV<Integer, BigDecimal> ele = ite.next();
+        size += ele.getKey();
+        acc = acc.add(ele.getValue());
       }
-      return mergedV;
+      return KV.of(size, acc);
     }
 
     @Override
-    public T result(T accumulator) {
-      return accumulator;
+    public Coder<KV<Integer, BigDecimal>> getAccumulatorCoder(CoderRegistry registry,
+        Coder<T> inputCoder) throws CannotProvideCoderException {
+      return KvCoder.of(BigEndianIntegerCoder.of(), BigDecimalCoder.of());
     }
 
-    @Override
-    public Coder<T> getAccumulatorCoder(CoderRegistry registry) throws CannotProvideCoderException {
-      return BeamBuiltinAggregations.getSqlTypeCoder(fieldType);
-    }
+    public abstract T extractOutput(KV<Integer, BigDecimal> accumulator);
+    public abstract BigDecimal toBigDecimal(T record);
   }
 
-  /**
-   * Built-in aggregation for SUM.
-   */
-  public static final class Sum<T> extends BeamSqlUdaf<T, BigDecimal, T> {
-    public static Sum create(SqlTypeName fieldType) {
-      switch (fieldType) {
-        case INTEGER:
-          return new BeamBuiltinAggregations.Sum<Integer>(fieldType);
-        case SMALLINT:
-          return new BeamBuiltinAggregations.Sum<Short>(fieldType);
-        case TINYINT:
-          return new BeamBuiltinAggregations.Sum<Byte>(fieldType);
-        case BIGINT:
-          return new BeamBuiltinAggregations.Sum<Long>(fieldType);
-        case FLOAT:
-          return new BeamBuiltinAggregations.Sum<Float>(fieldType);
-        case DOUBLE:
-          return new BeamBuiltinAggregations.Sum<Double>(fieldType);
-        case TIMESTAMP:
-          return new BeamBuiltinAggregations.Sum<Date>(fieldType);
-        case DECIMAL:
-          return new BeamBuiltinAggregations.Sum<BigDecimal>(fieldType);
-        default:
-          throw new UnsupportedOperationException(
-              String.format("[%s] is not support in SUM", fieldType));
-      }
+  static class IntegerAvg extends Avg<Integer>{
+    public Integer extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).intValue();
     }
 
-    private SqlTypeName fieldType;
-      private Sum(SqlTypeName fieldType) {
-        this.fieldType = fieldType;
-      }
+    public BigDecimal toBigDecimal(Integer record) {
+      return new BigDecimal(record);
+    }
+  }
 
-    @Override
-    public BigDecimal init() {
-      return new BigDecimal(0);
+  static class LongAvg extends Avg<Long>{
+    public Long extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).longValue();
     }
 
-    @Override
-    public BigDecimal add(BigDecimal accumulator, T input) {
-      return accumulator.add(new BigDecimal(input.toString()));
+    public BigDecimal toBigDecimal(Long record) {
+      return new BigDecimal(record);
     }
+  }
 
-    @Override
-    public BigDecimal merge(Iterable<BigDecimal> accumulators) {
-      BigDecimal v = new BigDecimal(0);
-      Iterator<BigDecimal> ite = accumulators.iterator();
-      while (ite.hasNext()) {
-        v = v.add(ite.next());
-      }
-      return v;
+  static class ShortAvg extends Avg<Short>{
+    public Short extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).shortValue();
     }
 
-    @Override
-    public T result(BigDecimal accumulator) {
-      Object result = null;
-      switch (fieldType) {
-        case INTEGER:
-          result = accumulator.intValue();
-          break;
-        case BIGINT:
-          result = accumulator.longValue();
-          break;
-        case SMALLINT:
-          result = accumulator.shortValue();
-          break;
-        case TINYINT:
-          result = accumulator.byteValue();
-          break;
-        case DOUBLE:
-          result = accumulator.doubleValue();
-          break;
-        case FLOAT:
-          result = accumulator.floatValue();
-          break;
-        case DECIMAL:
-          result = accumulator;
-          break;
-        default:
-          break;
-      }
-      return (T) result;
+    public BigDecimal toBigDecimal(Short record) {
+      return new BigDecimal(record);
     }
   }
 
-  /**
-   * Built-in aggregation for AVG.
-   */
-  public static final class Avg<T> extends BeamSqlUdaf<T, KV<BigDecimal, Long>, T> {
-    public static Avg create(SqlTypeName fieldType) {
-      switch (fieldType) {
-        case INTEGER:
-          return new BeamBuiltinAggregations.Avg<Integer>(fieldType);
-        case SMALLINT:
-          return new BeamBuiltinAggregations.Avg<Short>(fieldType);
-        case TINYINT:
-          return new BeamBuiltinAggregations.Avg<Byte>(fieldType);
-        case BIGINT:
-          return new BeamBuiltinAggregations.Avg<Long>(fieldType);
-        case FLOAT:
-          return new BeamBuiltinAggregations.Avg<Float>(fieldType);
-        case DOUBLE:
-          return new BeamBuiltinAggregations.Avg<Double>(fieldType);
-        case TIMESTAMP:
-          return new BeamBuiltinAggregations.Avg<Date>(fieldType);
-        case DECIMAL:
-          return new BeamBuiltinAggregations.Avg<BigDecimal>(fieldType);
-        default:
-          throw new UnsupportedOperationException(
-              String.format("[%s] is not support in AVG", fieldType));
-      }
+  static class ByteAvg extends Avg<Byte>{
+    public Byte extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).byteValue();
     }
 
-    private SqlTypeName fieldType;
-      private Avg(SqlTypeName fieldType) {
-        this.fieldType = fieldType;
-      }
-
-    @Override
-    public KV<BigDecimal, Long> init() {
-      return KV.of(new BigDecimal(0), 0L);
+    public BigDecimal toBigDecimal(Byte record) {
+      return new BigDecimal(record);
     }
+  }
 
-    @Override
-    public KV<BigDecimal, Long> add(KV<BigDecimal, Long> accumulator, T input) {
-      return KV.of(
-              accumulator.getKey().add(new BigDecimal(input.toString())),
-              accumulator.getValue() + 1);
+  static class FloatAvg extends Avg<Float>{
+    public Float extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).floatValue();
     }
 
-    @Override
-    public KV<BigDecimal, Long> merge(Iterable<KV<BigDecimal, Long>> accumulators) {
-      BigDecimal v = new BigDecimal(0);
-      long s = 0;
-      Iterator<KV<BigDecimal, Long>> ite = accumulators.iterator();
-      while (ite.hasNext()) {
-        KV<BigDecimal, Long> r = ite.next();
-        v = v.add(r.getKey());
-        s += r.getValue();
-      }
-      return KV.of(v, s);
+    public BigDecimal toBigDecimal(Float record) {
+      return new BigDecimal(record);
     }
+  }
 
-    @Override
-    public T result(KV<BigDecimal, Long> accumulator) {
-      BigDecimal decimalAvg = accumulator.getKey().divide(
-          new BigDecimal(accumulator.getValue()));
-      Object result = null;
-      switch (fieldType) {
-        case INTEGER:
-          result = decimalAvg.intValue();
-          break;
-        case BIGINT:
-          result = decimalAvg.longValue();
-          break;
-        case SMALLINT:
-          result = decimalAvg.shortValue();
-          break;
-        case TINYINT:
-          result = decimalAvg.byteValue();
-          break;
-        case DOUBLE:
-          result = decimalAvg.doubleValue();
-          break;
-        case FLOAT:
-          result = decimalAvg.floatValue();
-          break;
-        case DECIMAL:
-          result = decimalAvg;
-          break;
-        default:
-          break;
-      }
-      return (T) result;
+  static class DoubleAvg extends Avg<Double>{
+    public Double extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).doubleValue();
     }
 
-    @Override
-    public Coder<KV<BigDecimal, Long>> getAccumulatorCoder(CoderRegistry registry)
-        throws CannotProvideCoderException {
-      return KvCoder.of(BigDecimalCoder.of(), VarLongCoder.of());
+    public BigDecimal toBigDecimal(Double record) {
+      return new BigDecimal(record);
     }
   }
 
-  /**
-   * Find {@link Coder} for Beam SQL field types.
-   */
-  private static Coder getSqlTypeCoder(SqlTypeName sqlType) {
-    switch (sqlType) {
-      case INTEGER:
-        return VarIntCoder.of();
-      case SMALLINT:
-        return SerializableCoder.of(Short.class);
-      case TINYINT:
-        return ByteCoder.of();
-      case BIGINT:
-        return VarLongCoder.of();
-      case FLOAT:
-        return SerializableCoder.of(Float.class);
-      case DOUBLE:
-        return DoubleCoder.of();
-      case TIMESTAMP:
-        return SerializableCoder.of(Date.class);
-      case DECIMAL:
-        return BigDecimalCoder.of();
-      default:
-        throw new UnsupportedOperationException(
-            String.format("Cannot find a Coder for data type [%s]", sqlType));
+  static class BigDecimalAvg extends Avg<BigDecimal>{
+    public BigDecimal extractOutput(KV<Integer, BigDecimal> accumulator) {
+      return accumulator.getKey() == 0 ? null
+          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey()));
+    }
+
+    public BigDecimal toBigDecimal(BigDecimal record) {
+      return record;
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
deleted file mode 100644
index 2f78586..0000000
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlUdaf.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.extensions.sql.schema;
-
-import java.io.Serializable;
-import java.lang.reflect.ParameterizedType;
-import org.apache.beam.sdk.coders.CannotProvideCoderException;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderRegistry;
-import org.apache.beam.sdk.transforms.Combine.CombineFn;
-
-/**
- * abstract class of aggregation functions in Beam SQL.
- *
- * <p>There're several constrains for a UDAF:<br>
- * 1. A constructor with an empty argument list is required;<br>
- * 2. The type of {@code InputT} and {@code OutputT} can only be Interger/Long/Short/Byte/Double
- * /Float/Date/BigDecimal, mapping as SQL type INTEGER/BIGINT/SMALLINT/TINYINE/DOUBLE/FLOAT
- * /TIMESTAMP/DECIMAL;<br>
- * 3. Keep intermediate data in {@code AccumT}, and do not rely on elements in class;<br>
- */
-public abstract class BeamSqlUdaf<InputT, AccumT, OutputT> implements Serializable {
-  public BeamSqlUdaf(){}
-
-  /**
-   * create an initial aggregation object, equals to {@link CombineFn#createAccumulator()}.
-   */
-  public abstract AccumT init();
-
-  /**
-   * add an input value, equals to {@link CombineFn#addInput(Object, Object)}.
-   */
-  public abstract AccumT add(AccumT accumulator, InputT input);
-
-  /**
-   * merge aggregation objects from parallel tasks, equals to
-   *  {@link CombineFn#mergeAccumulators(Iterable)}.
-   */
-  public abstract AccumT merge(Iterable<AccumT> accumulators);
-
-  /**
-   * extract output value from aggregation object, equals to
-   * {@link CombineFn#extractOutput(Object)}.
-   */
-  public abstract OutputT result(AccumT accumulator);
-
-  /**
-   * get the coder for AccumT which stores the intermediate result.
-   * By default it's fetched from {@link CoderRegistry}.
-   */
-  public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry)
-      throws CannotProvideCoderException {
-    return registry.getCoder(
-        (Class<AccumT>) ((ParameterizedType) getClass()
-        .getGenericSuperclass()).getActualTypeArguments()[1]);
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/1770c861/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
index 0552cbf..1541123 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUdfUdafTest.java
@@ -21,9 +21,9 @@ import java.sql.Types;
 import java.util.Arrays;
 import java.util.Iterator;
 import org.apache.beam.sdk.extensions.sql.schema.BeamRecordSqlType;
-import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdaf;
 import org.apache.beam.sdk.extensions.sql.schema.BeamSqlUdf;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.BeamRecord;
 import org.apache.beam.sdk.values.PCollection;
@@ -49,7 +49,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
         + " FROM PCOLLECTION GROUP BY f_int2";
     PCollection<BeamRecord> result1 =
         boundedInput1.apply("testUdaf1",
-            BeamSql.simpleQuery(sql1).withUdaf("squaresum1", SquareSum.class));
+            BeamSql.simpleQuery(sql1).withUdaf("squaresum1", new SquareSum()));
     PAssert.that(result1).containsInAnyOrder(record);
 
     String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`"
@@ -57,7 +57,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
     PCollection<BeamRecord> result2 =
         PCollectionTuple.of(new TupleTag<BeamRecord>("PCOLLECTION"), boundedInput1)
         .apply("testUdaf2",
-            BeamSql.query(sql2).withUdaf("squaresum2", SquareSum.class));
+            BeamSql.query(sql2).withUdaf("squaresum2", new SquareSum()));
     PAssert.that(result2).containsInAnyOrder(record);
 
     pipeline.run().waitUntilFinish();
@@ -90,25 +90,21 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
   }
 
   /**
-   * UDAF for test, which returns the sum of square.
+   * UDAF(CombineFn) for test, which returns the sum of square.
    */
-  public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> {
-
-    public SquareSum() {
-    }
-
+  public static class SquareSum extends CombineFn<Integer, Integer, Integer> {
     @Override
-    public Integer init() {
+    public Integer createAccumulator() {
       return 0;
     }
 
     @Override
-    public Integer add(Integer accumulator, Integer input) {
+    public Integer addInput(Integer accumulator, Integer input) {
       return accumulator + input * input;
     }
 
     @Override
-    public Integer merge(Iterable<Integer> accumulators) {
+    public Integer mergeAccumulators(Iterable<Integer> accumulators) {
       int v = 0;
       Iterator<Integer> ite = accumulators.iterator();
       while (ite.hasNext()) {
@@ -118,7 +114,7 @@ public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
     }
 
     @Override
-    public Integer result(Integer accumulator) {
+    public Integer extractOutput(Integer accumulator) {
       return accumulator;
     }
 


Mime
View raw message