beam-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ming...@apache.org
Subject [1/2] beam git commit: var_pop and var_samp
Date Tue, 29 Aug 2017 20:35:33 GMT
Repository: beam
Updated Branches:
  refs/heads/DSL_SQL c6f1f9fd2 -> 417ff43c0


var_pop and var_samp

two builtin aggregation functions

fix checkstyle

address comments

change type of sum in class VarAgg to BigDecimal
move isSamp field to Var
rename VarPop -> Var to make more generic
move logic to prepareOutput() both for Avg and Var
set MathContext to handle potential exception with BigDecimal divide.

newlines

rebase issue

assertEquals to test with delta


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c26cbeaf
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c26cbeaf
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c26cbeaf

Branch: refs/heads/DSL_SQL
Commit: c26cbeaf46fbb5dc45e42c8a7de58dd7056f4f06
Parents: c6f1f9f
Author: Kai Jiang <jiangkai@gmail.com>
Authored: Wed Aug 16 16:13:26 2017 -0700
Committer: Kai Jiang <jiangkai@gmail.com>
Committed: Tue Aug 29 13:28:44 2017 -0700

----------------------------------------------------------------------
 .../transform/BeamAggregationTransforms.java    |   8 +
 .../impl/transform/BeamBuiltinAggregations.java | 433 +++++++++++++++----
 .../sql/BeamSqlDslAggregationTest.java          | 119 ++++-
 3 files changed, 461 insertions(+), 99 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/c26cbeaf/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
