drill-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jinfengni <...@git.apache.org>
Subject [GitHub] drill pull request #882: DRILL-4735: ConvertCountToDirectScan rule enhanceme...
Date Fri, 04 Aug 2017 18:02:02 GMT
Github user jinfengni commented on a diff in the pull request:

    https://github.com/apache/drill/pull/882#discussion_r131447047
  
    --- Diff: exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ConvertCountToDirectScan.java
---
    @@ -85,109 +91,231 @@ protected ConvertCountToDirectScan(RelOptRuleOperand rule, String
id) {
       @Override
       public void onMatch(RelOptRuleCall call) {
         final DrillAggregateRel agg = (DrillAggregateRel) call.rel(0);
    -    final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length -1);
    -    final DrillProjectRel proj = call.rels.length == 3 ? (DrillProjectRel) call.rel(1)
: null;
    +    final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length - 1);
    +    final DrillProjectRel project = call.rels.length == 3 ? (DrillProjectRel) call.rel(1)
: null;
     
         final GroupScan oldGrpScan = scan.getGroupScan();
         final PlannerSettings settings = PrelUtil.getPlannerSettings(call.getPlanner());
     
    -    // Only apply the rule when :
    +    // Only apply the rule when:
         //    1) scan knows the exact row count in getSize() call,
         //    2) No GroupBY key,
    -    //    3) only one agg function (Check if it's count(*) below).
    -    //    4) No distinct agg call.
    +    //    3) No distinct agg call.
         if (!(oldGrpScan.getScanStats(settings).getGroupScanProperty().hasExactRowCount()
             && agg.getGroupCount() == 0
    -        && agg.getAggCallList().size() == 1
             && !agg.containsDistinctCall())) {
           return;
         }
     
    -    AggregateCall aggCall = agg.getAggCallList().get(0);
    -
    -    if (aggCall.getAggregation().getName().equals("COUNT") ) {
    -
    -      long cnt = 0;
    -      //  count(*)  == >  empty arg  ==>  rowCount
    -      //  count(Not-null-input) ==> rowCount
    -      if (aggCall.getArgList().isEmpty() ||
    -          (aggCall.getArgList().size() == 1 &&
    -           ! agg.getInput().getRowType().getFieldList().get(aggCall.getArgList().get(0).intValue()).getType().isNullable()))
{
    -        cnt = (long) oldGrpScan.getScanStats(settings).getRecordCount();
    -      } else if (aggCall.getArgList().size() == 1) {
    -      // count(columnName) ==> Agg ( Scan )) ==> columnValueCount
    -        int index = aggCall.getArgList().get(0);
    -
    -        if (proj != null) {
    -          // project in the middle of Agg and Scan : Only when input of AggCall is a
RexInputRef in Project, we find the index of Scan's field.
    -          // For instance,
    -          // Agg - count($0)
    -          //  \
    -          //  Proj - Exp={$1}
    -          //    \
    -          //   Scan (col1, col2).
    -          // return count of "col2" in Scan's metadata, if found.
    -
    -          if (proj.getProjects().get(index) instanceof RexInputRef) {
    -            index = ((RexInputRef) proj.getProjects().get(index)).getIndex();
    -          } else {
    -            return;  // do not apply for all other cases.
    -          }
    -        }
    +    final CountsCollector countsCollector = new CountsCollector(settings);
    +    // if counts were not collected, rule won't be applied
    +    if (!countsCollector.collect(agg, scan, project)) {
    +      return;
    +    }
     
    -        String columnName = scan.getRowType().getFieldNames().get(index).toLowerCase();
    +    final RelDataType scanRowType = constructDataType(agg);
     
    -        cnt = oldGrpScan.getColumnValueCount(SchemaPath.getSimplePath(columnName));
    -        if (cnt == GroupScan.NO_COLUMN_STATS) {
    -          // if column stats are not available don't apply this rule
    -          return;
    -        }
    -      } else {
    -        return; // do nothing.
    -      }
    +    final DynamicPojoRecordReader<Long> reader = new DynamicPojoRecordReader<>(
    +        buildSchema(scanRowType.getFieldNames()),
    +        Collections.singletonList(countsCollector.getCounts()));
     
    -      RelDataType scanRowType = getCountDirectScanRowType(agg.getCluster().getTypeFactory());
    +    final ScanStats scanStats = new ScanStats(ScanStats.GroupScanProperty.EXACT_ROW_COUNT,
1, 1, scanRowType.getFieldCount());
    +    final GroupScan directScan = new MetadataDirectGroupScan(reader, oldGrpScan.getFiles(),
scanStats);
     
    -      final ScanPrel newScan = ScanPrel.create(scan,
    -          scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON),
getCountDirectScan(cnt),
    -          scanRowType);
    +    final ScanPrel newScan = ScanPrel.create(scan,
    +        scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON),
