spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yh...@apache.org
Subject [2/4] spark git commit: [SPARK-13244][SQL] Migrates DataFrame to Dataset
Date Fri, 11 Mar 2016 01:00:27 GMT
http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 75ecbaa..b95c5dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -388,6 +388,8 @@ case class MapObjects private(
     case a: ArrayType => (i: String) => s".getArray($i)"
     case _: MapType => (i: String) => s".getMap($i)"
     case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+    case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
+    case DateType => (i: String) => s".getInt($i)"
   }
 
   private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
@@ -485,7 +487,9 @@ case class MapObjects private(
  *
  * @param children A list of expression to use as content of the external row.
  */
-case class CreateExternalRow(children: Seq[Expression]) extends Expression with NonSQLExpression {
+case class CreateExternalRow(children: Seq[Expression], schema: StructType)
+  extends Expression with NonSQLExpression {
+
   override def dataType: DataType = ObjectType(classOf[Row])
 
   override def nullable: Boolean = false
@@ -494,8 +498,9 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
-    val rowClass = classOf[GenericRow].getName
+    val rowClass = classOf[GenericRowWithSchema].getName
     val values = ctx.freshName("values")
+    val schemaField = ctx.addReferenceObj("schema", schema)
     s"""
       boolean ${ev.isNull} = false;
       final Object[] $values = new Object[${children.size}];
@@ -510,7 +515,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression with
           }
          """
       }.mkString("\n") +
-      s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+      s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 24f6199..17a9197 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.io.CharArrayWriter
 
+import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
@@ -26,30 +27,38 @@ import com.fasterxml.jackson.core.JsonFactory
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.function._
 import org.apache.spark.api.python.PythonRDD
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable,
-  QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
 private[sql] object DataFrame {
   def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
-    new DataFrame(sqlContext, logicalPlan)
+    val qe = sqlContext.executePlan(logicalPlan)
+    qe.assertAnalyzed()  // Fail fast; qe is reused below so the plan is analyzed only once.
+    new Dataset[Row](sqlContext, qe, RowEncoder(qe.analyzed.schema))
+  }
+}
+
+private[sql] object Dataset {
+  def apply[T: Encoder](sqlContext: SQLContext, logicalPlan: LogicalPlan): Dataset[T] = {
+    new Dataset(sqlContext, logicalPlan, implicitly[Encoder[T]])
   }
 }
 
@@ -112,28 +121,19 @@ private[sql] object DataFrame {
  * @since 1.3.0
  */
 @Experimental
-class DataFrame private[sql](
+class Dataset[T] private[sql](
     @transient override val sqlContext: SQLContext,
-    @DeveloperApi @transient override val queryExecution: QueryExecution)
+    @DeveloperApi @transient override val queryExecution: QueryExecution,
+    encoder: Encoder[T])
   extends Queryable with Serializable {
 
+  queryExecution.assertAnalyzed()
+
   // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure
   // you wrap it with `withNewExecutionId` if this actions doesn't call other action.
 
-  /**
-   * A constructor that automatically analyzes the logical plan.
-   *
-   * This reports error eagerly as the [[DataFrame]] is constructed, unless
-   * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
-   */
-  def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
-    this(sqlContext, {
-      val qe = sqlContext.executePlan(logicalPlan)
-      if (sqlContext.conf.dataFrameEagerAnalysis) {
-        qe.assertAnalyzed()  // This should force analysis and throw errors if there are any
-      }
-      qe
-    })
+  def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
+    this(sqlContext, sqlContext.executePlan(logicalPlan), encoder)
   }
 
   @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
@@ -147,6 +147,26 @@ class DataFrame private[sql](
       queryExecution.analyzed
   }
 
+  /**
+   * An unresolved version of the internal encoder for the type of this [[Dataset]].  This one is
+   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
+   * same object type (which may be resolved to a different schema).
+   */
+  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
+  unresolvedTEncoder.validate(logicalPlan.output)
+
+  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
+  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+
+  /**
+   * The encoder where the expressions used to construct an object from an input row have been
+   * bound to the ordinals of this [[Dataset]]'s output schema.
+   */
+  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+
+  private implicit def classTag = unresolvedTEncoder.clsTag
+
   protected[sql] def resolve(colName: String): NamedExpression = {
     queryExecution.analyzed.resolveQuoted(colName, sqlContext.analyzer.resolver).getOrElse {
       throw new AnalysisException(
@@ -173,7 +193,11 @@ class DataFrame private[sql](
 
     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond 20 characters, replace it with the first 17 and "..."
-    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map {
+      case r: Row => r
+      case tuple: Product => Row.fromTuple(tuple)
+      case o => Row(o)
+    }.map { row =>
       row.toSeq.map { cell =>
         val str = cell match {
           case null => "null"
@@ -196,7 +220,7 @@ class DataFrame private[sql](
    */
   // This is declared with parentheses to prevent the Scala compiler from treating
   // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDF(): DataFrame = this
+  def toDF(): DataFrame = new Dataset[Row](sqlContext, queryExecution, RowEncoder(schema))
 
   /**
    * :: Experimental ::
@@ -206,7 +230,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan)
+  def as[U : Encoder]: Dataset[U] = Dataset[U](sqlContext, logicalPlan)
 
   /**
    * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
@@ -360,7 +384,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.1
    */
-  def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
+  def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF())
 
   /**
    * Returns a [[DataFrameStatFunctions]] for working statistic functions support.
@@ -372,7 +396,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
+  def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
 
   /**
    * Cartesian join with another [[DataFrame]].
@@ -573,6 +597,62 @@ class DataFrame private[sql](
   }
 
   /**
+   * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
+   * true.
+   *
+   * This is similar to the relational `join` function with one important difference in the
+   * result schema. Since `joinWith` preserves objects present on either side of the join, the
+   * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
+   *
+   * This type of join can be useful both for preserving type-safety with the original object
+   * types as well as working with relational data where either side of the join has column
+   * names in common.
+   *
+   * @param other Right side of the join.
+   * @param condition Join expression.
+   * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
+   * @since 1.6.0
+   */
+  def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
+    val left = this.logicalPlan
+    val right = other.logicalPlan
+
+    val joined = sqlContext.executePlan(
+      Join(left, right, joinType = JoinType(joinType), Some(condition.expr)))
+    val leftOutput = joined.analyzed.output.take(left.output.length)
+    val rightOutput = joined.analyzed.output.takeRight(right.output.length)
+
+    // For a flat encoder use its single output column; otherwise nest the side into a struct.
+    val leftData = this.unresolvedTEncoder match {
+      case e if e.flat => Alias(leftOutput.head, "_1")()
+      case _ => Alias(CreateStruct(leftOutput), "_1")()
+    }
+    val rightData = other.unresolvedTEncoder match {
+      case e if e.flat => Alias(rightOutput.head, "_2")()
+      case _ => Alias(CreateStruct(rightOutput), "_2")()
+    }
+
+    implicit val tuple2Encoder: Encoder[(T, U)] =
+      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+    // The projection reads `joined.analyzed` directly, so the plans passed in are ignored.
+    withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (_, _) =>
+      Project(leftData :: rightData :: Nil, joined.analyzed)
+    }
+  }
+
+  /**
+   * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair
+   * where `condition` evaluates to true.
+   *
+   * @param other Right side of the join.
+   * @param condition Join expression.
+   * @since 1.6.0
+   */
+  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
+    joinWith(other, condition, "inner")
+  }
+
+  /**
    * Returns a new [[DataFrame]] with each partition sorted by the given expressions.
    *
    * This is the same operation as "SORT BY" in SQL (Hive QL).
@@ -581,7 +661,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def sortWithinPartitions(sortCol: String, sortCols: String*): DataFrame = {
+  def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = {
     sortWithinPartitions((sortCol +: sortCols).map(Column(_)) : _*)
   }
 
@@ -594,7 +674,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def sortWithinPartitions(sortExprs: Column*): DataFrame = {
+  def sortWithinPartitions(sortExprs: Column*): Dataset[T] = {
     sortInternal(global = false, sortExprs)
   }
 
@@ -610,7 +690,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def sort(sortCol: String, sortCols: String*): DataFrame = {
+  def sort(sortCol: String, sortCols: String*): Dataset[T] = {
     sort((sortCol +: sortCols).map(apply) : _*)
   }
 
@@ -623,7 +703,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def sort(sortExprs: Column*): DataFrame = {
+  def sort(sortExprs: Column*): Dataset[T] = {
     sortInternal(global = true, sortExprs)
   }
 
@@ -634,7 +714,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols : _*)
+  def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols : _*)
 
   /**
    * Returns a new [[DataFrame]] sorted by the given expressions.
@@ -643,7 +723,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs : _*)
+  def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs : _*)
 
   /**
    * Selects column based on the column name and return it as a [[Column]].
@@ -672,7 +752,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def as(alias: String): DataFrame = withPlan {
+  def as(alias: String): Dataset[T] = withTypedPlan {
     SubqueryAlias(alias, logicalPlan)
   }
 
@@ -681,21 +761,21 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def as(alias: Symbol): DataFrame = as(alias.name)
+  def as(alias: Symbol): Dataset[T] = as(alias.name)
 
   /**
    * Returns a new [[DataFrame]] with an alias set. Same as `as`.
    * @group dfops
    * @since 1.6.0
    */
-  def alias(alias: String): DataFrame = as(alias)
+  def alias(alias: String): Dataset[T] = as(alias)
 
   /**
    * (Scala-specific) Returns a new [[DataFrame]] with an alias set. Same as `as`.
    * @group dfops
    * @since 1.6.0
    */
-  def alias(alias: Symbol): DataFrame = as(alias)
+  def alias(alias: Symbol): Dataset[T] = as(alias)
 
   /**
    * Selects a set of column based expressions.
@@ -745,6 +825,80 @@ class DataFrame private[sql](
   }
 
   /**
+   * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
+   *
+   * {{{
+   *   val ds = Seq(1, 2, 3).toDS()
+   *   val newDS = ds.select(expr("value + 1").as[Int])
+   * }}}
+   * @since 1.6.0
+   */
+  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    new Dataset[U1](
+      sqlContext,
+      Project(
+        c1.withInputType(
+          boundTEncoder,
+          logicalPlan.output).named :: Nil,
+        logicalPlan),
+      implicitly[Encoder[U1]])
+  }
+
+  /**
+   * Internal helper function for building typed selects that return tuples.  For simplicity and
+   * code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user-facing interface.
+   */
+  protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
+    val encoders = columns.map(_.encoder)
+    val namedColumns =
+      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
+    val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
+
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+  }
+
+  /**
+   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+   * @since 1.6.0
+   */
+  def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
+    selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
+
+  /**
+   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+   * @since 1.6.0
+   */
+  def select[U1, U2, U3](
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
+    selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
+
+  /**
+   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+   * @since 1.6.0
+   */
+  def select[U1, U2, U3, U4](
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3],
+      c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
+    selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
+
+  /**
+   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
+   * @since 1.6.0
+   */
+  def select[U1, U2, U3, U4, U5](
+      c1: TypedColumn[T, U1],
+      c2: TypedColumn[T, U2],
+      c3: TypedColumn[T, U3],
+      c4: TypedColumn[T, U4],
+      c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
+    selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
+
+  /**
    * Filters rows using the given condition.
    * {{{
    *   // The following are equivalent:
@@ -754,7 +908,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def filter(condition: Column): DataFrame = withPlan {
+  def filter(condition: Column): Dataset[T] = withTypedPlan {
     Filter(condition.expr, logicalPlan)
   }
 
@@ -766,7 +920,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def filter(conditionExpr: String): DataFrame = {
+  def filter(conditionExpr: String): Dataset[T] = {
     filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
   }
 
@@ -780,7 +934,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def where(condition: Column): DataFrame = filter(condition)
+  def where(condition: Column): Dataset[T] = filter(condition)
 
   /**
    * Filters rows using the given SQL expression.
@@ -790,7 +944,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.5.0
    */
-  def where(conditionExpr: String): DataFrame = {
+  def where(conditionExpr: String): Dataset[T] = {
     filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
   }
 
@@ -813,7 +967,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def groupBy(cols: Column*): GroupedData = {
-    GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+    GroupedData(toDF(), cols.map(_.expr), GroupedData.GroupByType)
   }
 
   /**
@@ -836,7 +990,7 @@ class DataFrame private[sql](
    */
   @scala.annotation.varargs
   def rollup(cols: Column*): GroupedData = {
-    GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+    GroupedData(toDF(), cols.map(_.expr), GroupedData.RollupType)
   }
 
   /**
@@ -858,7 +1012,7 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
+  def cube(cols: Column*): GroupedData = GroupedData(toDF(), cols.map(_.expr), GroupedData.CubeType)
 
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -883,10 +1037,73 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+    GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+  }
+
+  /**
+   * (Scala-specific)
+   * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
+   * must be commutative and associative or the result may be non-deterministic.
+   * @since 1.6.0
+   */
+  def reduce(func: (T, T) => T): T = rdd.reduce(func)
+
+  /**
+   * (Java-specific)
+   * Reduces the elements of this Dataset using the specified binary function.  The given `func`
+   * must be commutative and associative or the result may be non-deterministic.
+   * @since 1.6.0
+   */
+  def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
+
+  /**
+   * (Scala-specific)
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+   * @since 1.6.0
+   */
+  def groupByKey[K: Encoder](func: T => K): GroupedDataset[K, T] = {
+    val inputPlan = logicalPlan
+    val withGroupingKey = AppendColumns(func, inputPlan)
+    val executed = sqlContext.executePlan(withGroupingKey)
+
+    new GroupedDataset(
+      encoderFor[K],
+      encoderFor[T],
+      executed,
+      inputPlan.output,
+      withGroupingKey.newColumns)
+  }
+
+  /**
+   * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def groupByKey(cols: Column*): GroupedDataset[Row, T] = {
+    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
+    val withKey = Project(withKeyColumns, logicalPlan)
+    val executed = sqlContext.executePlan(withKey)
+
+    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+    new GroupedDataset(
+      RowEncoder(keyAttributes.toStructType),
+      encoderFor[T],
+      executed,
+      dataAttributes,
+      keyAttributes)
   }
 
   /**
+   * (Java-specific)
+   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
+   * @since 1.6.0
+   */
+  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
+    groupByKey(func.call(_))(encoder)
+
+  /**
    * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
    * so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
@@ -910,7 +1127,7 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def rollup(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+    GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.RollupType)
   }
 
   /**
@@ -937,7 +1154,7 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def cube(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
+    GroupedData(toDF(), colNames.map(colName => resolve(colName)), GroupedData.CubeType)
   }
 
   /**
@@ -997,7 +1214,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def limit(n: Int): DataFrame = withPlan {
+  def limit(n: Int): Dataset[T] = withTypedPlan {
     Limit(Literal(n), logicalPlan)
   }
 
@@ -1007,19 +1224,21 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def unionAll(other: DataFrame): DataFrame = withPlan {
+  def unionAll(other: Dataset[T]): Dataset[T] = withTypedPlan {
     // This breaks caching, but it's usually ok because it addresses a very specific use case:
     // using union to union many files or partitions.
     CombineUnions(Union(logicalPlan, other.logicalPlan))
   }
 
+  def union(other: Dataset[T]): Dataset[T] = unionAll(other)
+
   /**
    * Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
    * This is equivalent to `INTERSECT` in SQL.
    * @group dfops
    * @since 1.3.0
    */
-  def intersect(other: DataFrame): DataFrame = withPlan {
+  def intersect(other: Dataset[T]): Dataset[T] = withTypedPlan {
     Intersect(logicalPlan, other.logicalPlan)
   }
 
@@ -1029,10 +1248,12 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def except(other: DataFrame): DataFrame = withPlan {
+  def except(other: Dataset[T]): Dataset[T] = withTypedPlan {
     Except(logicalPlan, other.logicalPlan)
   }
 
+  def subtract(other: Dataset[T]): Dataset[T] = except(other)
+
   /**
    * Returns a new [[DataFrame]] by sampling a fraction of rows.
    *
@@ -1042,7 +1263,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = withPlan {
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = withTypedPlan {
     Sample(0.0, fraction, withReplacement, seed, logicalPlan)()
   }
 
@@ -1054,7 +1275,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
+  def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = {
     sample(withReplacement, fraction, Utils.random.nextLong)
   }
 
@@ -1066,7 +1287,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+  def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]] = {
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
@@ -1075,7 +1296,8 @@ class DataFrame private[sql](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)())
+      new Dataset[T](
+        sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
     }.toArray
   }
 
@@ -1086,7 +1308,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def randomSplit(weights: Array[Double]): Array[DataFrame] = {
+  def randomSplit(weights: Array[Double]): Array[Dataset[T]] = {
     randomSplit(weights, Utils.random.nextLong)
   }
 
@@ -1097,7 +1319,7 @@ class DataFrame private[sql](
    * @param seed Seed for sampling.
    * @group dfops
    */
-  private[spark] def randomSplit(weights: List[Double], seed: Long): Array[DataFrame] = {
+  private[spark] def randomSplit(weights: List[Double], seed: Long): Array[Dataset[T]] = {
     randomSplit(weights.toArray, seed)
   }
 
@@ -1238,7 +1460,7 @@ class DataFrame private[sql](
       }
       select(columns : _*)
     } else {
-      this
+      toDF()
     }
   }
 
@@ -1264,7 +1486,7 @@ class DataFrame private[sql](
     val remainingCols =
       schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
     if (remainingCols.size == this.schema.size) {
-      this
+      toDF()
     } else {
       this.select(remainingCols: _*)
     }
@@ -1297,7 +1519,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+  def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
 
   /**
    * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
@@ -1306,7 +1528,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def dropDuplicates(colNames: Seq[String]): DataFrame = withPlan {
+  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
     val groupCols = colNames.map(resolve)
     val groupColExprIds = groupCols.map(_.exprId)
     val aggCols = logicalPlan.output.map { attr =>
@@ -1326,7 +1548,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.4.0
    */
-  def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+  def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toSeq)
 
   /**
    * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
@@ -1396,7 +1618,7 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+  def head(n: Int): Array[T] = withTypedCallback("head", limit(n)) { df =>
     df.collect(needCallback = false)
   }
 
@@ -1405,14 +1627,14 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def head(): Row = head(1).head
+  def head(): T = head(1).head
 
   /**
    * Returns the first row. Alias for head().
    * @group action
    * @since 1.3.0
    */
-  def first(): Row = head()
+  def first(): T = head()
 
   /**
    * Concise syntax for chaining custom transformations.
@@ -1425,27 +1647,113 @@ class DataFrame private[sql](
    * }}}
    * @since 1.6.0
    */
-  def transform[U](t: DataFrame => DataFrame): DataFrame = t(this)
+  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+   * @since 1.6.0
+   */
+  def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
+   * @since 1.6.0
+   */
+  def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
+
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
+   * @since 1.6.0
+   */
+  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
+    map(t => func.call(t))(encoder)
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+   * @since 1.6.0
+   */
+  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
+    new Dataset[U](
+      sqlContext,
+      MapPartitions[T, U](func, logicalPlan),
+      implicitly[Encoder[U]])
+  }
+
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
+   * @since 1.6.0
+   */
+  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
+    mapPartitions(func)(encoder)
+  }
+
+  /**
+   * (Scala-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
+  def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
+    mapPartitions(_.flatMap(func))
+
+  /**
+   * (Java-specific)
+   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
+   * and then flattening the results.
+   * @since 1.6.0
+   */
+  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
+    val func: (T) => Iterator[U] = x => f.call(x).asScala
+    flatMap(func)(encoder)
+  }
 
   /**
    * Applies a function `f` to all rows.
    * @group rdd
    * @since 1.3.0
    */
-  def foreach(f: Row => Unit): Unit = withNewExecutionId {
+  def foreach(f: T => Unit): Unit = withNewExecutionId {
     rdd.foreach(f)
   }
 
   /**
+   * (Java-specific)
+   * Runs `func` on each element of this [[Dataset]].
+   * @since 1.6.0
+   */
+  def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
+
+  /**
    * Applies a function f to each partition of this [[DataFrame]].
    * @group rdd
    * @since 1.3.0
    */
-  def foreachPartition(f: Iterator[Row] => Unit): Unit = withNewExecutionId {
+  def foreachPartition(f: Iterator[T] => Unit): Unit = withNewExecutionId {
     rdd.foreachPartition(f)
   }
 
   /**
+   * (Java-specific)
+   * Runs `func` on each partition of this [[Dataset]].
+   * @since 1.6.0
+   */
+  def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
+    foreachPartition(it => func.call(it.asJava))
+
+  /**
    * Returns the first `n` rows in the [[DataFrame]].
    *
    * Running take requires moving data into the application's driver process, and doing so with
@@ -1454,7 +1762,11 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def take(n: Int): Array[Row] = head(n)
+  def take(n: Int): Array[T] = head(n)
+
+  def takeRows(n: Int): Array[Row] = withTypedCallback("takeRows", limit(n)) { ds =>
+    ds.collectRows(needCallback = false)
+  }
 
   /**
    * Returns the first `n` rows in the [[DataFrame]] as a list.
@@ -1465,7 +1777,7 @@ class DataFrame private[sql](
    * @group action
    * @since 1.6.0
    */
-  def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*)
+  def takeAsList(n: Int): java.util.List[T] = java.util.Arrays.asList(take(n) : _*)
 
   /**
    * Returns an array that contains all of [[Row]]s in this [[DataFrame]].
@@ -1478,7 +1790,9 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def collect(): Array[Row] = collect(needCallback = true)
+  def collect(): Array[T] = collect(needCallback = true)
+
+  def collectRows(): Array[Row] = collectRows(needCallback = true)
 
   /**
    * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
@@ -1489,19 +1803,32 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+  def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
     withNewExecutionId {
-      java.util.Arrays.asList(rdd.collect() : _*)
+      val values = queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+      java.util.Arrays.asList(values : _*)
     }
   }
 
-  private def collect(needCallback: Boolean): Array[Row] = {
+  private def collect(needCallback: Boolean): Array[T] = {
+    def execute(): Array[T] = withNewExecutionId {
+      queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
+    }
+
+    if (needCallback) {
+      withCallback("collect", toDF())(_ => execute())
+    } else {
+      execute()
+    }
+  }
+
+  private def collectRows(needCallback: Boolean): Array[Row] = {
     def execute(): Array[Row] = withNewExecutionId {
       queryExecution.executedPlan.executeCollectPublic()
     }
 
     if (needCallback) {
-      withCallback("collect", this)(_ => execute())
+      withCallback("collect", toDF())(_ => execute())
     } else {
       execute()
     }
@@ -1521,7 +1848,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def repartition(numPartitions: Int): DataFrame = withPlan {
+  def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
     Repartition(numPartitions, shuffle = true, logicalPlan)
   }
 
@@ -1535,7 +1862,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def repartition(numPartitions: Int, partitionExprs: Column*): DataFrame = withPlan {
+  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, Some(numPartitions))
   }
 
@@ -1549,7 +1876,7 @@ class DataFrame private[sql](
    * @since 1.6.0
    */
   @scala.annotation.varargs
-  def repartition(partitionExprs: Column*): DataFrame = withPlan {
+  def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
   }
 
@@ -1561,7 +1888,7 @@ class DataFrame private[sql](
    * @group rdd
    * @since 1.4.0
    */
-  def coalesce(numPartitions: Int): DataFrame = withPlan {
+  def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
     Repartition(numPartitions, shuffle = false, logicalPlan)
   }
 
@@ -1571,7 +1898,7 @@ class DataFrame private[sql](
    * @group dfops
    * @since 1.3.0
    */
-  def distinct(): DataFrame = dropDuplicates()
+  def distinct(): Dataset[T] = dropDuplicates()
 
   /**
    * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
@@ -1632,12 +1959,11 @@ class DataFrame private[sql](
    * @group rdd
    * @since 1.3.0
    */
-  lazy val rdd: RDD[Row] = {
+  lazy val rdd: RDD[T] = {
     // use a local variable to make sure the map closure doesn't capture the whole DataFrame
-    val schema = this.schema
+    val deserializer = boundTEncoder
     queryExecution.toRdd.mapPartitions { rows =>
-      val converter = CatalystTypeConverters.createToScalaConverter(schema)
-      rows.map(converter(_).asInstanceOf[Row])
+      rows.map(deserializer.fromRow)
     }
   }
 
@@ -1646,14 +1972,14 @@ class DataFrame private[sql](
    * @group rdd
    * @since 1.3.0
    */
-  def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
+  def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
 
   /**
    * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
    * @group rdd
    * @since 1.3.0
    */
-  def javaRDD: JavaRDD[Row] = toJavaRDD
+  def javaRDD: JavaRDD[T] = toJavaRDD
 
   /**
    * Registers this [[DataFrame]] as a temporary table using the given name.  The lifetime of this
@@ -1663,7 +1989,7 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   def registerTempTable(tableName: String): Unit = {
-    sqlContext.registerDataFrameAsTable(this, tableName)
+    sqlContext.registerDataFrameAsTable(toDF(), tableName)
   }
 
   /**
@@ -1674,7 +2000,7 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   @Experimental
-  def write: DataFrameWriter = new DataFrameWriter(this)
+  def write: DataFrameWriter = new DataFrameWriter(toDF())
 
   /**
    * Returns the content of the [[DataFrame]] as a RDD of JSON strings.
@@ -1745,7 +2071,7 @@ class DataFrame private[sql](
    * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with
    * an execution.
    */
-  private[sql] def withNewExecutionId[T](body: => T): T = {
+  private[sql] def withNewExecutionId[U](body: => U): U = {
     SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
   }
 
@@ -1753,7 +2079,7 @@ class DataFrame private[sql](
    * Wrap a DataFrame action to track the QueryExecution and time cost, then report to the
    * user-registered callback functions.
    */
-  private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
+  private def withCallback[U](name: String, df: DataFrame)(action: DataFrame => U) = {
     try {
       df.queryExecution.executedPlan.foreach { plan =>
         plan.resetMetrics()
@@ -1770,7 +2096,24 @@ class DataFrame private[sql](
     }
   }
 
-  private def sortInternal(global: Boolean, sortExprs: Seq[Column]): DataFrame = {
+  private def withTypedCallback[A, B](name: String, ds: Dataset[A])(action: Dataset[A] => B) = {
+    try {
+      ds.queryExecution.executedPlan.foreach { plan =>
+        plan.resetMetrics()
+      }
+      val start = System.nanoTime()
+      val result = action(ds)
+      val end = System.nanoTime()
+      sqlContext.listenerManager.onSuccess(name, ds.queryExecution, end - start)
+      result
+    } catch {
+      case e: Exception =>
+        sqlContext.listenerManager.onFailure(name, ds.queryExecution, e)
+        throw e
+    }
+  }
+
+  private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
     val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
       col.expr match {
         case expr: SortOrder =>
@@ -1779,14 +2122,23 @@ class DataFrame private[sql](
           SortOrder(expr, Ascending)
       }
     }
-    withPlan {
+    withTypedPlan {
       Sort(sortOrder, global = global, logicalPlan)
     }
   }
 
   /** A convenient function to wrap a logical plan and produce a DataFrame. */
   @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
-    new DataFrame(sqlContext, logicalPlan)
+    DataFrame(sqlContext, logicalPlan)
+  }
+
+  /** A convenient function to wrap a logical plan and produce a [[Dataset]] of type `T`. */
+  @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = {
+    new Dataset[T](sqlContext, logicalPlan, encoder)
   }
 
+  private[sql] def withTypedPlan[R](
+      other: Dataset[_], encoder: Encoder[R])(
+      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
+    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 509b299..8227024 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -345,7 +345,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
       InferSchema.infer(jsonRDD, sqlContext.conf.columnNameOfCorruptRecord, parsedOptions)
     }
 
-    new DataFrame(
+    DataFrame(
       sqlContext,
       LogicalRDD(
         schema.toAttributes,

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
deleted file mode 100644
index daddf6e..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ /dev/null
@@ -1,794 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function._
-import org.apache.spark.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-/**
- * :: Experimental ::
- * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel
- * using functional or relational operations.
- *
- * A [[Dataset]] differs from an [[RDD]] in the following ways:
- *  - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored
- *    in the encoded form.  This representation allows for additional logical operations and
- *    enables many operations (sorting, shuffling, etc.) to be performed without deserializing to
- *    an object.
- *  - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be
- *    used to serialize the object into a binary format.  Encoders are also capable of mapping the
- *    schema of a given object to the Spark SQL type system.  In contrast, RDDs rely on runtime
- *    reflection based serialization. Operations that change the type of object stored in the
- *    dataset also need an encoder for the new type.
- *
- * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific
- * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into
- * specific Dataset by calling `df.as[ElementType]`.  Similarly you can transform a strongly-typed
- * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`.
- *
- * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`.  However,
- * making this change to the class hierarchy would break the function signatures for the existing
- * functional operations (map, flatMap, etc).  As such, this class should be considered a preview
- * of the final API.  Changes will be made to the interface after Spark 1.6.
- *
- * @since 1.6.0
- */
-@Experimental
-class Dataset[T] private[sql](
-    @transient override val sqlContext: SQLContext,
-    @transient override val queryExecution: QueryExecution,
-    tEncoder: Encoder[T]) extends Queryable with Serializable with Logging {
-
-  /**
-   * An unresolved version of the internal encoder for the type of this [[Dataset]].  This one is
-   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
-   * same object type (that will be possibly resolved to a different schema).
-   */
-  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
-  unresolvedTEncoder.validate(logicalPlan.output)
-
-  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
-  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
-
-  /**
-   * The encoder where the expressions used to construct an object from an input row have been
-   * bound to the ordinals of this [[Dataset]]'s output schema.
-   */
-  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
-
-  private implicit def classTag = unresolvedTEncoder.clsTag
-
-  private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
-    this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
-
-  /**
-   * Returns the schema of the encoded form of the objects in this [[Dataset]].
-   * @since 1.6.0
-   */
-  override def schema: StructType = resolvedTEncoder.schema
-
-  /**
-   * Prints the schema of the underlying [[Dataset]] to the console in a nice tree format.
-   * @since 1.6.0
-   */
-  override def printSchema(): Unit = toDF().printSchema()
-
-  /**
-   * Prints the plans (logical and physical) to the console for debugging purposes.
-   * @since 1.6.0
-   */
-  override def explain(extended: Boolean): Unit = toDF().explain(extended)
-
-  /**
-   * Prints the physical plan to the console for debugging purposes.
-   * @since 1.6.0
-   */
-  override def explain(): Unit = toDF().explain()
-
-  /* ************* *
-   *  Conversions  *
-   * ************* */
-
-  /**
-   * Returns a new [[Dataset]] where each record has been mapped on to the specified type.  The
-   * method used to map columns depend on the type of `U`:
-   *  - When `U` is a class, fields for the class will be mapped to columns of the same name
-   *    (case sensitivity is determined by `spark.sql.caseSensitive`)
-   *  - When `U` is a tuple, the columns will be be mapped by ordinal (i.e. the first column will
-   *    be assigned to `_1`).
-   *  - When `U` is a primitive type (i.e. String, Int, etc). then the first column of the
-   *    [[DataFrame]] will be used.
-   *
-   * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
-   * along with `alias` or `as` to rearrange or rename as required.
-   * @since 1.6.0
-   */
-  def as[U : Encoder]: Dataset[U] = {
-    new Dataset(sqlContext, queryExecution, encoderFor[U])
-  }
-
-  /**
-   * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have
-   * the same name after two Datasets have been joined.
-   * @since 1.6.0
-   */
-  def as(alias: String): Dataset[T] = withPlan(SubqueryAlias(alias, _))
-
-  /**
-   * Converts this strongly typed collection of data to generic Dataframe.  In contrast to the
-   * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]]
-   * objects that allow fields to be accessed by ordinal or name.
-   */
-  // This is declared with parentheses to prevent the Scala compiler from treating
-  // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
-  def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan)
-
-  /**
-   * Returns this [[Dataset]].
-   * @since 1.6.0
-   */
-  // This is declared with parentheses to prevent the Scala compiler from treating
-  // `ds.toDS("1")` as invoking this toDF and then apply on the returned Dataset.
-  def toDS(): Dataset[T] = this
-
-  /**
-   * Converts this [[Dataset]] to an [[RDD]].
-   * @since 1.6.0
-   */
-  def rdd: RDD[T] = {
-    queryExecution.toRdd.mapPartitions { iter =>
-      iter.map(boundTEncoder.fromRow)
-    }
-  }
-
-  /**
-   * Returns the number of elements in the [[Dataset]].
-   * @since 1.6.0
-   */
-  def count(): Long = toDF().count()
-
-  /**
-   * Displays the content of this [[Dataset]] in a tabular form. Strings more than 20 characters
-   * will be truncated, and all cells will be aligned right. For example:
-   * {{{
-   *   year  month AVG('Adj Close) MAX('Adj Close)
-   *   1980  12    0.503218        0.595103
-   *   1981  01    0.523289        0.570307
-   *   1982  02    0.436504        0.475256
-   *   1983  03    0.410516        0.442194
-   *   1984  04    0.450090        0.483521
-   * }}}
-   * @param numRows Number of rows to show
-   *
-   * @since 1.6.0
-   */
-  def show(numRows: Int): Unit = show(numRows, truncate = true)
-
-  /**
-   * Displays the top 20 rows of [[Dataset]] in a tabular form. Strings more than 20 characters
-   * will be truncated, and all cells will be aligned right.
-   *
-   * @since 1.6.0
-   */
-  def show(): Unit = show(20)
-
-  /**
-   * Displays the top 20 rows of [[Dataset]] in a tabular form.
-   *
-   * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
-   *              be truncated and all cells will be aligned right
-   *
-   * @since 1.6.0
-   */
-  def show(truncate: Boolean): Unit = show(20, truncate)
-
-  /**
-   * Displays the [[Dataset]] in a tabular form. For example:
-   * {{{
-   *   year  month AVG('Adj Close) MAX('Adj Close)
-   *   1980  12    0.503218        0.595103
-   *   1981  01    0.523289        0.570307
-   *   1982  02    0.436504        0.475256
-   *   1983  03    0.410516        0.442194
-   *   1984  04    0.450090        0.483521
-   * }}}
-   * @param numRows Number of rows to show
-   * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
-   *              be truncated and all cells will be aligned right
-   *
-   * @since 1.6.0
-   */
-  // scalastyle:off println
-  def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate))
-  // scalastyle:on println
-
-  /**
-   * Compose the string representing rows for output
-   * @param _numRows Number of rows to show
-   * @param truncate Whether truncate long strings and align cells right
-   */
-  override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
-    val numRows = _numRows.max(0)
-    val takeResult = take(numRows + 1)
-    val hasMoreData = takeResult.length > numRows
-    val data = takeResult.take(numRows)
-
-    // For array values, replace Seq and Array with square brackets
-    // For cells that are beyond 20 characters, replace it with the first 17 and "..."
-    val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: (data.map {
-      case r: Row => r
-      case tuple: Product => Row.fromTuple(tuple)
-      case o => Row(o)
-    } map { row =>
-      row.toSeq.map { cell =>
-        val str = cell match {
-          case null => "null"
-          case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
-          case array: Array[_] => array.mkString("[", ", ", "]")
-          case seq: Seq[_] => seq.mkString("[", ", ", "]")
-          case _ => cell.toString
-        }
-        if (truncate && str.length > 20) str.substring(0, 17) + "..." else str
-      }: Seq[String]
-    })
-
-    formatString ( rows, numRows, hasMoreData, truncate )
-  }
-
-  /**
-    * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
-    * @since 1.6.0
-    */
-  def repartition(numPartitions: Int): Dataset[T] = withPlan {
-    Repartition(numPartitions, shuffle = true, _)
-  }
-
-  /**
-    * Returns a new [[Dataset]] that has exactly `numPartitions` partitions.
-    * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
-    * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
-    * the 100 new partitions will claim 10 of the current partitions.
-    * @since 1.6.0
-    */
-  def coalesce(numPartitions: Int): Dataset[T] = withPlan {
-    Repartition(numPartitions, shuffle = false, _)
-  }
-
-  /* *********************** *
-   *  Functional Operations  *
-   * *********************** */
-
-  /**
-   * Concise syntax for chaining custom transformations.
-   * {{{
-   *   def featurize(ds: Dataset[T]) = ...
-   *
-   *   dataset
-   *     .transform(featurize)
-   *     .transform(...)
-   * }}}
-   * @since 1.6.0
-   */
-  def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this)
-
-  /**
-   * (Scala-specific)
-   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
-   * @since 1.6.0
-   */
-  def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
-
-  /**
-   * (Java-specific)
-   * Returns a new [[Dataset]] that only contains elements where `func` returns `true`.
-   * @since 1.6.0
-   */
-  def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
-
-  /**
-   * (Scala-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
-   * @since 1.6.0
-   */
-  def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func))
-
-  /**
-   * (Java-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each element.
-   * @since 1.6.0
-   */
-  def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] =
-    map(t => func.call(t))(encoder)
-
-  /**
-   * (Scala-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
-   * @since 1.6.0
-   */
-  def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
-    new Dataset[U](
-      sqlContext,
-      MapPartitions[T, U](func, logicalPlan))
-  }
-
-  /**
-   * (Java-specific)
-   * Returns a new [[Dataset]] that contains the result of applying `func` to each partition.
-   * @since 1.6.0
-   */
-  def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).asScala
-    mapPartitions(func)(encoder)
-  }
-
-  /**
-   * (Scala-specific)
-   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
-   * and then flattening the results.
-   * @since 1.6.0
-   */
-  def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] =
-    mapPartitions(_.flatMap(func))
-
-  /**
-   * (Java-specific)
-   * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]],
-   * and then flattening the results.
-   * @since 1.6.0
-   */
-  def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = {
-    val func: (T) => Iterator[U] = x => f.call(x).asScala
-    flatMap(func)(encoder)
-  }
-
-  /* ************** *
-   *  Side effects  *
-   * ************** */
-
-  /**
-   * (Scala-specific)
-   * Runs `func` on each element of this [[Dataset]].
-   * @since 1.6.0
-   */
-  def foreach(func: T => Unit): Unit = rdd.foreach(func)
-
-  /**
-   * (Java-specific)
-   * Runs `func` on each element of this [[Dataset]].
-   * @since 1.6.0
-   */
-  def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_))
-
-  /**
-   * (Scala-specific)
-   * Runs `func` on each partition of this [[Dataset]].
-   * @since 1.6.0
-   */
-  def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func)
-
-  /**
-   * (Java-specific)
-   * Runs `func` on each partition of this [[Dataset]].
-   * @since 1.6.0
-   */
-  def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
-    foreachPartition(it => func.call(it.asJava))
-
-  /* ************* *
-   *  Aggregation  *
-   * ************* */
-
-  /**
-   * (Scala-specific)
-   * Reduces the elements of this [[Dataset]] using the specified binary function. The given `func`
-   * must be commutative and associative or the result may be non-deterministic.
-   * @since 1.6.0
-   */
-  def reduce(func: (T, T) => T): T = rdd.reduce(func)
-
-  /**
-   * (Java-specific)
-   * Reduces the elements of this Dataset using the specified binary function.  The given `func`
-   * must be commutative and associative or the result may be non-deterministic.
-   * @since 1.6.0
-   */
-  def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _))
-
-  /**
-   * (Scala-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
-   * @since 1.6.0
-   */
-  def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
-    val inputPlan = logicalPlan
-    val withGroupingKey = AppendColumns(func, inputPlan)
-    val executed = sqlContext.executePlan(withGroupingKey)
-
-    new GroupedDataset(
-      encoderFor[K],
-      encoderFor[T],
-      executed,
-      inputPlan.output,
-      withGroupingKey.newColumns)
-  }
-
-  /**
-   * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
-   * @since 1.6.0
-   */
-  @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedDataset[Row, T] = {
-    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
-    val withKey = Project(withKeyColumns, logicalPlan)
-    val executed = sqlContext.executePlan(withKey)
-
-    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
-    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
-
-    new GroupedDataset(
-      RowEncoder(keyAttributes.toStructType),
-      encoderFor[T],
-      executed,
-      dataAttributes,
-      keyAttributes)
-  }
-
-  /**
-   * (Java-specific)
-   * Returns a [[GroupedDataset]] where the data is grouped by the given key `func`.
-   * @since 1.6.0
-   */
-  def groupBy[K](func: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] =
-    groupBy(func.call(_))(encoder)
-
-  /* ****************** *
-   *  Typed Relational  *
-   * ****************** */
-
-  /**
-   * Returns a new [[DataFrame]] by selecting a set of column based expressions.
-   * {{{
-   *   df.select($"colA", $"colB" + 1)
-   * }}}
-   * @since 1.6.0
-   */
-  // Copied from Dataframe to make sure we don't have invalid overloads.
-  @scala.annotation.varargs
-  protected def select(cols: Column*): DataFrame = toDF().select(cols: _*)
-
-  /**
-   * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element.
-   *
-   * {{{
-   *   val ds = Seq(1, 2, 3).toDS()
-   *   val newDS = ds.select(expr("value + 1").as[Int])
-   * }}}
-   * @since 1.6.0
-   */
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sqlContext,
-      Project(
-        c1.withInputType(
-          boundTEncoder,
-          logicalPlan.output).named :: Nil,
-        logicalPlan))
-  }
-
-  /**
-   * Internal helper function for building typed selects that return tuples.  For simplicity and
-   * code reuse, we do this without the help of the type system and then use helper functions
-   * that cast appropriately for the user facing interface.
-   */
-  protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val encoders = columns.map(_.encoder)
-    val namedColumns =
-      columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named)
-    val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
-
-    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
-  }
-
-  /**
-   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-   * @since 1.6.0
-   */
-  def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] =
-    selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]]
-
-  /**
-   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-   * @since 1.6.0
-   */
-  def select[U1, U2, U3](
-      c1: TypedColumn[T, U1],
-      c2: TypedColumn[T, U2],
-      c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] =
-    selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]]
-
-  /**
-   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-   * @since 1.6.0
-   */
-  def select[U1, U2, U3, U4](
-      c1: TypedColumn[T, U1],
-      c2: TypedColumn[T, U2],
-      c3: TypedColumn[T, U3],
-      c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] =
-    selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]]
-
-  /**
-   * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
-   * @since 1.6.0
-   */
-  def select[U1, U2, U3, U4, U5](
-      c1: TypedColumn[T, U1],
-      c2: TypedColumn[T, U2],
-      c3: TypedColumn[T, U3],
-      c4: TypedColumn[T, U4],
-      c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] =
-    selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]]
-
-  /**
-   * Returns a new [[Dataset]] by sampling a fraction of records.
-   * @since 1.6.0
-   */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Long) : Dataset[T] =
-    withPlan(Sample(0.0, fraction, withReplacement, seed, _)())
-
-  /**
-   * Returns a new [[Dataset]] by sampling a fraction of records, using a random seed.
-   * @since 1.6.0
-   */
-  def sample(withReplacement: Boolean, fraction: Double) : Dataset[T] = {
-    sample(withReplacement, fraction, Utils.random.nextLong)
-  }
-
-  /* **************** *
-   *  Set operations  *
-   * **************** */
-
-  /**
-   * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]].
-   *
-   * Note that, equality checking is performed directly on the encoded representation of the data
-   * and thus is not affected by a custom `equals` function defined on `T`.
-   * @since 1.6.0
-   */
-  def distinct: Dataset[T] = withPlan(Distinct)
-
-  /**
-   * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also
-   * present in `other`.
-   *
-   * Note that, equality checking is performed directly on the encoded representation of the data
-   * and thus is not affected by a custom `equals` function defined on `T`.
-   * @since 1.6.0
-   */
-  def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect)
-
-  /**
-   * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]]
-   * combined.
-   *
-   * Note that, this function is not a typical set union operation, in that it does not eliminate
-   * duplicate items.  As such, it is analogous to `UNION ALL` in SQL.
-   * @since 1.6.0
-   */
-  def union(other: Dataset[T]): Dataset[T] = withPlan[T](other) { (left, right) =>
-    // This breaks caching, but it's usually ok because it addresses a very specific use case:
-    // using union to union many files or partitions.
-    CombineUnions(Union(left, right))
-  }
-
-  /**
-   * Returns a new [[Dataset]] where any elements present in `other` have been removed.
-   *
-   * Note that, equality checking is performed directly on the encoded representation of the data
-   * and thus is not affected by a custom `equals` function defined on `T`.
-   * @since 1.6.0
-   */
-  def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except)
-
-  /* ****** *
-   *  Joins *
-   * ****** */
-
-  /**
-   * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to
-   * true.
-   *
-   * This is similar to the relation `join` function with one important difference in the
-   * result schema. Since `joinWith` preserves objects present on either side of the join, the
-   * result schema is similarly nested into a tuple under the column names `_1` and `_2`.
-   *
-   * This type of join can be useful both for preserving type-safety with the original object
-   * types as well as working with relational data where either side of the join has column
-   * names in common.
-   *
-   * @param other Right side of the join.
-   * @param condition Join expression.
-   * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
-   * @since 1.6.0
-   */
-  def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
-    val left = this.logicalPlan
-    val right = other.logicalPlan
-
-    val joined = sqlContext.executePlan(Join(left, right, joinType =
-      JoinType(joinType), Some(condition.expr)))
-    val leftOutput = joined.analyzed.output.take(left.output.length)
-    val rightOutput = joined.analyzed.output.takeRight(right.output.length)
-
-    val leftData = this.unresolvedTEncoder match {
-      case e if e.flat => Alias(leftOutput.head, "_1")()
-      case _ => Alias(CreateStruct(leftOutput), "_1")()
-    }
-    val rightData = other.unresolvedTEncoder match {
-      case e if e.flat => Alias(rightOutput.head, "_2")()
-      case _ => Alias(CreateStruct(rightOutput), "_2")()
-    }
-
-    implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
-    withPlan[(T, U)](other) { (left, right) =>
-      Project(
-        leftData :: rightData :: Nil,
-        joined.analyzed)
-    }
-  }
-
-  /**
-   * Using inner equi-join to join this [[Dataset]] returning a [[Tuple2]] for each pair
-   * where `condition` evaluates to true.
-   *
-   * @param other Right side of the join.
-   * @param condition Join expression.
-   * @since 1.6.0
-   */
-  def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = {
-    joinWith(other, condition, "inner")
-  }
-
-  /* ************************** *
-   *  Gather to Driver Actions  *
-   * ************************** */
-
-  /**
-   * Returns the first element in this [[Dataset]].
-   * @since 1.6.0
-   */
-  def first(): T = take(1).head
-
-  /**
-   * Returns an array that contains all the elements in this [[Dataset]].
-   *
-   * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
-   *
-   * For Java API, use [[collectAsList]].
-   * @since 1.6.0
-   */
-  def collect(): Array[T] = {
-    // This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
-    // to convert the rows into objects of type T.
-    queryExecution.toRdd.map(_.copy()).collect().map(boundTEncoder.fromRow)
-  }
-
-  /**
-   * Returns an array that contains all the elements in this [[Dataset]].
-   *
-   * Running collect requires moving all the data into the application's driver process, and
-   * doing so on a very large [[Dataset]] can crash the driver process with OutOfMemoryError.
-   *
-   * For Java API, use [[collectAsList]].
-   * @since 1.6.0
-   */
-  def collectAsList(): java.util.List[T] = collect().toSeq.asJava
-
-  /**
-   * Returns the first `num` elements of this [[Dataset]] as an array.
-   *
-   * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `num` can crash the driver process with OutOfMemoryError.
-   * @since 1.6.0
-   */
-  def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
-
-  /**
-   * Returns the first `num` elements of this [[Dataset]] as an array.
-   *
-   * Running take requires moving data into the application's driver process, and doing so with
-   * a very large `num` can crash the driver process with OutOfMemoryError.
-   * @since 1.6.0
-   */
-  def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
-
-  /**
-    * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
-    * @since 1.6.0
-    */
-  def persist(): this.type = {
-    sqlContext.cacheManager.cacheQuery(this)
-    this
-  }
-
-  /**
-    * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
-    * @since 1.6.0
-    */
-  def cache(): this.type = persist()
-
-  /**
-    * Persist this [[Dataset]] with the given storage level.
-    * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
-    *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
-    *                 `MEMORY_AND_DISK_2`, etc.
-    * @group basic
-    * @since 1.6.0
-    */
-  def persist(newLevel: StorageLevel): this.type = {
-    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
-    this
-  }
-
-  /**
-    * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
-    * @param blocking Whether to block until all blocks are deleted.
-    * @since 1.6.0
-    */
-  def unpersist(blocking: Boolean): this.type = {
-    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
-    this
-  }
-
-  /**
-    * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
-    * @since 1.6.0
-    */
-  def unpersist(): this.type = unpersist(blocking = false)
-
-  /* ******************** *
-   *  Internal Functions  *
-   * ******************** */
-
-  private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
-
-  private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
-    new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
-
-  private[sql] def withPlan[R : Encoder](
-      other: Dataset[_])(
-      f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] =
-    new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan))
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a7258d7..2a0f773 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.NumericType
 
 /**
  * :: Experimental ::
- * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
+ * A set of methods for aggregations on a [[DataFrame]], created by [[Dataset.groupBy]].
  *
  * The main method is the agg function, which has multiple variants. This class also contains
  * convenience some first order statistics such as mean, sum for convenience.

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index cd8ed47..1639cc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -64,7 +64,7 @@ class GroupedDataset[K, V] private[sql](
 
   private def groupedData =
     new GroupedData(
-      new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+      DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
 
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
@@ -86,7 +86,7 @@ class GroupedDataset[K, V] private[sql](
    * @since 1.6.0
    */
   def keys: Dataset[K] = {
-    new Dataset[K](
+    Dataset[K](
       sqlContext,
       Distinct(
         Project(groupingAttributes, logicalPlan)))
@@ -111,7 +111,7 @@ class GroupedDataset[K, V] private[sql](
    * @since 1.6.0
    */
   def flatMapGroups[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
-    new Dataset[U](
+    Dataset[U](
       sqlContext,
       MapGroups(
         f,
@@ -308,7 +308,7 @@ class GroupedDataset[K, V] private[sql](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
     implicit val uEncoder = other.unresolvedVEncoder
-    new Dataset[R](
+    Dataset[R](
       sqlContext,
       CoGroup(
         f,

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c742bf2..54dbd6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -464,7 +464,7 @@ class SQLContext private[sql](
     val encoded = data.map(d => enc.toRow(d).copy())
     val plan = new LocalRelation(attributes, encoded)
 
-    new Dataset[T](this, plan)
+    Dataset[T](this, plan)
   }
 
   def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
@@ -473,7 +473,7 @@ class SQLContext private[sql](
     val encoded = data.map(d => enc.toRow(d))
     val plan = LogicalRDD(attributes, encoded)(self)
 
-    new Dataset[T](this, plan)
+    Dataset[T](this, plan)
   }
 
   def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 16c4095..e23d5e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -126,6 +126,7 @@ abstract class SQLImplicits {
 
   /**
    * Creates a [[Dataset]] from an RDD.
+   *
    * @since 1.6.0
    */
   implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 8616fe3..19ab3ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{AnalysisException, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 
@@ -31,7 +31,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
  */
 class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
-  def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed)
+  def assertAnalyzed(): Unit = try sqlContext.analyzer.checkAnalysis(analyzed) catch {
+    case e: AnalysisException =>
+      throw new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed))
+  }
 
   lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index e048ee1..60ec67c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -154,7 +154,7 @@ case class DataSource(
         }
 
         def dataFrameBuilder(files: Array[String]): DataFrame = {
-          new DataFrame(
+          DataFrame(
             sqlContext,
             LogicalRelation(
               DataSource(

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
index a191759..0dc3481 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala
@@ -121,6 +121,6 @@ private[sql] object FrequentItems extends Logging {
       StructField(v._1 + "_freqItems", ArrayType(v._2, false))
     }
     val schema = StructType(outputCols).toAttributes
-    new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
+    DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 26e4eda..daa065e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -454,6 +454,6 @@ private[sql] object StatFunctions extends Logging {
     }
     val schema = StructType(StructField(tableName, StringType) +: headerNames)
 
-    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+    DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bc7c520..7d7c51b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -211,7 +211,7 @@ class StreamExecution(
 
         // Construct the batch and send it to the sink.
         val batchOffset = streamProgress.toCompositeOffset(sources)
-        val nextBatch = new Batch(batchOffset, new DataFrame(sqlContext, newPlan))
+        val nextBatch = new Batch(batchOffset, DataFrame(sqlContext, newPlan))
         sink.addBatch(nextBatch)
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 8124df1..3b764c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -55,11 +55,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   def schema: StructType = encoder.schema
 
   def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
-    new Dataset(sqlContext, logicalPlan)
+    Dataset(sqlContext, logicalPlan)
   }
 
   def toDF()(implicit sqlContext: SQLContext): DataFrame = {
-    new DataFrame(sqlContext, logicalPlan)
+    DataFrame(sqlContext, logicalPlan)
   }
 
   def addData(data: A*): Offset = {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 6eea924..844f305 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -46,7 +46,6 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
  * @tparam I The input type for the aggregation.
  * @tparam B The type of the intermediate value of the reduction.
  * @tparam O The type of the final output result.
- *
  * @since 1.6.0
  */
 abstract class Aggregator[-I, B, O] extends Serializable {

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/main/scala/org/apache/spark/sql/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index bd73a36..97e35bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -42,4 +42,5 @@ package object sql {
   @DeveloperApi
   type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan]
 
+  type DataFrame = Dataset[Row]
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d542785/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 51f987f..42af813 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SQLContext;
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);
 
-    DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
+    Row[] actual = sqlContext.sql("SELECT * FROM people").collectRows();
 
     List<Row> expected = new ArrayList<>(2);
     expected.add(RowFactory.create("Michael", 29));
@@ -143,7 +143,7 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);
 
-    DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
     List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function<Row, String>() {
       @Override
@@ -198,14 +198,14 @@ public class JavaApplySchemaSuite implements Serializable {
         null,
         "this is another simple string."));
 
-    DataFrame df1 = sqlContext.read().json(jsonRDD);
+    Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
     StructType actualSchema1 = df1.schema();
     Assert.assertEquals(expectedSchema, actualSchema1);
     df1.registerTempTable("jsonTable1");
     List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
     Assert.assertEquals(expectedResult, actual1);
 
-    DataFrame df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+    Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
     StructType actualSchema2 = df2.schema();
     Assert.assertEquals(expectedSchema, actualSchema2);
     df2.registerTempTable("jsonTable2");


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message