calcite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mm...@apache.org
Subject [07/16] calcite git commit: [CALCITE-1945] Make return types of AVG, VARIANCE, STDDEV and COVAR customizable via RelDataTypeSystem
Date Tue, 05 Sep 2017 14:36:50 GMT
[CALCITE-1945] Make return types of AVG, VARIANCE, STDDEV and COVAR customizable via RelDataTypeSystem

* Introduce VARIANCE and STDDEV as alias for _SAMP

Close apache/calcite#518


Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/4208d802
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/4208d802
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/4208d802

Branch: refs/heads/branch-1.14
Commit: 4208d8021b4978a3f0a259ec299fa7a62c582180
Parents: 6d2fc4e
Author: MinJi Kim <minji@apache.org>
Authored: Sun Aug 27 16:21:22 2017 -0700
Committer: Julian Hyde <jhyde@apache.org>
Committed: Tue Aug 29 10:15:17 2017 -0700

----------------------------------------------------------------------
 .../calcite/rel/rel2sql/SqlImplementor.java     |   6 +-
 .../rel/rules/AggregateReduceFunctionsRule.java | 127 ++++++++++++-------
 .../calcite/rel/type/RelDataTypeSystem.java     |  18 ++-
 .../calcite/rel/type/RelDataTypeSystemImpl.java |  14 +-
 .../apache/calcite/runtime/SqlFunctions.java    |   4 +
 .../java/org/apache/calcite/sql/SqlKind.java    |   9 ++
 .../calcite/sql/fun/SqlAvgAggFunction.java      |  15 ++-
 .../calcite/sql/fun/SqlCovarAggFunction.java    |   2 +-
 .../calcite/sql/fun/SqlStdOperatorTable.java    |  12 ++
 .../apache/calcite/sql/type/ReturnTypes.java    |  34 ++++-
 .../sql2rel/StandardConvertletTable.java        |  49 +++++--
 .../calcite/sql/test/SqlOperatorBaseTest.java   |  70 ++++++++++
 core/src/test/resources/sql/agg.iq              |  24 +++-
 13 files changed, 303 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
index d227310..57155b7 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
@@ -665,7 +665,11 @@ public abstract class SqlImplementor {
         }
 
         final RexCall call = (RexCall) stripCastFromString(rex);
