spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [SPARK-19634][SQL][ML][FOLLOW-UP] Improve interface of dataframe vectorized summarizer
Date Thu, 21 Dec 2017 03:53:40 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9c289a5cb -> d3ae3e1e8


[SPARK-19634][SQL][ML][FOLLOW-UP] Improve interface of dataframe vectorized summarizer

## What changes were proposed in this pull request?

Make several improvements in dataframe vectorized summarizer.

1. Make the summarizer return `Vector` type for all metrics (except "count").
Previously it returned the "WrappedArray" type, which was not very convenient.

2. Make `MetricsAggregate` inherit the `ImplicitCastInputTypes` trait, so that it can check and
implicitly cast input values.

3. Add a "weight" parameter to all single-metric methods.

4. Update doc and improve the example code in doc.

5. Simplify test cases.

## How was this patch tested?

Tests added and simplified.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19156 from WeichenXu123/improve_vec_summarizer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d3ae3e1e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d3ae3e1e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d3ae3e1e

Branch: refs/heads/master
Commit: d3ae3e1e894f88a8500752d9633fe9ad00da5f20
Parents: 9c289a5
Author: WeichenXu <weichen.xu@databricks.com>
Authored: Wed Dec 20 19:53:35 2017 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Wed Dec 20 19:53:35 2017 -0800

----------------------------------------------------------------------
 .../org/apache/spark/ml/stat/Summarizer.scala   | 128 ++++---
 .../spark/ml/stat/JavaSummarizerSuite.java      |  64 ++++
 .../apache/spark/ml/stat/SummarizerSuite.scala  | 362 ++++++++++---------
 3 files changed, 341 insertions(+), 213 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d3ae3e1e/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index cae41ed..9bed74a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete,
TypedImperativeAggregate}
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.types._
@@ -41,7 +41,7 @@ sealed abstract class SummaryBuilder {
   /**
    * Returns an aggregate object that contains the summary of the column with the requested
metrics.
    * @param featuresCol a column that contains features Vector object.
-   * @param weightCol a column that contains weight value.
+   * @param weightCol a column that contains weight value. Default weight is 1.0.
    * @return an aggregate column that contains the statistics. The exact content of this
    *         structure is determined during the creation of the builder.
    */
@@ -50,6 +50,7 @@ sealed abstract class SummaryBuilder {
 
   @Since("2.3.0")
   def summary(featuresCol: Column): Column = summary(featuresCol, lit(1.0))
+
 }
 
 /**
@@ -60,15 +61,18 @@ sealed abstract class SummaryBuilder {
  * This class lets users pick the statistics they would like to extract for a given column.
Here is
  * an example in Scala:
  * {{{
- *   val dataframe = ... // Some dataframe containing a feature column
- *   val allStats = dataframe.select(Summarizer.metrics("min", "max").summary($"features"))
- *   val Row(Row(min_, max_)) = allStats.first()
+ *   import org.apache.spark.ml.linalg._
+ *   import org.apache.spark.sql.Row
+ *   val dataframe = ... // Some dataframe containing a feature column and a weight column
+ *   val multiStatsDF = dataframe.select(
+ *       Summarizer.metrics("min", "max", "count").summary($"features", $"weight")
+ *   val Row(Row(minVec, maxVec, count)) = multiStatsDF.first()
  * }}}
  *
  * If one wants to get a single metric, shortcuts are also available:
  * {{{
  *   val meanDF = dataframe.select(Summarizer.mean($"features"))
- *   val Row(mean_) = meanDF.first()
+ *   val Row(meanVec) = meanDF.first()
  * }}}
  *
  * Note: Currently, the performance of this interface is about 2x~3x slower then using the
RDD
@@ -94,8 +98,7 @@ object Summarizer extends Logging {
    *  - min: the minimum for each coefficient.
    *  - normL2: the Euclidian norm for each coefficient.
    *  - normL1: the L1 norm of each coefficient (sum of the absolute values).
-   * @param firstMetric the metric being provided
-   * @param metrics additional metrics that can be provided.
+   * @param metrics metrics that can be provided.
    * @return a builder.
    * @throws IllegalArgumentException if one of the metric names is not understood.
    *
@@ -103,37 +106,79 @@ object Summarizer extends Logging {
    * interface.
    */
   @Since("2.3.0")
-  def metrics(firstMetric: String, metrics: String*): SummaryBuilder = {
-    val (typedMetrics, computeMetrics) = getRelevantMetrics(Seq(firstMetric) ++ metrics)
+  @scala.annotation.varargs
+  def metrics(metrics: String*): SummaryBuilder = {
+    require(metrics.size >= 1, "Should include at least one metric")
+    val (typedMetrics, computeMetrics) = getRelevantMetrics(metrics)
     new SummaryBuilderImpl(typedMetrics, computeMetrics)
   }
 
   @Since("2.3.0")
-  def mean(col: Column): Column = getSingleMetric(col, "mean")
+  def mean(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "mean")
+  }
+
+  @Since("2.3.0")
+  def mean(col: Column): Column = mean(col, lit(1.0))
+
+  @Since("2.3.0")
+  def variance(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "variance")
+  }
+
+  @Since("2.3.0")
+  def variance(col: Column): Column = variance(col, lit(1.0))
+
+  @Since("2.3.0")
+  def count(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "count")
+  }
+
+  @Since("2.3.0")
+  def count(col: Column): Column = count(col, lit(1.0))
 
   @Since("2.3.0")
