spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector
Date Fri, 07 Aug 2015 21:51:08 GMT
Repository: spark
Updated Branches:
  refs/heads/master 881548ab2 -> e2fbbe731


[SPARK-8481] [MLLIB] GaussianMixtureModel predict accepting single vector

Resubmit of [https://github.com/apache/spark/pull/6906] for adding single-vec predict to GMMs

CC: dkobylarz  mengxr

To be merged with master and branch-1.5
Primary author: dkobylarz

Author: Dariusz Kobylarz <darek.kobylarz@gmail.com>

Closes #8039 from jkbradley/gmm-predict-vec and squashes the following commits:

bfbedc4 [Dariusz Kobylarz] [SPARK-8481] [MLlib] GaussianMixtureModel predict accepting single vector


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e2fbbe73
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e2fbbe73
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e2fbbe73

Branch: refs/heads/master
Commit: e2fbbe73111d4624390f596a19a1799c86a05f6c
Parents: 881548a
Author: Dariusz Kobylarz <darek.kobylarz@gmail.com>
Authored: Fri Aug 7 14:51:03 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Fri Aug 7 14:51:03 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/clustering/GaussianMixtureModel.scala  | 13 +++++++++++++
 .../spark/mllib/clustering/GaussianMixtureSuite.scala  | 10 ++++++++++
 2 files changed, 23 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e2fbbe73/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
index cb807c8..76aeebd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala
@@ -66,6 +66,12 @@ class GaussianMixtureModel(
     responsibilityMatrix.map(r => r.indexOf(r.max))
   }
 
+  /** Maps given point to its cluster index. */
+  def predict(point: Vector): Int = {
+    val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+    r.indexOf(r.max)
+  }
+
   /** Java-friendly version of [[predict()]] */
   def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
     predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
@@ -84,6 +90,13 @@ class GaussianMixtureModel(
   }
 
   /**
+   * Given the input vector, return the membership values to all mixture components.
+   */
+  def predictSoft(point: Vector): Array[Double] = {
+    computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
+  }
+
+  /**
    * Compute the partial assignments for each vector
    */
   private def computeSoftAssignments(

http://git-wip-us.apache.org/repos/asf/spark/blob/e2fbbe73/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index b218d72..b636d02 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -148,6 +148,16 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("model prediction, parallel and local") {
+    val data = sc.parallelize(GaussianTestData.data)
+    val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
+
+    val batchPredictions = gmm.predict(data)
+    batchPredictions.zip(data).collect().foreach { case (batchPred, datum) =>
+      assert(batchPred === gmm.predict(datum))
+    }
+  }
+
   object GaussianTestData {
 
     val data = Array(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message