-        final SqlOperator op = call.getOperator();
+        SqlOperator op = call.getOperator();
+        switch (op.getKind()) {
+        case SUM0:
+          op = SqlStdOperatorTable.SUM;
+        }
         final List<SqlNode> nodeList = toSql(program, call.getOperands());
         switch (call.getKind()) {
         case CAST:

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 8fceff0..7e6e4a1 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -31,10 +31,9 @@ import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
 import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.fun.SqlAvgAggFunction;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.fun.SqlSumAggFunction;
 import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
@@ -117,8 +116,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
    */
   private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
     for (AggregateCall call : aggCallList) {
-      if (call.getAggregation() instanceof SqlAvgAggFunction
-          || call.getAggregation() instanceof SqlSumAggFunction) {
+      if (isReducible(call.getAggregation().getKind())) {
         return true;
       }
     }
@@ -126,6 +124,20 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
   }
 
   /**
+   * Returns whether the aggregate call is a reducible function
+   */
+  private boolean isReducible(final SqlKind kind) {
+    if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) {
+      return true;
+    }
+    switch (kind) {
+    case SUM:
+      return true;
+    }
+    return false;
+  }
+
+  /**
    * Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
    * the aggregates list to.
    *
@@ -187,17 +199,16 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
       List<AggregateCall> newCalls,
       Map<AggregateCall, RexNode> aggCallMapping,
       List<RexNode> inputExprs) {
-    if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
-      // replace original SUM(x) with
-      // case COUNT(x) when 0 then null else SUM0(x) end
-      return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
-    }
-    if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
-      final SqlKind kind = oldCall.getAggregation().getKind();
+    final SqlKind kind = oldCall.getAggregation().getKind();
+    if (isReducible(kind)) {
       switch (kind) {
+      case SUM:
+        // replace original SUM(x) with
+        // case COUNT(x) when 0 then null else SUM0(x) end
+        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
       case AVG:
         // replace original AVG(x) with SUM(x) / COUNT(x)
-        return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
+        return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
       case STDDEV_POP:
         // replace original STDDEV_POP(x) with
         //   SQRT(
@@ -243,19 +254,39 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
     }
   }
 
+  private AggregateCall createAggregateCallWithBinding(
+      RelDataTypeFactory typeFactory,
+      SqlAggFunction aggFunction,
+      RelDataType operandType,
+      Aggregate oldAggRel,
+      AggregateCall oldCall,
+      int argOrdinal) {
+    final Aggregate.AggCallBinding binding =
+        new Aggregate.AggCallBinding(typeFactory, aggFunction,
+            ImmutableList.of(operandType), oldAggRel.getGroupCount(),
+            oldCall.filterArg >= 0);
+    return AggregateCall.create(aggFunction,
+        oldCall.isDistinct(),
+        ImmutableIntList.of(argOrdinal),
+        oldCall.filterArg,
+        aggFunction.inferReturnType(binding),
+        null);
+  }
+
   private RexNode reduceAvg(
       Aggregate oldAggRel,
       AggregateCall oldCall,
       List<AggregateCall> newCalls,
-      Map<AggregateCall, RexNode> aggCallMapping) {
+      Map<AggregateCall, RexNode> aggCallMapping,
+      List<RexNode> inputExprs) {
     final int nGroups = oldAggRel.getGroupCount();
-    RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
-    int iAvgInput = oldCall.getArgList().get(0);
-    RelDataType avgInputType =
+    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+    final int iAvgInput = oldCall.getArgList().get(0);
+    final RelDataType avgInputType =
         getFieldType(
             oldAggRel.getInput(),
             iAvgInput);
-    AggregateCall sumCall =
+    final AggregateCall sumCall =
         AggregateCall.create(
             SqlStdOperatorTable.SUM,
             oldCall.isDistinct(),
@@ -265,7 +296,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
             oldAggRel.getInput(),
             null,
             null);
-    AggregateCall countCall =
+    final AggregateCall countCall =
         AggregateCall.create(
             SqlStdOperatorTable.COUNT,
             oldCall.isDistinct(),
@@ -285,17 +316,20 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
             newCalls,
             aggCallMapping,
             ImmutableList.of(avgInputType));
-    RexNode denominatorRef =
+    final RexNode denominatorRef =
         rexBuilder.addAggCall(countCall,
             nGroups,
             oldAggRel.indicator,
             newCalls,
             aggCallMapping,
             ImmutableList.of(avgInputType));
+
+    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
+    final RelDataType avgType = typeFactory.createTypeWithNullability(
+        oldCall.getType(), numeratorRef.getType().isNullable());
+    numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
     final RexNode divideRef =
-        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
-            numeratorRef,
-            denominatorRef);
+        rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
     return rexBuilder.makeCast(oldCall.getType(), divideRef);
   }
 
@@ -381,36 +415,30 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
 
     assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
     final int argOrdinal = oldCall.getArgList().get(0);
-    final RelDataType argType =
-        getFieldType(
-            oldAggRel.getInput(),
-            argOrdinal);
+    final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
+    final RelDataType oldCallType =
+        typeFactory.createTypeWithNullability(oldCall.getType(),
+            argOrdinalType.isNullable());
 
-    final RexNode argRef = inputExprs.get(argOrdinal);
-    final RexNode argSquared =
-        rexBuilder.makeCall(
-            SqlStdOperatorTable.MULTIPLY, argRef, argRef);
+    final RexNode argRef =
+        rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
+    final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
+
+    final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
+        argRef, argRef);
     final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
 
-    final Aggregate.AggCallBinding binding =
-        new Aggregate.AggCallBinding(typeFactory, SqlStdOperatorTable.SUM,
-            ImmutableList.of(argRef.getType()), oldAggRel.getGroupCount(),
-            oldCall.filterArg >= 0);
     final AggregateCall sumArgSquaredAggCall =
-        AggregateCall.create(
-            SqlStdOperatorTable.SUM,
-            oldCall.isDistinct(),
-            ImmutableIntList.of(argSquaredOrdinal),
-            oldCall.filterArg,
-            SqlStdOperatorTable.SUM.inferReturnType(binding),
-            null);
+        createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM,
+            argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
+
     final RexNode sumArgSquared =
         rexBuilder.addAggCall(sumArgSquaredAggCall,
             nGroups,
             oldAggRel.indicator,
             newCalls,
             aggCallMapping,
-            ImmutableList.of(argType));
+            ImmutableList.of(sumArgSquaredAggCall.getType()));
 
     final AggregateCall sumArgAggCall =
         AggregateCall.create(
@@ -422,17 +450,18 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
             oldAggRel.getInput(),
             null,
             null);
+
     final RexNode sumArg =
         rexBuilder.addAggCall(sumArgAggCall,
             nGroups,
             oldAggRel.indicator,
             newCalls,
             aggCallMapping,
-            ImmutableList.of(argType));
-
+            ImmutableList.of(sumArgAggCall.getType()));
+    final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
     final RexNode sumSquaredArg =
         rexBuilder.makeCall(
-            SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
+            SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
 
     final AggregateCall countArgAggCall =
         AggregateCall.create(
@@ -441,21 +470,21 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
             oldCall.getArgList(),
             oldCall.filterArg,
             oldAggRel.getGroupCount(),
-            oldAggRel.getInput(),
+            oldAggRel,
             null,
             null);
+
     final RexNode countArg =
         rexBuilder.addAggCall(countArgAggCall,
             nGroups,
             oldAggRel.indicator,
             newCalls,
             aggCallMapping,
-            ImmutableList.of(argType));
+            ImmutableList.of(argOrdinalType));
 
     final RexNode avgSumSquaredArg =
         rexBuilder.makeCall(
-            SqlStdOperatorTable.DIVIDE,
-            sumSquaredArg, countArg);
+            SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
 
     final RexNode diff =
         rexBuilder.makeCall(

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
index 858567c..b8a8088 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
@@ -69,11 +69,21 @@ public interface RelDataTypeSystem {
    * 0 means "not applicable". */
   int getNumTypeRadix(SqlTypeName typeName);
 
