spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maropu <...@git.apache.org>
Subject [GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Date Mon, 11 Dec 2017 14:48:30 GMT
Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r156092752
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
---
    @@ -256,6 +258,85 @@ case class HashAggregateExec(
          """.stripMargin
       }
     
    +  // Extracts all the input variable references for a given `aggExpr`. This result will be
    +  // used to split aggregation into small functions.
    +  private def getInputVariableReferences(
    +      ctx: CodegenContext,
    +      aggExpr: Expression,
    +      subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = {
    +    // `argSet` collects all the pairs of variable names and their types; the first element
    +    // in each pair is a type name and the second is a variable name.
    +    val argSet = mutable.Set[(String, String)]()
    +    val stack = mutable.Stack[Expression](aggExpr)
    +    while (stack.nonEmpty) {
    +      stack.pop() match {
    +        case e if subExprs.contains(e) =>
    +          val exprCode = subExprs(e)
    +          if (CodegenContext.isJavaIdentifier(exprCode.value)) {
    +            argSet += ((ctx.javaType(e.dataType), exprCode.value))
    +          }
    +          if (CodegenContext.isJavaIdentifier(exprCode.isNull)) {
    +            argSet += (("boolean", exprCode.isNull))
    +          }
    +          // Since the children possibly have common expressions, we push them here
    +          stack.pushAll(e.children)
    +        case ref: BoundReference
    +            if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null =>
    +          val value = ctx.currentVars(ref.ordinal).value
    +          val isNull = ctx.currentVars(ref.ordinal).isNull
    +          if (CodegenContext.isJavaIdentifier(value)) {
    +            argSet += ((ctx.javaType(ref.dataType), value))
    +          }
    +          if (CodegenContext.isJavaIdentifier(isNull)) {
    +            argSet += (("boolean", isNull))
    +          }
    +        case _: BoundReference =>
    +          argSet += (("InternalRow", ctx.INPUT_ROW))
    +        case e =>
    +          stack.pushAll(e.children)
    +      }
    +    }
    +
    +    argSet.toSet
    +  }
    +
    +  // Splits aggregate code into small functions because the JVM does not compile overly
    +  // long functions.
    +  private def splitAggregateExpressions(
    +      ctx: CodegenContext,
    +      aggExprs: Seq[Expression],
    +      evalAndUpdateCodes: Seq[String],
    +      subExprs: Map[Expression, SubExprEliminationState],
    +      otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = {
    +    aggExprs.zipWithIndex.map { case (aggExpr, i) =>
    +      // The maximum length of parameters in non-static Java methods is 254, but a parameter of
    +      // type long or double contributes two units to the length. So, this method gives up
    +      // splitting the code if the parameter length goes over 127.
    +      val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq
    +
    +      // This is for testing/benchmarking only
    +      val maxParamNumInJavaMethod =
    +          sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match {
    --- End diff --
    
    ok


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message