spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils
Date Tue, 23 Jun 2015 19:43:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master 2b1111dd0 -> f2022fa0d


[SPARK-8265] [MLLIB] [PYSPARK] Add LinearDataGenerator to pyspark.mllib.utils

Generating linear data is useful for testing linear models, and for experimentation in general.
Scala already provides LinearDataGenerator; this change adds a thin Python wrapper around the
Scala code in pyspark.mllib.util.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #6715 from MechCoder/generate_linear_input and squashes the following commits:

6182884 [MechCoder] Minor changes
8bda047 [MechCoder] Minor style fixes
0f1053c [MechCoder] [SPARK-8265] Add LinearDataGenerator to pyspark.mllib.utils


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f2022fa0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f2022fa0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f2022fa0

Branch: refs/heads/master
Commit: f2022fa0d375c804eca7803e172543b23ecbb9b7
Parents: 2b1111d
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Tue Jun 23 12:43:32 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Jun 23 12:43:32 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 32 +++++++++++++++++-
 python/pyspark/mllib/tests.py                   | 22 ++++++++++--
 python/pyspark/mllib/util.py                    | 35 ++++++++++++++++++++
 3 files changed, 86 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f2022fa0/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f9a271f..c4bea7c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -51,6 +51,7 @@ import org.apache.spark.mllib.tree.loss.Losses
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel}
 import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.util.LinearDataGenerator
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel
@@ -972,7 +973,7 @@ private[python] class PythonMLLibAPI extends Serializable {
   def estimateKernelDensity(
       sample: JavaRDD[Double],
       bandwidth: Double, points: java.util.ArrayList[Double]): Array[Double] = {
-    return new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
+    new KernelDensity().setSample(sample).setBandwidth(bandwidth).estimate(
       points.asScala.toArray)
   }
 
@@ -991,6 +992,35 @@ private[python] class PythonMLLibAPI extends Serializable {
       List[AnyRef](model.clusterCenters, Vectors.dense(model.clusterWeights)).asJava
   }
 
+  /**
+   * Python-facing wrapper for LinearDataGenerator.generateLinearInput (JLists -> Arrays).
+   */
+  def generateLinearInputWrapper(
+      intercept: Double,
+      weights: JList[Double],
+      xMean: JList[Double],
+      xVariance: JList[Double],
+      nPoints: Int,
+      seed: Int,
+      eps: Double): Array[LabeledPoint] = {
+    LinearDataGenerator.generateLinearInput(
+      intercept, weights.asScala.toArray, xMean.asScala.toArray,
+      xVariance.asScala.toArray, nPoints, seed, eps).toArray
+  }
+
+  /**
+   * Python-facing wrapper for LinearDataGenerator.generateLinearRDD, returned as a JavaRDD.
+   */
+  def generateLinearRDDWrapper(
+      sc: JavaSparkContext,
+      nexamples: Int,
+      nfeatures: Int,
+      eps: Double,
+      nparts: Int,
+      intercept: Double): JavaRDD[LabeledPoint] = {
+    LinearDataGenerator.generateLinearRDD(
+      sc, nexamples, nfeatures, eps, nparts, intercept)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f2022fa0/python/pyspark/mllib/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index c8d61b9..509faa1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -49,8 +49,8 @@ from pyspark.mllib.random import RandomRDDs
 from pyspark.mllib.stat import Statistics
 from pyspark.mllib.feature import Word2Vec
 from pyspark.mllib.feature import IDF
-from pyspark.mllib.feature import StandardScaler
-from pyspark.mllib.feature import ElementwiseProduct
+from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
+from pyspark.mllib.util import LinearDataGenerator
 from pyspark.serializers import PickleSerializer
 from pyspark.streaming import StreamingContext
 from pyspark.sql import SQLContext
@@ -1019,6 +1019,24 @@ class StreamingKMeansTest(MLLibStreamingTestCase):
         self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
 
 
+class LinearDataGeneratorTests(MLlibTestCase):
+    def test_dim(self):
+        linear_data = LinearDataGenerator.generateLinearInput(
+            intercept=0.0, weights=[0.0, 0.0, 0.0],
+            xMean=[0.0, 0.0, 0.0], xVariance=[0.33, 0.33, 0.33],
+            nPoints=4, seed=0, eps=0.1)
+        self.assertEqual(len(linear_data), 4)
+        for point in linear_data:
+            self.assertEqual(len(point.features), 3)
+        # Use the test case's own SparkContext (self.sc), not a bare global `sc`.
+        linear_data = LinearDataGenerator.generateLinearRDD(
+            sc=self.sc, nexamples=6, nfeatures=2, eps=0.1,
+            nParts=2, intercept=0.0).collect()
+        self.assertEqual(len(linear_data), 6)
+        for point in linear_data:
+            self.assertEqual(len(point.features), 2)
+
+
 if __name__ == "__main__":
     if not _have_scipy:
         print("NOTE: Skipping SciPy tests as it does not seem to be installed")

http://git-wip-us.apache.org/repos/asf/spark/blob/f2022fa0/python/pyspark/mllib/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 16a90db..3482383 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -257,6 +257,41 @@ class JavaLoader(Loader):
         return cls(java_model)
 
 
+class LinearDataGenerator(object):
+    """Generate synthetic linear data by calling Scala's LinearDataGenerator."""
+
+    @staticmethod
+    def generateLinearInput(intercept, weights, xMean, xVariance,
+                            nPoints, seed, eps):
+        """Generate nPoints LabeledPoints whose labels are X'w + c plus Gaussian noise.
+
+        :param intercept: bias factor, the term c in X'w + c
+        :param weights: feature vector, the term w in X'w + c
+        :param xMean: point around which the data X is centered
+        :param xVariance: variance of the given data
+        :param nPoints: number of points to be generated
+        :param seed: random seed
+        :param eps: noise scale; a larger eps adds more Gaussian noise
+        :return: list of LabeledPoints of length nPoints
+        """
+        weights = [float(weight) for weight in weights]
+        xMean = [float(mean) for mean in xMean]
+        xVariance = [float(var) for var in xVariance]
+        return list(callMLlibFunc(
+            "generateLinearInputWrapper", float(intercept), weights, xMean,
+            xVariance, int(nPoints), int(seed), float(eps)))
+
+    @staticmethod
+    def generateLinearRDD(sc, nexamples, nfeatures, eps,
+                          nParts=2, intercept=0.0):
+        """
+        Generate an RDD of nexamples LabeledPoints, each with nfeatures features.
+        """
+        return callMLlibFunc(
+            "generateLinearRDDWrapper", sc, int(nexamples), int(nfeatures),
+            float(eps), int(nParts), float(intercept))
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message