spark-commits mailing list archives

From: jkbrad...@apache.org
Subject: spark git commit: [SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None
Date: Thu, 07 Jan 2016 18:33:13 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 47a58c799 -> 69a885a71


[SPARK-12006][ML][PYTHON] Fix GMM failure if initialModel is not None

If the initial model passed to GMM is not empty, it causes a net.razorvine.pickle.PickleException.
This can be fixed by converting initialModel.weights to a list.
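
A minimal reproduction, mirroring the regression test added below (it assumes a live SparkContext `sc`; the variable names here are illustrative only):

    from pyspark.mllib.clustering import GaussianMixture

    data = sc.parallelize([(-10, -5), (-9, -4), (10, 5), (9, 4)])
    # Train a first model, then reuse it as the starting point for a second run.
    gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
                                 maxIterations=10, seed=63)
    # Before this patch, passing initialModel raised net.razorvine.pickle.PickleException
    # while serializing initialModel.weights back to the JVM.
    gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
                                 maxIterations=10, seed=63, initialModel=gmm1)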

Author: zero323 <matthew.szymkiewicz@gmail.com>

Closes #10644 from zero323/SPARK-12006.

(cherry picked from commit 592f64985d0d58b4f6a0366bf975e04ca496bdbe)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69a885a7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69a885a7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69a885a7

Branch: refs/heads/branch-1.6
Commit: 69a885a71cfe7c62179e784e7d9eee023d3bb6eb
Parents: 47a58c7
Author: zero323 <matthew.szymkiewicz@gmail.com>
Authored: Thu Jan 7 10:32:56 2016 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Jan 7 10:33:09 2016 -0800

----------------------------------------------------------------------
 python/pyspark/mllib/clustering.py |  2 +-
 python/pyspark/mllib/tests.py      | 12 ++++++++++++
 2 files changed, 13 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/69a885a7/python/pyspark/mllib/clustering.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index c9e6f1d..48daa87 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -346,7 +346,7 @@ class GaussianMixture(object):
             if initialModel.k != k:
                 raise Exception("Mismatched cluster count, initialModel.k = %s, however k
= %s"
                                 % (initialModel.k, k))
-            initialModelWeights = initialModel.weights
+            initialModelWeights = list(initialModel.weights)
             initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
             initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
         java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
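
The intent of the one-line change above, sketched in isolation (the concrete container type of initialModel.weights is not stated in the patch, so the array.array below is only a hypothetical stand-in): the weights come back from the JVM as an array-like object rather than a plain Python list, and presumably the Python-to-JVM pickling bridge (Pyrolite / net.razorvine.pickle) rejects that type on the return trip, so the patch copies it into a list of built-in floats first.

    import array

    # Hypothetical stand-in for initialModel.weights: an array-like sequence of doubles.
    weights = array.array('d', [0.4, 0.6])
    # list() yields plain Python floats, which serialize back to the JVM cleanly.
    initialModelWeights = list(weights)   # [0.4, 0.6]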

http://git-wip-us.apache.org/repos/asf/spark/blob/69a885a7/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index f8e8e0e..26db87e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -474,6 +474,18 @@ class ListTests(MLlibTestCase):
         for c1, c2 in zip(clusters1.weights, clusters2.weights):
             self.assertEqual(round(c1, 7), round(c2, 7))
 
+    def test_gmm_with_initial_model(self):
+        from pyspark.mllib.clustering import GaussianMixture
+        data = self.sc.parallelize([
+            (-10, -5), (-9, -4), (10, 5), (9, 4)
+        ])
+
+        gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                     maxIterations=10, seed=63)
+        gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001,
+                                     maxIterations=10, seed=63, initialModel=gmm1)
+        self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0)
+
     def test_classification(self):
         from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
         from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\