-  def variance(col: Column): Column = getSingleMetric(col, "variance")
+  def numNonZeros(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "numNonZeros")
+  }
+
+  @Since("2.3.0")
+  def numNonZeros(col: Column): Column = numNonZeros(col, lit(1.0))
+
+  @Since("2.3.0")
+  def max(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "max")
+  }
+
+  @Since("2.3.0")
+  def max(col: Column): Column = max(col, lit(1.0))
 
   @Since("2.3.0")
-  def count(col: Column): Column = getSingleMetric(col, "count")
+  def min(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "min")
+  }
 
   @Since("2.3.0")
-  def numNonZeros(col: Column): Column = getSingleMetric(col, "numNonZeros")
+  def min(col: Column): Column = min(col, lit(1.0))
 
   @Since("2.3.0")
-  def max(col: Column): Column = getSingleMetric(col, "max")
+  def normL1(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL1")
+  }
 
   @Since("2.3.0")
-  def min(col: Column): Column = getSingleMetric(col, "min")
+  def normL1(col: Column): Column = normL1(col, lit(1.0))
 
   @Since("2.3.0")
-  def normL1(col: Column): Column = getSingleMetric(col, "normL1")
+  def normL2(col: Column, weightCol: Column): Column = {
+    getSingleMetric(col, weightCol, "normL2")
+  }
 
   @Since("2.3.0")
-  def normL2(col: Column): Column = getSingleMetric(col, "normL2")
+  def normL2(col: Column): Column = normL2(col, lit(1.0))
 
-  private def getSingleMetric(col: Column, metric: String): Column = {
-    val c1 = metrics(metric).summary(col)
+  private def getSingleMetric(col: Column, weightCol: Column, metric: String): Column = {
+    val c1 = metrics(metric).summary(col, weightCol)
     c1.getField(metric).as(s"$metric($col)")
   }
 }
@@ -187,8 +232,7 @@ private[ml] object SummaryBuilderImpl extends Logging {
     StructType(fields)
   }
 
