spark-commits mailing list archives

From ma...@apache.org
Subject [07/28] git commit: Scala classification and clustering stubs; matrix serialization/deserialization.
Date Thu, 26 Dec 2013 06:31:24 GMT
Scala classification and clustering stubs; matrix serialization/deserialization.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/f99970e8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/f99970e8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/f99970e8

Branch: refs/heads/master
Commit: f99970e8cdc85eae33999b57a4c5c1893fe3727a
Parents: 2328bdd
Author: Tor Myklebust <tmyklebu@gmail.com>
Authored: Fri Dec 20 00:12:22 2013 -0500
Committer: Tor Myklebust <tmyklebu@gmail.com>
Committed: Fri Dec 20 00:12:22 2013 -0500

----------------------------------------------------------------------
 .../apache/spark/mllib/api/PythonMLLibAPI.scala | 82 +++++++++++++++++++-
 1 file changed, 79 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
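
The new serializeDoubleMatrix and deserializeDoubleMatrix methods in the diff below use a simple packed layout: an 8-byte magic number (2), an 8-byte row count, an 8-byte column count, and then rows*cols doubles in row-major order, all in native byte order. The following minimal, self-contained sketch (not part of this commit; the object and method names are illustrative only) produces the same layout:

import java.nio.{ByteBuffer, ByteOrder}

object MatrixFormatSketch {
  // Pack a row-major Double matrix into the layout used by serializeDoubleMatrix:
  // magic (2), row count and column count as 8-byte longs, then the doubles,
  // all in native byte order.
  def pack(m: Array[Array[Double]]): Array[Byte] = {
    val rows = m.length
    val cols = if (rows > 0) m(0).length else 0
    val bb = ByteBuffer.allocate(24 + 8 * rows * cols)
    bb.order(ByteOrder.nativeOrder())
    bb.putLong(2L)
    bb.putLong(rows.toLong)
    bb.putLong(cols.toLong)
    for (row <- m; x <- row) {
      bb.putDouble(x)
    }
    bb.array()
  }

  def main(args: Array[String]) {
    val bytes = pack(Array(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0)))
    println(bytes.length)   // 24-byte header + 2*3*8 = 72 bytes
  }
}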


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/f99970e8/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index c9bd7c6..bcf2f07 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,7 @@
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.classification._
+import org.apache.spark.mllib.clustering._
 import org.apache.spark.rdd.RDD
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
@@ -39,6 +41,52 @@ class PythonMLLibAPI extends Serializable {
     return bytes
   }
 
+  def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+    val packetLength = bytes.length
+    if (packetLength < 24) {
+      throw new IllegalArgumentException("Byte array too short.")
+    }
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    val magic = bb.getLong()
+    if (magic != 2) {
+      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+    }
+    val rows = bb.getLong()
+    val cols = bb.getLong()
+    if (packetLength != 24 + 8 * rows * cols) {
+      throw new IllegalArgumentException("Size " + rows + "x" + cols + "is wrong.")
+    }
+    val db = bb.asDoubleBuffer()
+    val ans = new Array[Array[Double]](rows.toInt)
+    var i = 0
+    for (i <- 0 until rows.toInt) {
+      ans(i) = new Array[Double](cols.toInt)
+      db.get(ans(i))
+    }
+    return ans
+  }
+
+  def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+    val rows = doubles.length
+    var cols = 0
+    if (rows > 0) {
+      cols = doubles(0).length
+    }
+    val bytes = new Array[Byte](24 + 8 * rows * cols)
+    val bb = ByteBuffer.wrap(bytes)
+    bb.order(ByteOrder.nativeOrder())
+    bb.putLong(2)
+    bb.putLong(rows)
+    bb.putLong(cols)
+    val db = bb.asDoubleBuffer()
+    var i = 0
+    for (i <- 0 until rows) {
+      db.put(doubles(i))
+    }
+    return bytes
+  }
+
   def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
       dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
       java.util.LinkedList[java.lang.Object] = {
@@ -60,7 +108,7 @@ class PythonMLLibAPI extends Serializable {
     return trainRegressionModel((data, initialWeights) =>
         LinearRegressionWithSGD.train(data, numIterations, stepSize,
                                       miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
   }
 
   def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
@@ -69,7 +117,7 @@ class PythonMLLibAPI extends Serializable {
     return trainRegressionModel((data, initialWeights) =>
         LassoWithSGD.train(data, numIterations, stepSize, regParam,
                            miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
   }
 
   def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
@@ -78,6 +126,34 @@ class PythonMLLibAPI extends Serializable {
     return trainRegressionModel((data, initialWeights) =>
         RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
                                      miniBatchFraction, initialWeights),
-        dataBytesJRDD, initialWeightsBA);
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainSVMModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        SVMWithSGD.train(data, numIterations, stepSize, regParam,
+                         miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainLogisticRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LogisticRegressionWithSGD.train(data, numIterations, stepSize,
+                                        miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA)
+  }
+
+  def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int,
+      maxIterations: Int, runs: Int, initializationMode: String):
+      java.util.List[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
+    val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
+    val ret = new java.util.LinkedList[java.lang.Object]()
+    ret.add(serializeDoubleMatrix(model.clusterCenters))
+    return ret
   }
 }
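
For context (again not part of the commit), the new trainKMeansModel entry point could be driven from Scala roughly as follows. This is only a sketch of the intended call pattern: it assumes a local SparkContext and the serializeDoubleVector helper defined earlier in PythonMLLibAPI.scala (the counterpart of the deserializeDoubleVector call visible in the diff).

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.api.PythonMLLibAPI

object KMeansStubSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "KMeansStubSketch")
    val api = new PythonMLLibAPI()
    // Each point travels as a packed-double byte array, the same format the
    // Python side is expected to send.
    val points = Seq(Array(0.0, 0.0), Array(1.0, 1.0), Array(9.0, 8.0), Array(8.0, 9.0))
    val dataBytes = sc.parallelize(points).map(p => api.serializeDoubleVector(p))
    val result = api.trainKMeansModel(JavaRDD.fromRDD(dataBytes), 2, 10, 1, "k-means||")
    // The single element of the returned list is the serialized cluster-center matrix.
    val centers = api.deserializeDoubleMatrix(result.get(0).asInstanceOf[Array[Byte]])
    centers.foreach(c => println(c.mkString(", ")))
    sc.stop()
  }
}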

