spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic
Date Fri, 09 Jan 2015 21:00:21 GMT
Repository: spark
Updated Branches:
  refs/heads/master 454fe129e -> 7e8e62aec


[SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic

Issues:
* From JIRA: GaussianMixtureEM uses randomness but does not take a random seed. It should
take one as a parameter.
* The lack of a seed also makes the test suite flaky since initialization can fail due to stochasticity.

Fix:
* Add random seed
* Use it in test suite

CC: mengxr tgaloppo

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3981 from jkbradley/gmm-seed and squashes the following commits:

f0df4fd [Joseph K. Bradley] Added seed parameter to GMM.  Updated test suite to use seed to
prevent flakiness


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e8e62ae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e8e62ae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e8e62ae

Branch: refs/heads/master
Commit: 7e8e62aec11c43c983055adc475b96006412199a
Parents: 454fe12
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Fri Jan 9 13:00:15 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Jan 9 13:00:15 2015 -0800

----------------------------------------------------------------------
 .../mllib/clustering/GaussianMixtureEM.scala    | 26 ++++++++++++++------
 .../GMMExpectationMaximizationSuite.scala       | 14 ++++++-----
 2 files changed, 27 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e8e62ae/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
index 3a6c0e6..b3c5631 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix,
BLAS}
 import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.Utils
 
 /**
  * This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
 class GaussianMixtureEM private (
     private var k: Int, 
     private var convergenceTol: Double, 
-    private var maxIterations: Int) extends Serializable {
+    private var maxIterations: Int,
+    private var seed: Long) extends Serializable {
   
   /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
-  def this() = this(2, 0.01, 100)
+  def this() = this(2, 0.01, 100, Utils.random.nextLong())
   
   // number of samples per cluster to use when initializing Gaussians
   private val nSamples = 5
@@ -100,11 +102,21 @@ class GaussianMixtureEM private (
     this
   }
   
-  /** Return the largest change in log-likelihood at which convergence is
-   *  considered to have occurred.
+  /**
+   * Return the largest change in log-likelihood at which convergence is
+   * considered to have occurred.
    */
   def getConvergenceTol: Double = convergenceTol
-  
+
+  /** Set the random seed */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
+  /** Return the random seed */
+  def getSeed: Long = seed
+
   /** Perform expectation maximization */
   def run(data: RDD[Vector]): GaussianMixtureModel = {
     val sc = data.sparkContext
@@ -113,7 +125,7 @@ class GaussianMixtureEM private (
     val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
     
     // Get length of the input vectors
-    val d = breezeData.first.length 
+    val d = breezeData.first().length
     
     // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
@@ -126,7 +138,7 @@ class GaussianMixtureEM private (
       })
       
       case None => {
-        val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
         (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => 
           val slice = samples.view(i * nSamples, (i + 1) * nSamples)
           new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) 

http://git-wip-us.apache.org/repos/asf/spark/blob/7e8e62ae/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
index 23feb82..9da5495 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
     val Ew = 1.0
     val Emu = Vectors.dense(5.0, 10.0)
     val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
-    
-    val gmm = new GaussianMixtureEM().setK(1).run(data)
-                
-    assert(gmm.weight(0) ~== Ew absTol 1E-5)
-    assert(gmm.mu(0) ~== Emu absTol 1E-5)
-    assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+
+    val seeds = Array(314589, 29032897, 50181, 494821, 4660)
+    seeds.foreach { seed =>
+      val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
+      assert(gmm.weight(0) ~== Ew absTol 1E-5)
+      assert(gmm.mu(0) ~== Emu absTol 1E-5)
+      assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+    }
   }
   
   test("two clusters") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message