-  private val arrayDType = ArrayType(DoubleType, containsNull = false)
-  private val arrayLType = ArrayType(LongType, containsNull = false)
+  private val vectorUDT = new VectorUDT
 
   /**
    * All the metrics that can be currently computed by Spark for vectors.
@@ -197,14 +241,14 @@ private[ml] object SummaryBuilderImpl extends Logging {
    * metrics that need to de computed internally to get the final result.
    */
   private val allMetrics: Seq[(String, Metric, DataType, Seq[ComputeMetric])] = Seq(
-    ("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
-    ("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
+    ("mean", Mean, vectorUDT, Seq(ComputeMean, ComputeWeightSum)),
+    ("variance", Variance, vectorUDT, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
     ("count", Count, LongType, Seq()),
-    ("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
-    ("max", Max, arrayDType, Seq(ComputeMax, ComputeNNZ)),
-    ("min", Min, arrayDType, Seq(ComputeMin, ComputeNNZ)),
-    ("normL2", NormL2, arrayDType, Seq(ComputeM2)),
-    ("normL1", NormL1, arrayDType, Seq(ComputeL1))
+    ("numNonZeros", NumNonZeros, vectorUDT, Seq(ComputeNNZ)),
+    ("max", Max, vectorUDT, Seq(ComputeMax, ComputeNNZ)),
+    ("min", Min, vectorUDT, Seq(ComputeMin, ComputeNNZ)),
+    ("normL2", NormL2, vectorUDT, Seq(ComputeM2)),
+    ("normL1", NormL1, vectorUDT, Seq(ComputeL1))
   )
 
   /**
@@ -527,27 +571,28 @@ private[ml] object SummaryBuilderImpl extends Logging {
       weightExpr: Expression,
       mutableAggBufferOffset: Int,
       inputAggBufferOffset: Int)
-    extends TypedImperativeAggregate[SummarizerBuffer] {
+    extends TypedImperativeAggregate[SummarizerBuffer] with ImplicitCastInputTypes {
 
-    override def eval(state: SummarizerBuffer): InternalRow = {
+    override def eval(state: SummarizerBuffer): Any = {
       val metrics = requestedMetrics.map {
-        case Mean => UnsafeArrayData.fromPrimitiveArray(state.mean.toArray)
-        case Variance => UnsafeArrayData.fromPrimitiveArray(state.variance.toArray)
+        case Mean => vectorUDT.serialize(state.mean)
+        case Variance => vectorUDT.serialize(state.variance)
         case Count => state.count
-        case NumNonZeros => UnsafeArrayData.fromPrimitiveArray(
-          state.numNonzeros.toArray.map(_.toLong))
-        case Max => UnsafeArrayData.fromPrimitiveArray(state.max.toArray)
-        case Min => UnsafeArrayData.fromPrimitiveArray(state.min.toArray)
-        case NormL2 => UnsafeArrayData.fromPrimitiveArray(state.normL2.toArray)
-        case NormL1 => UnsafeArrayData.fromPrimitiveArray(state.normL1.toArray)
+        case NumNonZeros => vectorUDT.serialize(state.numNonzeros)
+        case Max => vectorUDT.serialize(state.max)
+        case Min => vectorUDT.serialize(state.min)
+        case NormL2 => vectorUDT.serialize(state.normL2)
+        case NormL1 => vectorUDT.serialize(state.normL1)
       }
       InternalRow.apply(metrics: _*)
     }
 
+    override def inputTypes: Seq[DataType] = vectorUDT :: DoubleType :: Nil
+
     override def children: Seq[Expression] = featuresExpr :: weightExpr :: Nil
 
     override def update(state: SummarizerBuffer, row: InternalRow): SummarizerBuffer = {
-      val features = udt.deserialize(featuresExpr.eval(row))
+      val features = vectorUDT.deserialize(featuresExpr.eval(row))
       val weight = weightExpr.eval(row).asInstanceOf[Double]
       state.add(features, weight)
       state
@@ -591,7 +636,4 @@ private[ml] object SummaryBuilderImpl extends Logging {
     override def prettyName: String = "aggregate_metrics"
 
   }
-
-  private[this] val udt = new VectorUDT
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/d3ae3e1e/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java
new file mode 100644
index 0000000..38ab39a
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/stat/JavaSummarizerSuite.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.stat;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.Dataset;
+import static org.apache.spark.sql.functions.col;
+import org.apache.spark.ml.feature.LabeledPoint;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
+
+public class JavaSummarizerSuite extends SharedSparkSession {
+
+  private transient Dataset<Row> dataset;
+
+  @Override
+  public void setUp() throws IOException {
+    super.setUp();
+    List<LabeledPoint> points = new ArrayList<LabeledPoint>();
+    points.add(new LabeledPoint(0.0, Vectors.dense(1.0, 2.0)));
+    points.add(new LabeledPoint(0.0, Vectors.dense(3.0, 4.0)));
+
+    dataset = spark.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
+  }
+
+  @Test
+  public void testSummarizer() {
+    dataset.select(col("features"));
+    Row result = dataset
+      .select(Summarizer.metrics("mean", "max", "count").summary(col("features")))
+      .first().getStruct(0);
+    Vector meanVec = result.getAs("mean");
+    Vector maxVec = result.getAs("max");
+    long count = result.getAs("count");
+
+    assertEquals(2L, count);
+    assertArrayEquals(new double[]{2.0, 3.0}, meanVec.toArray(), 0.0);
+    assertArrayEquals(new double[]{3.0, 4.0}, maxVec.toArray(), 0.0);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d3ae3e1e/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index 1ea851e..5e4f402 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.ml.stat
 
-import org.scalatest.exceptions.TestFailedException
-
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, Statistics}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.Row
 
 class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -35,237 +32,262 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext
{
   import SummaryBuilderImpl._
 
   private case class ExpectedMetrics(
-      mean: Seq[Double],
-      variance: Seq[Double],
+      mean: Vector,
+      variance: Vector,
       count: Long,
-      numNonZeros: Seq[Long],
-      max: Seq[Double],
-      min: Seq[Double],
-      normL2: Seq[Double],
-      normL1: Seq[Double])
+      numNonZeros: Vector,
+      max: Vector,
+      min: Vector,
+      normL2: Vector,
+      normL1: Vector)
 
   /**
-   * The input is expected to be either a sparse vector, a dense vector or an array of doubles
-   * (which will be converted to a dense vector)
-   * The expected is the list of all the known metrics.
+   * The input is expected to be either a sparse vector, a dense vector.
    *
-   * The tests take an list of input vectors and a list of all the summary values that
-   * are expected for this input. They currently test against some fixed subset of the
-   * metrics, but should be made fuzzy in the future.
+   * The tests take an list of input vectors, and compare results with
+   * `mllib.stat.MultivariateOnlineSummarizer`. They currently test against some fixed subset
+   * of the metrics, but should be made fuzzy in the future.
    */
-  private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = {
+  private def testExample(name: String, inputVec: Seq[(Vector, Double)],
+      exp: ExpectedMetrics, expWithoutWeight: ExpectedMetrics): Unit = {
 
-    def inputVec: Seq[Vector] = input.map {
-      case x: Array[Double @unchecked] => Vectors.dense(x)
-      case x: Seq[Double @unchecked] => Vectors.dense(x.toArray)
-      case x: Vector => x
-      case x => throw new Exception(x.toString)
+    val summarizer = {
+      val _summarizer = new MultivariateOnlineSummarizer
+      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1), v._2))
+      _summarizer
     }
 
-    val summarizer = {
+    val summarizerWithoutWeight = {
       val _summarizer = new MultivariateOnlineSummarizer
-      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v)))
+      inputVec.foreach(v => _summarizer.add(OldVectors.fromML(v._1)))
       _summarizer
     }
 
     // Because the Spark context is reset between tests, we cannot hold a reference onto
it.
     def wrappedInit() = {
-      val df = inputVec.map(Tuple1.apply).toDF("features")
-      val col = df.col("features")
-      (df, col)
+      val df = inputVec.toDF("features", "weight")
+      val featuresCol = df.col("features")
+      val weightCol = df.col("weight")
+      (df, featuresCol, weightCol)
     }
 
     registerTest(s"$name - mean only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("mean").summary(c), mean(c)), Seq(Row(exp.mean), summarizer.mean))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("mean").summary(c, w), mean(c, w)).first(),
+        Row(Row(summarizer.mean), exp.mean))
     }
 
-    registerTest(s"$name - mean only (direct)") {
-      val (df, c) = wrappedInit()
-      compare(df.select(mean(c)), Seq(exp.mean))
+    registerTest(s"$name - mean only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("mean").summary(c), mean(c)).first(),
+        Row(Row(summarizerWithoutWeight.mean), expWithoutWeight.mean))
     }
 
     registerTest(s"$name - variance only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("variance").summary(c), variance(c)),
-        Seq(Row(exp.variance), summarizer.variance))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("variance").summary(c, w), variance(c, w)).first(),
+        Row(Row(summarizer.variance), exp.variance))
     }
 