index 9a50e21..f8c4c6f 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamAggregationTransforms.java
@@ -187,6 +187,14 @@ public class BeamAggregationTransforms implements Serializable{
           case "AVG":
             aggregators.add(BeamBuiltinAggregations.createAvg(call.type.getSqlTypeName()));
             break;
+          case "VAR_POP":
+            aggregators.add(BeamBuiltinAggregations.createVar(call.type.getSqlTypeName(),
+                    false));
+            break;
+          case "VAR_SAMP":
+            aggregators.add(BeamBuiltinAggregations.createVar(call.type.getSqlTypeName(),
+                    true));
+            break;
           default:
             if (call.getAggregation() instanceof SqlUserDefinedAggFunction) {
               // handle UDAF.

http://git-wip-us.apache.org/repos/asf/beam/blob/c26cbeaf/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index 03edf13..b5a5266 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -17,7 +17,10 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.transform;
 
+import java.io.Serializable;
 import java.math.BigDecimal;
+import java.math.MathContext;
+import java.math.RoundingMode;
 import java.util.Date;
 import java.util.Iterator;
 import org.apache.beam.sdk.coders.BigDecimalCoder;
@@ -26,6 +29,7 @@ import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.Max;
@@ -35,34 +39,36 @@ import org.apache.beam.sdk.values.KV;
 import org.apache.calcite.sql.type.SqlTypeName;
 
 /**
- * Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG.
+ * Built-in aggregations functions for COUNT/MAX/MIN/SUM/AVG/VAR_POP/VAR_SAMP.
  */
 class BeamBuiltinAggregations {
+  private static MathContext mc = new MathContext(10, RoundingMode.HALF_UP);
+
   /**
    * {@link CombineFn} for MAX based on {@link Max} and {@link Combine.BinaryCombineFn}.
    */
   public static CombineFn createMax(SqlTypeName fieldType) {
     switch (fieldType) {
-    case INTEGER:
-      return Max.ofIntegers();
-    case SMALLINT:
-      return new CustMax<Short>();
-    case TINYINT:
-      return new CustMax<Byte>();
-    case BIGINT:
-      return Max.ofLongs();
-    case FLOAT:
-      return new CustMax<Float>();
-    case DOUBLE:
-      return Max.ofDoubles();
-    case TIMESTAMP:
-      return new CustMax<Date>();
-    case DECIMAL:
-      return new CustMax<BigDecimal>();
-    default:
-      throw new UnsupportedOperationException(
-          String.format("[%s] is not support in MAX", fieldType));
-  }
+      // INTEGER, BIGINT and DOUBLE use Beam's Max.ofXxx() combiners; the other supported
+      // types go through the generic Comparable-based CustMax.
+      case INTEGER:
+        return Max.ofIntegers();
+      case SMALLINT:
+        return new CustMax<Short>();
+      case TINYINT:
+        return new CustMax<Byte>();
+      case BIGINT:
+        return Max.ofLongs();
+      case FLOAT:
+        return new CustMax<Float>();
+      case DOUBLE:
+        return Max.ofDoubles();
+      case TIMESTAMP:
+        return new CustMax<Date>();
+      case DECIMAL:
+        return new CustMax<BigDecimal>();
+      default:
+        // Any other SQL type has no MAX implementation here.
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not support in MAX", fieldType));
+    }
   }
 
   /**
@@ -70,26 +76,26 @@ class BeamBuiltinAggregations {
    */
   public static CombineFn createMin(SqlTypeName fieldType) {
     switch (fieldType) {
-    case INTEGER:
-      return Min.ofIntegers();
-    case SMALLINT:
-      return new CustMin<Short>();
-    case TINYINT:
-      return new CustMin<Byte>();
-    case BIGINT:
-      return Min.ofLongs();
-    case FLOAT:
-      return new CustMin<Float>();
-    case DOUBLE:
-      return Min.ofDoubles();
-    case TIMESTAMP:
-      return new CustMin<Date>();
-    case DECIMAL:
-      return new CustMin<BigDecimal>();
-    default:
-      throw new UnsupportedOperationException(
-          String.format("[%s] is not support in MIN", fieldType));
-  }
+      // INTEGER, BIGINT and DOUBLE use Beam's Min.ofXxx() combiners; the other supported
+      // types go through the generic CustMin.
+      case INTEGER:
+        return Min.ofIntegers();
+      case SMALLINT:
+        return new CustMin<Short>();
+      case TINYINT:
+        return new CustMin<Byte>();
+      case BIGINT:
+        return Min.ofLongs();
+      case FLOAT:
+        return new CustMin<Float>();
+      case DOUBLE:
+        return Min.ofDoubles();
+      case TIMESTAMP:
+        return new CustMin<Date>();
+      case DECIMAL:
+        return new CustMin<BigDecimal>();
+      default:
+        // Any other SQL type has no MIN implementation here.
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not support in MIN", fieldType));
+    }
   }
 
   /**
@@ -97,24 +103,24 @@ class BeamBuiltinAggregations {
    */
   public static CombineFn createSum(SqlTypeName fieldType) {
     switch (fieldType) {
-    case INTEGER:
-      return Sum.ofIntegers();
-    case SMALLINT:
-      return new ShortSum();
-    case TINYINT:
-      return new ByteSum();
-    case BIGINT:
-      return Sum.ofLongs();
-    case FLOAT:
-      return new FloatSum();
-    case DOUBLE:
-      return Sum.ofDoubles();
-    case DECIMAL:
-      return new BigDecimalSum();
-    default:
-      throw new UnsupportedOperationException(
-          String.format("[%s] is not support in SUM", fieldType));
-  }
+      // INTEGER, BIGINT and DOUBLE use Beam's Sum.ofXxx() combiners; SMALLINT, TINYINT, FLOAT
+      // and DECIMAL use the dedicated ShortSum/ByteSum/FloatSum/BigDecimalSum classes.
+      case INTEGER:
+        return Sum.ofIntegers();
+      case SMALLINT:
+        return new ShortSum();
+      case TINYINT:
+        return new ByteSum();
+      case BIGINT:
+        return Sum.ofLongs();
+      case FLOAT:
+        return new FloatSum();
+      case DOUBLE:
+        return Sum.ofDoubles();
+      case DECIMAL:
+        return new BigDecimalSum();
+      default:
+        // Any other SQL type (including TIMESTAMP) has no SUM implementation here.
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not support in SUM", fieldType));
+    }
   }
 
   /**
@@ -122,24 +128,49 @@ class BeamBuiltinAggregations {
    */
   public static CombineFn createAvg(SqlTypeName fieldType) {
     switch (fieldType) {
-    case INTEGER:
-      return new IntegerAvg();
-    case SMALLINT:
-      return new ShortAvg();
-    case TINYINT:
-      return new ByteAvg();
-    case BIGINT:
-      return new LongAvg();
-    case FLOAT:
-      return new FloatAvg();
-    case DOUBLE:
-      return new DoubleAvg();
-    case DECIMAL:
-      return new BigDecimalAvg();
-    default:
-      throw new UnsupportedOperationException(
-          String.format("[%s] is not support in AVG", fieldType));
+      // Each numeric type has a dedicated Avg subclass that converts the BigDecimal mean
+      // back to that type in extractOutput().
+      case INTEGER:
+        return new IntegerAvg();
+      case SMALLINT:
+        return new ShortAvg();
+      case TINYINT:
+        return new ByteAvg();
+      case BIGINT:
+        return new LongAvg();
+      case FLOAT:
+        return new FloatAvg();
+      case DOUBLE:
+        return new DoubleAvg();
+      case DECIMAL:
+        return new BigDecimalAvg();
+      default:
+        // Any other SQL type has no AVG implementation here.
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not support in AVG", fieldType));
+    }
   }
