spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hvanhov...@apache.org
Subject spark git commit: [SPARK-17069] Expose spark.range() as table-valued function in SQL
Date Thu, 18 Aug 2016 11:34:01 GMT
Repository: spark
Updated Branches:
  refs/heads/master b81421afb -> 412dba63b


[SPARK-17069] Expose spark.range() as table-valued function in SQL

## What changes were proposed in this pull request?

This adds analyzer rules for resolving table-valued functions, and adds one builtin implementation
for range(). The arguments for range() are the same as those of `spark.range()`.

## How was this patch tested?

Unit tests.

cc hvanhovell

Author: Eric Liang <ekl@databricks.com>

Closes #14656 from ericl/sc-4309.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/412dba63
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/412dba63
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/412dba63

Branch: refs/heads/master
Commit: 412dba63b511474a6db3c43c8618d803e604bc6b
Parents: b81421a
Author: Eric Liang <ekl@databricks.com>
Authored: Thu Aug 18 13:33:55 2016 +0200
Committer: Herman van Hovell <hvanhovell@databricks.com>
Committed: Thu Aug 18 13:33:55 2016 +0200

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/parser/SqlBase.g4 |   1 +
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   1 +
 .../analysis/ResolveTableValuedFunctions.scala  | 132 +++++++++++++++++++
 .../sql/catalyst/analysis/unresolved.scala      |  11 ++
 .../spark/sql/catalyst/parser/AstBuilder.scala  |   8 ++
 .../sql/catalyst/parser/PlanParserSuite.scala   |   8 +-
 .../sql-tests/inputs/table-valued-functions.sql |  20 +++
 .../results/table-valued-functions.sql.out      |  87 ++++++++++++
 8 files changed, 267 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 6122bcd..cab7c3f 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -433,6 +433,7 @@ relationPrimary
     | '(' queryNoWith ')' sample? (AS? strictIdentifier)?           #aliasedQuery
     | '(' relation ')' sample? (AS? strictIdentifier)?              #aliasedRelation
     | inlineTable                                                   #inlineTableDefault2
