spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6791][ML] Add read/write for CrossValidator and Evaluators
Date Mon, 23 Nov 2015 05:49:00 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 a36d9bc75 -> 835b5488f


[SPARK-6791][ML] Add read/write for CrossValidator and Evaluators

I believe this works for general estimators within CrossValidator, including compound estimators. (See the complex unit test.)

Added read/write for all 3 Evaluators as well.

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9848 from jkbradley/cv-io.

(cherry picked from commit a6fda0bfc16a13b28b1cecc96f1ff91363089144)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/835b5488
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/835b5488
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/835b5488

Branch: refs/heads/branch-1.6
Commit: 835b5488ff644e2f51442943adffd3cd682703ac
Parents: a36d9bc
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Sun Nov 22 21:48:48 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun Nov 22 21:48:57 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala    |  38 +--
 .../BinaryClassificationEvaluator.scala         |  11 +-
 .../MulticlassClassificationEvaluator.scala     |  12 +-
 .../ml/evaluation/RegressionEvaluator.scala     |  11 +-
 .../apache/spark/ml/recommendation/ALS.scala    |  14 +-
 .../apache/spark/ml/tuning/CrossValidator.scala | 229 ++++++++++++++++++-
 .../org/apache/spark/ml/util/ReadWrite.scala    |  48 ++--
 .../org/apache/spark/ml/PipelineSuite.scala     |   4 +-
 .../BinaryClassificationEvaluatorSuite.scala    |  13 +-
 ...MulticlassClassificationEvaluatorSuite.scala |  13 +-
 .../evaluation/RegressionEvaluatorSuite.scala   |  12 +-
 .../spark/ml/tuning/CrossValidatorSuite.scala   | 202 +++++++++++++++-
 12 files changed, 522 insertions(+), 85 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 6f15b37..4b2b3f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -232,20 +231,9 @@ object Pipeline extends MLReadable[Pipeline] {
         stages: Array[PipelineStage],
         sc: SparkContext,
         path: String): Unit = {
-      // Copied and edited from DefaultParamsWriter.saveMetadata
-      // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
-      val uid = instance.uid
-      val cls = instance.getClass.getName
       val stageUids = stages.map(_.uid)
       val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
-      val metadata = ("class" -> cls) ~
-        ("timestamp" -> System.currentTimeMillis()) ~
-        ("sparkVersion" -> sc.version) ~
-        ("uid" -> uid) ~
-        ("paramMap" -> jsonParams)
-      val metadataPath = new Path(path, "metadata").toString
-      val metadataJson = compact(render(metadata))
-      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
 
       // Save stages
       val stagesDir = new Path(path, "stages").toString
@@ -266,30 +254,10 @@ object Pipeline extends MLReadable[Pipeline] {
 
       implicit val format = DefaultFormats
       val stagesDir = new Path(path, "stages").toString
-      val stageUids: Array[String] = metadata.params match {
-        case JObject(pairs) =>
-          if (pairs.length != 1) {
-            // Should not happen unless file is corrupted or we have a bug.
-            throw new RuntimeException(
-              s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
-          }
-          pairs.head match {
-            case ("stageUids", jsonValue) =>
-              jsonValue.extract[Seq[String]].toArray
-            case (paramName, jsonValue) =>
-              // Should not happen unless file is corrupted or we have a bug.
-              throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName"
+
-                s" in metadata: ${metadata.metadataStr}")
-          }
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
-      }
+      val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
       val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx)
=>
         val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
-        val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
-        val cls = Utils.classForName(stageMetadata.className)
-        cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
+        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
       }
       (metadata.uid, stages)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1fe3aba..bfb7096 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable,
