spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: SPARK-4156 [MLLIB] EM algorithm for GMMs
Date Mon, 29 Dec 2014 23:29:21 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9bc0df680 -> 6cf6fdf3f


SPARK-4156 [MLLIB] EM algorithm for GMMs

Implementation of Expectation-Maximization for Gaussian Mixture Models.
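
For reference, these are the update rules that `GaussianMixtureEM.run` and `ExpectationSum.add` in the diff below appear to implement, as read from the code. Here epsilon is `MLUtils.EPSILON`, N is the density returned by `MultivariateGaussian.pdf`, and w, mu, Sigma correspond to `weights`, `mu` and `sigma`.

```latex
% E-step: responsibility of component j for point x_i (epsilon added to every term)
p_{ij} = \frac{\epsilon + w_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
              {\sum_{l=1}^{k} \bigl( \epsilon + w_l \, \mathcal{N}(x_i \mid \mu_l, \Sigma_l) \bigr)}

% M-step: new weights, means and covariances
w_j = \frac{\sum_{i=1}^{n} p_{ij}}{\sum_{l=1}^{k} \sum_{i=1}^{n} p_{il}}, \qquad
\mu_j = \frac{\sum_{i=1}^{n} p_{ij} \, x_i}{\sum_{i=1}^{n} p_{ij}}, \qquad
\Sigma_j = \frac{\sum_{i=1}^{n} p_{ij} \, x_i x_i^{\top}}{\sum_{i=1}^{n} p_{ij}} - \mu_j \mu_j^{\top}

% Log-likelihood accumulated during the E-step
\ell = \sum_{i=1}^{n} \log \sum_{j=1}^{k} \bigl( \epsilon + w_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j) \bigr)
```

Because each point's responsibilities sum to 1, the weight denominator equals n. Iteration stops once the absolute change in the log-likelihood between passes is at most convergenceTol, or after maxIterations passes.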

This is my maiden contribution to Apache Spark, so I apologize now if I have done anything
incorrectly; having said that, this work is my own, and I offer it to the project under the
project's open source license.

Author: Travis Galoppo <tjg2107@columbia.edu>
Author: Travis Galoppo <travis@localhost.localdomain>
Author: tgaloppo <tjg2107@columbia.edu>
Author: FlytxtRnD <meethu.mathew@flytxt.com>

Closes #3022 from tgaloppo/master and squashes the following commits:

aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib]
709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter
acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations
an optional parameter to DenseGmmEM
9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class
b97fe00 [Travis Galoppo] Minor fixes and tweaks.
1de73f3 [Travis Galoppo] Removed redundant array from array creation
578c2d1 [Travis Galoppo] Removed unused import
227ad66 [Travis Galoppo] Moved prediction methods into model class.
308c8ad [Travis Galoppo] Numerous changes to improve code
cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate
20ebca1 [Travis Galoppo] Removed unused code
42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added
two cluster test to testing suite.
8b633f3 [Travis Galoppo] Style issue
9be2534 [Travis Galoppo] Style issue
d695034 [Travis Galoppo] Fixed style issues
c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark   Adds
predict() method
2df336b [Travis Galoppo] Fixed style issue
b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch
f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values
97044cf [Travis Galoppo] Fixed style issues
dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class
e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved
comments
9770261 [Travis Galoppo] Corrected a variety of style and naming issues.
8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count
and tolerance parameters.
676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line
e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context
changes. Improved cluster initialization strategy.
86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
719d8cc [Travis Galoppo] Added scala test suite with basic test
c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function
for better performance
5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master'
c15405c [Travis Galoppo] SPARK-4156


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6cf6fdf3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6cf6fdf3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6cf6fdf3

Branch: refs/heads/master
Commit: 6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464
Parents: 9bc0df6
Author: Travis Galoppo <tjg2107@columbia.edu>
Authored: Mon Dec 29 15:29:15 2014 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Dec 29 15:29:15 2014 -0800

----------------------------------------------------------------------
 .../spark/examples/mllib/DenseGmmEM.scala       |  67 ++++++
 .../mllib/clustering/GaussianMixtureEM.scala    | 241 +++++++++++++++++++
 .../mllib/clustering/GaussianMixtureModel.scala |  91 +++++++
 .../mllib/stat/impl/MultivariateGaussian.scala  |  39 +++
 .../org/apache/spark/mllib/util/MLUtils.scala   |   2 +-
 .../GMMExpectationMaximizationSuite.scala       |  78 ++++++
 6 files changed, 517 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
new file mode 100644
index 0000000..948c350
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.clustering.GaussianMixtureEM
+import org.apache.spark.mllib.linalg.Vectors
+
+/**
+ * An example Gaussian Mixture Model EM app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <convergenceTol> [maxIterations]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DenseGmmEM {
+  def main(args: Array[String]): Unit = {
+    if (args.length < 3) {
+      println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]")
+    } else {
+      val maxIterations = if (args.length > 3) args(3).toInt else 100
+      run(args(0), args(1).toInt, args(2).toDouble, maxIterations)
+    }
+  }
+
+  private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
+    val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
+    val ctx  = new SparkContext(conf)
+    
+    val data = ctx.textFile(inputFile).map { line =>
+      Vectors.dense(line.trim.split(' ').map(_.toDouble))
+    }.cache()
+      
+    val clusters = new GaussianMixtureEM()
+      .setK(k)
+      .setConvergenceTol(convergenceTol)
+      .setMaxIterations(maxIterations)
+      .run(data)
+    
+    for (i <- 0 until clusters.k) {
+      println("weight=%f\nmu=%s\nsigma=\n%s\n" format 
+        (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+    }
+    
+    println("Cluster labels (first <= 100):")
+    val clusterLabels = clusters.predict(data)
+    clusterLabels.take(100).foreach { x =>
+      print(" " + x)
+    }
+    println()
+  }
+}
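
The example expects one point per line, with coordinates separated by single spaces, and treats the fourth argument as optional with a default of 100. The Spark-free sketch below mirrors that handling so it can be checked without a cluster; `DenseGmmEMArgsSketch`, `Params`, `parseArgs` and `parsePoint` are hypothetical names, not part of the commit.

```scala
// Hypothetical, Spark-free sketch of DenseGmmEM's argument and input-line handling.
object DenseGmmEMArgsSketch {
  final case class Params(input: String, k: Int, convergenceTol: Double, maxIterations: Int)

  // Three required arguments; an optional fourth overrides the default of 100 iterations.
  def parseArgs(args: Array[String]): Option[Params] =
    if (args.length < 3) None
    else {
      val maxIterations = if (args.length > 3) args(3).toInt else 100
      Some(Params(args(0), args(1).toInt, args(2).toDouble, maxIterations))
    }

  // One point per line: trim, split on single spaces, parse each coordinate.
  def parsePoint(line: String): Array[Double] =
    line.trim.split(' ').map(_.toDouble)
}
```

Because `split(' ')` splits on every single space, a line containing two consecutive spaces yields an empty token and `toDouble` throws a `NumberFormatException`; splitting on `"\\s+"` would tolerate such files.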

http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
new file mode 100644
index 0000000..bdf984a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.collection.mutable.IndexedSeq
+
+import breeze.linalg.{DenseVector => BreezeVector, DenseMatrix => BreezeMatrix, diag, Transpose}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
+import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * This class performs expectation maximization for multivariate Gaussian
+ * Mixture Models (GMMs).  A GMM represents a composite distribution of
+ * independent Gaussian distributions with associated "mixing" weights
+ * specifying each component's contribution to the composite.
+ *
+ * Given a set of sample points, this class will maximize the log-likelihood 
+ * for a mixture of k Gaussians, iterating until the log-likelihood changes by 
+ * less than convergenceTol, or until it has reached the max number of iterations.
+ * While this process is generally guaranteed to converge, it is not guaranteed
+ * to find a global optimum.  
+ * 
+ * @param k The number of independent Gaussians in the mixture model
+ * @param convergenceTol The maximum change in log-likelihood at which convergence
+ * is considered to have occurred.
+ * @param maxIterations The maximum number of iterations to perform
+ */
+class GaussianMixtureEM private (
+    private var k: Int, 
+    private var convergenceTol: Double, 
+    private var maxIterations: Int) extends Serializable {
+  
+  /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
+  def this() = this(2, 0.01, 100)
+  
+  // number of samples per cluster to use when initializing Gaussians
+  private val nSamples = 5
+  
+  // an initializing GMM can be provided rather than using the 
+  // default random starting point
+  private var initialModel: Option[GaussianMixtureModel] = None
+  
+  /** Set the initial GMM starting point, bypassing the random initialization.
+   *  You must call setK() prior to calling this method, and the condition
+   *  (model.k == this.k) must be met; failure will result in an IllegalArgumentException
+   */
+  def setInitialModel(model: GaussianMixtureModel): this.type = {
+    if (model.k == k) {
+      initialModel = Some(model)
+    } else {
+      throw new IllegalArgumentException("mismatched cluster count (model.k != k)")
+    }
+    this
+  }
+  
+  /** Return the user supplied initial GMM, if supplied */
+  def getInitialModel: Option[GaussianMixtureModel] = initialModel
+  
+  /** Set the number of Gaussians in the mixture model.  Default: 2 */
+  def setK(k: Int): this.type = {
+    this.k = k
+    this
+  }
+  
+  /** Return the number of Gaussians in the mixture model */
+  def getK: Int = k
+  
+  /** Set the maximum number of iterations to run. Default: 100 */
+  def setMaxIterations(maxIterations: Int): this.type = {
+    this.maxIterations = maxIterations
+    this
+  }
+  
+  /** Return the maximum number of iterations to run */
+  def getMaxIterations: Int = maxIterations
+  
+  /**
+   * Set the largest change in log-likelihood at which convergence is 
+   * considered to have occurred.
+   */
+  def setConvergenceTol(convergenceTol: Double): this.type = {
+    this.convergenceTol = convergenceTol
+    this
+  }
+  
+  /** Return the largest change in log-likelihood at which convergence is
+   *  considered to have occurred.
+   */
+  def getConvergenceTol: Double = convergenceTol
+  
+  /** Perform expectation maximization */
+  def run(data: RDD[Vector]): GaussianMixtureModel = {
+    val sc = data.sparkContext
+    
+    // we will operate on the data as breeze data
+    val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
+    
+    // Get length of the input vectors
+    val d = breezeData.first.length 
+    
+    // Determine initial weights and corresponding Gaussians.
+    // If the user supplied an initial GMM, we use those values, otherwise
+    // we start with uniform weights, a random mean from the data, and
+    // diagonal covariance matrices using component variances
+    // derived from the samples    
+    val (weights, gaussians) = initialModel match {
+      case Some(gmm) => (gmm.weight, gmm.mu.zip(gmm.sigma).map { case(mu, sigma) =>
+        new MultivariateGaussian(mu.toBreeze.toDenseVector, sigma.toBreeze.toDenseMatrix)
+      })
+      
+      case None => {
+        val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => 
+          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
+          new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) 
+        })  
+      }
+    }
+    
+    var llh = Double.MinValue // current log-likelihood 
+    var llhp = 0.0            // previous log-likelihood
+    
+    var iter = 0
+    while(iter < maxIterations && Math.abs(llh-llhp) > convergenceTol) {
+      // create and broadcast curried cluster contribution function
+      val compute = sc.broadcast(ExpectationSum.add(weights, gaussians)_)
+      
+      // aggregate the cluster contribution for all sample points
+      val sums = breezeData.aggregate(ExpectationSum.zero(k, d))(compute.value, _ += _)
+      
+      // Create new distributions based on the partial assignments
+      // (often referred to as the "M" step in literature)
+      val sumWeights = sums.weights.sum
+      var i = 0
+      while (i < k) {
+        val mu = sums.means(i) / sums.weights(i)
+        val sigma = sums.sigmas(i) / sums.weights(i) - mu * new Transpose(mu) // TODO: Use BLAS.dsyr
+        weights(i) = sums.weights(i) / sumWeights
+        gaussians(i) = new MultivariateGaussian(mu, sigma)
+        i = i + 1
+      }
+   
+      llhp = llh // current becomes previous
+      llh = sums.logLikelihood // this is the freshly computed log-likelihood
+      iter += 1
+    } 
+    
+    // Need to convert the breeze matrices to MLlib matrices
+    val means = Array.tabulate(k) { i => Vectors.fromBreeze(gaussians(i).mu) }
+    val sigmas = Array.tabulate(k) { i => Matrices.fromBreeze(gaussians(i).sigma) }
+    new GaussianMixtureModel(weights, means, sigmas)
+  }
+    
+  /** Average of dense breeze vectors */
+  private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
+    val v = BreezeVector.zeros[Double](x(0).length)
+    x.foreach(xi => v += xi)
+    v / x.length.toDouble 
+  }
+  
+  /**
+   * Construct matrix where diagonal entries are element-wise
+   * variance of input vectors (computes biased variance)
+   */
+  private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
+    val mu = vectorMean(x)
+    val ss = BreezeVector.zeros[Double](x(0).length)
+    x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
+    diag(ss / x.length.toDouble)
+  }
+}
+
+// companion object to provide zero constructor for ExpectationSum
+private object ExpectationSum {
+  def zero(k: Int, d: Int): ExpectationSum = {
+    new ExpectationSum(0.0, Array.fill(k)(0.0), 
+      Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+  }
+  
+  // compute cluster contributions for each input point
+  // (U, T) => U for aggregation
+  def add(
+      weights: Array[Double], 
+      dists: Array[MultivariateGaussian])
+      (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
+    val p = weights.zip(dists).map {
+      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
+    }
+    val pSum = p.sum
+    sums.logLikelihood += math.log(pSum)
+    val xxt = x * new Transpose(x)
+    var i = 0
+    while (i < sums.k) {
+      p(i) /= pSum
+      sums.weights(i) += p(i)
+      sums.means(i) += x * p(i)
+      sums.sigmas(i) += xxt * p(i) // TODO: use BLAS.dsyr
+      i = i + 1
+    }
+    sums
+  }  
+}
+
+// Aggregation class for partial expectation results
+private class ExpectationSum(
+    var logLikelihood: Double,
+    val weights: Array[Double],
+    val means: Array[BreezeVector[Double]],
+    val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
+  
+  val k = weights.length
+  
+  def +=(x: ExpectationSum): ExpectationSum = {
+    var i = 0
+    while (i < k) {
+      weights(i) += x.weights(i)
+      means(i) += x.means(i)
+      sigmas(i) += x.sigmas(i)
+      i = i + 1
+    }
+    logLikelihood += x.logLikelihood
+    this
+  }  
+}
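
To see the E-step and M-step in isolation, here is a minimal sketch of the same loop restricted to one-dimensional data, written in plain Scala so it needs neither Spark nor Breeze. It assumes `math.ulp(1.0)` as a stand-in for `MLUtils.EPSILON` and uses hypothetical names (`Gmm1dSketch`, `Gmm`, `step`, `run`). Like `ExpectationSum.add`, `step` returns the log-likelihood under the parameters it was given, before the update.

```scala
// Hypothetical 1-D sketch of the EM loop in GaussianMixtureEM.run (no Spark, no Breeze).
object Gmm1dSketch {
  // Assumption: machine epsilon for Double, standing in for MLUtils.EPSILON.
  val Epsilon: Double = math.ulp(1.0)

  /** A 1-D mixture: weights, means and variances of k components. */
  final case class Gmm(weights: Array[Double], means: Array[Double], variances: Array[Double])

  // Univariate normal density.
  private def pdf(x: Double, mu: Double, variance: Double): Double =
    math.exp(-0.5 * (x - mu) * (x - mu) / variance) / math.sqrt(2.0 * math.Pi * variance)

  /** One E-step plus M-step; also returns the log-likelihood of `data` under `m`. */
  def step(data: Seq[Double], m: Gmm): (Gmm, Double) = {
    val k = m.weights.length
    val sw = Array.fill(k)(0.0)  // sum_i p_ij
    val sx = Array.fill(k)(0.0)  // sum_i p_ij * x_i
    val sxx = Array.fill(k)(0.0) // sum_i p_ij * x_i^2
    var llh = 0.0
    for (x <- data) {
      // E-step: unnormalized terms eps + w_j * N(x | mu_j, var_j)
      val p = Array.tabulate(k)(j => Epsilon + m.weights(j) * pdf(x, m.means(j), m.variances(j)))
      val pSum = p.sum
      llh += math.log(pSum)
      for (j <- 0 until k) {
        val r = p(j) / pSum // responsibility of component j for x
        sw(j) += r
        sx(j) += r * x
        sxx(j) += r * x * x
      }
    }
    // M-step: weighted mean, then E[x^2] - mean^2 as the variance
    val total = sw.sum
    val means = Array.tabulate(k)(j => sx(j) / sw(j))
    val variances = Array.tabulate(k)(j => sxx(j) / sw(j) - means(j) * means(j))
    (Gmm(sw.map(_ / total), means, variances), llh)
  }

  /** Iterate until the log-likelihood changes by at most `tol` or `maxIterations` is reached. */
  def run(data: Seq[Double], init: Gmm, tol: Double, maxIterations: Int): Gmm = {
    var model = init
    var llh = Double.MinValue // current log-likelihood (same sentinel as the diff)
    var llhp = 0.0            // previous log-likelihood
    var iter = 0
    while (iter < maxIterations && math.abs(llh - llhp) > tol) {
      val (next, l) = step(data, model)
      model = next
      llhp = llh
      llh = l
      iter += 1
    }
    model
  }
}
```

Run on the fifteen points and initial model from the "two clusters" test further down, this sketch should land close to the expected weights, means and variances listed there, since each point ends up assigned almost entirely to one component.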

http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
new file mode 100644
index 0000000..11a110d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseVector => BreezeVector}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.linalg.{Matrix, Vector}
+import org.apache.spark.mllib.stat.impl.MultivariateGaussian
+import org.apache.spark.mllib.util.MLUtils
+
+/**
+ * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points 
+ * are drawn from each Gaussian i=1..k with probability w(i); mu(i) and sigma(i) are 
+ * the respective mean and covariance for each Gaussian distribution i=1..k. 
+ * 
+ * @param weight Weights for each Gaussian distribution in the mixture, where weight(i) is
+ *               the weight for Gaussian i, and weight.sum == 1
+ * @param mu Means for each Gaussian in the mixture, where mu(i) is the mean for Gaussian i
+ * @param sigma Covariance matrix for each Gaussian in the mixture, where sigma(i) is the
+ *              covariance matrix for Gaussian i
+ */
+class GaussianMixtureModel(
+  val weight: Array[Double], 
+  val mu: Array[Vector], 
+  val sigma: Array[Matrix]) extends Serializable {
+  
+  /** Number of Gaussians in the mixture */
+  def k: Int = weight.length
+
+  /** Maps given points to their cluster indices. */
+  def predict(points: RDD[Vector]): RDD[Int] = {
+    val responsibilityMatrix = predictMembership(points, mu, sigma, weight, k)
+    responsibilityMatrix.map(r => r.indexOf(r.max))
+  }
+  
+  /**
+   * Given the input vectors, return the membership value of each vector
+   * to all mixture components. 
+   */
+  def predictMembership(
+      points: RDD[Vector], 
+      mu: Array[Vector], 
+      sigma: Array[Matrix],
+      weight: Array[Double], 
+      k: Int): RDD[Array[Double]] = {
+    val sc = points.sparkContext
+    val dists = sc.broadcast {
+      (0 until k).map { i => 
+        new MultivariateGaussian(mu(i).toBreeze.toDenseVector, sigma(i).toBreeze.toDenseMatrix)
+      }.toArray
+    }
+    val weights = sc.broadcast(weight)
+    points.map { x => 
+      computeSoftAssignments(x.toBreeze.toDenseVector, dists.value, weights.value, k)
+    }
+  }
+  
+  /**
+   * Compute the partial assignments for each vector
+   */
+  private def computeSoftAssignments(
+      pt: BreezeVector[Double],
+      dists: Array[MultivariateGaussian],
+      weights: Array[Double],
+      k: Int): Array[Double] = {
+    val p = weights.zip(dists).map {
+      case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(pt)
+    }
+    val pSum = p.sum 
+    for (i <- 0 until k) {
+      p(i) /= pSum
+    }
+    p
+  }  
+}
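
`predict` reduces the soft assignments from `computeSoftAssignments` to a hard label by taking the index of the largest responsibility. The sketch below isolates that step, taking precomputed component densities as input; `SoftAssignSketch` and its methods are hypothetical, and `math.ulp(1.0)` stands in for `MLUtils.EPSILON`.

```scala
// Hypothetical sketch of computeSoftAssignments followed by predict's argmax.
object SoftAssignSketch {
  // Assumption: machine epsilon for Double, standing in for MLUtils.EPSILON.
  val Epsilon: Double = math.ulp(1.0)

  /** Normalized responsibilities (eps + w_j * density_j) / sum over all components. */
  def softAssignments(densities: Array[Double], weights: Array[Double]): Array[Double] = {
    val p = weights.zip(densities).map { case (w, d) => Epsilon + w * d }
    val pSum = p.sum
    p.map(_ / pSum)
  }

  /** Hard label: index of the largest responsibility, as in r.indexOf(r.max). */
  def hardLabel(r: Array[Double]): Int = r.indexOf(r.max)
}
```

Adding the epsilon to every term means a point whose densities all underflow to zero gets equal responsibilities instead of `0.0 / 0.0 = NaN`; on ties, `indexOf(r.max)` returns the lowest index.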

http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
new file mode 100644
index 0000000..2eab5d2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/impl/MultivariateGaussian.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.stat.impl
+
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, Transpose, det, pinv}
+
+/**
+ * Utility class implementing the density function of the multivariate Gaussian distribution.
+ * Breeze provides this functionality, but it requires the Apache Commons Math library,
+ * so this class exists to avoid introducing a new dependency into Spark.
+ */
+private[mllib] class MultivariateGaussian(
+    val mu: DBV[Double], 
+    val sigma: DBM[Double]) extends Serializable {
+  private val sigmaInv2 = pinv(sigma) * -0.5
+  private val U = math.pow(2.0 * math.Pi, -mu.length / 2.0) * math.pow(det(sigma), -0.5)
+    
+  /** Returns density of this multivariate Gaussian at given point, x */
+  def pdf(x: DBV[Double]): Double = {
+    val delta = x - mu
+    val deltaTranspose = new Transpose(delta)
+    U * math.exp(deltaTranspose * sigmaInv2 * delta)
+  }
+}
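
For d = 2 the same density can be written out without Breeze, using the closed-form inverse and determinant of a 2x2 covariance. This is a hypothetical check (`Gaussian2dSketch` is not part of the commit) that mirrors the constant `U` and the quadratic form in `pdf` above.

```scala
// Hypothetical plain-Scala 2-D version of MultivariateGaussian.pdf.
object Gaussian2dSketch {
  /** Density at (x, y) of a 2-D Gaussian with mean (mx, my) and covariance [[a, b], [b, c]]. */
  def pdf(x: Double, y: Double, mx: Double, my: Double, a: Double, b: Double, c: Double): Double = {
    val det = a * c - b * b
    // A symmetric 2x2 matrix is positive definite iff a > 0 and det > 0.
    require(a > 0.0 && det > 0.0, "covariance must be positive definite")
    val dx = x - mx
    val dy = y - my
    // delta^T * inverse(sigma) * delta, with inverse(sigma) = [[c, -b], [-b, a]] / det
    val q = (c * dx * dx - 2.0 * b * dx * dy + a * dy * dy) / det
    // Same normalizing constant as U in the diff, with d = 2
    val u = math.pow(2.0 * math.Pi, -2.0 / 2.0) * math.pow(det, -0.5)
    u * math.exp(-0.5 * q)
  }
}
```

The class in the diff uses `pinv`, so a singular covariance does not fail on the inverse, but `det(sigma)` is then zero and `math.pow(0.0, -0.5)` makes `U` infinite; this sketch rejects that case up front instead.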

http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index b0d05ae..1d07b5d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -39,7 +39,7 @@ import org.apache.spark.streaming.dstream.DStream
  */
 object MLUtils {
 
-  private[util] lazy val EPSILON = {
+  private[mllib] lazy val EPSILON = {
     var eps = 1.0
     while ((1.0 + (eps / 2.0)) != 1.0) {
       eps /= 2.0
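
The visibility change from `private[util]` to `private[mllib]` is what lets `GaussianMixtureEM` and `GaussianMixtureModel`, which live in `org.apache.spark.mllib.clustering`, read `MLUtils.EPSILON`. The value itself is the machine epsilon for `Double`. The hunk is cut off after the loop body, so the sketch below assumes the lazy val simply yields `eps` once the loop exits; `EpsilonSketch` is a hypothetical name.

```scala
// Hypothetical copy of the MLUtils.EPSILON loop (assumed to yield eps after the loop).
object EpsilonSketch {
  def machineEpsilon(): Double = {
    var eps = 1.0
    // Keep halving while 1.0 + eps / 2 is still distinguishable from 1.0.
    while ((1.0 + (eps / 2.0)) != 1.0) {
      eps /= 2.0
    }
    eps
  }
}
```

The loop stops at 2^-52, the gap between 1.0 and the next larger `Double`, which is the same value `math.ulp(1.0)` returns.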

http://git-wip-us.apache.org/repos/asf/spark/blob/6cf6fdf3/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
new file mode 100644
index 0000000..23feb82
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vectors, Matrices}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext {
+  test("single cluster") {
+    val data = sc.parallelize(Array(
+      Vectors.dense(6.0, 9.0),
+      Vectors.dense(5.0, 10.0),
+      Vectors.dense(4.0, 11.0)
+    ))
+    
+    // expectations
+    val Ew = 1.0
+    val Emu = Vectors.dense(5.0, 10.0)
+    val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
+    
+    val gmm = new GaussianMixtureEM().setK(1).run(data)
+                
+    assert(gmm.weight(0) ~== Ew absTol 1E-5)
+    assert(gmm.mu(0) ~== Emu absTol 1E-5)
+    assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+  }
+  
+  test("two clusters") {
+    val data = sc.parallelize(Array(
+      Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+      Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+      Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+      Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+      Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+    ))
+  
+    // we set an initial gaussian to induce expected results
+    val initialGmm = new GaussianMixtureModel(
+      Array(0.5, 0.5),
+      Array(Vectors.dense(-1.0), Vectors.dense(1.0)),
+      Array(Matrices.dense(1, 1, Array(1.0)), Matrices.dense(1, 1, Array(1.0)))
+    )
+    
+    val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
+    val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
+    val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
+    
+    val gmm = new GaussianMixtureEM()
+      .setK(2)
+      .setInitialModel(initialGmm)
+      .run(data)
+      
+    assert(gmm.weight(0) ~== Ew(0) absTol 1E-3)
+    assert(gmm.weight(1) ~== Ew(1) absTol 1E-3)
+    assert(gmm.mu(0) ~== Emu(0) absTol 1E-3)
+    assert(gmm.mu(1) ~== Emu(1) absTol 1E-3)
+    assert(gmm.sigma(0) ~== Esigma(0) absTol 1E-3)
+    assert(gmm.sigma(1) ~== Esigma(1) absTol 1E-3)
+  }
+}
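
The expected values in the "single cluster" test can be checked by hand: with k = 1 every responsibility is exactly 1, so the first M-step already yields the sample mean and the biased (divide-by-n) covariance of the three points. A plain-Scala sketch of that arithmetic, using the hypothetical name `SingleClusterSketch`:

```scala
// Hypothetical check of the single-cluster expectations: sample mean and biased covariance.
object SingleClusterSketch {
  def meanAndCov(points: Seq[(Double, Double)]): ((Double, Double), Array[Double]) = {
    val n = points.length.toDouble
    val mx = points.map(_._1).sum / n
    val my = points.map(_._2).sum / n
    // Divide by n (not n - 1), matching sums.sigmas / sums.weights - mu * mu^T
    val sxx = points.map { case (x, _) => (x - mx) * (x - mx) }.sum / n
    val sxy = points.map { case (x, y) => (x - mx) * (y - my) }.sum / n
    val syy = points.map { case (_, y) => (y - my) * (y - my) }.sum / n
    // Column-major order, as taken by Matrices.dense(2, 2, ...)
    ((mx, my), Array(sxx, sxy, sxy, syy))
  }
}
```

For (6, 9), (5, 10) and (4, 11) this gives a mean of (5, 10) and entries of 2/3, -2/3, -2/3, 2/3, matching `Emu` and `Esigma` in the test.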

