Repository: flink
Updated Branches:
refs/heads/master c63580244 -> d2e2d79fc
[ml] Adds convenience functions for Breeze matrix/vector conversion
[ml] Adds breeze to flink-dist LICENSE file
[ml] Optimizes sanity checks in vector/matrix accessors
[ml] Fixes scala check style error with missing whitespaces before and after +
[ml] Fixes DenseMatrixTest
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5ddb2dd9
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5ddb2dd9
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5ddb2dd9
Branch: refs/heads/master
Commit: 5ddb2dd9634ab0908c99a08a1d0e10e761444120
Parents: 9219af7
Author: Till Rohrmann <trohrmann@apache.org>
Authored: Thu Mar 26 17:44:17 2015 +0100
Committer: Till Rohrmann <trohrmann@apache.org>
Committed: Wed Apr 1 10:56:47 2015 +0200
----------------------------------------------------------------------
flink-dist/src/main/flink-bin/LICENSE | 1 +
.../scala/org/apache/flink/ml/math/Breeze.scala | 92 ++++++++++++++++++++
.../org/apache/flink/ml/math/DenseMatrix.scala | 4 +-
.../org/apache/flink/ml/math/DenseVector.scala | 6 +-
.../org/apache/flink/ml/math/SparseMatrix.scala | 10 +--
.../org/apache/flink/ml/math/SparseVector.scala | 6 +-
.../org/apache/flink/ml/math/package.scala | 82 ++++++++++++++---
.../regression/MultipleLinearRegression.scala | 12 +--
.../apache/flink/ml/math/BreezeMathTest.scala | 69 +++++++++++++++
.../apache/flink/ml/math/DenseVectorTest.scala | 2 +-
.../apache/flink/ml/math/SparseMatrixTest.scala | 42 +++++++--
.../apache/flink/ml/math/SparseVectorTest.scala | 18 +++-
12 files changed, 308 insertions(+), 36 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-dist/src/main/flink-bin/LICENSE
----------------------------------------------------------------------
diff --git a/flink-dist/src/main/flink-bin/LICENSE b/flink-dist/src/main/flink-bin/LICENSE
index d0b7fb4..8c733e4 100644
--- a/flink-dist/src/main/flink-bin/LICENSE
+++ b/flink-dist/src/main/flink-bin/LICENSE
@@ -250,6 +250,7 @@ under the Apache License (v 2.0):
- Twitter Hosebird Client (hbc) (https://github.com/twitter/hbc)
- Jettison (http://jettison.codehaus.org)
- Akka (http://akka.io)
+ - Breeze (https://github.com/scalanlp/breeze)
-----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
new file mode 100644
index 0000000..dffb984
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import breeze.linalg.{ Matrix => BreezeMatrix, DenseMatrix => BreezeDenseMatrix,
+CSCMatrix => BreezeCSCMatrix, DenseVector => BreezeDenseVector, SparseVector =>
BreezeSparseVector,
+Vector => BreezeVector}
+
+/** Convenience conversions between Flink ML's [[Matrix]]/[[Vector]] types and the
+  * corresponding Breeze types.
+  *
+  * `asBreeze` wraps a Flink matrix/vector into a Breeze matrix/vector and `fromBreeze` unwraps
+  * it again. Wrapping shares the backing arrays instead of copying them. Unwrapping shares them
+  * as well when the Breeze value is stored compactly, and copies only when Breeze uses a view
+  * (offset, stride or transpose) or keeps spare capacity in its arrays.
+  */
+object Breeze {
+
+  /** Adds `asBreeze` to Flink [[Matrix]] instances. */
+  implicit class Matrix2BreezeConverter(matrix: Matrix) {
+
+    /** Wraps the matrix into a Breeze matrix that shares its backing arrays.
+      *
+      * @return a Breeze dense matrix for a [[DenseMatrix]], a Breeze CSC matrix for a
+      *         [[SparseMatrix]]
+      */
+    def asBreeze: BreezeMatrix[Double] = {
+      matrix match {
+        case dense: DenseMatrix =>
+          new BreezeDenseMatrix[Double](
+            dense.numRows,
+            dense.numCols,
+            dense.data)
+
+        case sparse: SparseMatrix =>
+          new BreezeCSCMatrix[Double](
+            sparse.data,
+            sparse.numRows,
+            sparse.numCols,
+            sparse.colPtrs,
+            sparse.rowIndices
+          )
+      }
+    }
+  }
+
+  /** Adds `fromBreeze` to Breeze matrices. */
+  implicit class Breeze2MatrixConverter(matrix: BreezeMatrix[Double]) {
+
+    /** Converts the Breeze matrix into a Flink [[Matrix]].
+      *
+      * @return a [[DenseMatrix]] for a Breeze dense matrix, a [[SparseMatrix]] for a Breeze
+      *         CSC matrix
+      */
+    def fromBreeze: Matrix = {
+      matrix match {
+        case dense: BreezeDenseMatrix[Double] =>
+          // A Breeze dense matrix can be a view (offset, larger major stride or transposed), in
+          // which case its data array is not this matrix's column-major content.
+          val isCompact = dense.offset == 0 && !dense.isTranspose &&
+            dense.majorStride == dense.rows && dense.data.length == dense.rows * dense.cols
+          val data = if (isCompact) dense.data else dense.toArray
+          new DenseMatrix(dense.rows, dense.cols, data)
+
+        case sparse: BreezeCSCMatrix[Double] =>
+          // In CSC format colPtrs(cols) is the number of stored entries; the row index and
+          // data arrays may be longer than that.
+          val numStored = sparse.colPtrs(sparse.cols)
+          new SparseMatrix(
+            sparse.rows,
+            sparse.cols,
+            trimmed(sparse.rowIndices, numStored),
+            sparse.colPtrs,
+            trimmed(sparse.data, numStored))
+      }
+    }
+  }
+
+  /** Adds `asBreeze` to arrays, wrapping them into a Breeze dense vector without copying. */
+  implicit class BreezeArrayConverter[T](array: Array[T]) {
+    def asBreeze: BreezeDenseVector[T] = {
+      new BreezeDenseVector[T](array)
+    }
+  }
+
+  /** Adds `fromBreeze` to Breeze vectors. */
+  implicit class Breeze2VectorConverter(vector: BreezeVector[Double]) {
+
+    /** Converts the Breeze vector into a Flink [[Vector]].
+      *
+      * @return a [[DenseVector]] for a Breeze dense vector, a [[SparseVector]] for a Breeze
+      *         sparse vector
+      */
+    def fromBreeze: Vector = {
+      vector match {
+        case dense: BreezeDenseVector[Double] =>
+          // A Breeze dense vector can be a view (e.g. a matrix column) with an offset or a
+          // stride, in which case its data array holds more than this vector's elements.
+          val isCompact =
+            dense.offset == 0 && dense.stride == 1 && dense.data.length == dense.length
+          new DenseVector(if (isCompact) dense.data else dense.toArray)
+
+        case sparse: BreezeSparseVector[Double] =>
+          // Only the first `used` slots of the index and data arrays are valid; trailing spare
+          // capacity would break the binary search SparseVector performs on its indices.
+          new SparseVector(
+            sparse.length,
+            trimmed(sparse.index, sparse.used),
+            trimmed(sparse.data, sparse.used))
+      }
+    }
+  }
+
+  /** Adds `asBreeze` to Flink [[Vector]] instances. */
+  implicit class Vector2BreezeConverter(vector: Vector) {
+
+    /** Wraps the vector into a Breeze vector that shares its backing arrays.
+      *
+      * @return a Breeze dense vector for a [[DenseVector]], a Breeze sparse vector for a
+      *         [[SparseVector]]
+      */
+    def asBreeze: BreezeVector[Double] = {
+      vector match {
+        case dense: DenseVector =>
+          new BreezeDenseVector[Double](dense.data)
+
+        case sparse: SparseVector =>
+          new BreezeSparseVector[Double](sparse.indices, sparse.data, sparse.size)
+      }
+    }
+  }
+
+  /** Returns `array` itself if it holds exactly `size` elements, otherwise a copy of its first
+    * `size` elements.
+    */
+  private def trimmed[T](array: Array[T], size: Int): Array[T] = {
+    if (array.length == size) array else array.take(size)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index 72eae05..16291b8 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -134,8 +134,8 @@ case class DenseMatrix(val numRows: Int,
* @return
*/
private def locate(row: Int, col: Int): Int = {
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+ // Row and column bounds are validated by a single require; the failure message is only
+ // built when the check fails, since require takes its message by name.
+ require(0 <= row && row < numRows && 0 <= col && col
< numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
+ // Column-major layout: column col starts at offset col * numRows.
row + col * numRows
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index 6d41d47..50992a9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -41,8 +41,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
* @return element at the given index
*/
override def apply(index: Int): Double = {
- require(0 <= index && index < data.length, s"Index $index is out of bounds
" +
- s"[0, ${data.length})")
+ // Check explicitly so an invalid index raises an IllegalArgumentException naming the
+ // valid range instead of an ArrayIndexOutOfBoundsException from the array access.
+ require(0 <= index && index < data.length, index + " not in [0, " + data.length
+ ")")
data(index)
}
@@ -72,8 +71,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
* @param value
*/
override def update(index: Int, value: Double): Unit = {
- require(0 <= index && index < data.length, s"Index $index is out of bounds
" +
- s"[0, ${data.length})")
+ // Check explicitly so an invalid index raises an IllegalArgumentException naming the
+ // valid range instead of an ArrayIndexOutOfBoundsException from the array write.
+ require(0 <= index && index < data.length, index + " not in [0, " + data.length
+ ")")
data(index) = value
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
index a46202c..b065630 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
@@ -112,11 +112,11 @@ class SparseMatrix(
}
+ /** Finds the storage position of entry (row, col).
+ *
+ * The row indices of column col are searched in rowIndices between colPtrs(col)
+ * (inclusive) and colPtrs(col + 1) (exclusive); binarySearch requires that range sorted.
+ *
+ * @return position in rowIndices if the entry is stored, otherwise the negative value
+ * returned by java.util.Arrays.binarySearch
+ * @throws IllegalArgumentException if (row, col) lies outside the matrix
+ */
private def locate(row: Int, col: Int): Int = {
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+ require(0 <= row && row < numRows && 0 <= col && col
< numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
val startIndex = colPtrs(col)
- val endIndex = colPtrs(col+1)
+ val endIndex = colPtrs(col + 1)
java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
}
@@ -155,8 +155,8 @@ object SparseMatrix{
val entryArray = entries.toArray
entryArray.foreach{ case (row, col, _) =>
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0,
$numRows).")
- require(0 <= col && col < numCols, s"Columm $col is out of bounds [0,
$numCols).")
+ // Reject entries outside the half-open ranges [0, numRows) x [0, numCols). Both upper
+ // bounds are exclusive, matching the check in locate.
+ require(0 <= row && row < numRows && 0 <= col && col < numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
}
val COOOrdering = new Ordering[(Int, Int, Double)] {
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
index 93da362..9fa69cb 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
@@ -78,7 +78,7 @@ class SparseVector(
}
+ /** Finds the storage position of index.
+ *
+ * @return position in indices if the index is stored, otherwise the negative value returned
+ * by java.util.Arrays.binarySearch; indices must be sorted for the search to be valid
+ * @throws IllegalArgumentException if index is outside [0, size)
+ */
private def locate(index: Int): Int = {
- require(0 <= index && index < size, s"Index $index is out of bounds [0,
$size).")
+ require(0 <= index && index < size, index + " not in [0, " + size + ")")
java.util.Arrays.binarySearch(indices, 0, indices.length, index)
}
@@ -107,6 +107,10 @@ object SparseVector {
def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
val entryArray = entries.toArray
+ // Fail fast on out-of-range indices before any entry is processed further.
+ entryArray.foreach { case (index, _) =>
+ require(0 <= index && index < size, index + " not in [0, " + size + ")")
+ }
+
val COOOrdering = new Ordering[(Int, Double)] {
override def compare(x: (Int, Double), y: (Int, Double)): Int = {
x._1 - y._1
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index 3ab6143..4c7f254 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -23,26 +23,88 @@ package org.apache.flink.ml
* abstraction.
*/
package object math {
- implicit class RichMatrix(matrix: Matrix) extends Iterable[Double] {
+  /** Views a [[Matrix]] as the collection of its (row, column, value) entries, enumerated in
+    * column-major order. Every position is visited, including zero entries of sparse matrices.
+    */
+  implicit class RichMatrix(matrix: Matrix) extends Iterable[(Int, Int, Double)] {
- override def iterator: Iterator[Double] = {
- matrix match {
- case dense: DenseMatrix => dense.data.iterator
+    override def iterator: Iterator[(Int, Int, Double)] = {
+      new Iterator[(Int, Int, Double)] {
+        private val numEntries = matrix.numRows * matrix.numCols
+        private var index = 0
+
+        override def hasNext: Boolean = index < numEntries
+
+        override def next(): (Int, Int, Double) = {
+          // Honour the Iterator contract rather than failing in the matrix's bounds check.
+          if (!hasNext) {
+            throw new NoSuchElementException("next on empty iterator")
+          }
+
+          val row = index % matrix.numRows
+          val column = index / matrix.numRows
+          index += 1
+
+          (row, column, matrix(row, column))
+        }
+      }
+    }
+
+    /** Iterates over the matrix values only, in the same order as `iterator`. */
+    def valueIterator: Iterator[Double] = {
+      val it = iterator
+
+      new Iterator[Double] {
+        override def hasNext: Boolean = it.hasNext
+
+        override def next(): Double = it.next()._3
}
}
+
}
- implicit class RichVector(vector: Vector) extends Iterable[Double] {
- override def iterator: Iterator[Double] = {
- vector match {
- case dense: DenseVector => dense.data.iterator
+  /** Views a [[Vector]] as the collection of its (index, value) entries in index order. Every
+    * index is visited, including zero entries of sparse vectors.
+    */
+  implicit class RichVector(vector: Vector) extends Iterable[(Int, Double)] {
+
+    override def iterator: Iterator[(Int, Double)] = {
+      new Iterator[(Int, Double)] {
+        private var index = 0
+
+        override def hasNext: Boolean = index < vector.size
+
+        override def next(): (Int, Double) = {
+          // Honour the Iterator contract rather than failing in the vector's bounds check.
+          if (!hasNext) {
+            throw new NoSuchElementException("next on empty iterator")
+          }
+
+          val resultIndex = index
+          index += 1
+
+          (resultIndex, vector(resultIndex))
+        }
+      }
+    }
+
+    /** Iterates over the vector values only, in index order. */
+    def valueIterator: Iterator[Double] = {
+      val it = iterator
+
+      new Iterator[Double] {
+        override def hasNext: Boolean = it.hasNext
+
+        override def next(): Double = it.next()._2
}
}
}
- implicit def vector2Array(vector: Vector): Array[Double] = {
+  /** Copies the values of a vector into a new dense array.
+    *
+    * The returned array never aliases the vector's storage, so callers may modify it freely.
+    *
+    * @param vector vector whose values are copied
+    * @return array holding every value of the vector, zeros included
+    */
+  def vector2Array(vector: Vector): Array[Double] = {
vector match {
- case dense: DenseVector => dense.data
+      case dense: DenseVector => dense.data.clone
+
+      case sparse: SparseVector =>
+        val result = new Array[Double](sparse.size)
+
+        // Only stored entries can be non-zero, so scatter them instead of looking up every
+        // index of the vector with a binary search.
+        for (i <- 0 until sparse.indices.length) {
+          result(sparse.indices(i)) = sparse.data(i)
+        }
+
+        result
+
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 8060d2b..9768cce 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -24,6 +24,8 @@ import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.vector2Array
+
import org.apache.flink.api.scala._
import com.github.fommil.netlib.BLAS.{ getInstance => blas }
@@ -283,14 +285,14 @@ private class SquaredResiduals extends RichMapFunction[LabeledVector,
Double] {
}
+ /** Computes the squared residual (w . x + w0 - label)^2 of a single labeled record. */
override def map(value: LabeledVector): Double = {
- val vector = value.vector
+ // vector2Array copies the values (also for dense vectors), so every call allocates a new
+ // array. NOTE(review): consider reading dense vectors without the copy on this path.
+ val array = vector2Array(value.vector)
val label = value.label
- val dotProduct = blas.ddot(weightVector.length, weightVector, 1, vector, 1)
+ val dotProduct = blas.ddot(weightVector.length, weightVector, 1, array, 1)
val residual = dotProduct + weight0 - label
- residual*residual
+ residual * residual
}
}
@@ -322,7 +324,7 @@ RichMapFunction[LabeledVector, (Array[Double], Double, Int)] {
}
override def map(value: LabeledVector): (Array[Double], Double, Int) = {
- val x = value.vector
+ val x = vector2Array(value.vector)
val label = value.label
val dotProduct = blas.ddot(weightVector.length, weightVector, 1, x, 1)
@@ -435,7 +437,7 @@ Transformer[ Vector, LabeledVector ] {
}
override def map(value: Vector): LabeledVector = {
- val dotProduct = blas.ddot(weights.length, weights, 1, value, 1)
+ val dotProduct = blas.ddot(weights.length, weights, 1, vector2Array(value), 1)
val prediction = dotProduct + weight0
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
new file mode 100644
index 0000000..7084f2a
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import Breeze._
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+/** Round-trip tests for the Breeze conversions: a Flink matrix is wrapped with `asBreeze`,
+  * scaled by two using Breeze arithmetic and converted back with `fromBreeze`.
+  */
+class BreezeMathTest extends ShouldMatchers {
+
+  @Test
+  def testBreezeDenseMatrixWrapping: Unit = {
+    val (rows, cols) = (5, 4)
+    val numEntries = rows * cols
+
+    val original = DenseMatrix(rows, cols, Array.range(0, numEntries))
+    val doubled = DenseMatrix(rows, cols, Array.range(0, numEntries).map(_ * 2))
+
+    val scaled = original.asBreeze * 2.0
+
+    scaled.fromBreeze should equal(doubled)
+  }
+
+  @Test
+  def testBreezeSparseMatrixWrapping: Unit = {
+    val (rows, cols) = (5, 4)
+
+    val original = SparseMatrix.fromCOO(rows, cols, (0, 1, 1), (4, 3, 13), (3, 2, 45), (4, 0, 12))
+    val doubled = SparseMatrix.fromCOO(rows, cols, (0, 1, 2), (4, 3, 26), (3, 2, 90), (4, 0, 24))
+
+    val scaled = original.asBreeze * 2.0
+
+    scaled.fromBreeze should equal(doubled)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
index 5da9fe2..66a51fe 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
@@ -32,7 +32,7 @@ class DenseVectorTest extends ShouldMatchers {
assertResult(data.length)(vector.size)
- data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)}
+ data.zip(vector.map(_._2)).foreach{case (expected, actual) => assertResult(expected)(actual)}
}
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
index a0e1d27..7fcdf54 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
@@ -25,9 +25,14 @@ class SparseMatrixTest extends ShouldMatchers {
@Test
def testSparseMatrixFromCOO: Unit = {
- val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1,
17),
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
(3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+
val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3,
88),
(4, 2, 99), (1, 4, 91))
@@ -43,8 +48,22 @@ class SparseMatrixTest extends ShouldMatchers {
sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
+ val dataMap = data.
+ map{ case (row, col, value) => (row, col) -> value }.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(row <- 0 until numRows; col <- 0 until numCols) {
+ sparseMatrix(row, col) should be(dataMap.getOrElse((row, col), 0))
+ }
+
+ // test access to defined field even though it was set to 0
sparseMatrix(0, 1) = 10
+ // test that a non-defined field is not accessible
intercept[IllegalArgumentException]{
sparseMatrix(1, 1) = 1
}
@@ -52,18 +71,29 @@ class SparseMatrixTest extends ShouldMatchers {
@Test
def testInvalidIndexAccess: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
+ // Every access outside [0, numRows) x [0, numCols) must be rejected, whichever index is
+ // out of range and on whichever side.
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+ // negative row index
intercept[IllegalArgumentException] {
- sparseVector(-1)
+ sparseMatrix(-1, 4)
}
+ // row index equal to numRows (the upper bound is exclusive)
intercept[IllegalArgumentException] {
- sparseVector(5)
+ sparseMatrix(numRows, 0)
}
- sparseVector(0) should equal(0)
- sparseVector(3) should equal(3)
+ // column index equal to numCols (the upper bound is exclusive)
+ intercept[IllegalArgumentException] {
+ sparseMatrix(0, numCols)
+ }
+
+ // negative column index
+ intercept[IllegalArgumentException] {
+ sparseMatrix(3, -1)
+ }
}
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
index 5e514c6..88d4878 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
@@ -25,7 +25,10 @@ class SparseVectorTest extends ShouldMatchers{
@Test
def testDataAfterInitialization: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3))
+ val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3))
+ val size = 5
+ val sparseVector = SparseVector.fromCOO(size, data)
+
val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
val expectedDenseVector = DenseVector.zeros(5)
@@ -38,11 +41,22 @@ class SparseVectorTest extends ShouldMatchers{
val denseVector = sparseVector.toDenseVector
denseVector should equal(expectedDenseVector)
+
+ val dataMap = data.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(index <- 0 until size) {
+ sparseVector(index) should be(dataMap.getOrElse(index, 0))
+ }
}
@Test
def testInvalidIndexAccess: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5))
+ val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
intercept[IllegalArgumentException] {
sparseVector(-1)
|