spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-5573][SQL] Add explode to dataframes
Date Thu, 12 Feb 2015 23:19:28 GMT
Repository: spark
Updated Branches:
  refs/heads/master c352ffbdb -> ee04a8b19


[SPARK-5573][SQL] Add explode to dataframes

Author: Michael Armbrust <michael@databricks.com>

Closes #4546 from marmbrus/explode and squashes the following commits:

eefd33a [Michael Armbrust] whitespace
a8d496c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into explode
4af740e [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explode
dc86a5c [Michael Armbrust] simple version
d633d01 [Michael Armbrust] add scala specific
950707a [Michael Armbrust] fix comments
ba8854c [Michael Armbrust] [SPARK-5573][SQL] Add explode to dataframes


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ee04a8b1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ee04a8b1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ee04a8b1

Branch: refs/heads/master
Commit: ee04a8b19be8330bfc48f470ef365622162c915f
Parents: c352ffb
Author: Michael Armbrust <michael@databricks.com>
Authored: Thu Feb 12 15:19:19 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Thu Feb 12 15:19:19 2015 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/generators.scala   | 19 ++++++++++
 .../scala/org/apache/spark/sql/DataFrame.scala  | 38 ++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameImpl.scala    | 30 ++++++++++++++--
 .../apache/spark/sql/IncomputableColumn.scala   |  9 +++++
 .../org/apache/spark/sql/DataFrameSuite.scala   | 25 +++++++++++++
 5 files changed, 119 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ee04a8b1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 43b6482..0983d27 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -74,6 +74,25 @@ abstract class Generator extends Expression {
 }
 
 /**
+ * A generator that produces its output using the provided lambda function.
+ */
+case class UserDefinedGenerator(
+    schema: Seq[Attribute],
+    function: Row => TraversableOnce[Row],
+    children: Seq[Expression])
+  extends Generator{
+
+  override protected def makeOutput(): Seq[Attribute] = schema
+
+  override def eval(input: Row): TraversableOnce[Row] = {
+    val inputRow = new InterpretedProjection(children)
+    function(inputRow(input))
+  }
+
+  override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
+}
+
+/**
  * Given an input array produces a sequence of rows for each value in the array.
  */
 case class Explode(attributeNames: Seq[String], child: Expression)

http://git-wip-us.apache.org/repos/asf/spark/blob/ee04a8b1/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 13aff76..6525788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -441,6 +442,43 @@ trait DataFrame extends RDDApi[Row] with Serializable {
     sample(withReplacement, fraction, Utils.random.nextLong)
   }
 
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
+   * rows by the provided function.  This is similar to a `LATERAL VIEW` in HiveQL. The columns of
+   * the input row are implicitly joined with each row that is output by the function.
+   *
+   * The following example uses this function to count the number of books which contain
+   * a given word:
+   *
+   * {{{
+   *   case class Book(title: String, words: String)
+   *   val df: RDD[Book]
+   *
+   *   case class Word(word: String)
+   *   val allWords = df.explode('words) {
+   *     case Row(words: String) => words.split(" ").map(Word(_))
+   *   }
+   *
+   *   val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
+   * }}}
+   */
+  def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame
+
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
+   * or more rows by the provided function.  This is similar to a `LATERAL VIEW` in HiveQL. All
+   * columns of the input row are implicitly joined with each value that is output by the function.
+   *
+   * {{{
+   *   df.explode("words", "word")(words: String => words.split(" "))
+   * }}}
+   */
+  def explode[A, B : TypeTag](
+      inputColumn: String,
+      outputColumn: String)(
+      f: A => TraversableOnce[B]): DataFrame
+
   /////////////////////////////////////////////////////////////////////////////
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ee04a8b1/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 4c6e19c..bb5c622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -21,6 +21,7 @@ import java.io.CharArrayWriter
 
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 import scala.collection.JavaConversions._
 
 import com.fasterxml.jackson.core.JsonFactory
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
+import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{NumericType, StructType}
 
-
 /**
  * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
  */
@@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql](
     Sample(fraction, withReplacement, seed, logicalPlan)
   }
 
+  override def explode[A <: Product : TypeTag]
+      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+    val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+    val attributes = schema.toAttributes
+    val rowFunction =
+      f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
+    val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+
+    Generate(generator, join = true, outer = false, None, logicalPlan)
+  }
+
+  override def explode[A, B : TypeTag](
+      inputColumn: String,
+      outputColumn: String)(
+      f: A => TraversableOnce[B]): DataFrame = {
+    val dataType = ScalaReflection.schemaFor[B].dataType
+    val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+    def rowFunction(row: Row) = {
+      f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
+    }
+    val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+
+    Generate(generator, join = true, outer = false, None, logicalPlan)
+
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // RDD API
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/ee04a8b1/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 4f9d92d..19c8e3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
@@ -110,6 +111,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
 
+  override def explode[A <: Product : TypeTag]
+      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()
+
+  override def explode[A, B : TypeTag](
+      inputColumn: String,
+      outputColumn: String)(
+      f: A => TraversableOnce[B]): DataFrame = err()
+
   /////////////////////////////////////////////////////////////////////////////
 
   override def head(n: Int): Array[Row] = err()

http://git-wip-us.apache.org/repos/asf/spark/blob/ee04a8b1/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 7be9215..33b35f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest {
       sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
   }
 
+  test("simple explode") {
+    val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")
+
+    checkAnswer(
+      df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
+      Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil
+    )
+  }
+
+  test("explode") {
+    val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
+    val df2 =
+      df.explode('letters) {
+        case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
+      }
+
+    checkAnswer(
+      df2
+        .select('_1 as 'letter, 'number)
+        .groupBy('letter)
+        .agg('letter, countDistinct('number)),
+      Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
+    )
+  }
+
   test("selectExpr") {
     checkAnswer(
       testData.selectExpr("abs(key)", "value"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message