+
+  /**
+   * {@link CombineFn} for VAR_POP and VAR_SAMP.
+   *
+   * @param fieldType SQL type of the aggregated column; the combiner's output has the same
+   *     Java type as its input
+   * @param isSamp {@code true} for VAR_SAMP (sum of squared deviations divided by n - 1),
+   *     {@code false} for VAR_POP (divided by n)
+   * @throws UnsupportedOperationException if {@code fieldType} is not one of the numeric
+   *     types handled below
+   */
+  public static CombineFn createVar(SqlTypeName fieldType, boolean isSamp) {
+    switch (fieldType) {
+      case INTEGER:
+        return new IntegerVar(isSamp);
+      case SMALLINT:
+        return new ShortVar(isSamp);
+      case TINYINT:
+        return new ByteVar(isSamp);
+      case BIGINT:
+        return new LongVar(isSamp);
+      case FLOAT:
+        return new FloatVar(isSamp);
+      case DOUBLE:
+        return new DoubleVar(isSamp);
+      case DECIMAL:
+        return new BigDecimalVar(isSamp);
+      default:
+        // Name the function actually requested (this used to say "AVG" by copy-paste).
+        throw new UnsupportedOperationException(
+            String.format("[%s] is not support in %s", fieldType,
+                isSamp ? "VAR_SAMP" : "VAR_POP"));
+    }
   }
 
   static class CustMax<T extends Comparable<T>> extends Combine.BinaryCombineFn<T>
{
@@ -213,14 +244,17 @@ class BeamBuiltinAggregations {
       return KvCoder.of(BigEndianIntegerCoder.of(), BigDecimalCoder.of());
     }
 
+    /**
+     * Returns the mean, {@code sum / count}, of the accumulator as a BigDecimal.
+     *
+     * <p>Divides with {@link MathContext#DECIMAL128} (34 significant digits) rather than the
+     * shared 10-digit {@code mc}, so averages of BIGINT/DECIMAL columns with more than 10
+     * significant digits are not silently rounded. A bounded context is still required because
+     * an unbounded divide throws for non-terminating quotients such as 1/3. Callers must handle
+     * {@code count == 0} before calling this.
+     */
+    protected BigDecimal prepareOutput(KV<Integer, BigDecimal> accumulator){
+      return accumulator.getValue().divide(new BigDecimal(accumulator.getKey()),
+          MathContext.DECIMAL128);
+    }
+
     public abstract T extractOutput(KV<Integer, BigDecimal> accumulator);
     public abstract BigDecimal toBigDecimal(T record);
   }
 
   static class IntegerAvg extends Avg<Integer>{
     public Integer extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).intValue();
