spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [SPARK-15819][PYSPARK][ML] Add KMeansSummary in KMeans of PySpark
Date Wed, 30 Nov 2016 04:52:44 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 55b1142bd -> b95aad7ca


[SPARK-15819][PYSPARK][ML] Add KMeansSummary in KMeans of PySpark

## What changes were proposed in this pull request?

Add a Python API for KMeansSummary.
## How was this patch tested?

Unit test added.

Author: Jeff Zhang <zjffdu@apache.org>

Closes #13557 from zjffdu/SPARK-15819.

(cherry picked from commit 4c82ca86d979e5526a15666683eef3c79c37dc68)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b95aad7c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b95aad7c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b95aad7c

Branch: refs/heads/branch-2.1
Commit: b95aad7cad99a62851fe5e61692fda9bceb4b160
Parents: 55b1142
Author: Jeff Zhang <zjffdu@apache.org>
Authored: Tue Nov 29 20:51:27 2016 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Tue Nov 29 20:52:21 2016 -0800

----------------------------------------------------------------------
 python/pyspark/ml/clustering.py | 41 ++++++++++++++++++++++++++++++++++++
 python/pyspark/ml/tests.py      | 15 +++++++++++++
 2 files changed, 56 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b95aad7c/python/pyspark/ml/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7f8d845..35d0aef 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -292,6 +292,17 @@ class GaussianMixtureSummary(ClusteringSummary):
         return self._call_java("probability")
 
 
+class KMeansSummary(ClusteringSummary):
+    """
+    .. note:: Experimental
+
+    Summary of KMeans.
+
+    .. versionadded:: 2.1.0
+    """
+    pass
+
+
 class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """
     Model fitted by KMeans.
@@ -312,6 +323,27 @@ class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
         """
         return self._call_java("computeCost", dataset)
 
+    @property
+    @since("2.1.0")
+    def hasSummary(self):
+        """
+        Indicates whether a training summary exists for this model instance.
+        """
+        return self._call_java("hasSummary")
+
+    @property
+    @since("2.1.0")
+    def summary(self):
+        """
+        Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
+        training set. An exception is thrown if no summary exists.
+        """
+        if self.hasSummary:
+            return KMeansSummary(self._call_java("summary"))
+        else:
+            raise RuntimeError("No training summary available for this %s" %
+                               self.__class__.__name__)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
@@ -337,6 +369,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     True
     >>> rows[2].prediction == rows[3].prediction
     True
+    >>> model.hasSummary
+    True
+    >>> summary = model.summary
+    >>> summary.k
+    2
+    >>> summary.clusterSizes
+    [2, 2]
     >>> kmeans_path = temp_path + "/kmeans"
     >>> kmeans.save(kmeans_path)
     >>> kmeans2 = KMeans.load(kmeans_path)
@@ -345,6 +384,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> model_path = temp_path + "/kmeans_model"
     >>> model.save(model_path)
     >>> model2 = KMeansModel.load(model_path)
+    >>> model2.hasSummary
+    False
     >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
     array([ True,  True], dtype=bool)
     >>> model.clusterCenters()[1] == model2.clusterCenters()[1]

http://git-wip-us.apache.org/repos/asf/spark/blob/b95aad7c/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c0f0d40..a0c288a 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1129,6 +1129,21 @@ class TrainingSummaryTest(SparkSessionTestCase):
         self.assertEqual(len(s.clusterSizes), 2)
         self.assertEqual(s.k, 2)
 
+    def test_kmeans_summary(self):
+        data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
+                (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
+        df = self.spark.createDataFrame(data, ["features"])
+        kmeans = KMeans(k=2, seed=1)
+        model = kmeans.fit(df)
+        self.assertTrue(model.hasSummary)
+        s = model.summary
+        self.assertTrue(isinstance(s.predictions, DataFrame))
+        self.assertEqual(s.featuresCol, "features")
+        self.assertEqual(s.predictionCol, "prediction")
+        self.assertTrue(isinstance(s.cluster, DataFrame))
+        self.assertEqual(len(s.clusterSizes), 2)
+        self.assertEqual(s.k, 2)
+
 
 class OneVsRestTests(SparkSessionTestCase):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message