+    | identifier '(' (expression (',' expression)*)? ')'            #tableValuedFunction
     ;
 
 inlineTable

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cfab6ae..333dd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -86,6 +86,7 @@ class Analyzer(
       EliminateUnions,
       new SubstituteUnresolvedOrdinals(conf)),
     Batch("Resolution", fixedPoint,
+      ResolveTableValuedFunctions ::
       ResolveRelations ::
       ResolveReferences ::
       ResolveDeserializer ::

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
new file mode 100644
index 0000000..7fdf7fa
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
+
+/**
+ * Rule that resolves table-valued function references.
+ */
+object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
+  private lazy val defaultParallelism =
+    SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism
+
+  /**
+   * List of argument names and their types, used to declare a function.
+   */
+  private case class ArgumentList(args: (String, DataType)*) {
+    /**
+     * Try to cast the expressions to satisfy the expected types of this argument list. If there
+     * are any types that cannot be casted, then None is returned.
+     */
+    def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = {
+      if (args.length == values.length) {
+        val casted = values.zip(args).map { case (value, (_, expectedType)) =>
+          TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType)
+        }
+        if (casted.forall(_.isDefined)) {
+          return Some(casted.map(_.get))
+        }
+      }
+      None
+    }
+
+    override def toString: String = {
+      args.map { a =>
+        s"${a._1}: ${a._2.typeName}"
+      }.mkString(", ")
+    }
+  }
+
+  /**
+   * A TVF maps argument lists to resolver functions that accept those arguments. Using a map
+   * here allows for function overloading.
+   */
+  private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan]
+
+  /**
+   * TVF builder.
+   */
+  private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan])
+      : (ArgumentList, Seq[Any] => LogicalPlan) = {
+    (ArgumentList(args: _*),
+     pf orElse {
+       case args =>
+         throw new IllegalArgumentException(
+           "Invalid arguments for resolved function: " + args.mkString(", "))
+     })
+  }
+
+  /**
+   * Internal registry of table-valued functions.
+   */
+  private val builtinFunctions: Map[String, TVF] = Map(
+    "range" -> Map(
+      /* range(end) */
+      tvf("end" -> LongType) { case Seq(end: Long) =>
+        Range(0, end, 1, defaultParallelism)
+      },
+
+      /* range(start, end) */
+      tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
+        Range(start, end, 1, defaultParallelism)
+      },
+
+      /* range(start, end, step) */
+      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) {
+        case Seq(start: Long, end: Long, step: Long) =>
+          Range(start, end, step, defaultParallelism)
+      },
+
+      /* range(start, end, step, numPartitions) */
+      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
+          "numPartitions" -> IntegerType) {
+        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
+          Range(start, end, step, numPartitions)
+      })
+  )
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
+      builtinFunctions.get(u.functionName) match {
+        case Some(tvf) =>
+          val resolved = tvf.flatMap { case (argList, resolver) =>
+            argList.implicitCast(u.functionArgs) match {
+              case Some(casted) =>
+                Some(resolver(casted.map(_.eval())))
+              case _ =>
+                None
+            }
+          }
+          resolved.headOption.getOrElse {
+            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
+            u.failAnalysis(
+              s"""error: table-valued function ${u.functionName} with alternatives:
+                |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
+                |cannot be applied to: (${argTypes})""".stripMargin)
+          }
+        case _ =>
+          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 42e7aae..3735a15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -50,6 +50,17 @@ case class UnresolvedRelation(
 }
 
 /**
+ * Holds a table-valued function call that has yet to be resolved.
+ */
+case class UnresolvedTableValuedFunction(
+    functionName: String, functionArgs: Seq[Expression]) extends LeafNode {
+
+  override def output: Seq[Attribute] = Nil
+
+  override lazy val resolved = false
+}
+
+/**
  * Holds the name of an attribute that has yet to be resolved.
  */
 case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable {

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index adf7839..01322ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -658,6 +658,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   }
 
   /**
+   * Create a table-valued function call with arguments, e.g. range(1000)
+   */
+  override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
+      : LogicalPlan = withOrigin(ctx) {
+    UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+  }
+
+  /**
    * Create an inline table (a virtual table in Hive parlance).
    */
   override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7af333b..cbe4a02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -426,6 +426,12 @@ class PlanParserSuite extends PlanTest {
     assertEqual("table d.t", table("d", "t"))
   }
 
+  test("table valued function") {
+    assertEqual(
+      "select * from range(2)",
+      UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+  }
+
   test("inline table") {
     assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
       Seq('col1.int),

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
new file mode 100644
index 0000000..2e6dcd5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -0,0 +1,20 @@
+-- unresolved function
+select * from dummy(3);
+
+-- range call with end
+select * from range(6 + cos(3));
+
+-- range call with start and end
+select * from range(5, 10);
+
+-- range call with step
+select * from range(0, 10, 2);
+
+-- range call with numPartitions
+select * from range(0, 10, 1, 200);
+
+-- range call error
+select * from range(1, 1, 1, 1, 1);
+
+-- range call with null
+select * from range(1, null);

http://git-wip-us.apache.org/repos/asf/spark/blob/412dba63/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
new file mode 100644
index 0000000..d769bce
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -0,0 +1,87 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 7
+
+
+-- !query 0
+select * from dummy(3)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+org.apache.spark.sql.AnalysisException
+could not resolve `dummy` to a table-valued function; line 1 pos 14
+
+
+-- !query 1
+select * from range(6 + cos(3))
+-- !query 1 schema
+struct<id:bigint>
+-- !query 1 output
+0
+1
+2
+3
+4
+
+
+-- !query 2
+select * from range(5, 10)
+-- !query 2 schema
+struct<id:bigint>
+-- !query 2 output
+5
+6
+7
+8
+9
+
+
+-- !query 3
+select * from range(0, 10, 2)
+-- !query 3 schema
+struct<id:bigint>
+-- !query 3 output
+0
+2
+4
+6
+8
+
+
+-- !query 4
+select * from range(0, 10, 1, 200)
+-- !query 4 schema
+struct<id:bigint>
+-- !query 4 output
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+
+
+-- !query 5
+select * from range(1, 1, 1, 1, 1)
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+error: table-valued function range with alternatives:
+ (end: long)
+ (start: long, end: long)
+ (start: long, end: long, step: long)
+ (start: long, end: long, step: long, numPartitions: integer)
+cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14
+
+
+-- !query 6
+select * from range(1, null)
+-- !query 6 schema
+struct<>
+-- !query 6 output
+java.lang.IllegalArgumentException
+Invalid arguments for resolved function: 1, null


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message