spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException
Date Tue, 09 May 2017 08:47:54 GMT
Repository: spark
Updated Branches:
  refs/heads/master 10b00abad -> be53a7835


[SPARK-20615][ML][TEST] SparseVector.argmax throws IndexOutOfBoundsException

## What changes were proposed in this pull request?

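Added a check for the number of defined (active) values. Previously the argmax function assumed
that at least one value was defined if the vector size was greater than zero. A non-empty vector
with no active values is implicitly all zeros, so argmax now returns 0 (the first index) in that case.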

## How was this patch tested?

Tests were added to the existing VectorsSuite in both the ml and mllib packages to cover this case.

Author: Jon McLean <jon.mclean@atsid.com>

Closes #17877 from jonmclean/vectorArgmaxIndexBug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/be53a783
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/be53a783
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/be53a783

Branch: refs/heads/master
Commit: be53a78352ae7c70d8a07d0df24574b3e3129b4a
Parents: 10b00ab
Author: Jon McLean <jon.mclean@atsid.com>
Authored: Tue May 9 09:47:50 2017 +0100
Committer: Sean Owen <sowen@cloudera.com>
Committed: Tue May 9 09:47:50 2017 +0100

----------------------------------------------------------------------
 .../src/main/scala/org/apache/spark/ml/linalg/Vectors.scala   | 2 ++
 .../test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala  | 7 +++++++
 .../main/scala/org/apache/spark/mllib/linalg/Vectors.scala    | 2 ++
 .../scala/org/apache/spark/mllib/linalg/VectorsSuite.scala    | 7 +++++++
 4 files changed, 18 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/be53a783/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index 8e166ba..3fbc095 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -657,6 +657,8 @@ class SparseVector @Since("2.0.0") (
   override def argmax: Int = {
     if (size == 0) {
       -1
+    } else if (numActives == 0) {
+      0
     } else {
       // Find the max active entry.
       var maxIdx = indices(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/be53a783/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
index dfbdaf1..4cd91af 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala
@@ -125,6 +125,13 @@ class VectorsSuite extends SparkMLFunSuite {
 
     val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
     assert(vec8.argmax === 0)
+
+    // Check for case when sparse vector is non-empty but the values are empty
+    val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec9.argmax === 0)
+
+    val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec10.argmax === 0)
   }
 
   test("vector equals") {

http://git-wip-us.apache.org/repos/asf/spark/blob/be53a783/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 723addc..f063420 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -846,6 +846,8 @@ class SparseVector @Since("1.0.0") (
   override def argmax: Int = {
     if (size == 0) {
       -1
+    } else if (numActives == 0) {
+      0
     } else {
       // Find the max active entry.
       var maxIdx = indices(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/be53a783/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 71a3cea..6172cff 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -122,6 +122,13 @@ class VectorsSuite extends SparkFunSuite with Logging {
 
     val vec8 = Vectors.sparse(5, Array(1, 2), Array(0.0, -1.0))
     assert(vec8.argmax === 0)
+
+    // Check for case when sparse vector is non-empty but the values are empty
+    val vec9 = Vectors.sparse(100, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec9.argmax === 0)
+
+    val vec10 = Vectors.sparse(1, Array.empty[Int], Array.empty[Double]).asInstanceOf[SparseVector]
+    assert(vec10.argmax === 0)
   }
 
   test("vector equals") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message