flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From trohrm...@apache.org
Subject [1/3] flink git commit: [ml] Adds convenience functions for Breeze matrix/vector conversion
Date Wed, 01 Apr 2015 09:21:16 GMT
Repository: flink
Updated Branches:
  refs/heads/master c63580244 -> d2e2d79fc


[ml] Adds convenience functions for Breeze matrix/vector conversion

[ml] Adds breeze to flink-dist LICENSE file

[ml] Optimizes sanity checks in vector/matrix accessors

[ml] Fixes scala check style error with missing whitespaces before and after +

[ml] Fixes DenseMatrixTest


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5ddb2dd9
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5ddb2dd9
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5ddb2dd9

Branch: refs/heads/master
Commit: 5ddb2dd9634ab0908c99a08a1d0e10e761444120
Parents: 9219af7
Author: Till Rohrmann <trohrmann@apache.org>
Authored: Thu Mar 26 17:44:17 2015 +0100
Committer: Till Rohrmann <trohrmann@apache.org>
Committed: Wed Apr 1 10:56:47 2015 +0200

----------------------------------------------------------------------
 flink-dist/src/main/flink-bin/LICENSE           |  1 +
 .../scala/org/apache/flink/ml/math/Breeze.scala | 92 ++++++++++++++++++++
 .../org/apache/flink/ml/math/DenseMatrix.scala  |  4 +-
 .../org/apache/flink/ml/math/DenseVector.scala  |  6 +-
 .../org/apache/flink/ml/math/SparseMatrix.scala | 10 +--
 .../org/apache/flink/ml/math/SparseVector.scala |  6 +-
 .../org/apache/flink/ml/math/package.scala      | 82 ++++++++++++++---
 .../regression/MultipleLinearRegression.scala   | 12 +--
 .../apache/flink/ml/math/BreezeMathTest.scala   | 69 +++++++++++++++
 .../apache/flink/ml/math/DenseVectorTest.scala  |  2 +-
 .../apache/flink/ml/math/SparseMatrixTest.scala | 42 +++++++--
 .../apache/flink/ml/math/SparseVectorTest.scala | 18 +++-
 12 files changed, 308 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-dist/src/main/flink-bin/LICENSE
