spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-11852][ML] StandardScaler minor refactor
Date Fri, 20 Nov 2015 17:55:55 GMT
Repository: spark
Updated Branches:
  refs/heads/master a66142dec -> 9ace2e5c8


[SPARK-11852][ML] StandardScaler minor refactor

```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9839 from yanboliang/standardScaler-refactor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9ace2e5c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9ace2e5c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9ace2e5c

Branch: refs/heads/master
Commit: 9ace2e5c8d7fbd360a93bc5fc4eace64a697b44f
Parents: a66142d
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Fri Nov 20 09:55:53 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Nov 20 09:55:53 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/StandardScaler.scala       | 60 +++++++++-----------
 .../spark/ml/feature/StandardScalerSuite.scala  | 11 ++--
 2 files changed, 32 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9ace2e5c/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 6d54521..d76a9c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
 
   /**
-   * Centers the data with mean before scaling.
+   * Whether to center the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
    * Default: false
    * @group param
    */
-  val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+  val withMean: BooleanParam = new BooleanParam(this, "withMean",
+    "Whether to center data with mean")
+
+  /** @group getParam */
+  def getWithMean: Boolean = $(withMean)
 
   /**
-   * Scales the data to unit standard deviation.
+   * Whether to scale the data to unit standard deviation.
    * Default: true
    * @group param
    */
-  val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+  val withStd: BooleanParam = new BooleanParam(this, "withStd",
+    "Whether to scale the data to unit standard deviation")
+
+  /** @group getParam */
+  def getWithStd: Boolean = $(withStd)
+
+  setDefault(withMean -> false, withStd -> true)
 }
 
 /**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
 
   def this() = this(Identifiable.randomUID("stdScal"))
 
-  setDefault(withMean -> false, withStd -> true)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
     val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
 /**
  * :: Experimental ::
  * Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
  */
 @Experimental
 class StandardScalerModel private[ml] (
     override val uid: String,
-    scaler: feature.StandardScalerModel)
+    val std: Vector,
+    val mean: Vector)
   extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
 
   import StandardScalerModel._
 
-  /** Standard deviation of the StandardScalerModel */
-  val std: Vector = scaler.std
-
-  /** Mean of the StandardScalerModel */
-  val mean: Vector = scaler.mean
-
-  /** Whether to scale to unit standard deviation. */
-  @Since("1.6.0")
-  def getWithStd: Boolean = scaler.withStd
-
-  /** Whether to center data with mean. */
-  @Since("1.6.0")
-  def getWithMean: Boolean = scaler.withMean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {
-    val copied = new StandardScalerModel(uid, scaler)
+    val copied = new StandardScalerModel(uid, std, mean)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
   private[StandardScalerModel]
   class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
 
-    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+    private case class Data(std: Vector, mean: Vector)
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-        sqlContext.read.parquet(dataPath)
-          .select("std", "mean", "withStd", "withMean")
-          .head()
-      // This is very likely to change in the future because withStd and withMean should be params.
-      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
-      val model = new StandardScalerModel(metadata.uid, oldModel)
+      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+        .select("std", "mean")
+        .head()
+      val model = new StandardScalerModel(metadata.uid, std, mean)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/9ace2e5c/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 49a4b2e..1eae125 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
 
   test("params") {
     ParamsSuite.checkParams(new StandardScaler)
-    val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
-    ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+    ParamsSuite.checkParams(new StandardScalerModel("empty",
+      Vectors.dense(1.0), Vectors.dense(2.0)))
   }
 
   test("Standardization with default parameter") {
@@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   test("StandardScalerModel read/write") {
-    val oldModel = new feature.StandardScalerModel(
-      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
-    val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+    val instance = new StandardScalerModel("myStandardScalerModel",
+      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.std === instance.std)
     assert(newInstance.mean === instance.mean)
-    assert(newInstance.getWithStd === instance.getWithStd)
-    assert(newInstance.getWithMean === instance.getWithMean)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message