SchemaUtils}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.2.0")
 @Experimental
 class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasRawPredictionCol with HasLabelCol {
+  extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable
{
 
   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("binEval"))
@@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.4.1")
   override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator]
{
+
+  @Since("1.6.0")
+  override def load(path: String): BinaryClassificationEvaluator = super.load(path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index df5f04c..c44db0e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils,
Identifiable}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.sql.{Row, DataFrame}
 import org.apache.spark.sql.types.DoubleType
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.5.0")
 @Experimental
 class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid:
String)
-  extends Evaluator with HasPredictionCol with HasLabelCol {
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mcEval"))
@@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   @Since("1.5.0")
   override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object MulticlassClassificationEvaluator
+  extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
+
+  @Since("1.6.0")
+  override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index ba012f4..daaa174 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable,
SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
 @Since("1.4.0")
 @Experimental
 final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol {
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("regEval"))
@@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
   @Since("1.5.0")
   override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {
+
+  @Since("1.6.0")
+  override def load(path: String): RegressionEvaluator = super.load(path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4d35177..b798aa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.json4s.{DefaultFormats, JValue}
+import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, Partitioner}
 import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
@@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
   private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
-      val extraMetadata = render("rank" -> instance.rank)
+      val extraMetadata = "rank" -> instance.rank
       DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
       val userPath = new Path(path, "userFactors").toString
       instance.userFactors.write.format("parquet").save(userPath)
@@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
     override def load(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       implicit val format = DefaultFormats
-      val rank: Int = metadata.extraMetadata match {
-        case Some(m: JValue) =>
-          (m \ "rank").extract[Int]
-        case None =>
-          throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:"
+
-            s" ${metadata.metadataStr}")
-      }
-
+      val rank = (metadata.metadata \ "rank").extract[Int]
       val userPath = new Path(path, "userFactors").toString
       val userFactors = sqlContext.read.format("parquet").load(userPath)
       val itemPath = new Path(path, "itemFactors").toString

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 77d9948..83a9048 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -18,17 +18,24 @@
 package org.apache.spark.ml.tuning
 
 import com.github.fommil.netlib.F2jBLAS
+import org.apache.hadoop.fs.Path
+import org.json4s.{JObject, DefaultFormats}
+import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
+
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
@@ -53,7 +60,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
  */
 @Experimental
 class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
-  with CrossValidatorParams with Logging {
+  with CrossValidatorParams with MLWritable with Logging {
 
   def this() = this(Identifiable.randomUID("cv"))
 
@@ -131,6 +138,166 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
     }
     copied
   }
+
+  // Currently, this only works if all [[Param]]s in [[estimatorParamMaps]] are simple types.
+  // E.g., this may fail if a [[Param]] is an instance of an [[Estimator]].
+  // However, this case should be unusual.
+  @Since("1.6.0")
+  override def write: MLWriter = new CrossValidator.CrossValidatorWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidator extends MLReadable[CrossValidator] {
+
+  @Since("1.6.0")
+  override def read: MLReader[CrossValidator] = new CrossValidatorReader
+
+  @Since("1.6.0")
+  override def load(path: String): CrossValidator = super.load(path)
+
+  private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter
{
+
+    SharedReadWrite.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit =
+      SharedReadWrite.saveImpl(path, instance, sc)
+  }
+
+  private class CrossValidatorReader extends MLReader[CrossValidator] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[CrossValidator].getName
+
+    override def load(path: String): CrossValidator = {
+      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+        SharedReadWrite.load(path, sc, className)
+      new CrossValidator(metadata.uid)
+        .setEstimator(estimator)
+        .setEvaluator(evaluator)
+        .setEstimatorParamMaps(estimatorParamMaps)
+        .setNumFolds(numFolds)
+    }
+  }
+
+  private object CrossValidatorReader {
+    /**
+     * Examine the given estimator (which may be a compound estimator) and extract a mapping
+     * from UIDs to corresponding [[Params]] instances.
+     */
+    def getUidMap(instance: Params): Map[String, Params] = {
+      val uidList = getUidMapImpl(instance)
+      val uidMap = uidList.toMap
+      if (uidList.size != uidMap.size) {
+        throw new RuntimeException("CrossValidator.load found a compound estimator with stages"
+
+          s" with duplicate UIDs.  List of UIDs: ${uidList.map(_._1).mkString(", ")}")
+      }
+      uidMap
+    }
+
+    /**
+     * Recursively collect (uid, instance) pairs for `instance` and all of its nested stages.
+     *
+     * Nested stages are the stages of a [[Pipeline]] or [[PipelineModel]], or the estimator and
+     * evaluator of a [[ValidatorParams]].  Any other [[Params]] is treated as a leaf.
+     * Duplicate UIDs are kept in the returned list; no de-duplication is done here.
+     *
+     * Throws UnsupportedOperationException if a [[OneVsRestParams]] or [[RFormulaModel]] is
+     * encountered, since those types are not handled yet.
+     *
+     * @param instance  Root [[Params]] instance to examine
+     * @return  List starting with `instance` itself, followed by the pairs of its sub-stages
+     */
+    def getUidMapImpl(instance: Params): List[(String, Params)] = {
+      val subStages: Array[Params] = instance match {
+        case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+        case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+        case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+        case ovr: OneVsRestParams =>
+          // TODO: SPARK-11892: This case may require special handling.
+          // The second literal needs the `s` interpolator so the class name is substituted.
+          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+            s" cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
+        case rform: RFormulaModel =>
+          // TODO: SPARK-11891: This case may require special handling.
+          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
+            " cannot yet handle an estimator containing an RFormulaModel")
+        case _: Params => Array()
+      }
+      val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+      List((instance.uid, instance)) ++ subStageMaps
+    }
+  }
+
+  private[tuning] object SharedReadWrite {
+
+    /**
+     * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable,
+     * and that every [[Param]] in [[CrossValidator.estimatorParamMaps]] has a parent UID found
+     * among the UIDs collected from `instance` by `CrossValidatorReader.getUidMap`.
+     * The values stored in [[CrossValidator.estimatorParamMaps]] are not otherwise checked.
+     *
+     * Throws UnsupportedOperationException if the evaluator or estimator is not [[MLWritable]].
+     * Fails via `require` (IllegalArgumentException) if an extraneous [[Param]] is found.
+     */
+    def validateParams(instance: ValidatorParams): Unit = {
+      def checkElement(elem: Params, name: String): Unit = elem match {
+        case stage: MLWritable => // good
+        case other =>
+          throw new UnsupportedOperationException("CrossValidator write will fail " +
+            s" because it contains $name which does not implement Writable." +
+            s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+      }
+      checkElement(instance.getEvaluator, "evaluator")
+      checkElement(instance.getEstimator, "estimator")
+      // Check to make sure all Params apply to this estimator.  Throw an error if any do
not.
+      // Extraneous Params would cause problems when loading the estimatorParamMaps.
+      val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
+      instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+        pMap.toSeq.foreach { case ParamPair(p, v) =>
+          require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params
in" +
+            s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its"
+
+            s" Evaluator.  An extraneous Param was found: $p")
+        }
+      }
+    }
+
+    private[tuning] def saveImpl(
+        path: String,
+        instance: CrossValidatorParams,
+        sc: SparkContext,
+        extraMetadata: Option[JObject] = None): Unit = {
+      import org.json4s.JsonDSL._
+
+      val estimatorParamMapsJson = compact(render(
+        instance.getEstimatorParamMaps.map { case paramMap =>
+          paramMap.toSeq.map { case ParamPair(p, v) =>
+            Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+          }
+        }.toSeq
+      ))
+      val jsonParams = List(
+        "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
+        "estimatorParamMaps" -> parse(estimatorParamMapsJson)
+      )
+      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+      val evaluatorPath = new Path(path, "evaluator").toString
+      instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+      val estimatorPath = new Path(path, "estimator").toString
+      instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+    }
+
+    private[tuning] def load[M <: Model[M]](
+        path: String,
+        sc: SparkContext,
+        expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap],
Int) = {
+
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+      implicit val format = DefaultFormats
+      val evaluatorPath = new Path(path, "evaluator").toString
+      val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+      val estimatorPath = new Path(path, "estimator").toString
+      val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath,
sc)
+
+      val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
+
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
+      val estimatorParamMaps: Array[ParamMap] =
+        (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map
{
+          pMap =>
+            val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+              val est = uidToParams(pInfo("parent"))
+              val param = est.getParam(pInfo("name"))
+              val value = param.jsonDecode(pInfo("value"))
+              param -> value
+            }
+            ParamMap(paramPairs: _*)
+        }.toArray
+      (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
+    }
+  }
 }
 
 /**
@@ -139,14 +306,14 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
  *
  * @param bestModel The best model selected from k-fold cross validation.
  * @param avgMetrics Average cross-validation metrics for each paramMap in
- *                   [[estimatorParamMaps]], in the corresponding order.
+ *                   [[CrossValidator.estimatorParamMaps]], in the corresponding order.
  */
 @Experimental
 class CrossValidatorModel private[ml] (
     override val uid: String,
     val bestModel: Model[_],
     val avgMetrics: Array[Double])
