spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H
Date Fri, 02 Mar 2018 06:27:37 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.3 2aa66eb38 -> 56cfbd932


[SPARK-22883][ML][TEST] Streaming tests for spark.ml.feature, from A to H

## What changes were proposed in this pull request?

Adds structured streaming tests using testTransformer for these suites:
* BinarizerSuite
* BucketedRandomProjectionLSHSuite
* BucketizerSuite
* ChiSqSelectorSuite
* CountVectorizerSuite
* DCTSuite
* ElementwiseProductSuite
* FeatureHasherSuite
* HashingTFSuite

## How was this patch tested?

This patch consists entirely of tests, so it is validated by running the new and updated test suites.
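For context, testTransformer is a helper on the MLTest base class that feeds the input data to a transformer both as a batch DataFrame and as a streaming source, then applies a row-level check to the selected output columns. A minimal sketch of the pattern these suites now follow, modeled on the BinarizerSuite hunk below (the suite name, threshold, and data here are illustrative, not from the patch):

import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.util.MLTest
import org.apache.spark.sql.Row

class BinarizerStreamingExampleSuite extends MLTest {

  import testImplicits._

  test("streaming transform") {
    // Input rows: (feature, expected binarized value) for threshold 0.5.
    val df = Seq((0.1, 0.0), (0.9, 1.0)).toDF("feature", "expected")
    val binarizer = new Binarizer()
      .setThreshold(0.5)
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
    // The type parameter gives the schema of the input rows; the trailing
    // string arguments name the output columns handed to the check function.
    // MLTest runs the check under both batch and streaming execution.
    testTransformer[(Double, Double)](df, binarizer, "binarized_feature", "expected") {
      case Row(actual: Double, expected: Double) =>
        assert(actual === expected)
    }
  }
}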

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #20111 from jkbradley/SPARK-22883-streaming-featureAM.

(cherry picked from commit 119f6a0e4729aa952e811d2047790a32ee90bf69)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56cfbd93
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56cfbd93
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56cfbd93

Branch: refs/heads/branch-2.3
Commit: 56cfbd932d3d038ce21cfa4939dfd9563c719003
Parents: 2aa66eb
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Thu Mar 1 21:04:01 2018 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Mar 1 22:27:30 2018 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/BinarizerSuite.scala       |  8 ++--
 .../BucketedRandomProjectionLSHSuite.scala      | 26 ++++++++---
 .../spark/ml/feature/BucketizerSuite.scala      | 11 +++--
 .../spark/ml/feature/ChiSqSelectorSuite.scala   | 36 ++++++++--------
 .../spark/ml/feature/CountVectorizerSuite.scala | 23 +++++-----
 .../org/apache/spark/ml/feature/DCTSuite.scala  | 14 +++---
 .../ml/feature/ElementwiseProductSuite.scala    | 30 ++++++++++---
 .../spark/ml/feature/FeatureHasherSuite.scala   | 45 +++++++++-----------
 .../spark/ml/feature/HashingTFSuite.scala       | 34 +++++++++------
 9 files changed, 126 insertions(+), 101 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 4455d35..05d4a6e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BinarizerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -47,7 +45,7 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       .setInputCol("feature")
       .setOutputCol("binarized_feature")
 
-    binarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+    testTransformer[(Double, Double)](dataFrame, binarizer, "binarized_feature", "expected") {
       case Row(x: Double, y: Double) =>
         assert(x === y, "The feature value is not correct after binarization.")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index 7175c72..ed9a39d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -20,16 +20,15 @@ package org.apache.spark.ml.feature
 import breeze.numerics.{cos, sin}
 import breeze.numerics.constants.Pi
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
 
-class BucketedRandomProjectionLSHSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   @transient var dataset: Dataset[_] = _
 
@@ -98,6 +97,21 @@ class BucketedRandomProjectionLSHSuite
     MLTestingUtils.checkCopyAndUids(brp, brpModel)
   }
 
+  test("BucketedRandomProjectionLSH: streaming transform") {
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(2)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(1.0)
+      .setSeed(12345)
+    val brpModel = brp.fit(dataset)
+
+    testTransformer[Tuple1[Vector]](dataset.toDF(), brpModel, "values") {
+      case Row(values: Seq[_]) =>
+        assert(values.length === brp.getNumHashTables)
+    }
+  }
+
   test("BucketedRandomProjectionLSH: test of LSH property") {
     // Project from 2 dimensional Euclidean Space to 1 dimensions
     val brp = new BucketedRandomProjectionLSH()

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 41cf72f..9ea15e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -23,14 +23,13 @@ import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
-class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class BucketizerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -50,7 +49,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCol("result")
       .setSplits(splits)
 
-    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+    testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
       case Row(x: Double, y: Double) =>
         assert(x === y,
           s"The feature value is not correct after bucketing.  Expected $y but found $x")
@@ -84,7 +83,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setOutputCol("result")
       .setSplits(splits)
 
-    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+    testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
       case Row(x: Double, y: Double) =>
         assert(x === y,
           s"The feature value is not correct after bucketing.  Expected $y but found $x")
@@ -103,7 +102,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setSplits(splits)
 
     bucketizer.setHandleInvalid("keep")
-    bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
+    testTransformer[(Double, Double)](dataFrame, bucketizer, "result", "expected") {
       case Row(x: Double, y: Double) =>
         assert(x === y,
           s"The feature value is not correct after bucketing.  Expected $y but found $x")

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index c83909c..c843df9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Dataset, Row}
 
-class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class ChiSqSelectorSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
 
   @transient var dataset: Dataset[_] = _
 
@@ -119,32 +118,32 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
   test("Test Chi-Square selector: numTopFeatures") {
     val selector = new ChiSqSelector()
       .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1)
-    val model = ChiSqSelectorSuite.testSelector(selector, dataset)
+    val model = testSelector(selector, dataset)
     MLTestingUtils.checkCopyAndUids(selector, model)
   }
 
   test("Test Chi-Square selector: percentile") {
     val selector = new ChiSqSelector()
       .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.17)
-    ChiSqSelectorSuite.testSelector(selector, dataset)
+    testSelector(selector, dataset)
   }
 
   test("Test Chi-Square selector: fpr") {
     val selector = new ChiSqSelector()
       .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.02)
-    ChiSqSelectorSuite.testSelector(selector, dataset)
+    testSelector(selector, dataset)
   }
 
   test("Test Chi-Square selector: fdr") {
     val selector = new ChiSqSelector()
       .setOutputCol("filtered").setSelectorType("fdr").setFdr(0.12)
-    ChiSqSelectorSuite.testSelector(selector, dataset)
+    testSelector(selector, dataset)
   }
 
   test("Test Chi-Square selector: fwe") {
     val selector = new ChiSqSelector()
       .setOutputCol("filtered").setSelectorType("fwe").setFwe(0.12)
-    ChiSqSelectorSuite.testSelector(selector, dataset)
+    testSelector(selector, dataset)
   }
 
   test("read/write") {
@@ -163,18 +162,19 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
         assert(expected.selectedFeatures === actual.selectedFeatures)
       }
   }
-}
 
-object ChiSqSelectorSuite {
-
-  private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): ChiSqSelectorModel = {
-    val selectorModel = selector.fit(dataset)
-    selectorModel.transform(dataset).select("filtered", "topFeature").collect()
-      .foreach { case Row(vec1: Vector, vec2: Vector) =>
+  private def testSelector(selector: ChiSqSelector, data: Dataset[_]): ChiSqSelectorModel = {
+    val selectorModel = selector.fit(data)
+    testTransformer[(Double, Vector, Vector)](data.toDF(), selectorModel,
+      "filtered", "topFeature") {
+      case Row(vec1: Vector, vec2: Vector) =>
         assert(vec1 ~== vec2 absTol 1e-1)
-      }
+    }
     selectorModel
   }
+}
+
+object ChiSqSelectorSuite {
 
   /**
    * Mapping from all Params to valid settings which differ from the defaults.

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index f213145..b4cabff 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -16,16 +16,13 @@
  */
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Row
 
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class CountVectorizerSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -50,7 +47,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
     val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
       .setInputCol("words")
       .setOutputCol("features")
-    cv.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -72,7 +69,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
     MLTestingUtils.checkCopyAndUids(cv, cvm)
     assert(cvm.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
 
-    cvm.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cvm, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -100,7 +97,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .fit(df)
     assert(cvModel2.vocabulary === Array("a", "b"))
 
-    cvModel2.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cvModel2, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -113,7 +110,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .fit(df)
     assert(cvModel3.vocabulary === Array("a", "b"))
 
-    cvModel3.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cvModel3, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -147,7 +144,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setInputCol("words")
       .setOutputCol("features")
       .setMinTF(3)
-    cv.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -166,7 +163,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setInputCol("words")
       .setOutputCol("features")
       .setMinTF(0.3)
-    cv.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -186,7 +183,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setOutputCol("features")
       .setBinary(true)
       .fit(df)
-    cv.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cv, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }
@@ -196,7 +193,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       .setInputCol("words")
       .setOutputCol("features")
       .setBinary(true)
-    cv2.transform(df).select("features", "expected").collect().foreach {
+    testTransformer[(Int, Seq[String], Vector)](df, cv2, "features", "expected") {
       case Row(features: Vector, expected: Vector) =>
         assert(features ~== expected absTol 1e-14)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 8dd3dd7..6734336 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -21,16 +21,14 @@ import scala.beans.BeanInfo
 
 import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.sql.Row
 
 @BeanInfo
 case class DCTTestData(vec: Vector, wantedVec: Vector)
 
-class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class DCTSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -72,11 +70,9 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       .setOutputCol("resultVec")
       .setInverse(inverse)
 
-    transformer.transform(dataset)
-      .select("resultVec", "wantedVec")
-      .collect()
-      .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
-      assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
+    testTransformer[(Vector, Vector)](dataset, transformer, "resultVec", "wantedVec") {
+      case Row(resultVec: Vector, wantedVec: Vector) =>
+        assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
index a4cca27..3a8d076 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ElementwiseProductSuite.scala
@@ -17,13 +17,31 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.Row
 
-class ElementwiseProductSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class ElementwiseProductSuite extends MLTest with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("streaming transform") {
+    val scalingVec = Vectors.dense(0.1, 10.0)
+    val data = Seq(
+      (Vectors.dense(0.1, 1.0), Vectors.dense(0.01, 10.0)),
+      (Vectors.dense(0.0, -1.1), Vectors.dense(0.0, -11.0))
+    )
+    val df = spark.createDataFrame(data).toDF("features", "expected")
+    val ep = new ElementwiseProduct()
+      .setInputCol("features")
+      .setOutputCol("actual")
+      .setScalingVec(scalingVec)
+    testTransformer[(Vector, Vector)](df, ep, "actual", "expected") {
+      case Row(actual: Vector, expected: Vector) =>
+        assert(actual ~== expected relTol 1e-14)
+    }
+  }
 
   test("read/write") {
     val ep = new ElementwiseProduct()

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
index 7bc1825..d799ba6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/FeatureHasherSuite.scala
@@ -17,27 +17,24 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-class FeatureHasherSuite extends SparkFunSuite
-  with MLlibTestSparkContext
-  with DefaultReadWriteTest {
+class FeatureHasherSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
   import FeatureHasherSuite.murmur3FeatureIdx
 
-  implicit private val vectorEncoder = ExpressionEncoder[Vector]()
+  implicit private val vectorEncoder: ExpressionEncoder[Vector] = ExpressionEncoder[Vector]()
 
   test("params") {
     ParamsSuite.checkParams(new FeatureHasher)
@@ -52,31 +49,31 @@ class FeatureHasherSuite extends SparkFunSuite
   }
 
   test("feature hashing") {
+    val numFeatures = 100
+    // Assume perfect hash on field names in computing expected results
+    def idx: Any => Int = murmur3FeatureIdx(numFeatures)
+
     val df = Seq(
-      (2.0, true, "1", "foo"),
-      (3.0, false, "2", "bar")
-    ).toDF("real", "bool", "stringNum", "string")
+      (2.0, true, "1", "foo",
+        Vectors.sparse(numFeatures, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
+          (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0)))),
+      (3.0, false, "2", "bar",
+        Vectors.sparse(numFeatures, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
+          (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0))))
+    ).toDF("real", "bool", "stringNum", "string", "expected")
 
-    val n = 100
     val hasher = new FeatureHasher()
       .setInputCols("real", "bool", "stringNum", "string")
       .setOutputCol("features")
-      .setNumFeatures(n)
+      .setNumFeatures(numFeatures)
     val output = hasher.transform(df)
     val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
-    assert(attrGroup.numAttributes === Some(n))
+    assert(attrGroup.numAttributes === Some(numFeatures))
 
-    val features = output.select("features").as[Vector].collect()
-    // Assume perfect hash on field names
-    def idx: Any => Int = murmur3FeatureIdx(n)
-    // check expected indices
-    val expected = Seq(
-      Vectors.sparse(n, Seq((idx("real"), 2.0), (idx("bool=true"), 1.0),
-        (idx("stringNum=1"), 1.0), (idx("string=foo"), 1.0))),
-      Vectors.sparse(n, Seq((idx("real"), 3.0), (idx("bool=false"), 1.0),
-        (idx("stringNum=2"), 1.0), (idx("string=bar"), 1.0)))
-    )
-    assert(features.zip(expected).forall { case (e, a) => e ~== a absTol 1e-14 })
+    testTransformer[(Double, Boolean, String, String, Vector)](df, hasher, "features", "expected") {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14 )
+    }
   }
 
   test("setting explicit numerical columns to treat as categorical") {

http://git-wip-us.apache.org/repos/asf/spark/blob/56cfbd93/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index a46272f..c5183ec 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -17,17 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
 import org.apache.spark.util.Utils
 
-class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class HashingTFSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
   import HashingTFSuite.murmur3FeatureIdx
@@ -37,21 +36,28 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
   }
 
   test("hashingTF") {
-    val df = Seq((0, "a a b b c d".split(" ").toSeq)).toDF("id", "words")
-    val n = 100
+    val numFeatures = 100
+    // Assume perfect hash when computing expected features.
+    def idx: Any => Int = murmur3FeatureIdx(numFeatures)
+    val data = Seq(
+      ("a a b b c d".split(" ").toSeq,
+        Vectors.sparse(numFeatures,
+          Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))))
+    )
+
+    val df = data.toDF("words", "expected")
     val hashingTF = new HashingTF()
       .setInputCol("words")
       .setOutputCol("features")
-      .setNumFeatures(n)
+      .setNumFeatures(numFeatures)
     val output = hashingTF.transform(df)
     val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
-    require(attrGroup.numAttributes === Some(n))
-    val features = output.select("features").first().getAs[Vector](0)
-    // Assume perfect hash on "a", "b", "c", and "d".
-    def idx: Any => Int = murmur3FeatureIdx(n)
-    val expected = Vectors.sparse(n,
-      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
-    assert(features ~== expected absTol 1e-14)
+    require(attrGroup.numAttributes === Some(numFeatures))
+
+    testTransformer[(Seq[String], Vector)](df, hashingTF, "features", "expected") {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
   }
 
   test("applying binary term freqs") {