-    registerTest(s"$name - variance only (direct)") {
-      val (df, c) = wrappedInit()
-      compare(df.select(variance(c)), Seq(summarizer.variance))
+    registerTest(s"$name - variance only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("variance").summary(c), variance(c)).first(),
+        Row(Row(summarizerWithoutWeight.variance), expWithoutWeight.variance))
     }
 
     registerTest(s"$name - count only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("count").summary(c), count(c)),
-        Seq(Row(exp.count), exp.count))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("count").summary(c, w), count(c, w)).first(),
+        Row(Row(summarizer.count), exp.count))
     }
 
-    registerTest(s"$name - count only (direct)") {
-      val (df, c) = wrappedInit()
-      compare(df.select(count(c)),
-        Seq(exp.count))
+    registerTest(s"$name - count only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("count").summary(c), count(c)).first(),
+        Row(Row(summarizerWithoutWeight.count), expWithoutWeight.count))
     }
 
     registerTest(s"$name - numNonZeros only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)),
-        Seq(Row(exp.numNonZeros), exp.numNonZeros))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("numNonZeros").summary(c, w), numNonZeros(c, w)).first(),
+        Row(Row(summarizer.numNonzeros), exp.numNonZeros))
     }
 
-    registerTest(s"$name - numNonZeros only (direct)") {
-      val (df, c) = wrappedInit()
-      compare(df.select(numNonZeros(c)),
-        Seq(exp.numNonZeros))
+    registerTest(s"$name - numNonZeros only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("numNonZeros").summary(c), numNonZeros(c)).first(),
+        Row(Row(summarizerWithoutWeight.numNonzeros), expWithoutWeight.numNonZeros))
     }
 
     registerTest(s"$name - min only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("min").summary(c), min(c)),
-        Seq(Row(exp.min), exp.min))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("min").summary(c, w), min(c, w)).first(),
+        Row(Row(summarizer.min), exp.min))
+    }
+
+    registerTest(s"$name - min only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("min").summary(c), min(c)).first(),
+        Row(Row(summarizerWithoutWeight.min), expWithoutWeight.min))
     }
 
     registerTest(s"$name - max only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("max").summary(c), max(c)),
-        Seq(Row(exp.max), exp.max))
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("max").summary(c, w), max(c, w)).first(),
+        Row(Row(summarizer.max), exp.max))
     }
 
