spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as OneVsRest
Date Mon, 17 Jul 2017 17:07:36 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4ce735eed -> 7047f49f4


[SPARK-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as
OneVsRest

## What changes were proposed in this pull request?
Added functionality for CrossValidator and TrainValidationSplit to persist nested estimators
such as OneVsRest. Also added CrossValidator and TrainValidationSplit persistence to PySpark.

## How was this patch tested?
Ran both CrossValidator and TrainValidationSplit with a OneVsRest estimator and tested
read/write of the estimator parameter maps required by these meta-algorithms.

Author: Ajay Saini <ajays725@gmail.com>

Closes #18428 from ajaysaini725/MetaAlgorithmPersistNestedEstimators.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7047f49f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7047f49f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7047f49f

Branch: refs/heads/master
Commit: 7047f49f45406be3b4a9b0aa209b3021621392ca
Parents: 4ce735e
Author: Ajay Saini <ajays725@gmail.com>
Authored: Mon Jul 17 10:07:32 2017 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Jul 17 10:07:32 2017 -0700

----------------------------------------------------------------------
 .../spark/ml/tuning/ValidatorParams.scala       |  31 ++-
 .../spark/ml/tuning/CrossValidatorSuite.scala   | 103 +++++++--
 .../ml/tuning/TrainValidationSplitSuite.scala   |  84 ++++++-
 .../ml/tuning/ValidatorParamsSuiteHelpers.scala |  86 +++++++
 .../spark/ml/util/DefaultReadWriteTest.scala    |   1 -
 python/pyspark/ml/classification.py             |  92 +++++---
 python/pyspark/ml/tests.py                      | 145 +++++++++++-
 python/pyspark/ml/tuning.py                     | 226 ++++++++++++++++++-
 python/pyspark/ml/wrapper.py                    |   2 +-
 9 files changed, 696 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index d55eb14..0ab6eed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -126,10 +126,26 @@ private[ml] object ValidatorParams {
       extraMetadata: Option[JObject] = None): Unit = {
     import org.json4s.JsonDSL._
 
+    var numParamsNotJson = 0
     val estimatorParamMapsJson = compact(render(
       instance.getEstimatorParamMaps.map { case paramMap =>
         paramMap.toSeq.map { case ParamPair(p, v) =>
-          Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+          v match {
+            case writeableObj: DefaultParamsWritable =>
+              val relativePath = "epm_" + p.name + numParamsNotJson
+              val paramPath = new Path(path, relativePath).toString
+              numParamsNotJson += 1
+              writeableObj.save(paramPath)
+              Map("parent" -> p.parent, "name" -> p.name,
+                "value" -> compact(render(JString(relativePath))),
+                "isJson" -> compact(render(JBool(false))))
+            case _: MLWritable =>
+              throw new NotImplementedError("ValidatorParams.saveImpl does not handle parameters " +
+                "of type: MLWritable that are not DefaultParamsWritable")
+            case _ =>
+              Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v),
+                "isJson" -> compact(render(JBool(true))))
+          }
         }
       }.toSeq
     ))
@@ -183,8 +199,17 @@ private[ml] object ValidatorParams {
           val paramPairs = pMap.map { case pInfo: Map[String, String] =>
             val est = uidToParams(pInfo("parent"))
             val param = est.getParam(pInfo("name"))
-            val value = param.jsonDecode(pInfo("value"))
-            param -> value
+            // [Spark-21221] introduced the isJson field
+            if (!pInfo.contains("isJson") ||
+                (pInfo.contains("isJson") && pInfo("isJson").toBoolean.booleanValue())) {
+              val value = param.jsonDecode(pInfo("value"))
+              param -> value
+            } else {
+              val relativePath = param.jsonDecode(pInfo("value")).toString
+              val value = DefaultParamsReader
+                .loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc)
+              param -> value
+            }
           }
           ParamMap(paramPairs: _*)
       }.toArray

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 2b4e6b5..2791ea7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -19,12 +19,12 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model, Pipeline}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
 import org.apache.spark.ml.feature.HashingTF
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
-import org.apache.spark.ml.param.{ParamMap, ParamPair}
+import org.apache.spark.ml.linalg.Vectors
+import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -153,7 +153,76 @@ class CrossValidatorSuite
           s" LogisticRegression but found ${other.getClass.getName}")
     }
 