+      // null when the accumulated count (key) is 0; otherwise the mean narrowed with intValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).intValue();
     }
 
     public BigDecimal toBigDecimal(Integer record) {
@@ -230,8 +264,7 @@ class BeamBuiltinAggregations {
 
   static class LongAvg extends Avg<Long>{
     public Long extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).longValue();
+      // null when the accumulated count (key) is 0; otherwise the mean narrowed with longValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).longValue();
     }
 
     public BigDecimal toBigDecimal(Long record) {
@@ -241,8 +274,7 @@ class BeamBuiltinAggregations {
 
   static class ShortAvg extends Avg<Short>{
     public Short extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).shortValue();
+      // null when the accumulated count (key) is 0; otherwise the mean converted with shortValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).shortValue();
     }
 
     public BigDecimal toBigDecimal(Short record) {
@@ -252,8 +284,7 @@ class BeamBuiltinAggregations {
 
   static class ByteAvg extends Avg<Byte>{
     public Byte extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).byteValue();
+      // null when the accumulated count (key) is 0; otherwise the mean converted with byteValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).byteValue();
     }
 
     public BigDecimal toBigDecimal(Byte record) {
@@ -263,8 +294,7 @@ class BeamBuiltinAggregations {
 
   static class FloatAvg extends Avg<Float>{
     public Float extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).floatValue();
+      // null when the accumulated count (key) is 0; otherwise the mean converted with floatValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).floatValue();
     }
 
     public BigDecimal toBigDecimal(Float record) {
@@ -274,8 +304,7 @@ class BeamBuiltinAggregations {
 
   static class DoubleAvg extends Avg<Double>{
     public Double extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey())).doubleValue();
+      // null when the accumulated count (key) is 0; otherwise the mean as doubleValue().
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator).doubleValue();
     }
 
     public BigDecimal toBigDecimal(Double record) {
@@ -285,10 +314,234 @@ class BeamBuiltinAggregations {
 
   static class BigDecimalAvg extends Avg<BigDecimal>{
     public BigDecimal extractOutput(KV<Integer, BigDecimal> accumulator) {
-      return accumulator.getKey() == 0 ? null
-          : accumulator.getValue().divide(new BigDecimal(accumulator.getKey()));
+      // null when the accumulated count (key) is 0; otherwise the BigDecimal mean unchanged.
+      return accumulator.getKey() == 0 ? null : prepareOutput(accumulator);
+    }
+
+    /** DECIMAL inputs are already BigDecimal, so no conversion is needed. */
+    public BigDecimal toBigDecimal(BigDecimal record) {
+      return record;
+    }
+  }
+
+  /**
+   * Running element count and sum used by {@link Var}. The sum of squared deviations is kept
+   * separately, as the key of the accumulator {@code KV}. Fields are mutable because
+   * {@code Var.addInput} updates them in place.
+   *
+   * <p>Encoded with {@link SerializableCoder}, so it declares an explicit serialVersionUID
+   * instead of relying on the compiler-generated one.
+   */
+  static class VarAgg implements Serializable {
+    private static final long serialVersionUID = 1L;
+
+    long count; // number of non-null elements added
+    BigDecimal sum; // sum of those elements
+
+    public VarAgg(long count, BigDecimal sum) {
+      this.count = count;
+      this.sum = sum;
+    }
+  }
+
+  /**
+   * {@link CombineFn} for <em>Var</em> on {@link Number} types: VAR_POP or VAR_SAMP.
+   *
+   * <p>The accumulator is {@code KV<M, VarAgg>}, where {@code M} is the running sum of squared
+   * deviations from the mean, sum[(x - avg)^2] (n times the population variance), and
+   * {@link VarAgg} holds the element count and sum.
+   *
+   * <p>Partial results are combined with the pairwise update from Chan, Golub and LeVeque,
+   * "Algorithms for computing the sample variance: analysis and recommendations",
+   * The American Statistician, 37 (1983) pp. 242--247:
+   *
+   * <p>M = M1 + M2 + n/(m*(m+n)) * ((m/n)*t1 - t2)^2 = M1 + M2 + (m*t1 - n*t2)^2 / (n*m*(n+m))
+   *
+   * <p>where n and m are the element counts of the two chunks and t1, t2 their sums.
+   *
+   * <p>All divisions use the shared {@code mc} context (10 significant digits).
+   */
+  abstract static class Var<T extends Number>
+          extends CombineFn<T, KV<BigDecimal, VarAgg>, T> {
+    /** {@code true} for VAR_SAMP (divide by n - 1), {@code false} for VAR_POP (divide by n). */
+    boolean isSamp;
+
+    public Var(boolean isSamp){
+      this.isSamp = isSamp;
+    }
+
+    @Override
+    public KV<BigDecimal, VarAgg> createAccumulator() {
+      VarAgg varagg = new VarAgg(0L, new BigDecimal(0));
+      return KV.of(new BigDecimal(0), varagg);
+    }
+
+    /**
+     * Folds one element into the accumulator; null inputs are ignored. Uses the incremental
+     * update M += (count * v - sum)^2 / (count * (count - 1)), where count and sum already
+     * include v.
+     */
+    @Override
+    public KV<BigDecimal, VarAgg> addInput(KV<BigDecimal, VarAgg> accumulator, T input) {
+      if (input == null) {
+        return accumulator;
+      }
+      // toString() keeps the decimal form of float/double inputs instead of their exact
+      // binary expansion.
+      BigDecimal v = new BigDecimal(input.toString());
+      VarAgg agg = accumulator.getValue();
+      agg.count++;
+      agg.sum = agg.sum.add(v);
+
+      BigDecimal variance = BigDecimal.ZERO;
+      if (agg.count > 1) {
+        BigDecimal count = new BigDecimal(agg.count);
+        BigDecimal t = v.multiply(count).subtract(agg.sum);
+        variance = t.pow(2).divide(count.multiply(count.subtract(BigDecimal.ONE)), mc);
+      }
+      return KV.of(accumulator.getKey().add(variance), agg);
+    }
+
+    /**
+     * Combines partial accumulators one at a time with the Chan et al. update. Each partial is
+     * merged against the count and sum accumulated <em>before</em> it is added. Empty partials
+     * are skipped: they contribute nothing and would otherwise cause a 0/0 division.
+     */
+    @Override
+    public KV<BigDecimal, VarAgg> mergeAccumulators(
+            Iterable<KV<BigDecimal, VarAgg>> accumulators) {
+      BigDecimal variance = new BigDecimal(0);
+      long count = 0;
+      BigDecimal sum = new BigDecimal(0);
+
+      for (KV<BigDecimal, VarAgg> partial : accumulators) {
+        long partialCount = partial.getValue().count;
+        if (partialCount == 0) {
+          continue;
+        }
+        BigDecimal partialSum = partial.getValue().sum;
+        variance = variance.add(partial.getKey());
+        if (count > 0) {
+          // d = (m*t1 - n*t2)^2 / (n*m*(n+m)); n and t1 are the totals before this partial.
+          BigDecimal n = new BigDecimal(count);
+          BigDecimal m = new BigDecimal(partialCount);
+          BigDecimal t = m.multiply(sum).subtract(n.multiply(partialSum));
+          variance = variance.add(t.pow(2).divide(n.multiply(m).multiply(n.add(m)), mc));
+        }
+        count += partialCount;
+        sum = sum.add(partialSum);
+      }
+
+      return KV.of(variance, new VarAgg(count, sum));
+    }
+
+    @Override
+    public Coder<KV<BigDecimal, VarAgg>> getAccumulatorCoder(CoderRegistry registry,
+        Coder<T> inputCoder) throws CannotProvideCoderException {
+      return KvCoder.of(BigDecimalCoder.of(), SerializableCoder.of(VarAgg.class));
+    }
+
+    /**
+     * Divides the accumulated sum of squared deviations by n (VAR_POP) or n - 1 (VAR_SAMP).
+     * Returns zero when fewer than two elements were accumulated.
+     * NOTE(review): standard SQL yields NULL for an empty group (and for VAR_SAMP of a single
+     * row); this returns 0 instead — confirm that is intended.
+     */
+    protected BigDecimal prepareOutput(KV<BigDecimal, VarAgg> accumulator){
+      BigDecimal decimalVar;
+      if (accumulator.getValue().count > 1) {
+        BigDecimal a = accumulator.getKey();
+        BigDecimal b = new BigDecimal(accumulator.getValue().count)
+                .subtract(this.isSamp ? BigDecimal.ONE : BigDecimal.ZERO);
+
+        decimalVar = a.divide(b, mc);
+      } else {
+        decimalVar = BigDecimal.ZERO;
+      }
+      return decimalVar;
+    }
+
+    public abstract T extractOutput(KV<BigDecimal, VarAgg> accumulator);
+
+    public abstract BigDecimal toBigDecimal(T record);
+  }
+
+  /** VAR_POP / VAR_SAMP over INTEGER columns; the variance is narrowed with intValue(). */
+  static class IntegerVar extends Var<Integer> {
+    public IntegerVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    public Integer extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).intValue();
     }
 
+    @Override
+    public BigDecimal toBigDecimal(Integer record) {
+      return new BigDecimal(record);
+    }
+  }
+
+  /** VAR_POP / VAR_SAMP over SMALLINT columns; the variance is converted with shortValue(). */
+  static class ShortVar extends Var<Short> {
+    public ShortVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    public Short extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).shortValue();
+    }
+
+    @Override
+    public BigDecimal toBigDecimal(Short record) {
+      return new BigDecimal(record);
+    }
+  }
+
+  /** VAR_POP / VAR_SAMP over TINYINT columns; the variance is converted with byteValue(). */
+  static class ByteVar extends Var<Byte> {
+    public ByteVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    public Byte extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).byteValue();
+    }
+
+    @Override
+    public BigDecimal toBigDecimal(Byte record) {
+      return new BigDecimal(record);
+    }
+  }
+
+  /** VAR_POP / VAR_SAMP over BIGINT columns; the variance is narrowed with longValue(). */
+  static class LongVar extends Var<Long> {
+    public LongVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    public Long extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).longValue();
+    }
+
+    @Override
+    public BigDecimal toBigDecimal(Long record) {
+      return new BigDecimal(record);
+    }
+  }
+
+  /** VAR_POP / VAR_SAMP over FLOAT columns; the variance is converted with floatValue(). */
+  static class FloatVar extends Var<Float> {
+    public FloatVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    @Override
+    public Float extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).floatValue();
+    }
+
+    /**
+     * Converts through {@link Float#toString()}, matching how {@code Var.addInput} converts its
+     * inputs, so 0.1f becomes 0.1 instead of the exact binary expansion that
+     * {@code new BigDecimal(double)} would produce.
+     */
+    @Override
+    public BigDecimal toBigDecimal(Float record) {
+      return new BigDecimal(record.toString());
+    }
+  }
+
+  /** VAR_POP / VAR_SAMP over DOUBLE columns; the variance is converted with doubleValue(). */
+  static class DoubleVar extends Var<Double> {
+    public DoubleVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    @Override
+    public Double extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator).doubleValue();
+    }
+
+    /**
+     * Converts through {@link Double#toString()}, matching how {@code Var.addInput} converts its
+     * inputs, so 0.1 stays 0.1 instead of the exact binary expansion that
+     * {@code new BigDecimal(double)} would produce.
+     */
+    @Override
+    public BigDecimal toBigDecimal(Double record) {
+      return new BigDecimal(record.toString());
+    }
+  }
+
+  static class BigDecimalVar extends Var<BigDecimal> {
+    public BigDecimalVar(boolean isSamp) {
+      super(isSamp);
+    }
+
+    /** Returns the variance from prepareOutput as-is; DECIMAL needs no narrowing. */
+    public BigDecimal extractOutput(KV<BigDecimal, VarAgg> accumulator) {
+      return prepareOutput(accumulator);
+    }
+
+    @Override
     public BigDecimal toBigDecimal(BigDecimal record) {
       return record;
     }

http://git-wip-us.apache.org/repos/asf/beam/blob/c26cbeaf/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
index c0b857d..76d2313 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationTest.java
@@ -17,13 +17,25 @@
  */
 package org.apache.beam.sdk.extensions.sql;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.math.BigDecimal;
 import java.sql.Types;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.BeamRecord;
+import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.TupleTag;
+import org.junit.Before;
 import org.junit.Test;
 
 /**
@@ -31,6 +43,47 @@ import org.junit.Test;
  * with BOUNDED PCollection.
  */
 public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
+  public PCollection<BeamRecord> boundedInput3;
+
+  /**
+   * Builds {@code boundedInput3}: seven rows that all have f_int2 = 0, so GROUP BY f_int2
+   * yields a single group. f_int, f_double and f_decimal carry the same values
+   * (1, 4, 7, 13, 5, 10, 17) in three different SQL types.
+   */
+  @Before
+  public void setUp(){
+    BeamRecordSqlType rowTypeInTableB = BeamRecordSqlType.create(
+            Arrays.asList("f_int", "f_double", "f_int2", "f_decimal"),
+            Arrays.asList(Types.INTEGER, Types.DOUBLE, Types.INTEGER, Types.DECIMAL));
+
+    List<BeamRecord> recordsInTableB = new ArrayList<>();
+    BeamRecord row1 = new BeamRecord(rowTypeInTableB
+            , 1, 1.0, 0, new BigDecimal(1));
+    recordsInTableB.add(row1);
+
+    BeamRecord row2 = new BeamRecord(rowTypeInTableB
+            , 4, 4.0, 0, new BigDecimal(4));
+    recordsInTableB.add(row2);
+
+    BeamRecord row3 = new BeamRecord(rowTypeInTableB
+            , 7, 7.0, 0, new BigDecimal(7));
+    recordsInTableB.add(row3);
+
+    BeamRecord row4 = new BeamRecord(rowTypeInTableB
+            , 13, 13.0, 0, new BigDecimal(13));
+    recordsInTableB.add(row4);
+
+    BeamRecord row5 = new BeamRecord(rowTypeInTableB
+            , 5, 5.0, 0, new BigDecimal(5));
+    recordsInTableB.add(row5);
+
+    BeamRecord row6 = new BeamRecord(rowTypeInTableB
+            , 10, 10.0, 0, new BigDecimal(10));
+    recordsInTableB.add(row6);
+
+    BeamRecord row7 = new BeamRecord(rowTypeInTableB
+            , 17, 17.0, 0, new BigDecimal(17));
+    recordsInTableB.add(row7);
+
+    boundedInput3 = PBegin.in(pipeline).apply("boundedInput3",
+            Create.of(recordsInTableB).withCoder(rowTypeInTableB.getRecordCoder()));
+  }
+
   /**
    * GROUP-BY with single aggregation function with bounded PCollection.
    */
@@ -82,13 +135,15 @@ public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
 
   private void runAggregationFunctions(PCollection<BeamRecord> input) throws Exception{
     String sql = "select f_int2, count(*) as getFieldCount, "
-        + "sum(f_long) as sum1, avg(f_long) as avg1, max(f_long) as max1, min(f_long) as
min1,"
-        + "sum(f_short) as sum2, avg(f_short) as avg2, max(f_short) as max2, min(f_short)
as min2,"
-        + "sum(f_byte) as sum3, avg(f_byte) as avg3, max(f_byte) as max3, min(f_byte) as
min3,"
-        + "sum(f_float) as sum4, avg(f_float) as avg4, max(f_float) as max4, min(f_float)
as min4,"
+        + "sum(f_long) as sum1, avg(f_long) as avg1, max(f_long) as max1, min(f_long) as
min1, "
+        + "sum(f_short) as sum2, avg(f_short) as avg2, max(f_short) as max2, min(f_short)
as min2, "
+        + "sum(f_byte) as sum3, avg(f_byte) as avg3, max(f_byte) as max3, min(f_byte) as
min3, "
+        + "sum(f_float) as sum4, avg(f_float) as avg4, max(f_float) as max4, min(f_float)
as min4, "
         + "sum(f_double) as sum5, avg(f_double) as avg5, "
-        + "max(f_double) as max5, min(f_double) as min5,"
-        + "max(f_timestamp) as max6, min(f_timestamp) as min6 "
+        + "max(f_double) as max5, min(f_double) as min5, "
+        + "max(f_timestamp) as max6, min(f_timestamp) as min6, "
+        + "var_pop(f_double) as varpop1, var_samp(f_double) as varsamp1, "
+        + "var_pop(f_int) as varpop2, var_samp(f_int) as varsamp2 "
         + "FROM TABLE_A group by f_int2";
 
     PCollection<BeamRecord> result =
@@ -98,12 +153,14 @@ public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
     BeamRecordSqlType resultType = BeamRecordSqlType.create(
         Arrays.asList("f_int2", "size", "sum1", "avg1", "max1", "min1", "sum2", "avg2", "max2",
             "min2", "sum3", "avg3", "max3", "min3", "sum4", "avg4", "max4", "min4", "sum5",
"avg5",
-            "max5", "min5", "max6", "min6"),
+            "max5", "min5", "max6", "min6",
+            "varpop1", "varsamp1", "varpop2", "varsamp2"),
         Arrays.asList(Types.INTEGER, Types.BIGINT, Types.BIGINT, Types.BIGINT, Types.BIGINT,
             Types.BIGINT, Types.SMALLINT, Types.SMALLINT, Types.SMALLINT, Types.SMALLINT,
             Types.TINYINT, Types.TINYINT, Types.TINYINT, Types.TINYINT, Types.FLOAT, Types.FLOAT,
             Types.FLOAT, Types.FLOAT, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE, Types.DOUBLE,
-            Types.TIMESTAMP, Types.TIMESTAMP));
+            Types.TIMESTAMP, Types.TIMESTAMP,
+            Types.DOUBLE, Types.DOUBLE, Types.INTEGER, Types.INTEGER));
 
     BeamRecord record = new BeamRecord(resultType
         , 0, 4L
@@ -112,13 +169,57 @@ public class BeamSqlDslAggregationTest extends BeamSqlDslBase {
         , (byte) 10, (byte) 2, (byte) 4, (byte) 1
         , 10.0F, 2.5F, 4.0F, 1.0F
         , 10.0, 2.5, 4.0, 1.0
-        , FORMAT.parse("2017-01-01 02:04:03"), FORMAT.parse("2017-01-01 01:01:03"));
+        , FORMAT.parse("2017-01-01 02:04:03"), FORMAT.parse("2017-01-01 01:01:03")
+        , 1.25, 1.666666667, 1, 1);
 
     PAssert.that(result).containsInAnyOrder(record);
 
     pipeline.run().waitUntilFinish();
   }
 