-  extends Model[CrossValidatorModel] with CrossValidatorParams {
+  extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {
 
   override def validateParams(): Unit = {
     bestModel.validateParams()
@@ -168,4 +335,54 @@ class CrossValidatorModel private[ml] (
       avgMetrics.clone())
     copyValues(copied, extra).setParent(parent)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new CrossValidatorModel.CrossValidatorModelWriter(this)
+}
+
+@Since("1.6.0")
+object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
+
+  import CrossValidator.SharedReadWrite
+
+  @Since("1.6.0")
+  override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): CrossValidatorModel = super.load(path)
+
+  private[CrossValidatorModel]
+  class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
+
+    SharedReadWrite.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      import org.json4s.JsonDSL._
+      val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
+      SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+      val bestModelPath = new Path(path, "bestModel").toString
+      instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+    }
+  }
+
+  private class CrossValidatorModelReader extends MLReader[CrossValidatorModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[CrossValidatorModel].getName
+
+    override def load(path: String): CrossValidatorModel = {
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
+        SharedReadWrite.load(path, sc, className)
+      val bestModelPath = new Path(path, "bestModel").toString
+      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+      val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
+      val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
+      cv.set(cv.estimator, estimator)
+        .set(cv.evaluator, evaluator)
+        .set(cv.estimatorParamMaps, estimatorParamMaps)
+        .set(cv.numFolds, numFolds)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index ff9322d..8484b1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -202,25 +202,36 @@ private[ml] object DefaultParamsWriter {
    *  - timestamp
    *  - sparkVersion
    *  - uid
-   *  - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
+   *  - paramMap
+   *  - (optionally, extra metadata)
+   * @param extraMetadata  Extra metadata to be saved at same level as uid, paramMap, etc.
+   * @param paramMap  If given, this is saved in the "paramMap" field.
+   *                  Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using
+   *                  [[org.apache.spark.ml.param.Param.jsonEncode()]].
    */
   def saveMetadata(
       instance: Params,
       path: String,
       sc: SparkContext,
-      extraMetadata: Option[JValue] = None): Unit = {
+      extraMetadata: Option[JObject] = None,
+      paramMap: Option[JValue] = None): Unit = {
     val uid = instance.uid
     val cls = instance.getClass.getName
     val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
-    val jsonParams = params.map { case ParamPair(p, v) =>
+    val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
-    }.toList
-    val metadata = ("class" -> cls) ~
+    }.toList))
+    val basicMetadata = ("class" -> cls) ~
       ("timestamp" -> System.currentTimeMillis()) ~
       ("sparkVersion" -> sc.version) ~
       ("uid" -> uid) ~
-      ("paramMap" -> jsonParams) ~
-      ("extraMetadata" -> extraMetadata)
+      ("paramMap" -> jsonParams)
+    val metadata = extraMetadata match {
+      case Some(jObject) =>
+        basicMetadata ~ jObject
+      case None =>
+        basicMetadata
+    }
     val metadataPath = new Path(path, "metadata").toString
     val metadataJson = compact(render(metadata))
     sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -251,8 +262,8 @@ private[ml] object DefaultParamsReader {
   /**
    * All info from metadata file.
    * @param params  paramMap, as a [[JValue]]
-   * @param extraMetadata  Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
-   * @param metadataStr  Full metadata file String (for debugging)
+   * @param metadata  All metadata, including the other fields
+   * @param metadataJson  Full metadata file String (for debugging)
    */
   case class Metadata(
       className: String,
@@ -260,8 +271,8 @@ private[ml] object DefaultParamsReader {
       timestamp: Long,
       sparkVersion: String,
       params: JValue,
-      extraMetadata: Option[JValue],
-      metadataStr: String)
+      metadata: JValue,
+      metadataJson: String)
 
   /**
    * Load metadata from file.
@@ -279,13 +290,12 @@ private[ml] object DefaultParamsReader {
     val timestamp = (metadata \ "timestamp").extract[Long]
     val sparkVersion = (metadata \ "sparkVersion").extract[String]
     val params = metadata \ "paramMap"
-    val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
     if (expectedClassName.nonEmpty) {
       require(className == expectedClassName, s"Error loading metadata: Expected class name"
+
         s" $expectedClassName but found class name $className")
     }
 
-    Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
+    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
   }
 
   /**
@@ -303,7 +313,17 @@ private[ml] object DefaultParamsReader {
         }
       case _ =>
         throw new IllegalArgumentException(
-          s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
+          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
     }
   }
+
+  /**
+   * Load a [[Params]] instance from the given path, and return it.
+   * This assumes the instance implements [[MLReadable]].
+   */
+  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+    val metadata = DefaultParamsReader.loadMetadata(path, sc)
+    val cls = Utils.classForName(metadata.className)
+    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 12aba6b..8c86767 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -17,11 +17,9 @@
 
 package org.apache.spark.ml
 
-import java.io.File
-
 import scala.collection.JavaConverters._
 
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito.when
 import org.scalatest.mock.MockitoSugar.mock

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
index def869f..a535c12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class BinaryClassificationEvaluatorSuite extends SparkFunSuite {
+class BinaryClassificationEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new BinaryClassificationEvaluator)
   }
+
+  test("read/write") {
+    val evaluator = new BinaryClassificationEvaluator()
+      .setRawPredictionCol("myRawPrediction")
+      .setLabelCol("myLabel")
+      .setMetricName("areaUnderPR")
+    testDefaultReadWrite(evaluator)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
index 6d8412b..7ee6597 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -19,10 +19,21 @@ package org.apache.spark.ml.evaluation
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
 
-class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+class MulticlassClassificationEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new MulticlassClassificationEvaluator)
   }
