spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT
Date Fri, 20 Mar 2015 21:13:26 GMT
Repository: spark
Updated Branches:
  refs/heads/master 49a01c7ea -> 11e025956


[SPARK-6309] [SQL] [MLlib] Implement MatrixUDT

Utilities to serialize and deserialize Matrices in MLlib

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #5048 from MechCoder/spark-6309 and squashes the following commits:

05dc6f2 [MechCoder] Hashcode and organize imports
16d5d47 [MechCoder] Test some more
6e67020 [MechCoder] TST: Test using Array conversion instead of equals
7fa7a2c [MechCoder] [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/11e02595
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/11e02595
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/11e02595

Branch: refs/heads/master
Commit: 11e025956be3818c00effef0d650734f8feeb436
Parents: 49a01c7
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Fri Mar 20 17:13:18 2015 -0400
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Mar 20 17:13:18 2015 -0400

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    | 90 ++++++++++++++++++++
 .../spark/mllib/linalg/MatricesSuite.scala      | 13 +++
 2 files changed, 103 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/11e02595/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index fdd8848..849f442 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
 
 import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
 
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
 /**
  * Trait for a local matrix.
  */
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
 sealed trait Matrix extends Serializable {
 
   /** Number of rows. */
@@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
   private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
 }
 
+@DeveloperApi
+private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
+
+  override def sqlType: StructType = {
+    // type: 0 = sparse, 1 = dense
+    // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
+    // set as not nullable, except values since in the future, support for binary matrices might
+    // be added for which values are not needed.
+    // the sparse matrix needs colPtrs and rowIndices, which are set as
+    // null, while building the dense matrix.
+    StructType(Seq(
+      StructField("type", ByteType, nullable = false),
+      StructField("numRows", IntegerType, nullable = false),
+      StructField("numCols", IntegerType, nullable = false),
+      StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+      StructField("isTransposed", BooleanType, nullable = false)
+      ))
+  }
+
+  override def serialize(obj: Any): Row = {
+    val row = new GenericMutableRow(7)
+    obj match {
+      case sm: SparseMatrix =>
+        row.setByte(0, 0)
+        row.setInt(1, sm.numRows)
+        row.setInt(2, sm.numCols)
+        row.update(3, sm.colPtrs.toSeq)
+        row.update(4, sm.rowIndices.toSeq)
+        row.update(5, sm.values.toSeq)
+        row.setBoolean(6, sm.isTransposed)
+
+      case dm: DenseMatrix =>
+        row.setByte(0, 1)
+        row.setInt(1, dm.numRows)
+        row.setInt(2, dm.numCols)
+        row.setNullAt(3)
+        row.setNullAt(4)
+        row.update(5, dm.values.toSeq)
+        row.setBoolean(6, dm.isTransposed)
+    }
+    row
+  }
+
+  override def deserialize(datum: Any): Matrix = {
+    datum match {
+      // TODO: something wrong with UDT serialization, should never happen.
+      case m: Matrix => m
+      case row: Row =>
+        require(row.length == 7,
+          s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
+        val tpe = row.getByte(0)
+        val numRows = row.getInt(1)
+        val numCols = row.getInt(2)
+        val values = row.getAs[Iterable[Double]](5).toArray
+        val isTransposed = row.getBoolean(6)
+        tpe match {
+          case 0 =>
+            val colPtrs = row.getAs[Iterable[Int]](3).toArray
+            val rowIndices = row.getAs[Iterable[Int]](4).toArray
+            new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+          case 1 =>
+            new DenseMatrix(numRows, numCols, values, isTransposed)
+        }
+    }
+  }
+
+  override def userClass: Class[Matrix] = classOf[Matrix]
+
+  override def equals(o: Any): Boolean = {
+    o match {
+      case v: MatrixUDT => true
+      case _ => false
+    }
+  }
+
+  override def hashCode(): Int = 1994
+
+  private[spark] override def asNullable: MatrixUDT = this
+}
+
 /**
  * Column-major dense matrix.
  * The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
  * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
  *                     row major.
  */
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
 class DenseMatrix(
     val numRows: Int,
     val numCols: Int,
@@ -360,6 +449,7 @@ object DenseMatrix {
  *                     Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
  *                     and `rowIndices` behave as colIndices, and `values` are stored in row major.
  */
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
 class SparseMatrix(
     val numRows: Int,
     val numCols: Int,

http://git-wip-us.apache.org/repos/asf/spark/blob/11e02595/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index c098b54..96f677d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
     assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
     assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
   }
+
+  test("MatrixUDT") {
+    val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+    val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+    val dm3 = new DenseMatrix(0, 0, Array())
+    val sm1 = dm1.toSparse
+    val sm2 = dm2.toSparse
+    val sm3 = dm3.toSparse
+    val mUDT = new MatrixUDT()
+    Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
+        mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message