spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-23455][ML] Default Params in ML should be saved separately in metadata
Date Tue, 24 Apr 2018 17:40:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master ce7ba2e98 -> 83013752e


[SPARK-23455][ML] Default Params in ML should be saved separately in metadata

## What changes were proposed in this pull request?

Currently, ML persistence saves user-supplied params and default params as a single entity in metadata. When a saved model is loaded, all of the loaded params are set on the newly created ML instance as user-supplied params.

This causes problems. For example, if we strictly disallow certain params from being set at the same time, a default param can fail the param check because it is treated as a user-supplied param after loading.
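
The failure mode can be illustrated with a small toy model that does not depend on Spark. Everything here (`ToyStage`, `saveMerged`, `loadMerged`, the param names and the default value) is hypothetical and only mimics the behavior described above; it is not the Spark implementation.

```scala
// Toy model of the problem; none of these names are Spark APIs.
object DefaultParamPitfall {
  final class ToyStage {
    // Defaults exist on every instance; user-supplied values are tracked separately.
    val defaults: Map[String, String] = Map("outputCol" -> "toy__output")
    var userSet: Map[String, String] = Map.empty

    def set(name: String, value: String): this.type = {
      userSet += (name -> value)
      this
    }

    // Strict check: the two params must not both be user-supplied.
    def exclusiveParamsOk: Boolean =
      !(userSet.contains("inputCols") && userSet.contains("outputCol"))
  }

  // Old behavior: defaults and user-supplied values are saved as one map.
  def saveMerged(stage: ToyStage): Map[String, String] =
    stage.defaults ++ stage.userSet

  // Old behavior: every loaded value is set as if the user had supplied it.
  def loadMerged(saved: Map[String, String]): ToyStage = {
    val stage = new ToyStage
    saved.foreach { case (k, v) => stage.set(k, v) }
    stage
  }

  def main(args: Array[String]): Unit = {
    val original = new ToyStage().set("inputCols", "a,b")
    println(original.exclusiveParamsOk) // the original instance passes the check

    val reloaded = loadMerged(saveMerged(original))
    println(reloaded.exclusiveParamsOk) // the reloaded copy fails it
  }
}
```

In this sketch the reloaded instance fails the check only because the default `outputCol` came back as a user-supplied value, not because the user ever set it.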

The loaded default params should not be set as user-supplied params. We should save ML default params separately in metadata.

For backward compatibility, when loading a metadata file written by a previous version of Spark, we should not raise an error if the default param field is missing.
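
A minimal sketch of the idea, again as a Spark-free toy: defaults are stored in their own field, restored as defaults rather than as user-supplied values, and a missing field is tolerated. `ToyMetadata`, `ToyStage`, `save` and `load` are assumptions made for illustration; the real metadata layout and loading logic live in `ReadWrite.scala` and may differ.

```scala
// Toy model of the approach; the names below are illustrative, not Spark APIs.
object SeparateDefaultParams {
  // `defaultParams` is None for metadata written before defaults were saved separately.
  final case class ToyMetadata(
      params: Map[String, String],
      defaultParams: Option[Map[String, String]])

  final class ToyStage {
    var defaults: Map[String, String] = Map.empty
    var userSet: Map[String, String] = Map.empty

    def isSet(name: String): Boolean = userSet.contains(name)

    // A user-supplied value wins; otherwise fall back to the default, if any.
    def get(name: String): Option[String] =
      userSet.get(name).orElse(defaults.get(name))
  }

  // Save user-supplied values and defaults in two separate fields.
  def save(stage: ToyStage): ToyMetadata =
    ToyMetadata(stage.userSet, Some(stage.defaults))

  def load(metadata: ToyMetadata): ToyStage = {
    val stage = new ToyStage
    stage.userSet = metadata.params
    // A missing default-param field is tolerated instead of raising an error.
    metadata.defaultParams.foreach(d => stage.defaults = d)
    stage
  }

  def main(args: Array[String]): Unit = {
    val stage = new ToyStage
    stage.defaults = Map("outputCol" -> "toy__output")
    stage.userSet = Map("inputCols" -> "a,b")

    val reloaded = load(save(stage))
    println(reloaded.isSet("outputCol")) // the default is not marked as user-supplied
    println(reloaded.get("outputCol"))   // but its value is still available

    // Legacy metadata without a default-param field still loads.
    val legacy = load(ToyMetadata(Map("inputCols" -> "a,b"), None))
    println(legacy.isSet("inputCols"))
  }
}
```

Modeling the optional field as an `Option` keeps the legacy path explicit: old metadata simply yields no restored defaults, and the load does not fail.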

## How was this patch tested?

Existing tests pass, and new tests were added.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #20633 from viirya/save-ml-default-params.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/83013752
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/83013752
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/83013752

Branch: refs/heads/master
Commit: 83013752e3cfcbc3edeef249439ac20b143eeabc
Parents: ce7ba2e
Author: Liang-Chi Hsieh <viirya@gmail.com>
Authored: Tue Apr 24 10:40:25 2018 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Tue Apr 24 10:40:25 2018 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |   2 +-
 .../spark/ml/classification/GBTClassifier.scala |   4 +-
 .../spark/ml/classification/LinearSVC.scala     |   2 +-
 .../ml/classification/LogisticRegression.scala  |   2 +-
 .../MultilayerPerceptronClassifier.scala        |   2 +-
 .../spark/ml/classification/NaiveBayes.scala    |   2 +-
 .../spark/ml/classification/OneVsRest.scala     |   4 +-
 .../classification/RandomForestClassifier.scala |   4 +-
 .../spark/ml/clustering/BisectingKMeans.scala   |   2 +-
 .../spark/ml/clustering/GaussianMixture.scala   |   2 +-
 .../org/apache/spark/ml/clustering/KMeans.scala |   2 +-
 .../org/apache/spark/ml/clustering/LDA.scala    |   4 +-
 .../feature/BucketedRandomProjectionLSH.scala   |   2 +-
 .../apache/spark/ml/feature/Bucketizer.scala    |  24 ----
 .../apache/spark/ml/feature/ChiSqSelector.scala |   2 +-
 .../spark/ml/feature/CountVectorizer.scala      |   2 +-
 .../scala/org/apache/spark/ml/feature/IDF.scala |   2 +-
 .../org/apache/spark/ml/feature/Imputer.scala   |   2 +-
 .../apache/spark/ml/feature/MaxAbsScaler.scala  |   2 +-
 .../apache/spark/ml/feature/MinHashLSH.scala    |   2 +-
 .../apache/spark/ml/feature/MinMaxScaler.scala  |   2 +-
 .../ml/feature/OneHotEncoderEstimator.scala     |   2 +-
 .../scala/org/apache/spark/ml/feature/PCA.scala |   2 +-
 .../spark/ml/feature/QuantileDiscretizer.scala  |  24 ----
 .../org/apache/spark/ml/feature/RFormula.scala  |   6 +-
 .../spark/ml/feature/StandardScaler.scala       |   2 +-
 .../apache/spark/ml/feature/StringIndexer.scala |   2 +-
 .../apache/spark/ml/feature/VectorIndexer.scala |   2 +-
 .../org/apache/spark/ml/feature/Word2Vec.scala  |   2 +-
 .../org/apache/spark/ml/fpm/FPGrowth.scala      |   2 +-
 .../org/apache/spark/ml/param/params.scala      |  13 +-
 .../apache/spark/ml/recommendation/ALS.scala    |   2 +-
 .../ml/regression/AFTSurvivalRegression.scala   |   2 +-
 .../ml/regression/DecisionTreeRegressor.scala   |   2 +-
 .../spark/ml/regression/GBTRegressor.scala      |   4 +-
 .../GeneralizedLinearRegression.scala           |   2 +-
 .../ml/regression/IsotonicRegression.scala      |   2 +-
 .../spark/ml/regression/LinearRegression.scala  |   2 +-
 .../ml/regression/RandomForestRegressor.scala   |   4 +-
 .../apache/spark/ml/tuning/CrossValidator.scala |   6 +-
 .../spark/ml/tuning/TrainValidationSplit.scala  |   6 +-
 .../org/apache/spark/ml/util/ReadWrite.scala    | 130 ++++++++++++-------
 .../spark/ml/util/DefaultReadWriteTest.scala    |  73 ++++++++++-
 project/MimaExcludes.scala                      |   6 +
 44 files changed, 223 insertions(+), 147 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 771cd4f..57797d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -279,7 +279,7 @@ object DecisionTreeClassificationModel extends MLReadable[DecisionTreeClassifica
       val root = loadTreeNodes(path, metadata, sparkSession, isClassification = true)
       val model = new DecisionTreeClassificationModel(metadata.uid,
         root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c025510..0aa24f0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -379,14 +379,14 @@ object GBTClassificationModel extends MLReadable[GBTClassificationModel] {
         case (treeMetadata, root) =>
           val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
             root.asInstanceOf[RegressionNode], numFeatures)
-          DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+          treeMetadata.getAndSetParams(tree)
           tree
       }
       require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
       val model = new GBTClassificationModel(metadata.uid,
         trees, treeWeights, numFeatures)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 8f950cd..80c537e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -377,7 +377,7 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
       val Row(coefficients: Vector, intercept: Double) =
         data.select("coefficients", "intercept").head()
       val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index ee4b010..e426263 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1270,7 +1270,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
           numClasses, isMultinomial)
       }
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index af2e469..57ba47e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -388,7 +388,7 @@ object MultilayerPerceptronClassificationModel
       val weights = data.getAs[Vector](1)
       val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 0293e03..45fb585 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -407,7 +407,7 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
         .head()
       val model = new NaiveBayesModel(metadata.uid, pi, theta)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 5348d88..7df53a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -289,7 +289,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
         DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
       }
       val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
-      DefaultParamsReader.getAndSetParams(ovrModel, metadata)
+      metadata.getAndSetParams(ovrModel)
       ovrModel.set("classifier", classifier)
       ovrModel
     }
@@ -484,7 +484,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
     override def load(path: String): OneVsRest = {
       val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
       val ovr = new OneVsRest(metadata.uid)
-      DefaultParamsReader.getAndSetParams(ovr, metadata)
+      metadata.getAndSetParams(ovr)
       ovr.setClassifier(classifier)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bb972e9..f1ef26a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -319,14 +319,14 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica
         case (treeMetadata, root) =>
           val tree = new DecisionTreeClassificationModel(treeMetadata.uid,
             root.asInstanceOf[ClassificationNode], numFeatures, numClasses)
-          DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+          treeMetadata.getAndSetParams(tree)
           tree
       }
       require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
 
       val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index f7c422d..addc12ac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -193,7 +193,7 @@ object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
       val dataPath = new Path(path, "data").toString
       val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath)
       val model = new BisectingKMeansModel(metadata.uid, mllibModel)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index f19ad7a..b580490 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -233,7 +233,7 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] {
       }
       val model = new GaussianMixtureModel(metadata.uid, weights, gaussians)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index d475c72..de61c9c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -280,7 +280,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
         sparkSession.read.parquet(dataPath).as[OldData].head().clusterCenters
       }
       val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 4bab670..4707723 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -366,7 +366,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
 private object LDAParams {
 
   /**
-   * Equivalent to [[DefaultParamsReader.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
+   * Equivalent to [[Metadata.getAndSetParams()]], but handles [[LDA]] and [[LDAModel]]
    * formats saved with Spark 1.6, which differ from the formats in Spark 2.0+.
    *
    * @param model    [[LDA]] or [[LDAModel]] instance.  This instance will be modified with
@@ -391,7 +391,7 @@ private object LDAParams {
               s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
         }
       case _ => // 2.0+
-        DefaultParamsReader.getAndSetParams(model, metadata)
+        metadata.getAndSetParams(model)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 41eaaf9..a906e95 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -238,7 +238,7 @@ object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProject
       val model = new BucketedRandomProjectionLSHModel(metadata.uid,
         randUnitVectors.rowIter.toArray)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index f49c410..f99649f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -217,8 +217,6 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   override def copy(extra: ParamMap): Bucketizer = {
     defaultCopy[Bucketizer](extra).setParent(parent)
   }
-
-  override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
 }
 
 @Since("1.6.0")
@@ -296,28 +294,6 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
     }
   }
 
-
-  private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {
-
-    override protected def saveImpl(path: String): Unit = {
-      // SPARK-23377: The default params will be saved and loaded as user-supplied params.
-      // Once `inputCols` is set, the default value of `outputCol` param causes the error
-      // when checking exclusive params. As a temporary to fix it, we skip the default value
-      // of `outputCol` if `inputCols` is set when saving the metadata.
-      // TODO: If we modify the persistence mechanism later to better handle default params,
-      // we can get rid of this.
-      var paramWithoutOutputCol: Option[JValue] = None
-      if (instance.isSet(instance.inputCols)) {
-        val params = instance.extractParamMap().toSeq
-        val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
-          p.name -> parse(p.jsonEncode(v))
-        }.toList
-        paramWithoutOutputCol = Some(render(jsonParams))
-      }
-      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
-    }
-  }
-
   @Since("1.6.0")
   override def load(path: String): Bucketizer = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 16abc49..dbfb199 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -334,7 +334,7 @@ object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
       val selectedFeatures = data.getAs[Seq[Int]](0).toArray
       val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
       val model = new ChiSqSelectorModel(metadata.uid, oldModel)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 9e0ed43..10c48c3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -363,7 +363,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
         .head()
       val vocabulary = data.getAs[Seq[String]](0).toArray
       val model = new CountVectorizerModel(metadata.uid, vocabulary)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 46a0730..58897cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -182,7 +182,7 @@ object IDFModel extends MLReadable[IDFModel] {
         .select("idf")
         .head()
       val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 730ee9f..1c074e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -262,7 +262,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
       val dataPath = new Path(path, "data").toString
       val surrogateDF = sqlContext.read.parquet(dataPath)
       val model = new ImputerModel(metadata.uid, surrogateDF)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 85f9732..90eceb0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -172,7 +172,7 @@ object MaxAbsScalerModel extends MLReadable[MaxAbsScalerModel] {
         .select("maxAbs")
         .head()
       val model = new MaxAbsScalerModel(metadata.uid, maxAbs)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 556848e..a67a3b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -205,7 +205,7 @@ object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
         .map(tuple => (tuple(0), tuple(1))).toArray
       val model = new MinHashLSHModel(metadata.uid, randCoefficients)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index f648dec..2e0ae4a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -243,7 +243,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
           .select("originalMin", "originalMax")
           .head()
       val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
index bd1e342..4a44f31 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala
@@ -386,7 +386,7 @@ object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
         .head()
       val categorySizes = data.getAs[Seq[Int]](0).toArray
       val model = new OneHotEncoderModel(metadata.uid, categorySizes)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 4143d86..8172491 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -220,7 +220,7 @@ object PCAModel extends MLReadable[PCAModel] {
         new PCAModel(metadata.uid, pc.asML,
           Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
       }
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 3b4c254..56e2c54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -253,35 +253,11 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
-
-  override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
 }
 
 @Since("1.6.0")
 object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
 
-  private[QuantileDiscretizer]
-  class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {
-
-    override protected def saveImpl(path: String): Unit = {
-      // SPARK-23377: The default params will be saved and loaded as user-supplied params.
-      // Once `inputCols` is set, the default value of `outputCol` param causes the error
-      // when checking exclusive params. As a temporary to fix it, we skip the default value
-      // of `outputCol` if `inputCols` is set when saving the metadata.
-      // TODO: If we modify the persistence mechanism later to better handle default params,
-      // we can get rid of this.
-      var paramWithoutOutputCol: Option[JValue] = None
-      if (instance.isSet(instance.inputCols)) {
-        val params = instance.extractParamMap().toSeq
-        val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
-          p.name -> parse(p.jsonEncode(v))
-        }.toList
-        paramWithoutOutputCol = Some(render(jsonParams))
-      }
-      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
-    }
-  }
-
   @Since("1.6.0")
   override def load(path: String): QuantileDiscretizer = super.load(path)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index e214765..55e595e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -446,7 +446,7 @@ object RFormulaModel extends MLReadable[RFormulaModel] {
 
       val model = new RFormulaModel(metadata.uid, resolvedRFormula, pipelineModel)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }
@@ -510,7 +510,7 @@ private object ColumnPruner extends MLReadable[ColumnPruner] {
       val columnsToPrune = data.getAs[Seq[String]](0).toSet
       val pruner = new ColumnPruner(metadata.uid, columnsToPrune)
 
-      DefaultParamsReader.getAndSetParams(pruner, metadata)
+      metadata.getAndSetParams(pruner)
       pruner
     }
   }
@@ -602,7 +602,7 @@ private object VectorAttributeRewriter extends MLReadable[VectorAttributeRewrite
       val prefixesToRewrite = data.getAs[Map[String, String]](1)
       val rewriter = new VectorAttributeRewriter(metadata.uid, vectorCol, prefixesToRewrite)
 
-      DefaultParamsReader.getAndSetParams(rewriter, metadata)
+      metadata.getAndSetParams(rewriter)
       rewriter
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 8f125d8..91b0707 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -212,7 +212,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
         .select("std", "mean")
         .head()
       val model = new StandardScalerModel(metadata.uid, std, mean)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 67cdb09..a833d8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -315,7 +315,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
         .head()
       val labels = data.getAs[Seq[String]](0).toArray
       val model = new StringIndexerModel(metadata.uid, labels)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index e6ec4e2..0e7396a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -537,7 +537,7 @@ object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
       val numFeatures = data.getAs[Int](0)
       val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
       val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index fe3306e..fc9996d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -410,7 +410,7 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
       }
 
       val model = new Word2VecModel(metadata.uid, oldModel)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 3d041fc..0bf405d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -335,7 +335,7 @@ object FPGrowthModel extends MLReadable[FPGrowthModel] {
       val dataPath = new Path(path, "data").toString
       val frequentItems = sparkSession.read.parquet(dataPath)
       val model = new FPGrowthModel(metadata.uid, frequentItems)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 9a83a58..e6c347e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -865,10 +865,10 @@ trait Params extends Identifiable with Serializable {
   }
 
   /** Internal param map for user-supplied values. */
-  private val paramMap: ParamMap = ParamMap.empty
+  private[ml] val paramMap: ParamMap = ParamMap.empty
 
   /** Internal param map for default values. */
-  private val defaultParamMap: ParamMap = ParamMap.empty
+  private[ml] val defaultParamMap: ParamMap = ParamMap.empty
 
   /** Validates that the input param belongs to this instance. */
   private def shouldOwn(param: Param[_]): Unit = {
@@ -905,6 +905,15 @@ trait Params extends Identifiable with Serializable {
   }
 }
 
+private[ml] object Params {
+  /**
+   * Sets a default param value for a `Params`.
+   */
+  private[ml] final def setDefault[T](params: Params, param: Param[T], value: T): Unit = {
+    params.defaultParamMap.put(param -> value)
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Java-friendly wrapper for [[Params]].

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 81a8f50..a23f955 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -529,7 +529,7 @@ object ALSModel extends MLReadable[ALSModel] {
 
       val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 4b46c38..7c6ec2a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -423,7 +423,7 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
           .head()
       val model = new AFTSurvivalRegressionModel(metadata.uid, coefficients, intercept, scale)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 5cef5c9..8bcf079 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -282,7 +282,7 @@ object DecisionTreeRegressionModel extends MLReadable[DecisionTreeRegressionMode
       val root = loadTreeNodes(path, metadata, sparkSession, isClassification = false)
       val model = new DecisionTreeRegressionModel(metadata.uid,
         root.asInstanceOf[RegressionNode], numFeatures)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 834aaa0..8598e80 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -311,7 +311,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
         case (treeMetadata, root) =>
           val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
             root.asInstanceOf[RegressionNode], numFeatures)
-          DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+          treeMetadata.getAndSetParams(tree)
           tree
       }
 
@@ -319,7 +319,7 @@ object GBTRegressionModel extends MLReadable[GBTRegressionModel] {
         s" trees based on metadata but found ${trees.length} trees.")
 
       val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 4c3f143..e030a40 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1146,7 +1146,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
 
       val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept)
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 8faab52..b046897 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -308,7 +308,7 @@ object IsotonicRegressionModel extends MLReadable[IsotonicRegressionModel] {
       val model = new IsotonicRegressionModel(
         metadata.uid, new MLlibIsotonicRegressionModel(boundaries, predictions, isotonic))
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 9cdd3a0..f1d9a44 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -799,7 +799,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
         new LinearRegressionModel(metadata.uid, coefficients, intercept, scale)
       }
 
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 7f77398..4509f85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -276,14 +276,14 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode
       val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) =>
         val tree = new DecisionTreeRegressionModel(treeMetadata.uid,
           root.asInstanceOf[RegressionNode], numFeatures)
-        DefaultParamsReader.getAndSetParams(tree, treeMetadata)
+        treeMetadata.getAndSetParams(tree)
         tree
       }
       require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" +
         s" trees based on metadata but found ${trees.length} trees.")
 
       val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures)
-      DefaultParamsReader.getAndSetParams(model, metadata)
+      metadata.getAndSetParams(model)
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index c2826dc..5e916cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -234,8 +234,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-      DefaultParamsReader.getAndSetParams(cv, metadata,
-        skipParams = Option(List("estimatorParamMaps")))
+      metadata.getAndSetParams(cv, skipParams = Option(List("estimatorParamMaps")))
       cv
     }
   }
@@ -424,8 +423,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-      DefaultParamsReader.getAndSetParams(model, metadata,
-        skipParams = Option(List("estimatorParamMaps")))
+      metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps")))
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 8d1b9a8..13369c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -228,8 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
         .setEstimator(estimator)
         .setEvaluator(evaluator)
         .setEstimatorParamMaps(estimatorParamMaps)
-      DefaultParamsReader.getAndSetParams(tvs, metadata,
-        skipParams = Option(List("estimatorParamMaps")))
+      metadata.getAndSetParams(tvs, skipParams = Option(List("estimatorParamMaps")))
       tvs
     }
   }
@@ -407,8 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
       model.set(model.estimator, estimator)
         .set(model.evaluator, evaluator)
         .set(model.estimatorParamMaps, estimatorParamMaps)
-      DefaultParamsReader.getAndSetParams(model, metadata,
-        skipParams = Option(List("estimatorParamMaps")))
+      metadata.getAndSetParams(model, skipParams = Option(List("estimatorParamMaps")))
       model
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 7edcd49..72a60e0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
 import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Utils, VersionUtils}
 
 /**
  * Trait for `MLWriter` and `MLReader`.
@@ -421,6 +421,7 @@ private[ml] object DefaultParamsWriter {
    *  - timestamp
    *  - sparkVersion
    *  - uid
+   *  - defaultParamMap
    *  - paramMap
    *  - (optionally, extra metadata)
    *
@@ -453,15 +454,20 @@ private[ml] object DefaultParamsWriter {
       paramMap: Option[JValue] = None): String = {
     val uid = instance.uid
     val cls = instance.getClass.getName
-    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
+    val params = instance.paramMap.toSeq
+    val defaultParams = instance.defaultParamMap.toSeq
     val jsonParams = paramMap.getOrElse(render(params.map { case ParamPair(p, v) =>
       p.name -> parse(p.jsonEncode(v))
     }.toList))
+    val jsonDefaultParams = render(defaultParams.map { case ParamPair(p, v) =>
+      p.name -> parse(p.jsonEncode(v))
+    }.toList)
     val basicMetadata = ("class" -> cls) ~
       ("timestamp" -> System.currentTimeMillis()) ~
       ("sparkVersion" -> sc.version) ~
       ("uid" -> uid) ~
-      ("paramMap" -> jsonParams)
+      ("paramMap" -> jsonParams) ~
+      ("defaultParamMap" -> jsonDefaultParams)
     val metadata = extraMetadata match {
       case Some(jObject) =>
         basicMetadata ~ jObject
@@ -488,7 +494,7 @@ private[ml] class DefaultParamsReader[T] extends MLReader[T] {
     val cls = Utils.classForName(metadata.className)
     val instance =
       cls.getConstructor(classOf[String]).newInstance(metadata.uid).asInstanceOf[Params]
-    DefaultParamsReader.getAndSetParams(instance, metadata)
+    metadata.getAndSetParams(instance)
     instance.asInstanceOf[T]
   }
 }
@@ -499,6 +505,8 @@ private[ml] object DefaultParamsReader {
    * All info from metadata file.
    *
    * @param params  paramMap, as a `JValue`
+   * @param defaultParams defaultParamMap, as a `JValue`. For metadata file prior to Spark 2.4,
+   *                      this is `JNothing`.
    * @param metadata  All metadata, including the other fields
    * @param metadataJson  Full metadata file String (for debugging)
    */
@@ -508,27 +516,90 @@ private[ml] object DefaultParamsReader {
       timestamp: Long,
       sparkVersion: String,
       params: JValue,
+      defaultParams: JValue,
       metadata: JValue,
       metadataJson: String) {
 
+
+    private def getValueFromParams(params: JValue): Seq[(String, JValue)] = {
+      params match {
+        case JObject(pairs) => pairs
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Cannot recognize JSON metadata: $metadataJson.")
+      }
+    }
+
     /**
      * Get the JSON value of the [[org.apache.spark.ml.param.Param]] of the given name.
      * This can be useful for getting a Param value before an instance of `Params`
-     * is available.
+     * is available. This looks up `params` first and, if the param is not found there,
+     * falls back to `defaultParams`.
      */
     def getParamValue(paramName: String): JValue = {
       implicit val format = DefaultFormats
-      params match {
+
+      // Looking up for `params` first.
+      var pairs = getValueFromParams(params)
+      var foundPairs = pairs.filter { case (pName, jsonValue) =>
+        pName == paramName
+      }
+      if (foundPairs.length == 0) {
+        // Looking up for `defaultParams` then.
+        pairs = getValueFromParams(defaultParams)
+        foundPairs = pairs.filter { case (pName, jsonValue) =>
+          pName == paramName
+        }
+      }
+      assert(foundPairs.length == 1, s"Expected one instance of Param '$paramName' but found" +
+        s" ${foundPairs.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
+
+      foundPairs.map(_._2).head
+    }
+
+    /**
+     * Extract Params from metadata, and set them in the instance.
+     * This works if all Params (except params included by `skipParams` list) implement
+     * [[org.apache.spark.ml.param.Param.jsonDecode()]].
+     *
+     * @param skipParams The params included in `skipParams` won't be set. This is useful if some
+     *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
+     *                   and need special handling.
+     */
+    def getAndSetParams(
+        instance: Params,
+        skipParams: Option[List[String]] = None): Unit = {
+      setParams(instance, skipParams, isDefault = false)
+
+      // For metadata file prior to Spark 2.4, there is no default section.
+      val (major, minor) = VersionUtils.majorMinorVersion(sparkVersion)
+      if (major > 2 || (major == 2 && minor >= 4)) {
+        setParams(instance, skipParams, isDefault = true)
+      }
+    }
+
+    private def setParams(
+        instance: Params,
+        skipParams: Option[List[String]],
+        isDefault: Boolean): Unit = {
+      implicit val format = DefaultFormats
+      val paramsToSet = if (isDefault) defaultParams else params
+      paramsToSet match {
         case JObject(pairs) =>
-          val values = pairs.filter { case (pName, jsonValue) =>
-            pName == paramName
-          }.map(_._2)
-          assert(values.length == 1, s"Expected one instance of Param '$paramName' but found" +
-            s" ${values.length} in JSON Params: " + pairs.map(_.toString).mkString(", "))
-          values.head
+          pairs.foreach { case (paramName, jsonValue) =>
+            if (skipParams == None || !skipParams.get.contains(paramName)) {
+              val param = instance.getParam(paramName)
+              val value = param.jsonDecode(compact(render(jsonValue)))
+              if (isDefault) {
+                Params.setDefault(instance, param, value)
+              } else {
+                instance.set(param, value)
+              }
+            }
+          }
         case _ =>
           throw new IllegalArgumentException(
-            s"Cannot recognize JSON metadata: $metadataJson.")
+            s"Cannot recognize JSON metadata: ${metadataJson}.")
       }
     }
   }
@@ -561,43 +632,14 @@ private[ml] object DefaultParamsReader {
     val uid = (metadata \ "uid").extract[String]
     val timestamp = (metadata \ "timestamp").extract[Long]
     val sparkVersion = (metadata \ "sparkVersion").extract[String]
+    val defaultParams = metadata \ "defaultParamMap"
     val params = metadata \ "paramMap"
     if (expectedClassName.nonEmpty) {
       require(className == expectedClassName, s"Error loading metadata: Expected class name" +
         s" $expectedClassName but found class name $className")
     }
 
-    Metadata(className, uid, timestamp, sparkVersion, params, metadata, metadataStr)
-  }
-
-  /**
-   * Extract Params from metadata, and set them in the instance.
-   * This works if all Params (except params included by `skipParams` list) implement
-   * [[org.apache.spark.ml.param.Param.jsonDecode()]].
-   *
-   * @param skipParams The params included in `skipParams` won't be set. This is useful if some
-   *                   params don't implement [[org.apache.spark.ml.param.Param.jsonDecode()]]
-   *                   and need special handling.
-   * TODO: Move to [[Metadata]] method
-   */
-  def getAndSetParams(
-      instance: Params,
-      metadata: Metadata,
-      skipParams: Option[List[String]] = None): Unit = {
-    implicit val format = DefaultFormats
-    metadata.params match {
-      case JObject(pairs) =>
-        pairs.foreach { case (paramName, jsonValue) =>
-          if (skipParams == None || !skipParams.get.contains(paramName)) {
-            val param = instance.getParam(paramName)
-            val value = param.jsonDecode(compact(render(jsonValue)))
-            instance.set(param, value)
-          }
-        }
-      case _ =>
-        throw new IllegalArgumentException(
-          s"Cannot recognize JSON metadata: ${metadata.metadataJson}.")
-    }
+    Metadata(className, uid, timestamp, sparkVersion, params, defaultParams, metadata, metadataStr)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 4da95e7..4d9e664 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -19,9 +19,10 @@ package org.apache.spark.ml.util
 
 import java.io.{File, IOException}
 
+import org.json4s.JNothing
 import org.scalatest.Suite
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -129,6 +130,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
 class MyParams(override val uid: String) extends Params with MLWritable {
 
   final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
+  final val shouldNotSetIfSetintParamWithDefault: IntParam =
+    new IntParam(this, "shouldNotSetIfSetintParamWithDefault", "doc")
   final val intParam: IntParam = new IntParam(this, "intParam", "doc")
   final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
   final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
@@ -150,6 +153,13 @@ class MyParams(override val uid: String) extends Params with MLWritable {
   set(doubleArrayParam -> Array(8.0, 9.0))
   set(stringArrayParam -> Array("10", "11"))
 
+  def checkExclusiveParams(): Unit = {
+    if (isSet(shouldNotSetIfSetintParamWithDefault) && isSet(intParamWithDefault)) {
+      throw new SparkException("intParamWithDefault and shouldNotSetIfSetintParamWithDefault " +
+        "shouldn't be set at the same time")
+    }
+  }
+
   override def copy(extra: ParamMap): Params = defaultCopy(extra)
 
   override def write: MLWriter = new DefaultParamsWriter(this)
@@ -169,4 +179,65 @@ class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
     val myParams = new MyParams("my_params")
     testDefaultReadWrite(myParams)
   }
+
+  test("default param shouldn't become user-supplied param after persistence") {
+    val myParams = new MyParams("my_params")
+    myParams.set(myParams.shouldNotSetIfSetintParamWithDefault, 1)
+    myParams.checkExclusiveParams()
+    val loadedMyParams = testDefaultReadWrite(myParams)
+    loadedMyParams.checkExclusiveParams()
+    assert(loadedMyParams.getDefault(loadedMyParams.intParamWithDefault) ==
+      myParams.getDefault(myParams.intParamWithDefault))
+
+    loadedMyParams.set(myParams.intParamWithDefault, 1)
+    intercept[SparkException] {
+      loadedMyParams.checkExclusiveParams()
+    }
+  }
+
+  test("User-supplied value for default param should be kept after persistence") {
+    val myParams = new MyParams("my_params")
+    myParams.set(myParams.intParamWithDefault, 100)
+    val loadedMyParams = testDefaultReadWrite(myParams)
+    assert(loadedMyParams.get(myParams.intParamWithDefault).get == 100)
+  }
+
+  test("Read metadata without default field prior to 2.4") {
+    // default params are saved in `paramMap` field in metadata file prior to Spark 2.4.
+    val metadata = """{"class":"org.apache.spark.ml.util.MyParams",
+      |"timestamp":1518852502761,"sparkVersion":"2.3.0",
+      |"uid":"my_params",
+      |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+    val parsedMetadata = DefaultParamsReader.parseMetadata(metadata)
+    val myParams = new MyParams("my_params")
+    assert(!myParams.isSet(myParams.intParamWithDefault))
+    parsedMetadata.getAndSetParams(myParams)
+
+    // The behavior prior to Spark 2.4, default params are set in loaded ML instance.
+    assert(myParams.isSet(myParams.intParamWithDefault))
+  }
+
+  test("Should raise error when read metadata without default field after Spark 2.4") {
+    val myParams = new MyParams("my_params")
+
+    val metadata1 = """{"class":"org.apache.spark.ml.util.MyParams",
+      |"timestamp":1518852502761,"sparkVersion":"2.4.0",
+      |"uid":"my_params",
+      |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+    val parsedMetadata1 = DefaultParamsReader.parseMetadata(metadata1)
+    val err1 = intercept[IllegalArgumentException] {
+      parsedMetadata1.getAndSetParams(myParams)
+    }
+    assert(err1.getMessage().contains("Cannot recognize JSON metadata"))
+
+    val metadata2 = """{"class":"org.apache.spark.ml.util.MyParams",
+      |"timestamp":1518852502761,"sparkVersion":"3.0.0",
+      |"uid":"my_params",
+      |"paramMap":{"intParamWithDefault":0}}""".stripMargin
+    val parsedMetadata2 = DefaultParamsReader.parseMetadata(metadata2)
+    val err2 = intercept[IllegalArgumentException] {
+      parsedMetadata2.getAndSetParams(myParams)
+    }
+    assert(err2.getMessage().contains("Cannot recognize JSON metadata"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/83013752/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index a87fa68..7d0e88e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -62,6 +62,12 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"),
 
+    // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="),
+
     // [SPARK-14681][ML] Provide label/impurity stats for spark.ml decision tree nodes
     ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.LeafNode"),
     ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
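
The loading rule introduced by the `getAndSetParams` change in ReadWrite.scala above can be sketched without Spark. This is a minimal illustration, not the patch's code: `SimpleParams`, `SimpleMetadata`, and `majorMinor` are hypothetical stand-ins, and param values are simplified to `Int` instead of JSON. It shows that user-supplied params are always restored, while defaults are read from their own section only for metadata written by Spark 2.4 or later.

```scala
import scala.collection.mutable

// Hypothetical stand-in for a `Params` instance: two separate maps, as in the patch.
final class SimpleParams {
  val paramMap: mutable.Map[String, Int] = mutable.Map.empty        // user-supplied values
  val defaultParamMap: mutable.Map[String, Int] = mutable.Map.empty // default values

  def isSet(name: String): Boolean = paramMap.contains(name)

  // A user-supplied value wins; otherwise fall back to the default.
  def getOrDefault(name: String): Option[Int] =
    paramMap.get(name).orElse(defaultParamMap.get(name))
}

// Hypothetical, simplified metadata record; `defaultParams` is None when the field is absent.
final case class SimpleMetadata(
    sparkVersion: String,
    params: Map[String, Int],
    defaultParams: Option[Map[String, Int]]) {

  // Simplified stand-in for a "major.minor" version parser (assumption, not Spark's helper).
  private def majorMinor(version: String): (Int, Int) = {
    val parts = version.split("\\.")
    (parts(0).toInt, parts(1).toInt)
  }

  def getAndSetParams(instance: SimpleParams): Unit = {
    // User-supplied params are restored regardless of version.
    instance.paramMap ++= params

    // Metadata written before 2.4 has no default section, so it is only read for 2.4+.
    val (major, minor) = majorMinor(sparkVersion)
    if (major > 2 || (major == 2 && minor >= 4)) {
      val defaults = defaultParams.getOrElse(
        throw new IllegalArgumentException(
          s"Missing defaultParamMap in metadata written by Spark $sparkVersion"))
      instance.defaultParamMap ++= defaults
    }
  }
}

object Demo {
  def main(args: Array[String]): Unit = {
    // Pre-2.4 metadata: the default was saved inside paramMap, so it loads as "set".
    val legacy = new SimpleParams
    SimpleMetadata("2.3.0", Map("maxIter" -> 10), None).getAndSetParams(legacy)
    println(s"legacy isSet(maxIter) = ${legacy.isSet("maxIter")}")

    // 2.4+ metadata: the default stays a default after loading.
    val current = new SimpleParams
    SimpleMetadata("2.4.0", Map.empty, Some(Map("maxIter" -> 10))).getAndSetParams(current)
    println(s"current isSet(maxIter) = ${current.isSet("maxIter")}")
  }
}
```

Keeping the two maps apart is what lets an exclusivity check such as `checkExclusiveParams` in the test suite above pass after a save/load round trip: a restored default no longer counts as set.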


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