-    registerTest(s"$name - normL1 only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("normL1").summary(c), normL1(c)),
-        Seq(Row(exp.normL1), exp.normL1))
+    registerTest(s"$name - max only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("max").summary(c), max(c)).first(),
+        Row(Row(summarizerWithoutWeight.max), expWithoutWeight.max))
     }
 
-    registerTest(s"$name - normL2 only") {
-      val (df, c) = wrappedInit()
-      compare(df.select(metrics("normL2").summary(c), normL2(c)),
-        Seq(Row(exp.normL2), exp.normL2))
+    registerTest(s"$name - normL1 only") {
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("normL1").summary(c, w), normL1(c, w)).first(),
+        Row(Row(summarizer.normL1), exp.normL1))
     }
 
-    registerTest(s"$name - all metrics at once") {
-      val (df, c) = wrappedInit()
-      compare(df.select(
-        metrics("mean", "variance", "count", "numNonZeros").summary(c),
-        mean(c), variance(c), count(c), numNonZeros(c)),
-        Seq(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros),
-          exp.mean, exp.variance, exp.count, exp.numNonZeros))
+    registerTest(s"$name - normL1 only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("normL1").summary(c), normL1(c)).first(),
+        Row(Row(summarizerWithoutWeight.normL1), expWithoutWeight.normL1))
     }