+
+  test("read/write") {
+    val evaluator = new MulticlassClassificationEvaluator()
+      .setPredictionCol("myPrediction")
+      .setLabelCol("myLabel")
+      .setMetricName("recall")
+    testDefaultReadWrite(evaluator)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index aa722da..60886bf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
-class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegressionEvaluatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new RegressionEvaluator)
@@ -73,4 +75,12 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
     evaluator.setMetricName("mae")
     assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
   }
+
+  test("read/write") {
+    val evaluator = new RegressionEvaluator()
+      .setPredictionCol("myPrediction")
+      .setLabelCol("myLabel")
+      .setMetricName("r2")
+    testDefaultReadWrite(evaluator)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/835b5488/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index cbe0929..dd63660 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -18,19 +18,22 @@
 package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.{Pipeline, Estimator, Model}
+import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression}
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamPair, ParamMap}
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.types.StructType
 
-class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CrossValidatorSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   @transient var dataset: DataFrame = _
 
@@ -95,7 +98,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("validateParams should check estimatorParamMaps") {
-    import CrossValidatorSuite._
+    import CrossValidatorSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
     val eval = new MyEvaluator
@@ -116,9 +119,194 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
       cv.validateParams()
     }
   }
+
+  test("read/write: CrossValidator with simple estimator") {
+    val lr = new LogisticRegression().setMaxIter(3)
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")  // not default metric
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEvaluator(evaluator)
+      .setNumFolds(20)
+      .setEstimatorParamMaps(paramMaps)
+
+    val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+    assert(cv.uid === cv2.uid)
+    assert(cv.getNumFolds === cv2.getNumFolds)
+
+    assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+    val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+    assert(evaluator.uid === evaluator2.uid)
+    assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+    cv2.getEstimator match {
+      case lr2: LogisticRegression =>
+        assert(lr.uid === lr2.uid)
+        assert(lr.getMaxIter === lr2.getMaxIter)
+      case other =>
+        throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+          s" LogisticRegression but found ${other.getClass.getName}")
+    }
+
+    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+  }
+
+  test("read/write: CrossValidator with complex estimator") {
+    // workflow: CrossValidator[Pipeline[HashingTF, CrossValidator[LogisticRegression]]]
+    val lrEvaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")  // not default metric
+
+    val lr = new LogisticRegression().setMaxIter(3)
+    val lrParamMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.1, 0.2))
+      .build()
+    val lrcv = new CrossValidator()
+      .setEstimator(lr)
+      .setEvaluator(lrEvaluator)
+      .setEstimatorParamMaps(lrParamMaps)
+
+    val hashingTF = new HashingTF()
+    val pipeline = new Pipeline().setStages(Array(hashingTF, lrcv))
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(hashingTF.numFeatures, Array(10, 20))
+      .addGrid(lr.elasticNetParam, Array(0.0, 1.0))
+      .build()
+    val evaluator = new BinaryClassificationEvaluator()
+
+    val cv = new CrossValidator()
+      .setEstimator(pipeline)
+      .setEvaluator(evaluator)
+      .setNumFolds(20)
+      .setEstimatorParamMaps(paramMaps)
+
+    val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+    assert(cv.uid === cv2.uid)
+    assert(cv.getNumFolds === cv2.getNumFolds)
+
+    assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+    assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
+
+    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+    cv2.getEstimator match {
+      case pipeline2: Pipeline =>
+        assert(pipeline.uid === pipeline2.uid)
+        pipeline2.getStages match {
+          case Array(hashingTF2: HashingTF, lrcv2: CrossValidator) =>
+            assert(hashingTF.uid === hashingTF2.uid)
+            lrcv2.getEstimator match {
+              case lr2: LogisticRegression =>
+                assert(lr.uid === lr2.uid)
+                assert(lr.getMaxIter === lr2.getMaxIter)
+              case other =>
+                throw new AssertionError(s"Loaded internal CrossValidator expected to be" +
+                  s" LogisticRegression but found type ${other.getClass.getName}")
+            }
+            assert(lrcv.uid === lrcv2.uid)
+            assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+            assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
+            CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+          case other =>
+            throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
+              " but found: " + other.map(_.getClass.getName).mkString(", "))
+        }
+      case other =>
+        throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+          s" CrossValidator but found ${other.getClass.getName}")
+    }
+  }
+
+  test("read/write: CrossValidator fails for extraneous Param") {
+    val lr = new LogisticRegression()
+    val lr2 = new LogisticRegression()
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(lr.regParam, Array(0.1, 0.2))
+      .addGrid(lr2.regParam, Array(0.1, 0.2))
+      .build()
+    val cv = new CrossValidator()
+      .setEstimator(lr)
+      .setEvaluator(evaluator)
+      .setEstimatorParamMaps(paramMaps)
+    withClue("CrossValidator.write failed to catch extraneous Param error") {
+      intercept[IllegalArgumentException] {
+        cv.write
+      }
+    }
+  }
+
+  test("read/write: CrossValidatorModel") {
+    val lr = new LogisticRegression()
+      .setThreshold(0.6)
+    val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+      .setThreshold(0.6)
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")  // not default metric
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val cv = new CrossValidatorModel("cvUid", lrModel, Array(0.3, 0.6))
+    cv.set(cv.estimator, lr)
+      .set(cv.evaluator, evaluator)
+      .set(cv.numFolds, 20)
+      .set(cv.estimatorParamMaps, paramMaps)
+
+    val cv2 = testDefaultReadWrite(cv, testParams = false)
+
+    assert(cv.uid === cv2.uid)
+    assert(cv.getNumFolds === cv2.getNumFolds)
+
+    assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
+    val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
+    assert(evaluator.uid === evaluator2.uid)
+    assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+    cv2.getEstimator match {
+      case lr2: LogisticRegression =>
+        assert(lr.uid === lr2.uid)
+        assert(lr.getThreshold === lr2.getThreshold)
+      case other =>
+        throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+          s" LogisticRegression but found ${other.getClass.getName}")
+    }
+
+    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+
+    cv2.bestModel match {
+      case lrModel2: LogisticRegressionModel =>
+        assert(lrModel.uid === lrModel2.uid)
+        assert(lrModel.getThreshold === lrModel2.getThreshold)
+        assert(lrModel.coefficients === lrModel2.coefficients)
+        assert(lrModel.intercept === lrModel2.intercept)
+      case other =>
+        throw new AssertionError(s"Loaded CrossValidator expected bestModel of type" +
+          s" LogisticRegressionModel but found ${other.getClass.getName}")
+    }
+    assert(cv.avgMetrics === cv2.avgMetrics)
+  }
 }
 
-object CrossValidatorSuite {
+object CrossValidatorSuite extends SparkFunSuite {
+
+  /**
+   * Assert sequences of estimatorParamMaps are identical.
+   * Params must be simple types comparable with `===`.
+   */
+  def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+    assert(pMaps.length === pMaps2.length)
+    pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+      assert(pMap.size === pMap2.size)
+      pMap.toSeq.foreach { case ParamPair(p, v) =>
+        assert(pMap2.contains(p))
+        assert(pMap2(p) === v)
+      }
+    }
+  }
 
   abstract class MyModel extends Model[MyModel]
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message