spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-6787][ML] add read/write to estimators under ml.feature (1)
Date Wed, 18 Nov 2015 23:48:01 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 d9945bc46 -> dc1e23744


[SPARK-6787][ML] add read/write to estimators under ml.feature (1)

Add read/write support to the following estimators under spark.ml (a brief usage sketch follows the list):

* CountVectorizer
* IDF
* MinMaxScaler
* StandardScaler (a little awkward because we store some params in the spark.mllib model)
* StringIndexer
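
For reference, a minimal usage sketch of the new persistence API (not part of this
patch's text; the `save(path)` call on `Writer` is assumed from the `saveImpl` hook
shown in the diff below, while the companion-object `load` appears there explicitly):

    import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

    // Persist an estimator's params with the new Writable API, then load them back.
    val cv = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(10)
    cv.write.save("/tmp/spark-ml/cv")        // assumed: Writer exposes save(path)
    val cvLoaded = CountVectorizer.load("/tmp/spark-ml/cv")

    // A fitted CountVectorizerModel additionally persists its learned vocabulary:
    // val model: CountVectorizerModel = cv.fit(df)
    // model.write.save("/tmp/spark-ml/cv-model")
    // val modelLoaded = CountVectorizerModel.load("/tmp/spark-ml/cv-model")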

Added some necessary methods for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable`
and `DefaultParamsWritable` to save some boilerplate code, though we would still need to override
`load` for Java compatibility.
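
For illustration only, such traits might look roughly like the following sketch (an
assumption layered on the Writer/Reader/DefaultParamsWriter/DefaultParamsReader types
introduced in this patch, not code contained in it):

    // Hypothetical boilerplate-saving traits (assumption; not part of this commit).
    private[ml] trait DefaultParamsWritable extends Writable { self: Params =>
      override def write: Writer = new DefaultParamsWriter(this)
    }

    private[ml] trait DefaultParamsReadable[T] extends Readable[T] {
      override def read: Reader[T] = new DefaultParamsReader[T]
      // Each companion object would still override load(path) so that Java
      // callers get a concrete return type.
    }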

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9798 from mengxr/SPARK-6787.

(cherry picked from commit 7e987de1770f4ab3d54bc05db8de0a1ef035941d)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dc1e2374
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dc1e2374
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dc1e2374

Branch: refs/heads/branch-1.6
Commit: dc1e23744b7fc1b8ee5fac07cf56d5760d66503e
Parents: d9945bc
Author: Xiangrui Meng <meng@databricks.com>
Authored: Wed Nov 18 15:47:49 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Nov 18 15:47:57 2015 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/CountVectorizer.scala      | 72 ++++++++++++++++--
 .../scala/org/apache/spark/ml/feature/IDF.scala | 71 +++++++++++++++++-
 .../apache/spark/ml/feature/MinMaxScaler.scala  | 72 ++++++++++++++++--
 .../spark/ml/feature/StandardScaler.scala       | 78 +++++++++++++++++++-
 .../apache/spark/ml/feature/StringIndexer.scala | 70 ++++++++++++++++--
 .../spark/ml/feature/CountVectorizerSuite.scala | 24 +++++-
 .../org/apache/spark/ml/feature/IDFSuite.scala  | 19 ++++-
 .../spark/ml/feature/MinMaxScalerSuite.scala    | 25 ++++++-
 .../spark/ml/feature/StandardScalerSuite.scala  | 64 +++++++++++-----
 .../spark/ml/feature/StringIndexerSuite.scala   | 19 ++++-
 10 files changed, 467 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 49028e4..5ff9bfb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -16,17 +16,19 @@
  */
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
  */
 @Experimental
 class CountVectorizer(override val uid: String)
-  extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+  extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {
 
   def this() = this(Identifiable.randomUID("cntVec"))
 
@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
   }
 
   override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizer extends Readable[CountVectorizer] {
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizer] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizer = super.load(path)
 }
 
 /**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
  */
 @Experimental
 class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
-  extends Model[CountVectorizerModel] with CountVectorizerParams {
+  extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {
+
+  import CountVectorizerModel._
 
   def this(vocabulary: Array[String]) = {
     this(Identifiable.randomUID("cntVecModel"), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
     val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
     copyValues(copied, extra)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new CountVectorizerModelWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizerModel extends Readable[CountVectorizerModel] {
+
+  private[CountVectorizerModel]
+  class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {
+
+    private case class Data(vocabulary: Seq[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.vocabulary)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {
+
+    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+
+    override def load(path: String): CountVectorizerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("vocabulary")
+        .head()
+      val vocabulary = data.getAs[Seq[String]](0).toArray
+      val model = new CountVectorizerModel(metadata.uid, vocabulary)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): CountVectorizerModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 4c36df7..53ad34e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
@@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
  * Compute the Inverse Document Frequency (IDF) given a collection of documents.
  */
 @Experimental
-final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {
 
   def this() = this(Identifiable.randomUID("idf"))
 
@@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
   }
 
   override def copy(extra: ParamMap): IDF = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object IDF extends Readable[IDF] {
+
+  @Since("1.6.0")
+  override def read: Reader[IDF] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): IDF = super.load(path)
 }
 
 /**
@@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
 class IDFModel private[ml] (
     override val uid: String,
     idfModel: feature.IDFModel)
-  extends Model[IDFModel] with IDFBase {
+  extends Model[IDFModel] with IDFBase with Writable {
+
+  import IDFModel._
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -117,4 +134,50 @@ class IDFModel private[ml] (
     val copied = new IDFModel(uid, idfModel)
     copyValues(copied, extra).setParent(parent)
   }
+
+  /** Returns the IDF vector. */
+  @Since("1.6.0")
+  def idf: Vector = idfModel.idf
+
+  @Since("1.6.0")
+  override def write: Writer = new IDFModelWriter(this)
+}
+
+@Since("1.6.0")
+object IDFModel extends Readable[IDFModel] {
+
+  private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {
+
+    private case class Data(idf: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.idf)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class IDFModelReader extends Reader[IDFModel] {
+
+    private val className = "org.apache.spark.ml.feature.IDFModel"
+
+    override def load(path: String): IDFModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("idf")
+        .head()
+      val idf = data.getAs[Vector](0)
+      val model = new IDFModel(metadata.uid, new feature.IDFModel(idf))
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[IDFModel] = new IDFModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): IDFModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 1b494ec..24d964f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
-import org.apache.spark.ml.util.Identifiable
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.sql._
@@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
  */
 @Experimental
 class MinMaxScaler(override val uid: String)
-  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {
 
   def this() = this(Identifiable.randomUID("minMaxScal"))
 
@@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String)
   }
 
   override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScaler extends Readable[MinMaxScaler] {
+
+  @Since("1.6.0")
+  override def read: Reader[MinMaxScaler] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): MinMaxScaler = super.load(path)
 }
 
 /**
@@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] (
     override val uid: String,
     val originalMin: Vector,
     val originalMax: Vector)
-  extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+  extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+
+  import MinMaxScalerModel._
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] (
     val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new MinMaxScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
+
+  private[MinMaxScalerModel]
+  class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {
+
+    private case class Data(originalMin: Vector, originalMax: Vector)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = new Data(instance.originalMin, instance.originalMax)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {
+
+    private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+
+    override def load(path: String): MinMaxScalerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+        .select("originalMin", "originalMax")
+        .head()
+      val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): MinMaxScalerModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index f6d0b0c..ab04e54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
@@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
  */
 @Experimental
 class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
-  with StandardScalerParams {
+  with StandardScalerParams with Writable {
 
   def this() = this(Identifiable.randomUID("stdScal"))
 
@@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object StandardScaler extends Readable[StandardScaler] {
+
+  @Since("1.6.0")
+  override def read: Reader[StandardScaler] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): StandardScaler = super.load(path)
 }
 
 /**
@@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
 class StandardScalerModel private[ml] (
     override val uid: String,
     scaler: feature.StandardScalerModel)
-  extends Model[StandardScalerModel] with StandardScalerParams {
+  extends Model[StandardScalerModel] with StandardScalerParams with Writable {
+
+  import StandardScalerModel._
 
   /** Standard deviation of the StandardScalerModel */
   val std: Vector = scaler.std
@@ -112,6 +129,14 @@ class StandardScalerModel private[ml] (
   /** Mean of the StandardScalerModel */
   val mean: Vector = scaler.mean
 
+  /** Whether to scale to unit standard deviation. */
+  @Since("1.6.0")
+  def getWithStd: Boolean = scaler.withStd
+
+  /** Whether to center data with mean. */
+  @Since("1.6.0")
+  def getWithMean: Boolean = scaler.withMean
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -138,4 +163,49 @@ class StandardScalerModel private[ml] (
     val copied = new StandardScalerModel(uid, scaler)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: Writer = new StandardScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object StandardScalerModel extends Readable[StandardScalerModel] {
+
+  private[StandardScalerModel]
+  class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer {
+
+    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class StandardScalerModelReader extends Reader[StandardScalerModel] {
+
+    private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+
+    override def load(path: String): StandardScalerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
+        sqlContext.read.parquet(dataPath)
+          .select("std", "mean", "withStd", "withMean")
+          .head()
+      // This is very likely to change in the future because withStd and withMean should be params.
+      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
+      val model = new StandardScalerModel(metadata.uid, oldModel)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[StandardScalerModel] = new StandardScalerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): StandardScalerModel = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f782a27..f16f6af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,13 +17,14 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
@@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
  */
 @Experimental
 class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
-  with StringIndexerBase {
+  with StringIndexerBase with Writable {
 
   def this() = this(Identifiable.randomUID("strIdx"))
 
@@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
   }
 
   override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
+
+  @Since("1.6.0")
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object StringIndexer extends Readable[StringIndexer] {
+
+  @Since("1.6.0")
+  override def read: Reader[StringIndexer] = new DefaultParamsReader
+
+  @Since("1.6.0")
+  override def load(path: String): StringIndexer = super.load(path)
 }
 
 /**
@@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 @Experimental
 class StringIndexerModel (
     override val uid: String,
-    val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+    val labels: Array[String])
+  extends Model[StringIndexerModel] with StringIndexerBase with Writable {
+
+  import StringIndexerModel._
 
   def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
 
@@ -176,6 +193,49 @@ class StringIndexerModel (
     val copied = new StringIndexerModel(uid, labels)
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
+}
+
+@Since("1.6.0")
+object StringIndexerModel extends Readable[StringIndexerModel] {
+
+  private[StringIndexerModel]
+  class StringIndexModelWriter(instance: StringIndexerModel) extends Writer {
+
+    private case class Data(labels: Array[String])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.labels)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class StringIndexerModelReader extends Reader[StringIndexerModel] {
+
+    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+
+    override def load(path: String): StringIndexerModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath)
+        .select("labels")
+        .head()
+      val labels = data.getAs[Seq[String]](0).toArray
+      val model = new StringIndexerModel(metadata.uid, labels)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+
+  @Since("1.6.0")
+  override def read: Reader[StringIndexerModel] = new StringIndexerModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): StringIndexerModel = super.load(path)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index e192fa4..9c99990 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.Row
 
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
 
   test("params") {
+    ParamsSuite.checkParams(new CountVectorizer)
     ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
   }
 
@@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(features ~== expected absTol 1e-14)
     }
   }
+
+  test("CountVectorizer read/write") {
+    val t = new CountVectorizer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMinDF(0.5)
+      .setMinTF(3.0)
+      .setVocabSize(10)
+    testDefaultReadWrite(t)
+  }
+
+  test("CountVectorizerModel read/write") {
+    val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c"))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMinTF(3.0)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.vocabulary === instance.vocabulary)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 08f80af..bc958c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.Row
 
-class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
     dataSet.map {
@@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
     }
   }
+
+  test("IDF read/write") {
+    val t = new IDF()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMinDocFreq(5)
+    testDefaultReadWrite(t)
+  }
+
+  test("IDFModel read/write") {
+    val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0)))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.idf === instance.idf)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index c04dda4..09183fe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -18,12 +18,12 @@
 package org.apache.spark.ml.feature
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SQLContext}
 
-class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("MinMaxScaler fit basic case") {
     val sqlContext = new SQLContext(sc)
@@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
       }
     }
   }
+
+  test("MinMaxScaler read/write") {
+    val t = new MinMaxScaler()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMax(1.0)
+      .setMin(-1.0)
+    testDefaultReadWrite(t)
+  }
+
+  test("MinMaxScalerModel read/write") {
+    val instance = new MinMaxScalerModel(
+        "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setMin(-1.0)
+      .setMax(1.0)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.originalMin === instance.originalMin)
+    assert(newInstance.originalMax === instance.originalMax)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 879a3ae..49a4b2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -19,12 +19,16 @@ package org.apache.spark.ml.feature
 
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
 
-class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
+  with DefaultReadWriteTest {
 
   @transient var data: Array[Vector] = _
   @transient var resWithStd: Array[Vector] = _
@@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
     )
   }
 
-  def assertResult(dataframe: DataFrame): Unit = {
-    dataframe.select("standarded_features", "expected").collect().foreach {
+  def assertResult(df: DataFrame): Unit = {
+    df.select("standardized_features", "expected").collect().foreach {
       case Row(vector1: Vector, vector2: Vector) =>
         assert(vector1 ~== vector2 absTol 1E-5,
           "The vector value is not correct after standardization.")
     }
   }
 
+  test("params") {
+    ParamsSuite.checkParams(new StandardScaler)
+    val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
+    ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+  }
+
   test("Standardization with default parameter") {
     val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
 
-    val standardscaler0 = new StandardScaler()
+    val standardScaler0 = new StandardScaler()
       .setInputCol("features")
-      .setOutputCol("standarded_features")
+      .setOutputCol("standardized_features")
       .fit(df0)
 
-    assertResult(standardscaler0.transform(df0))
+    assertResult(standardScaler0.transform(df0))
   }
 
   test("Standardization with setter") {
@@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
     val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
     val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
 
-    val standardscaler1 = new StandardScaler()
+    val standardScaler1 = new StandardScaler()
       .setInputCol("features")
-      .setOutputCol("standarded_features")
+      .setOutputCol("standardized_features")
       .setWithMean(true)
       .setWithStd(true)
       .fit(df1)
 
-    val standardscaler2 = new StandardScaler()
+    val standardScaler2 = new StandardScaler()
       .setInputCol("features")
-      .setOutputCol("standarded_features")
+      .setOutputCol("standardized_features")
       .setWithMean(true)
       .setWithStd(false)
       .fit(df2)
 
-    val standardscaler3 = new StandardScaler()
+    val standardScaler3 = new StandardScaler()
       .setInputCol("features")
-      .setOutputCol("standarded_features")
+      .setOutputCol("standardized_features")
       .setWithMean(false)
       .setWithStd(false)
       .fit(df3)
 
-    assertResult(standardscaler1.transform(df1))
-    assertResult(standardscaler2.transform(df2))
-    assertResult(standardscaler3.transform(df3))
+    assertResult(standardScaler1.transform(df1))
+    assertResult(standardScaler2.transform(df2))
+    assertResult(standardScaler3.transform(df3))
+  }
+
+  test("StandardScaler read/write") {
+    val t = new StandardScaler()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setWithStd(false)
+      .setWithMean(true)
+    testDefaultReadWrite(t)
+  }
+
+  test("StandardScalerModel read/write") {
+    val oldModel = new feature.StandardScalerModel(
+      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
+    val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.std === instance.std)
+    assert(newInstance.mean === instance.mean)
+    assert(newInstance.getWithStd === instance.getWithStd)
+    assert(newInstance.getWithMean === instance.getWithMean)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/dc1e2374/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index be37bfb..749bfac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,23 @@ class StringIndexerSuite
     assert(indexerModel.transform(df).eq(df))
   }
 
+  test("StringIndexer read/write") {
+    val t = new StringIndexer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setHandleInvalid("skip")
+    testDefaultReadWrite(t)
+  }
+
+  test("StringIndexerModel read/write") {
+    val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c"))
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setHandleInvalid("skip")
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.labels === instance.labels)
+  }
+
   test("IndexToString params") {
     val idxToStr = new IndexToString()
     ParamsSuite.checkParams(idxToStr)
@@ -175,7 +192,7 @@ class StringIndexerSuite
     assert(outSchema("output").dataType === StringType)
   }
 
-  test("read/write") {
+  test("IndexToString read/write") {
     val t = new IndexToString()
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")



