spark-commits mailing list archives

From: marmb...@apache.org
Subject: spark git commit: [SPARK-11654][SQL][FOLLOW-UP] fix some mistakes and clean up
Date: Fri, 13 Nov 2015 19:13:14 GMT
Repository: spark
Updated Branches:
  refs/heads/master a24477996 -> 23b8188f7


[SPARK-11654][SQL][FOLLOW-UP] fix some mistakes and clean up

* rename `AppendColumn` to `AppendColumns` to be consistent with the physical plan name.
* clean up stale comments.
* always pass in the resolved encoder to `TypedColumn.withInputType` (test added); a sketch of the resulting call pattern follows these notes.
* enable a mistakenly disabled Java test.
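
A minimal sketch of the call pattern after this change, using the names that appear in the Column.scala and Dataset.scala hunks below (binding to the input schema now happens once, inside `withInputType`, instead of at every call site):

    // sketch only: callers hand over the resolved (but unbound) encoder plus the
    // analyzed output attributes; withInputType binds the encoder itself.
    val namedColumns =
      columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)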

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9688 from cloud-fan/follow.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/23b8188f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/23b8188f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/23b8188f

Branch: refs/heads/master
Commit: 23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3
Parents: a244779
Author: Wenchen Fan <wenchen@databricks.com>
Authored: Fri Nov 13 11:13:09 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Fri Nov 13 11:13:09 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/logical/basicOperators.scala |  8 ++++----
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala |  3 ++-
 .../src/main/scala/org/apache/spark/sql/Dataset.scala     | 10 +++-------
 .../main/scala/org/apache/spark/sql/GroupedDataset.scala  |  4 ++--
 .../org/apache/spark/sql/execution/SparkStrategies.scala  |  2 +-
 .../java/test/org/apache/spark/sql/JavaDatasetSuite.java  |  1 +
 .../org/apache/spark/sql/DatasetAggregatorSuite.scala     |  4 ++++
 7 files changed, 17 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d9f046e..e2b97b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -482,13 +482,13 @@ case class MapPartitions[T, U](
 }
 
 /** Factory for constructing new `AppendColumn` nodes. */
-object AppendColumn {
+object AppendColumns {
   def apply[T, U : Encoder](
       func: T => U,
       tEncoder: ExpressionEncoder[T],
-      child: LogicalPlan): AppendColumn[T, U] = {
+      child: LogicalPlan): AppendColumns[T, U] = {
     val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
+    new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
   }
 }
 
@@ -497,7 +497,7 @@ object AppendColumn {
 * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
  * decode/encode from the JVM object representation expected by `func.`
  */
-case class AppendColumn[T, U](
+case class AppendColumns[T, U](
     func: T => U,
     tEncoder: ExpressionEncoder[T],
     uEncoder: ExpressionEncoder[U],
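
For orientation, the only construction site of the renamed node in this patch is `Dataset.groupBy` (see the Dataset.scala hunk further down); its use of the companion factory reads:

    // from Dataset.groupBy below: the apply method derives the appended attributes
    // from the U encoder's schema and builds the AppendColumns node over the input plan.
    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
    val executed = sqlContext.executePlan(withGroupingKey)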

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9292244..82e9cd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -58,10 +58,11 @@ class TypedColumn[-T, U](
   private[sql] def withInputType(
       inputEncoder: ExpressionEncoder[_],
       schema: Seq[Attribute]): TypedColumn[T, U] = {
+    val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]]
     new TypedColumn[T, U] (expr transform {
       case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
         ta.copy(
-          aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+          aEncoder = Some(boundEncoder),
           children = schema)
     }, encoder)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b930e46..4cc3aa2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -299,7 +299,7 @@ class Dataset[T] private[sql](
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
     val inputPlan = queryExecution.analyzed
-    val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
+    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
@@ -364,13 +364,11 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    // We use an unbound encoder since the expression will make up its own schema.
-    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
     new Dataset[U1](
       sqlContext,
       Project(
         c1.withInputType(
-          resolvedTEncoder.bind(queryExecution.analyzed.output),
+          resolvedTEncoder,
           queryExecution.analyzed.output).named :: Nil,
         logicalPlan))
   }
@@ -382,10 +380,8 @@ class Dataset[T] private[sql](
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
-    // We use an unbound encoder since the expression will make up its own schema.
-    // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
     val namedColumns =
-      columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+      columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named)
     val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
 
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index ae1272a..9c16940 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -89,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
   }
 
   /**
-   * Applies the given function to each group of data.  For each unique group, the function  will
+   * Applies the given function to each group of data.  For each unique group, the function will
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
@@ -162,7 +162,7 @@ class GroupedDataset[K, T] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(
-        _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+        _.withInputType(resolvedTEncoder, dataAttributes).named)
     val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index a99ae46..67201a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -321,7 +321,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
         execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
-      case logical.AppendColumn(f, tEnc, uEnc, newCol, child) =>
+      case logical.AppendColumns(f, tEnc, uEnc, newCol, child) =>
         execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
       case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
         execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 46169ca..eb6fa1e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,6 +157,7 @@ public class JavaDatasetSuite implements Serializable {
     Assert.assertEquals(6, reduced);
   }
 
+  @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());

http://git-wip-us.apache.org/repos/asf/spark/blob/23b8188f/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 20896ef..46f9f07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -162,6 +162,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
       1)
 
     checkAnswer(
+      ds.select(expr("avg(a)").as[Double], ClassInputAgg.toColumn),
+      (1.0, 1))
+
+    checkAnswer(
       ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
       ("one", 1))
   }