----------------------------------------------------------------------
diff --git a/flink-dist/src/main/flink-bin/LICENSE b/flink-dist/src/main/flink-bin/LICENSE
index d0b7fb4..8c733e4 100644
--- a/flink-dist/src/main/flink-bin/LICENSE
+++ b/flink-dist/src/main/flink-bin/LICENSE
@@ -250,6 +250,7 @@ under the Apache License (v 2.0):
  - Twitter Hosebird Client (hbc) (https://github.com/twitter/hbc)
  - Jettison (http://jettison.codehaus.org)
  - Akka (http://akka.io)
+ - Breeze (https://github.com/scalanlp/breeze)
  
 
 -----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
new file mode 100644
index 0000000..dffb984
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import breeze.linalg.{ Matrix => BreezeMatrix, DenseMatrix => BreezeDenseMatrix,
+CSCMatrix => BreezeCSCMatrix, DenseVector => BreezeDenseVector, SparseVector => BreezeSparseVector,
+Vector => BreezeVector}
+
+/** This class contains convenience function to wrap a matrix/vector into a breeze matrix/vector
+  * and to unwrap it again.
+  *
+  */
+object Breeze {
+
+  implicit class Matrix2BreezeConverter(matrix: Matrix) {
+    def asBreeze: BreezeMatrix[Double] = {
+      matrix match {
+        case dense: DenseMatrix =>
+          new BreezeDenseMatrix[Double](
+            dense.numRows,
+            dense.numCols,
+            dense.data)
+
+        case sparse: SparseMatrix =>
+          new BreezeCSCMatrix[Double](
+            sparse.data,
+            sparse.numRows,
+            sparse.numCols,
+            sparse.colPtrs,
+            sparse.rowIndices
+          )
+      }
+    }
+  }
+
+  implicit class Breeze2MatrixConverter(matrix: BreezeMatrix[Double]) {
+    def fromBreeze: Matrix = {
+      matrix match {
+        case dense: BreezeDenseMatrix[Double] =>
+          new DenseMatrix(dense.rows, dense.cols, dense.data)
+
+        case sparse: BreezeCSCMatrix[Double] =>
+          new SparseMatrix(sparse.rows, sparse.cols, sparse.rowIndices, sparse.colPtrs, sparse.data)
+      }
+    }
+  }
+
+  implicit class BreezeArrayConverter[T](array: Array[T]) {
+    def asBreeze: BreezeDenseVector[T] = {
+      new BreezeDenseVector[T](array)
+    }
+  }
+
+  implicit class Breeze2VectorConverter(vector: BreezeVector[Double]) {
+    def fromBreeze: Vector = {
+      vector match {
+        case dense: BreezeDenseVector[Double] => new DenseVector(dense.data)
+
+        case sparse: BreezeSparseVector[Double] =>
+          new SparseVector(sparse.length, sparse.index, sparse.data)
+      }
+    }
+  }
+
+  implicit class Vector2BreezeConverter(vector: Vector) {
+    def asBreeze: BreezeVector[Double] = {
+      vector match {
+        case dense: DenseVector =>
+          new BreezeDenseVector[Double](dense.data)
+
+        case sparse: SparseVector =>
+          new BreezeSparseVector[Double](sparse.indices, sparse.data, sparse.size)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index 72eae05..16291b8 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -134,8 +134,8 @@ case class DenseMatrix(val numRows: Int,
     * @return
     */
   private def locate(row: Int, col: Int): Int = {
-    require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
-    require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+    require(0 <= row && row < numRows && 0 <= col && col < numCols,
+      (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
 
     row + col * numRows
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index 6d41d47..50992a9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -41,8 +41,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
    * @return element at the given index
    */
   override def apply(index: Int): Double = {
-    require(0 <= index && index < data.length, s"Index $index is out of bounds " +
-      s"[0, ${data.length})")
+    require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")
     data(index)
   }
 
@@ -72,8 +71,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
     * @param value
     */
   override def update(index: Int, value: Double): Unit = {
-    require(0 <= index && index < data.length, s"Index $index is out of bounds " +
-      s"[0, ${data.length})")
+    require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")
 
     data(index) = value
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
index a46202c..b065630 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
@@ -112,11 +112,11 @@ class SparseMatrix(
   }
 
   private def locate(row: Int, col: Int): Int = {
-    require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
-    require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+    require(0 <= row && row < numRows && 0 <= col && col < numCols,
+      (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
 
     val startIndex = colPtrs(col)
-    val endIndex = colPtrs(col+1)
+    val endIndex = colPtrs(col + 1)
 
     java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
   }
@@ -155,8 +155,8 @@ object SparseMatrix{
     val entryArray = entries.toArray
 
     entryArray.foreach{ case (row, col, _) =>
-        require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
-        require(0 <= col && col < numCols, s"Columm $col is out of bounds [0, $numCols).")
+      require(0 <= row && row < numRows && 0 <= col && col <= numCols,
+        (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
     }
 
     val COOOrdering = new Ordering[(Int, Int, Double)] {

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
index 93da362..9fa69cb 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
@@ -78,7 +78,7 @@ class SparseVector(
   }
 
   private def locate(index: Int): Int = {
-    require(0 <= index && index < size, s"Index $index is out of bounds [0, $size).")
+    require(0 <= index && index < size, index + " not in [0, " + size + ")")
 
     java.util.Arrays.binarySearch(indices, 0, indices.length, index)
   }
@@ -107,6 +107,10 @@ object SparseVector {
   def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
     val entryArray = entries.toArray
 
+    entryArray.foreach { case (index, _) =>
+      require(0 <= index && index < size, index + " not in [0, " + size + ")")
+    }
+
     val COOOrdering = new Ordering[(Int, Double)] {
       override def compare(x: (Int, Double), y: (Int, Double)): Int = {
         x._1 - y._1

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index 3ab6143..4c7f254 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -23,26 +23,88 @@ package org.apache.flink.ml
  * abstraction.
  */
 package object math {
-  implicit class RichMatrix(matrix: Matrix) extends Iterable[Double] {
+  implicit class RichMatrix(matrix: Matrix) extends Iterable[(Int, Int, Double)] {
 
-    override def iterator: Iterator[Double] = {
-      matrix match {
-        case dense: DenseMatrix => dense.data.iterator
+    override def iterator: Iterator[(Int, Int, Double)] = {
+      new Iterator[(Int, Int, Double)] {
+        var index = 0
+
+        override def hasNext: Boolean = {
+          index < matrix.numRows * matrix.numCols
+        }
+
+        override def next(): (Int, Int, Double) = {
+          val row = index % matrix.numRows
+          val column = index / matrix.numRows
+
+          index += 1
+
+          (row, column, matrix(row, column))
+        }
+      }
+    }
+
+    def valueIterator: Iterator[Double] = {
+      val it = iterator
+
+      new Iterator[Double] {
+        override def hasNext: Boolean = it.hasNext
+
+        override def next(): Double = it.next._3
       }
     }
+
   }
 
-  implicit class RichVector(vector: Vector) extends Iterable[Double] {
-    override def iterator: Iterator[Double] = {
-      vector match {
-        case dense: DenseVector => dense.data.iterator
+  implicit class RichVector(vector: Vector) extends Iterable[(Int, Double)] {
+
+    override def iterator: Iterator[(Int, Double)] = {
+      new Iterator[(Int, Double)] {
+        var index = 0
+
+        override def hasNext: Boolean = {
+          index < vector.size
+        }
+
+        override def next(): (Int, Double) = {
+          val resultIndex = index
+
+          index += 1
+
+          (resultIndex, vector(resultIndex))
+        }
+      }
+    }
+
+    def valueIterator: Iterator[Double] = {
+      val it = iterator
+
+      new Iterator[Double] {
+        override def hasNext: Boolean = it.hasNext
+
+        override def next(): Double = it.next._2
       }
     }
   }
 
-  implicit def vector2Array(vector: Vector): Array[Double] = {
+  /** Stores the vector values in a dense array
+    *
+    * @param vector
+    * @return Array containing the vector values
+    */
+  def vector2Array(vector: Vector): Array[Double] = {
     vector match {
-      case dense: DenseVector => dense.data
+      case dense: DenseVector => dense.data.clone
+
+      case sparse: SparseVector =>
+        val result = new Array[Double](sparse.size)
+
+        for((index, value) <- sparse) {
+          result(index) = value
+        }
+
+        result
+
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 8060d2b..9768cce 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -24,6 +24,8 @@ import org.apache.flink.configuration.Configuration
 import org.apache.flink.ml.math.Vector
 import org.apache.flink.ml.common._
 
+import org.apache.flink.ml.math.vector2Array
+
 import org.apache.flink.api.scala._
 
 import com.github.fommil.netlib.BLAS.{ getInstance => blas }
@@ -283,14 +285,14 @@ private class SquaredResiduals extends RichMapFunction[LabeledVector, Double] {
   }
 
   override def map(value: LabeledVector): Double = {
-    val vector = value.vector
+    val array = vector2Array(value.vector)
     val label = value.label
 
-    val dotProduct = blas.ddot(weightVector.length, weightVector, 1, vector, 1)
+    val dotProduct = blas.ddot(weightVector.length, weightVector, 1, array, 1)
 
     val residual = dotProduct + weight0 - label
 
-    residual*residual
+    residual * residual
   }
 }
 
@@ -322,7 +324,7 @@ RichMapFunction[LabeledVector, (Array[Double], Double, Int)] {
   }
 
   override def map(value: LabeledVector): (Array[Double], Double, Int) = {
-    val x = value.vector
+    val x = vector2Array(value.vector)
     val label = value.label
 
     val dotProduct = blas.ddot(weightVector.length, weightVector, 1, x, 1)
@@ -435,7 +437,7 @@ Transformer[ Vector, LabeledVector ] {
     }
 
     override def map(value: Vector): LabeledVector = {
-      val dotProduct = blas.ddot(weights.length, weights, 1, value, 1)
+      val dotProduct = blas.ddot(weights.length, weights, 1, vector2Array(value), 1)
 
       val prediction = dotProduct + weight0
 

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
new file mode 100644
index 0000000..7084f2a
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import Breeze._
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+class BreezeMathTest extends ShouldMatchers {
+
+  @Test
+  def testBreezeDenseMatrixWrapping: Unit = {
+    val numRows = 5
+    val numCols = 4
+
+    val data = Array.range(0, numRows * numCols)
+    val expectedData = Array.range(0, numRows * numCols).map(_ * 2)
+
+    val denseMatrix = DenseMatrix(numRows, numCols, data)
+    val expectedMatrix = DenseMatrix(numRows, numCols, expectedData)
+
+    val m = denseMatrix.asBreeze
+
+    val result = (m * 2.0).fromBreeze
+
+    result should equal(expectedMatrix)
+  }
+
+  @Test
+  def testBreezeSparseMatrixWrapping: Unit = {
+    val numRows = 5
+    val numCols = 4
+
+    val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols,
+      (0, 1, 1),
+      (4, 3, 13),
+      (3, 2, 45),
+      (4, 0, 12))
+
+    val expectedMatrix = SparseMatrix.fromCOO(numRows, numCols,
+      (0, 1, 2),
+      (4, 3, 26),
+      (3, 2, 90),
+      (4, 0, 24))
+
+    val sm = sparseMatrix.asBreeze
+
+    val result = (sm * 2.0).fromBreeze
+
+    result should equal(expectedMatrix)
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
index 5da9fe2..66a51fe 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
@@ -32,7 +32,7 @@ class DenseVectorTest extends ShouldMatchers {
 
     assertResult(data.length)(vector.size)
 
-    data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)}
+    data.zip(vector.map(_._2)).foreach{case (expected, actual) => assertResult(expected)(actual)}
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
index a0e1d27..7fcdf54 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
@@ -25,9 +25,14 @@ class SparseMatrixTest extends ShouldMatchers {
 
   @Test
   def testSparseMatrixFromCOO: Unit = {
-    val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+    val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
       (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
 
+    val numRows = 5
+    val numCols = 5
+
+    val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+
     val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
       (4, 2, 99), (1, 4, 91))
 
@@ -43,8 +48,22 @@ class SparseMatrixTest extends ShouldMatchers {
 
     sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
 
+    val dataMap = data.
+      map{ case (row, col, value) => (row, col) -> value }.
+      groupBy{_._1}.
+      mapValues{
+      entries =>
+        entries.map(_._2).reduce(_ + _)
+    }
+
+    for(row <- 0 until numRows; col <- 0 until numCols) {
+      sparseMatrix(row, col) should be(dataMap.getOrElse((row, col), 0))
+    }
+
+    // test access to defined field even though it was set to 0
     sparseMatrix(0, 1) = 10
 
+    // test that a non-defined field is not accessible
     intercept[IllegalArgumentException]{
       sparseMatrix(1, 1) = 1
     }
@@ -52,18 +71,29 @@ class SparseMatrixTest extends ShouldMatchers {
 
   @Test
   def testInvalidIndexAccess: Unit = {
-    val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
+    val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+      (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+    val numRows = 5
+    val numCols = 5
+
+    val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
 
     intercept[IllegalArgumentException] {
-      sparseVector(-1)
+      sparseMatrix(-1, 4)
     }
 
     intercept[IllegalArgumentException] {
-      sparseVector(5)
+      sparseMatrix(numRows, 0)
     }
 
-    sparseVector(0) should equal(0)
-    sparseVector(3) should equal(3)
+    intercept[IllegalArgumentException] {
+      sparseMatrix(0, numCols)
+    }
+
+    intercept[IllegalArgumentException] {
+      sparseMatrix(3, -1)
+    }
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
index 5e514c6..88d4878 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
@@ -25,7 +25,10 @@ class SparseVectorTest extends ShouldMatchers{
 
   @Test
   def testDataAfterInitialization: Unit = {
-    val sparseVector = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3))
+    val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3))
+    val size = 5
+    val sparseVector = SparseVector.fromCOO(size, data)
+
     val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
     val expectedDenseVector = DenseVector.zeros(5)
 
@@ -38,11 +41,22 @@ class SparseVectorTest extends ShouldMatchers{
     val denseVector = sparseVector.toDenseVector
 
     denseVector should equal(expectedDenseVector)
+
+    val dataMap = data.
+      groupBy{_._1}.
+      mapValues{
+      entries =>
+        entries.map(_._2).reduce(_ + _)
+    }
+
+    for(index <- 0 until size) {
+      sparseVector(index) should be(dataMap.getOrElse(index, 0))
+    }
   }
 
   @Test
   def testInvalidIndexAccess: Unit = {
-    val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5))
+    val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
 
     intercept[IllegalArgumentException] {
       sparseVector(-1)


Mime
View raw message