spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-14375][ML] Unit test for spark.ml KMeansSummary
Date Wed, 13 Apr 2016 20:23:14 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0d17593b3 -> a91aaf5a8


[SPARK-14375][ML] Unit test for spark.ml KMeansSummary

## What changes were proposed in this pull request?
* Modify ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters.
* Add unit test for spark.ml ```KMeansSummary```.
* Add Since tag.

## How was this patch tested?
unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12254 from yanboliang/spark-14375.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a91aaf5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a91aaf5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a91aaf5a

Branch: refs/heads/master
Commit: a91aaf5a8cca18811c0cccc20f4e77f36231b344
Parents: 0d17593
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Wed Apr 13 13:23:10 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Apr 13 13:23:10 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/KMeans.scala | 35 ++++++++++++++++----
 .../org/apache/spark/ml/r/KMeansWrapper.scala   |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala       | 18 +++++++++-
 3 files changed, 47 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index d716bc6..b324196 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -144,6 +144,12 @@ class KMeansModel private[ml] (
   }
 
   /**
+   * Return true if there exists summary of model.
+   */
+  @Since("2.0.0")
+  def hasSummary: Boolean = trainingSummary.nonEmpty
+
+  /**
    * Gets summary of model on training set. An exception is
    * thrown if `trainingSummary == None`.
    */
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
-    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    val summary = new KMeansSummary(
+      model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
     model.setSummary(summary)
   }
 
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }
 
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions  [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol  Name for column of predicted clusters in `predictions`
+ * @param featuresCol  Name for column of features in `predictions`
+ * @param k  Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
 class KMeansSummary private[clustering] (
     @Since("2.0.0") @transient val predictions: DataFrame,
     @Since("2.0.0") val predictionCol: String,
-    @Since("2.0.0") val featuresCol: String) extends Serializable {
+    @Since("2.0.0") val featuresCol: String,
+    @Since("2.0.0") val k: Int) extends Serializable {
 
   /**
    * Cluster centers of the transformed data.
@@ -296,11 +315,15 @@ class KMeansSummary private[clustering] (
   @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
 
   /**
-   * Size of each cluster.
+   * Size of (number of data points in) each cluster.
    */
   @Since("2.0.0")
-  lazy val clusterSizes: Array[Int] = cluster.rdd.map {
-    case Row(clusterIdx: Int) => (clusterIdx, 1)
-  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+  lazy val clusterSizes: Array[Long] = {
+    val sizes = Array.fill[Long](k)(0)
+    cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+      case Row(cluster: Int, count: Long) => sizes(cluster) = count
+    }
+    sizes
+  }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ee51357..9e2b81e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -37,7 +37,7 @@ private[r] class KMeansWrapper private (
 
   lazy val k: Int = kMeansModel.getK
 
-  lazy val size: Array[Int] = kMeansModel.summary.clusterSizes
+  lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
 
   lazy val cluster: DataFrame = kMeansModel.summary.cluster
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a91aaf5a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2076c74..2ca386e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     }
   }
 
-  test("fit & transform") {
+  test("fit, transform, and summary") {
     val predictionColName = "kmeans_prediction"
     val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
     val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
+
+    // Check validity of model summary
+    val numRows = dataset.count()
+    assert(model.hasSummary)
+    val summary: KMeansSummary = model.summary
+    assert(summary.predictionCol === predictionColName)
+    assert(summary.featuresCol === "features")
+    assert(summary.predictions.count() === numRows)
+    for (c <- Array(predictionColName, "features")) {
+      assert(summary.predictions.columns.contains(c))
+    }
+    assert(summary.cluster.columns === Array(predictionColName))
+    val clusterSizes = summary.clusterSizes
+    assert(clusterSizes.length === k)
+    assert(clusterSizes.sum === numRows)
+    assert(clusterSizes.forall(_ >= 0))
   }
 
   test("read/write") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message