spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-11029] [ML] Add computeCost to KMeansModel in spark.ml
Date Sat, 17 Oct 2015 17:04:27 GMT
Repository: spark
Updated Branches:
  refs/heads/master 8ac71d62d -> e1e77b22b


[SPARK-11029] [ML] Add computeCost to KMeansModel in spark.ml

jira: https://issues.apache.org/jira/browse/SPARK-11029

We should add a method analogous to spark.mllib.clustering.KMeansModel.computeCost to spark.ml.clustering.KMeansModel.
This will be a temp fix until we have proper evaluators defined for clustering.
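As a quick illustration of the method this patch adds, a fitted spark.ml `KMeansModel` can now report its clustering cost directly on a DataFrame. This is a hedged sketch, not part of the commit: the `sqlContext`, the input `dataset`, and its "features" column name are assumed for illustration.

```scala
// Sketch: calling the new computeCost on a fitted spark.ml KMeansModel.
// Assumes an existing DataFrame `dataset` with a vector column named
// "features" (names here are illustrative, not from this commit).
import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans()
  .setK(3)
  .setFeaturesCol("features")
val model = kmeans.fit(dataset)

// Sum of squared distances from each point to its nearest cluster center,
// mirroring spark.mllib.clustering.KMeansModel.computeCost.
val cost: Double = model.computeCost(dataset)
```

Lower values indicate tighter clusters; the test added below asserts the cost is near zero for well-separated synthetic data.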

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: yuhaoyang <yuhao@zhanglipings-iMac.local>

Closes #9073 from hhbyyh/computeCost.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e1e77b22
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e1e77b22
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e1e77b22

Branch: refs/heads/master
Commit: e1e77b22b3b577909a12c3aa898eb53be02267fd
Parents: 8ac71d6
Author: Yuhao Yang <hhbyyh@gmail.com>
Authored: Sat Oct 17 10:04:19 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Sat Oct 17 10:04:19 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/clustering/KMeans.scala   | 12 ++++++++++++
 .../org/apache/spark/ml/clustering/KMeansSuite.scala    |  1 +
 2 files changed, 13 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e1e77b22/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f40ab71..509be63 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -117,6 +117,18 @@ class KMeansModel private[ml] (
 
   @Since("1.5.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters
+
+  /**
+   * Return the K-means cost (sum of squared distances of points to their nearest center) for this
+   * model on the given data.
+   */
+  // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
+  @Since("1.6.0")
+  def computeCost(dataset: DataFrame): Double = {
+    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+    val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+    parentModel.computeCost(data)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/e1e77b22/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 688b0e3..c05f905 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -104,5 +104,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
     assert(clusters.size === k)
     assert(clusters === Set(0, 1, 2, 3, 4))
+    assert(model.computeCost(dataset) < 0.1)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

