spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-23568][ML] Use metadata numAttributes if available in Silhouette
Date Wed, 21 Mar 2018 15:19:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master 983e8d9d6 -> 500b21c3d


[SPARK-23568][ML] Use metadata numAttributes if available in Silhouette

## What changes were proposed in this pull request?

Silhouette needs to know the number of features. This was obtained using `first` and checking
the size of the vector. Although this works fine, if the number of attributes is present in
the metadata, we can avoid triggering a job for this and use the metadata value instead. This
can help improve performance.

## How was this patch tested?

existing UTs + added UT

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #20719 from mgaido91/SPARK-23568.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/500b21c3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/500b21c3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/500b21c3

Branch: refs/heads/master
Commit: 500b21c3d6247015e550be7e144e9b4b26fe28be
Parents: 983e8d9
Author: Marco Gaido <marcogaido91@gmail.com>
Authored: Wed Mar 21 10:19:02 2018 -0500
Committer: Sean Owen <srowen@gmail.com>
Committed: Wed Mar 21 10:19:02 2018 -0500

----------------------------------------------------------------------
 .../ml/evaluation/ClusteringEvaluator.scala     | 22 +++++++++++++++----
 .../evaluation/ClusteringEvaluatorSuite.scala   | 23 +++++++++++++++++++-
 2 files changed, 40 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/500b21c3/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
index 8d4ae56..4353c46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
@@ -170,6 +171,15 @@ private[evaluation] abstract class Silhouette {
   def overallScore(df: DataFrame, scoreColumn: Column): Double = {
     df.select(avg(scoreColumn)).collect()(0).getDouble(0)
   }
+
+  protected def getNumberOfFeatures(dataFrame: DataFrame, columnName: String): Int = {
+    val group = AttributeGroup.fromStructField(dataFrame.schema(columnName))
+    if (group.size < 0) {
+      dataFrame.select(col(columnName)).first().getAs[Vector](0).size
+    } else {
+      group.size
+    }
+  }
 }
 
 /**
@@ -360,7 +370,7 @@ private[evaluation] object SquaredEuclideanSilhouette extends Silhouette
{
     df: DataFrame,
     predictionCol: String,
     featuresCol: String): Map[Double, ClusterStats] = {
-    val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size
+    val numFeatures = getNumberOfFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
         col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
       .rdd
@@ -552,8 +562,11 @@ private[evaluation] object CosineSilhouette extends Silhouette {
    * @return A [[scala.collection.immutable.Map]] which associates each cluster id to a
    *         its statistics (ie. the precomputed values `N` and `$\Omega_{\Gamma}$`).
    */
-  def computeClusterStats(df: DataFrame, predictionCol: String): Map[Double, (Vector, Long)]
= {
-    val numFeatures = df.select(col(normalizedFeaturesColName)).first().getAs[Vector](0).size
+  def computeClusterStats(
+      df: DataFrame,
+      featuresCol: String,
+      predictionCol: String): Map[Double, (Vector, Long)] = {
+    val numFeatures = getNumberOfFeatures(df, featuresCol)
     val clustersStatsRDD = df.select(
       col(predictionCol).cast(DoubleType), col(normalizedFeaturesColName))
       .rdd
@@ -626,7 +639,8 @@ private[evaluation] object CosineSilhouette extends Silhouette {
       normalizeFeatureUDF(col(featuresCol)))
 
     // compute aggregate values for clusters needed by the algorithm
-    val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, predictionCol)
+    val clustersStatsMap = computeClusterStats(dfWithNormalizedFeatures, featuresCol,
+      predictionCol)
 
     // Silhouette is reasonable only when the number of clusters is greater then 1
     assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")

http://git-wip-us.apache.org/repos/asf/spark/blob/500b21c3/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
index 3bf3477..2c175ff 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.ml.evaluation
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.ml.util.TestingUtils._
@@ -100,4 +102,23 @@ class ClusteringEvaluatorSuite
     }
   }
 
+  test("SPARK-23568: we should use metadata to determine features number") {
+    val attributesNum = irisDataset.select("features").rdd.first().getAs[Vector](0).size
+    val attrGroup = new AttributeGroup("features", attributesNum)
+    val df = irisDataset.select($"features".as("features", attrGroup.toMetadata()), $"label")
+    require(AttributeGroup.fromStructField(df.schema("features"))
+      .numAttributes.isDefined, "numAttributes metadata should be defined")
+    val evaluator = new ClusteringEvaluator()
+      .setFeaturesCol("features")
+      .setPredictionCol("label")
+
+    // with the proper metadata we compute correctly the result
+    assert(evaluator.evaluate(df) ~== 0.6564679231 relTol 1e-5)
+
+    val wrongAttrGroup = new AttributeGroup("features", attributesNum + 1)
+    val dfWrong = irisDataset.select($"features".as("features", wrongAttrGroup.toMetadata()),
+      $"label")
+    // with wrong metadata the evaluator throws an Exception
+    intercept[SparkException](evaluator.evaluate(dfWrong))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message