-    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+    ValidatorParamsSuiteHelpers
+      .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+  }
+
+  test("read/write: CrossValidator with nested estimator") {
+    val ova = new OneVsRest().setClassifier(new LogisticRegression)
+    val evaluator = new MulticlassClassificationEvaluator()
+      .setMetricName("accuracy")
+    val classifier1 = new LogisticRegression().setRegParam(2.0)
+    val classifier2 = new LogisticRegression().setRegParam(3.0)
+    // param values that are not JSON serializable must be DefaultParamsWritable
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(ova.classifier, Array(classifier1, classifier2))
+      .build()
+    val cv = new CrossValidator()
+      .setEstimator(ova)
+      .setEvaluator(evaluator)
+      .setNumFolds(20)
+      .setEstimatorParamMaps(paramMaps)
+
+    val cv2 = testDefaultReadWrite(cv, testParams = false)  // see compareParamMaps below
+
+    assert(cv.uid === cv2.uid)
+    assert(cv.getNumFolds === cv2.getNumFolds)
+    assert(cv.getSeed === cv2.getSeed)
+
+    assert(cv2.getEvaluator.isInstanceOf[MulticlassClassificationEvaluator])
+    val evaluator2 = cv2.getEvaluator.asInstanceOf[MulticlassClassificationEvaluator]
+    assert(evaluator.uid === evaluator2.uid)
+    assert(evaluator.getMetricName === evaluator2.getMetricName)
+
+    cv2.getEstimator match {
+      case ova2: OneVsRest =>
+        assert(ova.uid === ova2.uid)
+        val classifier = ova2.getClassifier
+        classifier match {
+          case lr: LogisticRegression =>
+            assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
+              === lr.getMaxIter)
+          case _ =>
+            throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+              s" LogisticRegression but found ${classifier.getClass.getName}")
+        }
+
+      case other =>
+        throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
+          s" OneVsRest but found ${other.getClass.getName}")
+    }
+
+    ValidatorParamsSuiteHelpers
+      .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+  }
+
+  test("read/write: Persistence of nested estimator works if parent directory changes") {
+    val ova = new OneVsRest().setClassifier(new LogisticRegression)
+    val evaluator = new MulticlassClassificationEvaluator()
+      .setMetricName("accuracy")
+    val classifier1 = new LogisticRegression().setRegParam(2.0)
+    val classifier2 = new LogisticRegression().setRegParam(3.0)
+    // param values that are not JSON serializable must be DefaultParamsWritable
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(ova.classifier, Array(classifier1, classifier2))
+      .build()
+    val cv = new CrossValidator()
+      .setEstimator(ova)
+      .setEvaluator(evaluator)
+      .setNumFolds(20)
+      .setEstimatorParamMaps(paramMaps)
+
+    ValidatorParamsSuiteHelpers.testFileMove(cv)
   }
 
   test("read/write: CrossValidator with complex estimator") {
@@ -193,7 +262,8 @@ class CrossValidatorSuite
     assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
     assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
 
-    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+    ValidatorParamsSuiteHelpers
+      .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
 
     cv2.getEstimator match {
       case pipeline2: Pipeline =>
@@ -212,7 +282,8 @@ class CrossValidatorSuite
             assert(lrcv.uid === lrcv2.uid)
             assert(lrcv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
             assert(lrEvaluator.uid === lrcv2.getEvaluator.uid)
-            CrossValidatorSuite.compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
+            ValidatorParamsSuiteHelpers
+              .compareParamMaps(lrParamMaps, lrcv2.getEstimatorParamMaps)
           case other =>
             throw new AssertionError("Loaded Pipeline expected stages (HashingTF, CrossValidator)" +
               " but found: " + other.map(_.getClass.getName).mkString(", "))
@@ -278,7 +349,8 @@ class CrossValidatorSuite
           s" LogisticRegression but found ${other.getClass.getName}")
     }
 
-    CrossValidatorSuite.compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
+   ValidatorParamsSuiteHelpers
+     .compareParamMaps(cv.getEstimatorParamMaps, cv2.getEstimatorParamMaps)
 
     cv2.bestModel match {
       case lrModel2: LogisticRegressionModel =>
@@ -296,21 +368,6 @@ class CrossValidatorSuite
 
 object CrossValidatorSuite extends SparkFunSuite {
 
-  /**
-   * Assert sequences of estimatorParamMaps are identical.
-   * Params must be simple types comparable with `===`.
-   */
-  def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
-    assert(pMaps.length === pMaps2.length)
-    pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
-      assert(pMap.size === pMap2.size)
-      pMap.toSeq.foreach { case ParamPair(p, v) =>
-        assert(pMap2.contains(p))
-        assert(pMap2(p) === v)
-      }
-    }
-  }
-
   abstract class MyModel extends Model[MyModel]
 
   class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index a34f930..71a1776 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel, OneVsRest}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.linalg.Vectors
