spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pwend...@apache.org
Subject [2/9] git commit: Added unit tests for bulk prediction in MatrixFactorizationModel
Date Wed, 08 Jan 2014 00:57:20 GMT
Added unit tests for bulk prediction in MatrixFactorizationModel


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/2c1cba85
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/2c1cba85
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/2c1cba85

Branch: refs/heads/master
Commit: 2c1cba851c2954bacf10006c0d5dad67aba77ab5
Parents: 67f937e
Author: Hossein Falaki <falaki@gmail.com>
Authored: Fri Jan 3 15:35:20 2014 -0800
Committer: Hossein Falaki <falaki@gmail.com>
Committed: Fri Jan 3 15:35:20 2014 -0800

----------------------------------------------------------------------
 .../spark/mllib/recommendation/ALSSuite.scala   | 33 ++++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/2c1cba85/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index fafc5ec..e683a90 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -90,18 +90,34 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
     testALS(50, 100, 1, 15, 0.7, 0.3)
   }
 
+  test("rank-1 matrices bulk") {
+    testALS(50, 100, 1, 15, 0.7, 0.3, false, true)
+  }
+
   test("rank-2 matrices") {
     testALS(100, 200, 2, 15, 0.7, 0.3)
   }
 
+  test("rank-2 matrices bulk") {
+    testALS(100, 200, 2, 15, 0.7, 0.3, false, true)
+  }
+
   test("rank-1 matrices implicit") {
     testALS(80, 160, 1, 15, 0.7, 0.4, true)
   }
 
+  test("rank-1 matrices implicit bulk") {
+    testALS(80, 160, 1, 15, 0.7, 0.4, true, true)
+  }
+
   test("rank-2 matrices implicit") {
     testALS(100, 200, 2, 15, 0.7, 0.4, true)
   }
 
+  test("rank-2 matrices implicit bulk") {
+    testALS(100, 200, 2, 15, 0.7, 0.4, true, true)
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *
@@ -111,9 +127,12 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
    * @param iterations     number of iterations to run
    * @param samplingRate   what fraction of the user-product pairs are known
    * @param matchThreshold max difference allowed to consider a predicted rating correct
+   * @param implicitPrefs  flag to test implicit feedback
+   * @param bulkPredict    flag to test bulk prediction
    */
   def testALS(users: Int, products: Int, features: Int, iterations: Int,
-    samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false)
+    samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
+    bulkPredict: Boolean = false)
   {
     val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
       features, samplingRate, implicitPrefs)
@@ -130,7 +149,17 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll {
     for ((p, vec) <- model.productFeatures.collect(); i <- 0 until features) {
       predictedP.put(p, i, vec(i))
     }
-    val predictedRatings = predictedU.mmul(predictedP.transpose)
+    val predictedRatings = bulkPredict match {
+      case false => predictedU.mmul(predictedP.transpose)
+      case true =>
+        val allRatings = new DoubleMatrix(users, products)
+        val usersProducts = for (u <- 0 until users; p <- 0 until products) yield (u, p)
+        val userProductsRDD = sc.parallelize(usersProducts)
+        model.predict(userProductsRDD).collect().foreach { elem =>
+          allRatings.put(elem.user, elem.product, elem.rating)
+        }
+        allRatings
+    }
 
     if (!implicitPrefs) {
       for (u <- 0 until users; p <- 0 until products) {


Mime
View raw message