-  }
 
-  private def denseData(input: Seq[Seq[Double]]): DataFrame = {
-    input.map(_.toArray).map(Vectors.dense).map(Tuple1.apply).toDF("features")
-  }
+    registerTest(s"$name - normL2 only") {
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(metrics("normL2").summary(c, w), normL2(c, w)).first(),
+        Row(Row(summarizer.normL2), exp.normL2))
+    }
 
-  private def compare(df: DataFrame, exp: Seq[Any]): Unit = {
-    val coll = df.collect().toSeq
-    val Seq(row) = coll
-    val res = row.toSeq
-    val names = df.schema.fieldNames.zipWithIndex.map { case (n, idx) => s"$n ($idx)"
}
-    assert(res.size === exp.size, (res.size, exp.size))
-    for (((x1, x2), name) <- res.zip(exp).zip(names)) {
-      compareStructures(x1, x2, name)
+    registerTest(s"$name - normL2 only w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(metrics("normL2").summary(c), normL2(c)).first(),
+        Row(Row(summarizerWithoutWeight.normL2), expWithoutWeight.normL2))
     }
-  }
 
-  // Compares structured content.
-  private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match {
-    case (y1: Seq[Double @unchecked], v1: OldVector) =>
-      compareStructures(y1, v1.toArray.toSeq, name)
-    case (d1: Double, d2: Double) =>
-      assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name)
-    case (r1: GenericRowWithSchema, r2: Row) =>
-      assert(r1.size === r2.size, (r1, r2))
-      for (((fname, x1), x2) <- r1.schema.fieldNames.zip(r1.toSeq).zip(r2.toSeq)) {
-        compareStructures(x1, x2, s"$name.$fname")
-      }
-    case (r1: Row, r2: Row) =>
-      assert(r1.size === r2.size, (r1, r2))
-      for ((x1, x2) <- r1.toSeq.zip(r2.toSeq)) { compareStructures(x1, x2, name) }
-    case (v1: Vector, v2: Vector) =>
-      assert2(v1 ~== v2 absTol 1e-4, name)
-    case (l1: Long, l2: Long) => assert(l1 === l2)
-    case (s1: Seq[_], s2: Seq[_]) =>
-      assert(s1.size === s2.size, s"$name ${(s1, s2)}")
-      for (((x1, idx), x2) <- s1.zipWithIndex.zip(s2)) {
-        compareStructures(x1, x2, s"$name.$idx")
-      }
-    case (arr1: Array[_], arr2: Array[_]) =>
-      assert(arr1.toSeq === arr2.toSeq)
-    case _ => throw new Exception(s"$name: ${x1.getClass} ${x2.getClass} $x1 $x2")
-  }
+    registerTest(s"$name - multiple metrics at once") {
+      val (df, c, w) = wrappedInit()
+      compareRow(df.select(
+        metrics("mean", "variance", "count", "numNonZeros").summary(c, w)).first(),
+        Row(Row(exp.mean, exp.variance, exp.count, exp.numNonZeros))
+      )
+    }
 
