spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
Date Tue, 30 Jun 2015 21:08:21 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.3 0ce83db11 -> d720426e1


[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k

I closed https://github.com/apache/spark/pull/6949 by mistake, so I am resubmitting the change here.

I also added a unit test.

Bug: `IndexedRowMatrix.computeSVD()` returns a `U` for which `U.numCols() = self.nCols`.
It should be `U.numCols() = k = svd.U.numCols()`.

```
self = U * sigma * V.transpose
(m x n) = (m x n) * (k x k) * (k x n) // as is (before the fix)
-->
(m x n) = (m x k) * (k x k) * (k x n) // to be (after the fix)
```
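
The dimension argument above can be checked mechanically. The following is only a minimal sketch in plain Scala, not Spark code: `Shape`, `SvdShapeCheck`, and `reconstruct` are hypothetical names introduced for illustration, and they track dimensions only, not matrix values. It assumes `U` is declared with `uCols` columns and shows that the product `U * sigma * V.transpose` conforms only when `uCols = k`.

```scala
// Hypothetical helper, not part of Spark: tracks only matrix dimensions.
final case class Shape(rows: Long, cols: Long) {
  // Shape of the product this * that; fails if the inner dimensions differ.
  def *(that: Shape): Shape = {
    require(cols == that.rows,
      s"cannot multiply ${rows}x${cols} by ${that.rows}x${that.cols}")
    Shape(rows, that.cols)
  }

  // Shape of the transpose.
  def t: Shape = Shape(cols, rows)
}

object SvdShapeCheck {
  // Shape of U * sigma * V.transpose for a rank-k SVD of an m x n matrix,
  // where uCols is the number of columns that U declares.
  def reconstruct(m: Long, n: Long, k: Long, uCols: Long): Shape = {
    val u = Shape(m, uCols)
    val sigma = Shape(k, k)
    val v = Shape(n, k)
    u * sigma * v.t
  }

  def main(args: Array[String]): Unit = {
    val (m, n, k) = (4L, 3L, 2L)
    // After the fix: U declares k columns, so the product is m x n.
    println(reconstruct(m, n, k, uCols = k))
    // Before the fix: U declared n columns, which does not conform to sigma.
    println(scala.util.Try(reconstruct(m, n, k, uCols = n)).isFailure)
  }
}
```

With `m = 4`, `n = 3`, `k = 2`, the first call yields a 4 x 3 shape, while the second fails the `require` check because a declared 4 x 3 `U` cannot be multiplied by a 2 x 2 `sigma`.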

Author: lee19 <lee19@live.co.kr>

Closes #6953 from lee19/MLlibBugfix and squashes the following commits:

c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden.
4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error.
c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib]
8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k

(cherry picked from commit e72526227fdcf93b7a33375ef954746ac08753f5)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d720426e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d720426e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d720426e

Branch: refs/heads/branch-1.3
Commit: d720426e1f1961bd9e095cd3c5723a969f07a2dc
Parents: 0ce83db
Author: lee19 <lee19@live.co.kr>
Authored: Tue Jun 30 14:08:00 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Jun 30 14:08:18 2015 -0700

----------------------------------------------------------------------
 .../mllib/linalg/distributed/IndexedRowMatrix.scala      |  2 +-
 .../mllib/linalg/distributed/IndexedRowMatrixSuite.scala | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d720426e/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
index 3be530f..1c33b43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala
@@ -146,7 +146,7 @@ class IndexedRowMatrix(
       val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
         IndexedRow(i, v)
       }
-      new IndexedRowMatrix(indexedRows, nRows, nCols)
+      new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
     } else {
       null
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/d720426e/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
index 4a7b99a..0ecb7a2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala
@@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(closeToZero(U * brzDiag(s) * V.t - localA))
   }
 
+  test("validate matrix sizes of svd") {
+    val k = 2
+    val A = new IndexedRowMatrix(indexedRows)
+    val svd = A.computeSVD(k, computeU = true)
+    assert(svd.U.numRows() === m)
+    assert(svd.U.numCols() === k)
+    assert(svd.s.size === k)
+    assert(svd.V.numRows === n)
+    assert(svd.V.numCols === k)
+  }
+
   test("validate k in svd") {
     val A = new IndexedRowMatrix(indexedRows)
     intercept[IllegalArgumentException] {