-import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.{ParamMap}
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -95,7 +95,7 @@ class TrainValidationSplitSuite
   }
 
   test("transformSchema should check estimatorParamMaps") {
-    import TrainValidationSplitSuite._
+    import TrainValidationSplitSuite.{MyEstimator, MyEvaluator}
 
     val est = new MyEstimator("est")
     val eval = new MyEvaluator
@@ -134,6 +134,82 @@ class TrainValidationSplitSuite
 
     assert(tvs.getTrainRatio === tvs2.getTrainRatio)
     assert(tvs.getSeed === tvs2.getSeed)
+
+    ValidatorParamsSuiteHelpers
+      .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
+
+    tvs2.getEstimator match {
+      case lr2: LogisticRegression =>
+        assert(lr.uid === lr2.uid)
+        assert(lr.getMaxIter === lr2.getMaxIter)
+      case other =>
+        throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+          s" LogisticRegression but found ${other.getClass.getName}")
+    }
+  }
+
+  test("read/write: TrainValidationSplit with nested estimator") {
+    val ova = new OneVsRest()
+      .setClassifier(new LogisticRegression)
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")  // not default metric
+    val classifier1 = new LogisticRegression().setRegParam(2.0)
+    val classifier2 = new LogisticRegression().setRegParam(3.0)
+    // param values that are not JSON serializable must be DefaultParamsWritable
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(ova.classifier, Array(classifier1, classifier2))
+      .build()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(ova)
+      .setEvaluator(evaluator)
+      .setTrainRatio(0.5)
+      .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)  // see compareParamMaps below
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+    assert(tvs.getSeed === tvs2.getSeed)
+
+    tvs2.getEstimator match {
+      case ova2: OneVsRest =>
+        assert(ova.uid === ova2.uid)
+        val classifier = ova2.getClassifier
+        classifier match {
+          case lr: LogisticRegression =>
+            assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
+              === lr.getMaxIter)
+          case _ =>
+            throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+              s" LogisticRegression but found ${classifier.getClass.getName}")
+        }
+
+      case other =>
+        throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
+          s" OneVsRest but found ${other.getClass.getName}")
+    }
+
+    ValidatorParamsSuiteHelpers
+      .compareParamMaps(tvs.getEstimatorParamMaps, tvs2.getEstimatorParamMaps)
+  }
+
+  test("read/write: Persistence of nested estimator works if parent directory changes") {
+    val ova = new OneVsRest()
+      .setClassifier(new LogisticRegression)
+    val evaluator = new BinaryClassificationEvaluator()
+      .setMetricName("areaUnderPR")  // not default metric
+    val classifier1 = new LogisticRegression().setRegParam(2.0)
+    val classifier2 = new LogisticRegression().setRegParam(3.0)
+    val paramMaps = new ParamGridBuilder()
+      .addGrid(ova.classifier, Array(classifier1, classifier2))
+      .build()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(ova)
+      .setEvaluator(evaluator)
+      .setTrainRatio(0.5)
+      .setEstimatorParamMaps(paramMaps)
+      .setSeed(42L)
+
+    ValidatorParamsSuiteHelpers.testFileMove(tvs)
   }
 
   test("read/write: TrainValidationSplitModel") {
@@ -160,7 +236,7 @@ class TrainValidationSplitSuite
   }
 }
 