-  private def assert2(x: => Boolean, hint: String): Unit = {
-    try {
-      assert(x, hint)
-    } catch {
-      case tfe: TestFailedException =>
-        throw new TestFailedException(Some(s"Failure with hint $hint"), Some(tfe), 1)
+    registerTest(s"$name - multiple metrics at once w/o weight") {
+      val (df, c, _) = wrappedInit()
+      compareRow(df.select(
+        metrics("mean", "variance", "count", "numNonZeros").summary(c)).first(),
+        Row(Row(expWithoutWeight.mean, expWithoutWeight.variance,
+          expWithoutWeight.count, expWithoutWeight.numNonZeros))
+      )
     }
   }
 
-  test("debugging test") {
-    val df = denseData(Nil)
-    val c = df.col("features")
-    val c1 = metrics("mean").summary(c)
-    val res = df.select(c1)
-    intercept[SparkException] {
-      compare(res, Seq.empty)
+  private def compareRow(r1: Row, r2: Row): Unit = {
+    assert(r1.size === r2.size, (r1, r2))
+    r1.toSeq.zip(r2.toSeq).foreach {
+      case (v1: Vector, v2: Vector) =>
+        assert(v1 ~== v2 absTol 1e-4)
+      case (v1: Vector, v2: OldVector) =>
+        assert(v1 ~== v2.asML absTol 1e-4)
+      case (l1: Long, l2: Long) =>
+        assert(l1 === l2)
+      case (r1: Row, r2: Row) =>
+        compareRow(r1, r2)
+      case (x1: Any, x2: Any) =>
+        throw new Exception(s"type mismatch: ${x1.getClass} ${x2.getClass} $x1 $x2")
     }
   }
 
-  test("basic error handling") {
-    val df = denseData(Nil)
+  test("no element") {
+    val df = Seq[Tuple1[Vector]]().toDF("features")
     val c = df.col("features")
-    val res = df.select(metrics("mean").summary(c), mean(c))
     intercept[SparkException] {
-      compare(res, Seq.empty)
+      df.select(metrics("mean").summary(c), mean(c)).first()
     }
+    compareRow(df.select(metrics("count").summary(c), count(c)).first(),
+      Row(Row(0L), 0L))
   }
 
-  test("no element, working metrics") {
-    val df = denseData(Nil)
-    val c = df.col("features")
-    val res = df.select(metrics("count").summary(c), count(c))
-    compare(res, Seq(Row(0L), 0L))
-  }
+  val singleElem = Vectors.dense(0.0, 1.0, 2.0)
+  testExample("single element", Seq((singleElem, 2.0)),
+    ExpectedMetrics(
+      mean = singleElem,
+      variance = Vectors.dense(0.0, 0.0, 0.0),
+      count = 1L,
+      numNonZeros = Vectors.dense(0.0, 1.0, 1.0),
+      max = singleElem,
+      min = singleElem,
+      normL1 = Vectors.dense(0.0, 2.0, 4.0),
+      normL2 = Vectors.dense(0.0, 1.414213, 2.828427)
+    ),
+    ExpectedMetrics(
+      mean = singleElem,
+      variance = Vectors.dense(0.0, 0.0, 0.0),
+      count = 1L,
+      numNonZeros = Vectors.dense(0.0, 1.0, 1.0),
+      max = singleElem,
+      min = singleElem,
+      normL1 = singleElem,
+      normL2 = singleElem
+    )
+  )
+
+  testExample("multiple elements (dense)",
+    Seq(
+      (Vectors.dense(-1.0, 0.0, 6.0), 0.5),
+      (Vectors.dense(3.0, -3.0, 0.0), 2.8),
+      (Vectors.dense(1.0, -3.0, 0.0), 0.0)
+    ),
+    ExpectedMetrics(
+      mean = Vectors.dense(2.393939, -2.545454, 0.909090),
+      variance = Vectors.dense(8.0, 4.5, 18.0),
+      count = 2L,
+      numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
+      max = Vectors.dense(3.0, 0.0, 6.0),
+      min = Vectors.dense(-1.0, -3.0, 0.0),
+      normL1 = Vectors.dense(8.9, 8.4, 3.0),
+      normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
+    ),
+    ExpectedMetrics(
+      mean = Vectors.dense(1.0, -2.0, 2.0),
+      variance = Vectors.dense(4.0, 3.0, 12.0),
+      count = 3L,
+      numNonZeros = Vectors.dense(3.0, 2.0, 1.0),
+      max = Vectors.dense(3.0, 0.0, 6.0),
+      min = Vectors.dense(-1.0, -3.0, 0.0),
+      normL1 = Vectors.dense(5.0, 6.0, 6.0),
+      normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
+    )
+  )
 
-  val singleElem = Seq(0.0, 1.0, 2.0)
-  testExample("single element", Seq(singleElem), ExpectedMetrics(
-    mean = singleElem,
-    variance = Seq(0.0, 0.0, 0.0),
-    count = 1,
-    numNonZeros = Seq(0, 1, 1),
-    max = singleElem,
-    min = singleElem,
-    normL1 = singleElem,
-    normL2 = singleElem
-  ))
-
-  testExample("two elements", Seq(Seq(0.0, 1.0, 2.0), Seq(0.0, -1.0, -2.0)), ExpectedMetrics(
-    mean = Seq(0.0, 0.0, 0.0),
-    // TODO: I have a doubt about these values, they are not normalized.
-    variance = Seq(0.0, 2.0, 8.0),
-    count = 2,
-    numNonZeros = Seq(0, 2, 2),
-    max = Seq(0.0, 1.0, 2.0),
-    min = Seq(0.0, -1.0, -2.0),
-    normL1 = Seq(0.0, 2.0, 4.0),
-    normL2 = Seq(0.0, math.sqrt(2.0), math.sqrt(2.0) * 2.0)
-  ))
-
-  testExample("dense vector input",
-    Seq(Seq(-1.0, 0.0, 6.0), Seq(3.0, -3.0, 0.0)),
+  testExample("multiple elements (sparse)",
+    Seq(
+      (Vectors.dense(-1.0, 0.0, 6.0).toSparse, 0.5),
+      (Vectors.dense(3.0, -3.0, 0.0).toSparse, 2.8),
+      (Vectors.dense(1.0, -3.0, 0.0).toSparse, 0.0)
+    ),
+    ExpectedMetrics(
+      mean = Vectors.dense(2.393939, -2.545454, 0.909090),
+      variance = Vectors.dense(8.0, 4.5, 18.0),
+      count = 2L,
+      numNonZeros = Vectors.dense(2.0, 1.0, 1.0),
+      max = Vectors.dense(3.0, 0.0, 6.0),
+      min = Vectors.dense(-1.0, -3.0, 0.0),
+      normL1 = Vectors.dense(8.9, 8.4, 3.0),
+      normL2 = Vectors.dense(5.069516, 5.019960, 4.242640)
+    ),
     ExpectedMetrics(
-      mean = Seq(1.0, -1.5, 3.0),
-      variance = Seq(8.0, 4.5, 18.0),
-      count = 2,
-      numNonZeros = Seq(2, 1, 1),
-      max = Seq(3.0, 0.0, 6.0),
-      min = Seq(-1.0, -3, 0.0),
-      normL1 = Seq(4.0, 3.0, 6.0),
-      normL2 = Seq(math.sqrt(10), 3, 6.0)
+      mean = Vectors.dense(1.0, -2.0, 2.0),
+      variance = Vectors.dense(4.0, 3.0, 12.0),
+      count = 3L,
+      numNonZeros = Vectors.dense(3.0, 2.0, 1.0),
+      max = Vectors.dense(3.0, 0.0, 6.0),
+      min = Vectors.dense(-1.0, -3.0, 0.0),
+      normL1 = Vectors.dense(5.0, 6.0, 6.0),
+      normL2 = Vectors.dense(3.316624, 4.242640, 6.0)
     )
   )
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message