+  /**
+   * Checks the single output row of the BigDecimal-division test. For the inputs
+   * 1, 4, 7, 13, 5, 10, 17 the exact values are mean = 57/7, VAR_POP = 1294/49 and
+   * VAR_SAMP = 647/21. DOUBLE results are compared with a tolerance; INTEGER results are the
+   * truncated values.
+   */
+  private static class CheckerBigDecimalDivide
+          implements SerializableFunction<Iterable<BeamRecord>, Void> {
+    @Override public Void apply(Iterable<BeamRecord> input) {
+      Iterator<BeamRecord> iter = input.iterator();
+      assertTrue(iter.hasNext());
+      BeamRecord row = iter.next();
+      // JUnit's assertEquals takes (expected, actual[, delta]); keep that order so failure
+      // messages report the values the right way round.
+      assertEquals(8.142857143, row.getDouble("avg1"), 1e-7);
+      assertEquals(8, (int) row.getInteger("avg2"));
+      assertEquals(26.40816326, row.getDouble("varpop1"), 1e-7);
+      assertEquals(26, (int) row.getInteger("varpop2"));
+      assertEquals(30.80952381, row.getDouble("varsamp1"), 1e-7);
+      assertEquals(30, (int) row.getInteger("varsamp2"));
+      assertFalse(iter.hasNext());
+      return null;
+    }
+  }
+
+  /**
+   * GROUP-BY with aggregation functions that need BigDecimal division (AVG, VAR_POP, VAR_SAMP)
+   * whose quotients have non-terminating decimal expansions.
+   */
+  @Test
+  public void testAggregationFunctionsWithBoundedOnBigDecimalDivide() throws Exception {
+    String sql = "SELECT AVG(f_double) as avg1, AVG(f_int) as avg2, "
+            + "VAR_POP(f_double) as varpop1, VAR_POP(f_int) as varpop2, "
+            + "VAR_SAMP(f_double) as varsamp1, VAR_SAMP(f_int) as varsamp2 "
+            + "FROM PCOLLECTION GROUP BY f_int2";
+
+    PCollection<BeamRecord> result =
+            boundedInput3.apply("testAggregationWithDecimalValue", BeamSql.query(sql));
+
+    // Fields are checked by name with a tolerance, so no expected record type is needed here.
+    PAssert.that(result).satisfies(new CheckerBigDecimalDivide());
+
+    pipeline.run().waitUntilFinish();
+  }
+
   /**
    * Implicit GROUP-BY with DISTINCT with bounded PCollection.
    */


Mime
View raw message