-object TrainValidationSplitSuite {
+object TrainValidationSplitSuite extends SparkFunSuite{
 
   abstract class MyModel extends Model[MyModel]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
new file mode 100644
index 0000000..1df673c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ValidatorParamsSuiteHelpers.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tuning
+
+import java.io.File
+import java.nio.file.{Files, StandardCopyOption}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLReader, MLWritable}
+
+object ValidatorParamsSuiteHelpers extends SparkFunSuite with DefaultReadWriteTest {
+  /**
+   * Assert sequences of estimatorParamMaps are identical.
+   * Values that are themselves `Params` (e.g. a nested estimator such as the classifier of
+   * OneVsRest) are not comparable with `===`, so their extracted param maps are compared
+   * recursively instead.
+   */
+  def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = {
+    assert(pMaps.length === pMaps2.length)
+    pMaps.zip(pMaps2).foreach { case (pMap, pMap2) =>
+      assert(pMap.size === pMap2.size)
+      pMap.toSeq.foreach { case ParamPair(p, v) =>
+        assert(pMap2.contains(p))
+        val otherParam = pMap2(p)
+        v match {
+          case estimator: Params =>
+            otherParam match {
+              case estimator2: Params =>
+                val estimatorParamMap = Array(estimator.extractParamMap())
+                val estimatorParamMap2 = Array(estimator2.extractParamMap())
+                compareParamMaps(estimatorParamMap, estimatorParamMap2)
+              case other =>
+                throw new AssertionError(s"Expected parameter of type Params but" +
+                  s" found ${other.getClass.getName}")
+            }
+          case _ =>
+            assert(otherParam === v)
+        }
+      }
+    }
+  }
+
+  /**
+   * When nested estimators (ex. OneVsRest) are saved within meta-algorithms such as
+   * CrossValidator and TrainValidationSplit, relative paths should be used to store
+   * the path of the estimator so that if the parent directory changes, loading the
+   * model still works.
+   *
+   * Saves `instance`, moves the saved directory under a different parent directory, loads it
+   * back from the new location and checks that the uid survives the round trip.
+   */
+  def testFileMove[T <: Params with MLWritable](instance: T): Unit = {
+    // ScalaTest never runs this object as a suite, so the beforeAll() hook through which
+    // DefaultReadWriteTest's tempDir is (presumably) initialized never fires here; tempDir is
+    // expected to be null, which would put these files in the working directory. Use a
+    // dedicated temporary directory instead and always remove it afterwards.
+    val baseDir = org.apache.spark.util.Utils.createTempDir(namePrefix = "validatorParams")
+    try {
+      val uid = instance.uid
+      val subdir = new File(baseDir, Identifiable.randomUID("test"))
+      val subDirWithUid = new File(subdir, uid)
+      instance.save(subDirWithUid.getPath)
+
+      val newSubdir = new File(baseDir, Identifiable.randomUID("test_moved"))
+      val newSubdirWithUid = new File(newSubdir, uid)
+      // Only the new parent is created: with ATOMIC_MOVE, whether an existing target is
+      // replaced or the move fails is implementation specific.
+      Files.createDirectory(newSubdir.toPath)
+      Files.move(subDirWithUid.toPath, newSubdirWithUid.toPath, StandardCopyOption.ATOMIC_MOVE)
+
+      val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
+      val newInstance = loader.load(newSubdirWithUid.getPath)
+      assert(uid === newInstance.uid)
+    } finally {
+      org.apache.spark.util.Utils.deleteRecursively(baseDir)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 27d606c..4da95e7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -55,7 +55,6 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     instance.write.overwrite().save(path)
     val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
     val newInstance = loader.load(path)
-
     assert(newInstance.uid === instance.uid)
     if (testParams) {
       instance.params.foreach { p =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 948806a..82207f6 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -25,7 +25,7 @@ from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel
 from pyspark.ml.util import *
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.wrapper import JavaWrapper
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.common import inherit_doc, _java2py, _py2java
 from pyspark.sql import DataFrame
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
@@ -1472,7 +1472,7 @@ class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol):
 
 
 @inherit_doc
-class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
+class OneVsRest(Estimator, OneVsRestParams, JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
 
@@ -1589,22 +1589,6 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
             newOvr.setClassifier(self.getClassifier().copy(extra))
         return newOvr
 
-    @since("2.0.0")
-    def write(self):
-        """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
-
-    @since("2.0.0")
-    def save(self, path):
-        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
-        self.write().save(path)
-
-    @classmethod
-    @since("2.0.0")
-    def read(cls):
-        """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
-
     @classmethod
     def _from_java(cls, java_stage):
         """
@@ -1634,8 +1618,52 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
         _java_obj.setPredictionCol(self.getPredictionCol())
         return _java_obj
 
+    def _make_java_param_pair(self, param, value):
+        """
+        Build a Java ``ParamPair`` binding ``param`` (resolved against this instance) to
+        ``value`` converted to its Java representation.
+        """
+        sc = SparkContext._active_spark_context
+        resolved = self._resolveParam(param)
+        # A local Java OneVsRest with the same uid supplies the matching Java Param object.
+        java_ovr = JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRest",
+                                            self.uid)
+        java_param = java_ovr.getParam(resolved.name)
+        # Nested estimators (JavaParams values, e.g. the classifier) convert themselves. This is
+        # not done in common._py2java because importing Estimator and Model there would cause
+        # a circular import with inherit_doc.
+        if not isinstance(value, JavaParams):
+            return java_param.w(_py2java(sc, value))
+        return java_param.w(value._to_java())
 
-class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
+    def _transfer_param_map_to_java(self, pyParamMap):
+        """
+        Convert a Python param map into a Java ``ParamMap``, copying only the entries whose
+        key is one of this instance's params (visited in ``self.params`` order).
+        """
+        java_param_map = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+        pairs = (self._make_java_param_pair(p, pyParamMap[p])
+                 for p in self.params if p in pyParamMap)
+        for java_pair in pairs:
+            java_param_map.put([java_pair])
+        return java_param_map
+
+    def _transfer_param_map_from_java(self, javaParamMap):
+        """
+        Convert a Java ``ParamMap`` into a Python dict keyed by this instance's Params.
+        Entries for params this instance does not define are dropped.
+        """
+        sc = SparkContext._active_spark_context
+        py_param_map = {}
+        for java_pair in javaParamMap.toList():
+            name = java_pair.param().name()
+            if not self.hasParam(str(name)):
+                continue
+            # "classifier" holds a nested estimator, which needs a full Java -> Python stage
+            # conversion instead of the generic value converter.
+            if name == "classifier":
+                converted = JavaParams._from_java(java_pair.value())
+            else:
+                converted = _java2py(sc, java_pair.value())
+            py_param_map[self.getParam(name)] = converted
+        return py_param_map
+
+
+class OneVsRestModel(Model, OneVsRestParams, JavaMLReadable, JavaMLWritable):
     """
     .. note:: Experimental
 
@@ -1650,6 +1678,16 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
     def __init__(self, models):
+        """
+        Wrap the given sub-models and build the matching JVM OneVsRestModel.
+
+        :param models: sub-models to combine; each must provide ``_to_java()`` (they are
+                       handed to the JVM as an array of ClassificationModel).
+        """
         super(OneVsRestModel, self).__init__()
         self.models = models
+        # Convert every Python sub-model to its Java counterpart and pack them into a Java
+        # array whose element type is ClassificationModel.
+        java_models = [model._to_java() for model in self.models]
+        sc = SparkContext._active_spark_context
+        java_models_array = JavaWrapper._new_java_array(java_models,
+                                                        sc._gateway.jvm.org.apache.spark.ml
+                                                        .classification.ClassificationModel)
+        # TODO: need to set metadata
+        # (for now Metadata.empty() is passed to the JVM constructor as a placeholder)
+        metadata = JavaParams._new_java_obj("org.apache.spark.sql.types.Metadata")
+        # NOTE(review): self._java_obj is presumably what the JavaMLWritable mixin writes
+        # through; confirm against pyspark.ml.util before relying on it.
+        self._java_obj = \
+            JavaParams._new_java_obj("org.apache.spark.ml.classification.OneVsRestModel",
+                                     self.uid, metadata.empty(), java_models_array)
 
     def _transform(self, dataset):
         # determine the input columns: these need to be passed through
@@ -1715,22 +1753,6 @@ class OneVsRestModel(Model, OneVsRestParams, MLReadable, MLWritable):
         newModel.models = [model.copy(extra) for model in self.models]
         return newModel
 
-    @since("2.0.0")
-    def write(self):
-        """Returns an MLWriter instance for this ML instance."""
-        return JavaMLWriter(self)
-
-    @since("2.0.0")
-    def save(self, path):
-        """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
-        self.write().save(path)
-
-    @classmethod
-    @since("2.0.0")
-    def read(cls):
-        """Returns an MLReader instance for this class."""
-        return JavaMLReader(cls)
-
     @classmethod
     def _from_java(cls, java_stage):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7870047..6c71e69 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -49,7 +49,8 @@ from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
 from pyspark.ml.classification import *
 from pyspark.ml.clustering import *
 from pyspark.ml.common import _java2py, _py2java
-from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
+    MulticlassClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.fpm import FPGrowth, FPGrowthModel
 from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
@@ -678,7 +679,7 @@ class CrossValidatorTests(SparkSessionTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
-    def test_save_load(self):
+    def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786
         temp_path = tempfile.mkdtemp()
@@ -702,6 +703,76 @@ class CrossValidatorTests(SparkSessionTestCase):
         self.assertEqual(loadedLrModel.uid, lrModel.uid)
         self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
 
+    def test_save_load_simple_estimator(self):
+        # Round-trip a CrossValidator wrapping a plain LogisticRegression, then the
+        # CrossValidatorModel obtained by fitting it.
+        temp_path = tempfile.mkdtemp()
+        points = [(Vectors.dense([0.0]), 0.0),
+                  (Vectors.dense([0.4]), 1.0),
+                  (Vectors.dense([0.5]), 0.0),
+                  (Vectors.dense([0.6]), 1.0),
+                  (Vectors.dense([1.0]), 1.0)]
+        dataset = self.spark.createDataFrame(points * 10, ["features", "label"])
+
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+
+        # CrossValidator: estimator, evaluator and param grid must survive save/load
+        cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        cvPath = temp_path + "/cv"
+        cv.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+        self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+        self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+
+        # CrossValidatorModel: the best model keeps its uid
+        cvModelPath = temp_path + "/cvModel"
+        cvModel.save(cvModelPath)
+        self.assertEqual(CrossValidatorModel.load(cvModelPath).bestModel.uid,
+                         cvModel.bestModel.uid)
+
+    def test_save_load_nested_estimator(self):
+        # Save/load a CrossValidator whose param grid varies OneVsRest's nested ``classifier``
+        # estimator, then the CrossValidatorModel obtained by fitting it.
+        temp_path = tempfile.mkdtemp()
+        points = [(Vectors.dense([0.0]), 0.0),
+                  (Vectors.dense([0.4]), 1.0),
+                  (Vectors.dense([0.5]), 0.0),
+                  (Vectors.dense([0.6]), 1.0),
+                  (Vectors.dense([1.0]), 1.0)]
+        dataset = self.spark.createDataFrame(points * 10, ["features", "label"])
+
+        ova = OneVsRest(classifier=LogisticRegression())
+        candidates = [LogisticRegression().setMaxIter(100), LogisticRegression().setMaxIter(150)]
+        grid = ParamGridBuilder().addGrid(ova.classifier, candidates).build()
+        evaluator = MulticlassClassificationEvaluator()
+
+        # CrossValidator: estimator and evaluator uids must survive save/load
+        cv = CrossValidator(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+        cvModel = cv.fit(dataset)
+        cvPath = temp_path + "/cv"
+        cv.save(cvPath)
+        loadedCV = CrossValidator.load(cvPath)
+        self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+        self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+
+        originalParamMaps = cv.getEstimatorParamMaps()
+        for i, loadedMap in enumerate(loadedCV.getEstimatorParamMaps()):
+            for p in loadedMap:
+                if p.name == "classifier":
+                    # nested estimators are compared by uid rather than by value
+                    self.assertEqual(loadedMap[p].uid, originalParamMaps[i][p].uid)
+                else:
+                    self.assertEqual(loadedMap[p], originalParamMaps[i][p])
+
+        # CrossValidatorModel: the best model keeps its uid
+        cvModelPath = temp_path + "/cvModel"
+        cvModel.save(cvModelPath)
+        loadedModel = CrossValidatorModel.load(cvModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
 
 class TrainValidationSplitTests(SparkSessionTestCase):
 
@@ -759,7 +830,7 @@ class TrainValidationSplitTests(SparkSessionTestCase):
                          "validationMetrics has the same size of grid parameter")
         self.assertEqual(1.0, max(validationMetrics))
 
-    def test_save_load(self):
+    def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for TrainValidationSplit will be added later: SPARK-13786
         temp_path = tempfile.mkdtemp()
@@ -783,6 +854,74 @@ class TrainValidationSplitTests(SparkSessionTestCase):
         self.assertEqual(loadedLrModel.uid, lrModel.uid)
         self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
 
+    def test_save_load_simple_estimator(self):
+        # This tests saving and loading both the TrainValidationSplit estimator and
+        # its trained TrainValidationSplitModel, using a simple (non-nested) estimator.
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        # test save/load of TrainValidationSplit
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+        self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+        self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+        # test save/load of TrainValidationSplitModel
+        tvsModelPath = temp_path + "/tvsModel"
+        tvsModel.save(tvsModelPath)
+        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
+    def test_save_load_nested_estimator(self):
+        # This tests saving and loading the TrainValidationSplit estimator and its
+        # trained model when the estimator (OneVsRest) nests another estimator.
+        temp_path = tempfile.mkdtemp()
+        dataset = self.spark.createDataFrame(
+            [(Vectors.dense([0.0]), 0.0),
+             (Vectors.dense([0.4]), 1.0),
+             (Vectors.dense([0.5]), 0.0),
+             (Vectors.dense([0.6]), 1.0),
+             (Vectors.dense([1.0]), 1.0)] * 10,
+            ["features", "label"])
+        ova = OneVsRest(classifier=LogisticRegression())
+        lr1 = LogisticRegression().setMaxIter(100)
+        lr2 = LogisticRegression().setMaxIter(150)
+        grid = ParamGridBuilder().addGrid(ova.classifier, [lr1, lr2]).build()
+        evaluator = MulticlassClassificationEvaluator()
+        # test save/load of TrainValidationSplit
+        tvs = TrainValidationSplit(estimator=ova, estimatorParamMaps=grid, evaluator=evaluator)
+        tvsModel = tvs.fit(dataset)
+        tvsPath = temp_path + "/tvs"
+        tvs.save(tvsPath)
+        loadedTvs = TrainValidationSplit.load(tvsPath)
+        self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+        self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+        # Estimator-valued params (classifier) are compared by uid, others by value.
+        originalParamMap = tvs.getEstimatorParamMaps()
+        loadedParamMap = loadedTvs.getEstimatorParamMaps()
+        for i, param in enumerate(loadedParamMap):
+            for p in param:
+                if p.name == "classifier":
+                    self.assertEqual(param[p].uid, originalParamMap[i][p].uid)
+                else:
+                    self.assertEqual(param[p], originalParamMap[i][p])
+        # test save/load of TrainValidationSplitModel
+        tvsModelPath = temp_path + "/tvsModel"
+        tvsModel.save(tvsModelPath)
+        loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+        self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
     def test_copy(self):
         dataset = self.spark.createDataFrame([
             (10, 10.0),

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index b648582..00c348a 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -20,8 +20,11 @@ import numpy as np
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
+from pyspark.ml.common import _py2java
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasSeed
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaParams
 from pyspark.sql.functions import rand
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -137,8 +140,37 @@ class ValidatorParams(HasSeed):
         """
         return self.getOrDefault(self.evaluator)
 
+    @classmethod
+    def _from_java_impl(cls, java_stage):
+        """
+        Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+        """
+        # Wrap the Java estimator and evaluator as Python objects, then convert each
+        # Java ParamMap via the Python estimator's _transfer_param_map_from_java.
+        estimator = JavaParams._from_java(java_stage.getEstimator())
+        evaluator = JavaParams._from_java(java_stage.getEvaluator())
+        epms = [estimator._transfer_param_map_from_java(epm)
+                for epm in java_stage.getEstimatorParamMaps()]
+        return estimator, epms, evaluator
+
+    def _to_java_impl(self):
+        """
+        Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+        """
+        # NOTE(review): SparkContext is presumably provided by `from pyspark.ml.util import *`.
+        gateway = SparkContext._gateway
+        cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+        # Build a Java ParamMap[] with one converted entry per Python param map.
+        java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+        for idx, epm in enumerate(self.getEstimatorParamMaps()):
+            java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
 
-class CrossValidator(Estimator, ValidatorParams):
+        java_estimator = self.getEstimator()._to_java()
+        java_evaluator = self.getEvaluator()._to_java()
+        return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
     """
 
     K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -263,8 +295,53 @@ class CrossValidator(Estimator, ValidatorParams):
             newCV.setEvaluator(self.getEvaluator().copy(extra))
         return newCV
 
+    @since("2.3.0")
+    def write(self):
+        """Returns an MLWriter (a JavaMLWriter) for saving this CrossValidator."""
+        return JavaMLWriter(self)
+
+    @classmethod
+    @since("2.3.0")
+    def read(cls):
+        """Returns an MLReader (a JavaMLReader) for loading CrossValidator instances."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java CrossValidator, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
 
-class CrossValidatorModel(Model, ValidatorParams):
+        estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+        numFolds = java_stage.getNumFolds()
+        seed = java_stage.getSeed()
+        # Rebuild the Python stage from the Java params, then reuse the Java uid.
+        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+                       numFolds=numFolds, seed=seed)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+        # Convert the nested estimator, param maps and evaluator to Java first.
+        estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+        _java_obj.setEstimatorParamMaps(epms)
+        _java_obj.setEvaluator(evaluator)
+        _java_obj.setEstimator(estimator)
+        _java_obj.setSeed(self.getSeed())
+        _java_obj.setNumFolds(self.getNumFolds())
+
+        return _java_obj
+
+
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
     """
 
     CrossValidatorModel contains the model with the highest average cross-validation
@@ -302,8 +379,55 @@ class CrossValidatorModel(Model, ValidatorParams):
         avgMetrics = self.avgMetrics
         return CrossValidatorModel(bestModel, avgMetrics)
 
+    @since("2.3.0")
+    def write(self):
+        """Returns an MLWriter (a JavaMLWriter) for saving this CrossValidatorModel."""
+        return JavaMLWriter(self)
+
+    @classmethod
+    @since("2.3.0")
+    def read(cls):
+        """Returns an MLReader (a JavaMLReader) for loading CrossValidatorModel instances."""
+        return JavaMLReader(cls)
 
-class TrainValidationSplit(Estimator, ValidatorParams):
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+        # NOTE: avgMetrics are not read from java_stage; only bestModel and params are.
+        bestModel = JavaParams._from_java(java_stage.bestModel())
+        estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+
+        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        sc = SparkContext._active_spark_context
+        # TODO: persist average metrics as well; an empty list is passed for now.
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+                                             self.uid,
+                                             self.bestModel._to_java(),
+                                             _py2java(sc, []))
+        estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+        # Set the validator params on the Java model by name.
+        _java_obj.set("evaluator", evaluator)
+        _java_obj.set("estimator", estimator)
+        _java_obj.set("estimatorParamMaps", epms)
+        return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
@@ -418,8 +542,53 @@ class TrainValidationSplit(Estimator, ValidatorParams):
             newTVS.setEvaluator(self.getEvaluator().copy(extra))
         return newTVS
 
+    @since("2.3.0")
+    def write(self):
+        """Returns an MLWriter (a JavaMLWriter) for saving this TrainValidationSplit."""
+        return JavaMLWriter(self)
+
+    @classmethod
+    @since("2.3.0")
+    def read(cls):
+        """Returns an MLReader (a JavaMLReader) for loading TrainValidationSplit instances."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+        trainRatio = java_stage.getTrainRatio()
+        seed = java_stage.getSeed()
+        # Rebuild the Python stage from the Java params, then reuse the Java uid.
+        py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+                       trainRatio=trainRatio, seed=seed)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
+        :return: Java object equivalent to this instance.
+        """
+        # Convert the nested estimator, param maps and evaluator to Java first.
+        estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
 
-class TrainValidationSplitModel(Model, ValidatorParams):
+        _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+                                             self.uid)
+        _java_obj.setEstimatorParamMaps(epms)
+        _java_obj.setEvaluator(evaluator)
+        _java_obj.setEstimator(estimator)
+        _java_obj.setTrainRatio(self.getTrainRatio())
+        _java_obj.setSeed(self.getSeed())
+
+        return _java_obj
+
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
     """
     .. note:: Experimental
 
@@ -456,6 +625,55 @@ class TrainValidationSplitModel(Model, ValidatorParams):
         validationMetrics = list(self.validationMetrics)
         return TrainValidationSplitModel(bestModel, validationMetrics)
 
+    @since("2.3.0")
+    def write(self):
+        """Returns an MLWriter (a JavaMLWriter) for saving this TrainValidationSplitModel."""
+        return JavaMLWriter(self)
+
+    @classmethod
+    @since("2.3.0")
+    def read(cls):
+        """Returns an MLReader (a JavaMLReader) for loading TrainValidationSplitModels."""
+        return JavaMLReader(cls)
+
+    @classmethod
+    def _from_java(cls, java_stage):
+        """
+        Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+
+        # Convert the Java best model and validator params to Python objects.
+        bestModel = JavaParams._from_java(java_stage.bestModel())
+        estimator, epms, evaluator = super(TrainValidationSplitModel,
+                                           cls)._from_java_impl(java_stage)
+        # Create a new instance of this stage.
+        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
+        # NOTE: validationMetrics are not read from java_stage; only bestModel and params are.
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+        :return: Java object equivalent to this instance.
+        """
+
+        sc = SparkContext._active_spark_context
+        # TODO: persist validation metrics as well; an empty list is passed for now.
+        _java_obj = JavaParams._new_java_obj(
+            "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+            self.uid,
+            self.bestModel._to_java(),
+            _py2java(sc, []))
+        estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+        # Set the validator params on the Java model by name.
+        _java_obj.set("evaluator", evaluator)
+        _java_obj.set("estimator", estimator)
+        _java_obj.set("estimatorParamMaps", epms)
+        return _java_obj
+
 
 if __name__ == "__main__":
     import doctest

http://git-wip-us.apache.org/repos/asf/spark/blob/7047f49f/python/pyspark/ml/wrapper.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 80a0b31..ee6301e 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -106,7 +106,7 @@ class JavaParams(JavaWrapper, Params):
 
     def _make_java_param_pair(self, param, value):
         """
-        Makes a Java parm pair.
+        Makes a Java param pair.
         """
         sc = SparkContext._active_spark_context
         param = self._resolveParam(param)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message