-  /**
-   * Returns the return type of a call to the {@code SUM} aggregate function
-   * inferred from its argument type.
+  /** Returns the return type of a call to the {@code SUM} aggregate function,
+   * inferred from its argument type. */
+  RelDataType deriveSumType(RelDataTypeFactory typeFactory,
+      RelDataType argumentType);
+
+  /** Returns the return type of a call to the {@code AVG}, {@code STDDEV} or
+   * {@code VAR} aggregate functions, inferred from its argument type.
    */
-  RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType argumentType);
+  RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
+      RelDataType argumentType);
+
+  /** Returns the return type of a call to the {@code COVAR} aggregate function,
+   * inferred from its argument types. */
+  RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
+      RelDataType arg0Type, RelDataType arg1Type);
 
   /** Returns the return type of the {@code CUME_DIST} and {@code PERCENT_RANK}
    * aggregate functions. */

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
index ef89895..3e0eebd 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
@@ -207,11 +207,21 @@ public abstract class RelDataTypeSystemImpl implements RelDataTypeSystem {
     return 0;
   }
 
-  @Override public RelDataType deriveSumType(
-      RelDataTypeFactory typeFactory, RelDataType argumentType) {
+  @Override public RelDataType deriveSumType(RelDataTypeFactory typeFactory,
+      RelDataType argumentType) {
     return argumentType;
   }
 
+  @Override public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
+      RelDataType argumentType) {
+    return argumentType;
+  }
+
+  @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
+      RelDataType arg0Type, RelDataType arg1Type) {
+    return arg0Type;
+  }
+
   @Override public RelDataType deriveFractionalRankType(RelDataTypeFactory typeFactory) {
     return typeFactory.createTypeWithNullability(
         typeFactory.createSqlType(SqlTypeName.DOUBLE), false);

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 69c6154..6832ee4 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -823,6 +823,10 @@ public class SqlFunctions {
     return Math.pow(b0, b1);
   }
 
+  public static double power(double b0, BigDecimal b1) {
+    return Math.pow(b0, b1.doubleValue());
+  }
+
   public static double power(long b0, long b1) {
     return Math.pow(b0, b1);
   }

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/SqlKind.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index ad7c4e2..8d7c8aa 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -1119,6 +1119,15 @@ public enum SqlKind {
       EnumSet.of(OTHER_FUNCTION, ROW, TRIM, LTRIM, RTRIM, CAST, JDBC_FN);
 
   /**
+   * Category of SqlAvgAggFunction.
+   *
+   * <p>Consists of {@link #AVG}, {@link #STDDEV_POP}, {@link #STDDEV_SAMP},
+   * {@link #VAR_POP}, {@link #VAR_SAMP}.
+   */
+  public static final Set<SqlKind> AVG_AGG_FUNCTIONS =
+      EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP);
+
+  /**
    * Category of comparison operators.
    *
    * <p>Consists of:

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
index 95f8049..6be1ce9 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
@@ -32,26 +32,27 @@ import com.google.common.base.Preconditions;
  * double</code>), and the result is the same type.
  */
 public class SqlAvgAggFunction extends SqlAggFunction {
+
   //~ Constructors -----------------------------------------------------------
 
   /**
    * Creates a SqlAvgAggFunction.
    */
   public SqlAvgAggFunction(SqlKind kind) {
-    super(kind.name(),
+    this(kind.name(), kind);
+  }
+
+  SqlAvgAggFunction(String name, SqlKind kind) {
+    super(name,
         null,
         kind,
-        ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+        ReturnTypes.AVG_AGG_FUNCTION,
         null,
         OperandTypes.NUMERIC,
         SqlFunctionCategory.NUMERIC,
         false,
         false);
-    Preconditions.checkArgument(kind == SqlKind.AVG
-        || kind == SqlKind.STDDEV_POP
-        || kind == SqlKind.STDDEV_SAMP
-        || kind == SqlKind.VAR_POP
-        || kind == SqlKind.VAR_SAMP);
+    Preconditions.checkArgument(SqlKind.AVG_AGG_FUNCTIONS.contains(kind), "unsupported sql kind");
   }
 
   @Deprecated // to be removed before 2.0

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
index ea23300..8c62290 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
@@ -43,7 +43,7 @@ public class SqlCovarAggFunction extends SqlAggFunction {
     super(kind.name(),
         null,
         kind,
-        ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+        ReturnTypes.COVAR_FUNCTION,
         null,
         OperandTypes.NUMERIC_NUMERIC,
         SqlFunctionCategory.NUMERIC,

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 3f125bd..39a45b3 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -918,6 +918,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
       new SqlAvgAggFunction(SqlKind.STDDEV_SAMP);
 
   /**
+   * <code>STDDEV</code> aggregate function.
+   */
+  public static final SqlAggFunction STDDEV =
+      new SqlAvgAggFunction("STDDEV", SqlKind.STDDEV_SAMP);
+
+  /**
    * <code>VAR_POP</code> aggregate function.
    */
   public static final SqlAggFunction VAR_POP =
@@ -929,6 +935,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
   public static final SqlAggFunction VAR_SAMP =
       new SqlAvgAggFunction(SqlKind.VAR_SAMP);
 
+  /**
+   * <code>VARIANCE</code> aggregate function.
+   */
+  public static final SqlAggFunction VARIANCE =
+      new SqlAvgAggFunction("VARIANCE", SqlKind.VAR_SAMP);
+
   //-------------------------------------------------------------
   // WINDOW Aggregate Functions
   //-------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 73e99f8..15ca544 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -779,8 +779,10 @@ public abstract class ReturnTypes {
         @Override public RelDataType
         inferReturnType(SqlOperatorBinding opBinding) {
           final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
-          return typeFactory.getTypeSystem()
+          final RelDataType sumType = typeFactory.getTypeSystem()
               .deriveSumType(typeFactory, opBinding.getOperandType(0));
+          // SUM0 should not return null.
+          return typeFactory.createTypeWithNullability(sumType, false);
         }
       };
 
@@ -809,6 +811,36 @@ public abstract class ReturnTypes {
           return typeFactory.getTypeSystem().deriveRankType(typeFactory);
         }
       };
+
+  public static final SqlReturnTypeInference AVG_AGG_FUNCTION =
+      new SqlReturnTypeInference() {
+        @Override public RelDataType
+        inferReturnType(SqlOperatorBinding opBinding) {
+          final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+          final RelDataType relDataType = typeFactory.getTypeSystem().deriveAvgAggType(
+              typeFactory, opBinding.getOperandType(0));
+          if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) {
+            return typeFactory.createTypeWithNullability(relDataType, true);
+          } else {
+            return relDataType;
+          }
+        }
+      };
+
+  public static final SqlReturnTypeInference COVAR_FUNCTION =
+      new SqlReturnTypeInference() {
+        @Override public RelDataType
+        inferReturnType(SqlOperatorBinding opBinding) {
+          final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+          final RelDataType relDataType = typeFactory.getTypeSystem().deriveCovarType(
+              typeFactory, opBinding.getOperandType(0), opBinding.getOperandType(1));
+          if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) {
+            return typeFactory.createTypeWithNullability(relDataType, true);
+          } else {
+            return relDataType;
+          }
+        }
+      };
 }
 
 // End ReturnTypes.java

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index 0d62f9f..8940629 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -299,11 +299,17 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
         SqlStdOperatorTable.STDDEV_SAMP,
         new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
     registerOp(
+        SqlStdOperatorTable.STDDEV,
+        new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
+    registerOp(
         SqlStdOperatorTable.VAR_POP,
         new AvgVarianceConvertlet(SqlKind.VAR_POP));
     registerOp(
         SqlStdOperatorTable.VAR_SAMP,
         new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
+    registerOp(
+        SqlStdOperatorTable.VARIANCE,
+        new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
 
     final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
     registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
@@ -1272,44 +1278,56 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       assert call.operandCount() == 1;
       final SqlNode arg = call.operand(0);
       final SqlNode expr;
+      final RelDataType type =
+          cx.getValidator().getValidatedNodeType(call);
       switch (kind) {
       case AVG:
-        expr = expandAvg(arg);
+        expr = expandAvg(arg, type, cx);
         break;
       case STDDEV_POP:
-        expr = expandVariance(arg, true, true);
+        expr = expandVariance(arg, type, cx, true, true);
         break;
       case STDDEV_SAMP:
-        expr = expandVariance(arg, false, true);
+        expr = expandVariance(arg, type, cx, false, true);
         break;
       case VAR_POP:
-        expr = expandVariance(arg, true, false);
+        expr = expandVariance(arg, type, cx, true, false);
         break;
       case VAR_SAMP:
-        expr = expandVariance(arg, false, false);
+        expr = expandVariance(arg, type, cx, false, false);
         break;
       default:
         throw Util.unexpected(kind);
       }
-      RelDataType type =
-          cx.getValidator().getValidatedNodeType(call);
       RexNode rex = cx.convertExpression(expr);
       return cx.getRexBuilder().ensureType(type, rex, true);
     }
 
     private SqlNode expandAvg(
-        final SqlNode arg) {
+        final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) {
       final SqlParserPos pos = SqlParserPos.ZERO;
       final SqlNode sum =
           SqlStdOperatorTable.SUM.createCall(pos, arg);
+      final RexNode sumRex = cx.convertExpression(sum);
+      final SqlNode sumCast;
+      if (!sumRex.getType().equals(avgType)) {
+        sumCast = SqlStdOperatorTable.CAST.createCall(pos,
+            new SqlDataTypeSpec(
+                new SqlIdentifier(avgType.getSqlTypeName().getName(), pos),
+                avgType.getPrecision(), avgType.getScale(), null, null, pos));
+      } else {
+        sumCast = sum;
+      }
       final SqlNode count =
           SqlStdOperatorTable.COUNT.createCall(pos, arg);
       return SqlStdOperatorTable.DIVIDE.createCall(
-          pos, sum, count);
+          pos, sumCast, count);
     }
 
     private SqlNode expandVariance(
-        final SqlNode arg,
+        final SqlNode argInput,
+        final RelDataType varType,
+        final SqlRexContext cx,
         boolean biased,
         boolean sqrt) {
       // stddev_pop(x) ==>
@@ -1332,6 +1350,17 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       //     (sum(x * x) - sum(x) * sum(x) / count(x))
       //     / (count(x) - 1)
       final SqlParserPos pos = SqlParserPos.ZERO;
+
+      final RexNode argRex = cx.convertExpression(argInput);
+      final SqlNode arg;
+      if (!argRex.getType().equals(varType)) {
+        arg = SqlStdOperatorTable.CAST.createCall(pos,
+            new SqlDataTypeSpec(new SqlIdentifier(varType.getSqlTypeName().getName(), pos),
+                varType.getPrecision(), varType.getScale(), null, null, pos));
+      } else {
+        arg = argInput;
+      }
+
       final SqlNode argSquared =
           SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
       final SqlNode sumArgSquared =

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
index f73921f..15ddb13 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
@@ -6417,6 +6417,33 @@ public abstract class SqlOperatorBaseTest {
         0d);
   }
 
+  @Test public void testStddevFunc() {
+    tester.setFor(SqlStdOperatorTable.STDDEV, VM_EXPAND);
+    tester.checkFails(
+        "stddev(^*^)",
+        "Unknown identifier '\\*'",
+        false);
+    tester.checkFails(
+        "^stddev(cast(null as varchar(2)))^",
+        "(?s)Cannot apply 'STDDEV' to arguments of type 'STDDEV\\(<VARCHAR\\(2\\)>\\)'\\. Supported form\\(s\\): 'STDDEV\\(<NUMERIC>\\)'.*",
+        false);
+    tester.checkType("stddev(CAST(NULL AS INTEGER))", "INTEGER");
+    checkAggType(tester, "stddev(DISTINCT 1.5)", "DECIMAL(2, 1) NOT NULL");
+    final String[] values = {"0", "CAST(null AS FLOAT)", "3", "3"};
+    // with one value
+    tester.checkAgg(
+        "stddev(x)",
+        new String[]{"5"},
+        null,
+        0d);
+    // with zero values
+    tester.checkAgg(
+        "stddev(x)",
+        new String[]{},
+        null,
+        0d);
+  }
+
   @Test public void testVarPopFunc() {
     tester.setFor(SqlStdOperatorTable.VAR_POP, VM_EXPAND);
     tester.checkFails(
@@ -6505,6 +6532,49 @@ public abstract class SqlOperatorBaseTest {
         0d);
   }
 
+  @Test public void testVarFunc() {
+    tester.setFor(SqlStdOperatorTable.VARIANCE, VM_EXPAND);
+    tester.checkFails(
+        "variance(^*^)",
+        "Unknown identifier '\\*'",
+        false);
+    tester.checkFails(
+        "^variance(cast(null as varchar(2)))^",
+        "(?s)Cannot apply 'VARIANCE' to arguments of type 'VARIANCE\\(<VARCHAR\\(2\\)>\\)'\\. Supported form\\(s\\): 'VARIANCE\\(<NUMERIC>\\)'.*",
+        false);
+    tester.checkType("variance(CAST(NULL AS INTEGER))", "INTEGER");
+    checkAggType(tester, "variance(DISTINCT 1.5)", "DECIMAL(2, 1) NOT NULL");
+    final String[] values = {"0", "CAST(null AS FLOAT)", "3", "3"};
+    if (!enable) {
+      return;
+    }
+    tester.checkAgg(
+        "variance(x)", values, 3d, // verified on Oracle 10g
+        0d);
+    tester.checkAgg(
+        "variance(DISTINCT x)", // Oracle does not allow distinct
+        values,
+        4.5d,
+        0.0001d);
+    tester.checkAgg(
+        "variance(DISTINCT CASE x WHEN 0 THEN NULL ELSE -1 END)",
+        values,
+        null,
+        0d);
+    // with one value
+    tester.checkAgg(
+        "variance(x)",
+        new String[]{"5"},
+        null,
+        0d);
+    // with zero values
+    tester.checkAgg(
+        "variance(x)",
+        new String[]{},
+        null,
+        0d);
+  }
+
   @Test public void testMinFunc() {
     tester.setFor(SqlStdOperatorTable.MIN, VM_EXPAND);
     tester.checkFails(

http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/test/resources/sql/agg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index e4ec228..28e8b4b 100755
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -85,19 +85,31 @@ select stddev_pop(deptno) as s from emp;
 
 !ok
 
+# stddev
+select stddev(deptno) as s from emp;
++----+
+| S  |
++----+
+| 19 |
++----+
+(1 row)
+
+!ok
+
 # both
 select gender,
   stddev_pop(deptno) as p,
   stddev_samp(deptno) as s,
+  stddev(deptno) as ss,
   count(deptno) as c
 from emp
 group by gender;
-+--------+----+----+---+
-| GENDER | P  | S  | C |
-+--------+----+----+---+
-| F      | 17 | 19 | 5 |
-| M      | 17 | 20 | 3 |
-+--------+----+----+---+
++--------+----+----+----+---+
+| GENDER | P  | S  | SS | C |
++--------+----+----+----+---+
+| F      | 17 | 19 | 19 | 5 |
+| M      | 17 | 20 | 20 | 3 |
++--------+----+----+----+---+
 (2 rows)
 
 !ok


Mime
View raw message