directScan,
    +        scanRowType);
     
    -      List<RexNode> exprs = Lists.newArrayList();
    -      exprs.add(RexInputRef.of(0, scanRowType));
    +    final ProjectPrel newProject = new ProjectPrel(agg.getCluster(), agg.getTraitSet().plus(Prel.DRILL_PHYSICAL)
    +        .plus(DrillDistributionTrait.SINGLETON), newScan, prepareFieldExpressions(scanRowType),
agg.getRowType());
     
    -      final ProjectPrel newProj = new ProjectPrel(agg.getCluster(), agg.getTraitSet().plus(Prel.DRILL_PHYSICAL)
    -          .plus(DrillDistributionTrait.SINGLETON), newScan, exprs, agg.getRowType());
    +    call.transformTo(newProject);
    +  }
     
    -      call.transformTo(newProj);
    +  /**
    +   * For each aggregate call creates field with "count$" prefix and bigint type.
    +   * Constructs record type for created fields.
    +   *
    +   * @param aggregateRel aggregate relation expression
    +   * @return record type
    +   */
    +  private RelDataType constructDataType(DrillAggregateRel aggregateRel) {
    +    List<RelDataTypeField> fields = new ArrayList<>();
    +    for (int i = 0; i < aggregateRel.getAggCallList().size(); i++) {
    +      RelDataTypeField field = new RelDataTypeFieldImpl("count$" + i, i, aggregateRel.getCluster().getTypeFactory().createSqlType(SqlTypeName.BIGINT));
    +      fields.add(field);
         }
    -
    +    return new RelRecordType(fields);
       }
     
       /**
    -   * Class to represent the count aggregate result.
    +   * Builds schema based on given field names.
    +   * Type for each schema is set to long.class.
    +   *
    +   * @param fieldNames field names
    +   * @return schema
        */
    -  public static class CountQueryResult {
    -    public long count;
    +  private LinkedHashMap<String, Class<?>> buildSchema(List<String>
fieldNames) {
    +    LinkedHashMap<String, Class<?>> schema = new LinkedHashMap<>();
    +    for (String fieldName: fieldNames) {
    +      schema.put(fieldName, long.class);
    +    }
    +    return schema;
    +  }
     
    -    public CountQueryResult(long cnt) {
    -      this.count = cnt;
    +  /**
    +   * For each field creates row expression.
    +   *
    +   * @param rowType row type
    +   * @return list of row expressions
    +   */
    +  private List<RexNode> prepareFieldExpressions(RelDataType rowType) {
    +    List<RexNode> expressions = new ArrayList<>();
    +    for (int i = 0; i < rowType.getFieldCount(); i++) {
    +      expressions.add(RexInputRef.of(i, rowType));
         }
    +    return expressions;
       }
     
    -  private RelDataType getCountDirectScanRowType(RelDataTypeFactory typeFactory) {
    -    List<RelDataTypeField> fields = Lists.newArrayList();
    -    fields.add(new RelDataTypeFieldImpl("count", 0, typeFactory.createSqlType(SqlTypeName.BIGINT)));
    +  /**
    +   * Helper class to collect counts based on metadata information.
    +   * For example, for parquet files it can be obtained from parquet footer (total row
count)
    +   * or from parquet metadata files (column counts).
    +   */
    +  private class CountsCollector {
    +
    +    private final PlannerSettings settings;
    +    private final Set<String> implicitColumnsNames;
    +    private final List<SchemaPath> columns;
    +    private final List<Long> counts;
    +
    +    CountsCollector(PlannerSettings settings) {
    +      this.settings = settings;
    +      this.implicitColumnsNames = ColumnExplorer.initImplicitFileColumns(settings.getOptions()).keySet();
    +      this.counts = new ArrayList<>();
    +      this.columns = new ArrayList<>();
    +    }
     
    -    return new RelRecordType(fields);
    -  }
    +    /**
    +     * Collects counts for each aggregation call.
    +     * Will fail to collect counts if:
    +     * <ol>
    +     *   <li>was not able to determine count for at least one aggregation call</li>
    +     *   <li>count if used for file system partition column</li>
    +     * </ol>
    +     *
    +     * @param agg aggregate relational expression
    +     * @param scan scan relational expression
    +     * @param project project relational expression
    +     *
    +     * @return true if counts were collected, false otherwise
    +     */
    +    boolean collect(DrillAggregateRel agg, DrillScanRel scan, DrillProjectRel project)
{
    +      return calculateCounts(agg, scan, project) && !containsPartitionColumns();
    +    }
    +
    +    /**
    +     * @return list of counts
    +     */
    +    List<Long> getCounts() {
    --- End diff --
    
    Instead of returning a list of Longs, can we return either a map from SchemaPath to count,
or a list of (SchemaPath, count) pairs? Returning just a list of counts makes it hard to figure
out which count corresponds to which aggregate function. 



---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message