spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-7176] [ML] Add validation functionality to Param
Date Thu, 30 Apr 2015 00:26:48 GMT
Repository: spark
Updated Branches:
  refs/heads/master 1fdfdb47b -> 114bad606


[SPARK-7176] [ML] Add validation functionality to Param

Main change: Added isValid field to Param.  Modified all usages to use isValid when relevant.  Added helper methods in ParamValidate.

Also overrode Params.validate() in:
* CrossValidator + model
* Pipeline + model

I made a few updates for the elastic net patch:
* I changed "tol" to "convergenceTol"
* I added some documentation

This PR is Scala + Java only.  Python will be in a follow-up PR.
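
A minimal sketch of the new API follows; the HasNumBuckets trait, its param name, and its default value are hypothetical and are not part of this patch:

  import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}

  // Hypothetical illustration only (not in this patch): a shared-param style trait
  // declaring a parameter together with its validation function.
  trait HasNumBuckets extends Params {

    /**
     * Param for number of buckets (> 0).
     * @group param
     */
    final val numBuckets: IntParam =
      new IntParam(this, "numBuckets", "number of buckets (> 0)", ParamValidators.gt(0))

    setDefault(numBuckets -> 4)

    /** @group getParam */
    final def getNumBuckets: Int = getOrDefault(numBuckets)
  }

Because every ParamPair construction calls Param.validate, specifying an out-of-range value (e.g. numBuckets -> -1) fails fast with IllegalArgumentException. Checks that involve interactions between several parameters still belong in an override of Params.validate(), as done for Pipeline and CrossValidator in this patch.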

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5740 from jkbradley/enforce-validate and squashes the following commits:

ad9c6c1 [Joseph K. Bradley] re-generated sharedParams after merging with current master
76415e8 [Joseph K. Bradley] reverted convergenceTol to tol
af62f4b [Joseph K. Bradley] Removed changes to SparkBuild, python linalg.  Fixed test failures.  Renamed ParamValidate to ParamValidators.  Removed explicit type from ParamValidators calls where possible.
bb2665a [Joseph K. Bradley] merged with elastic net pr
ecda302 [Joseph K. Bradley] fix rat tests, plus add a little doc
6895dfc [Joseph K. Bradley] small cleanups
069ac6d [Joseph K. Bradley] many cleanups
928fb84 [Joseph K. Bradley] Maybe done
a910ac7 [Joseph K. Bradley] still workin
6d60e2e [Joseph K. Bradley] Still workin
b987319 [Joseph K. Bradley] Partly done with adding checks, but blocking on adding checking functionality to Param
dbc9fb2 [Joseph K. Bradley] merged with master.  enforcing Params.validate


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/114bad60
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/114bad60
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/114bad60

Branch: refs/heads/master
Commit: 114bad606e7a17f980ea6c99e31c8ab0179fec2e
Parents: 1fdfdb4
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Wed Apr 29 17:26:46 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Apr 29 17:26:46 2015 -0700

----------------------------------------------------------------------
 .../examples/ml/JavaDeveloperApiExample.java    |  14 +-
 .../scala/org/apache/spark/ml/Pipeline.scala    |  19 +-
 .../spark/ml/classification/GBTClassifier.scala |  13 +-
 .../org/apache/spark/ml/feature/HashingTF.scala |  12 +-
 .../apache/spark/ml/feature/Normalizer.scala    |  11 +-
 .../spark/ml/feature/PolynomialExpansion.scala  |   9 +-
 .../spark/ml/feature/StandardScaler.scala       |  10 +-
 .../org/apache/spark/ml/feature/Tokenizer.scala |  20 ++-
 .../apache/spark/ml/feature/VectorIndexer.scala |  18 +-
 .../apache/spark/ml/impl/tree/treeParams.scala  | 115 ++++--------
 .../org/apache/spark/ml/param/params.scala      | 179 +++++++++++++++++--
 .../ml/param/shared/SharedParamsCodeGen.scala   |  35 ++--
 .../spark/ml/param/shared/sharedParams.scala    | 122 +++++--------
 .../apache/spark/ml/recommendation/ALS.scala    |  35 ++--
 .../spark/ml/regression/GBTRegressor.scala      |  13 +-
 .../spark/ml/regression/LinearRegression.scala  |  16 +-
 .../apache/spark/ml/tuning/CrossValidator.scala |  22 ++-
 .../spark/mllib/tree/GradientBoostedTrees.scala |   4 +-
 .../apache/spark/ml/param/JavaParamsSuite.java  |  66 +++++++
 .../apache/spark/ml/param/JavaTestParams.java   |  63 +++++++
 .../org/apache/spark/ml/param/ParamsSuite.scala |  69 ++++++-
 .../org/apache/spark/ml/param/TestParams.scala  |   2 +-
 22 files changed, 593 insertions(+), 274 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index eaf00d0..46377a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.Classifier;
 import org.apache.spark.ml.classification.ClassificationModel;
 import org.apache.spark.ml.param.IntParam;
 import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params;
 import org.apache.spark.ml.param.Params$;
 import org.apache.spark.mllib.linalg.BLAS;
 import org.apache.spark.mllib.linalg.Vector;
@@ -100,11 +99,12 @@ public class JavaDeveloperApiExample {
 /**
  * Example of defining a type of {@link Classifier}.
  *
- * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
+ *       {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
+ *       However, this should still compile and run successfully.
  */
 class MyJavaLogisticRegression
-    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
-    implements Params {
+    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
 
   /**
    * Param for max number of iterations
@@ -145,10 +145,12 @@ class MyJavaLogisticRegression
 /**
  * Example of defining a type of {@link ClassificationModel}.
  *
- * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
+ *       {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
+ *       However, this should still compile and run successfully.
  */
 class MyJavaLogisticRegressionModel
-    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
 
   private MyJavaLogisticRegression parent_;
   public MyJavaLogisticRegression parent() { return parent_; }

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 8eddf79..6bfeecd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
@@ -86,6 +86,14 @@ class Pipeline extends Estimator[PipelineModel] {
   def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
   def getStages: Array[PipelineStage] = getOrDefault(stages)
 
+  override def validate(paramMap: ParamMap): Unit = {
+    val map = extractParamMap(paramMap)
+    getStages.foreach {
+      case pStage: Params => pStage.validate(map)
+      case _ =>
+    }
+  }
+
   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -140,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] {
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)
     val theStages = map(stages)
-    require(theStages.toSet.size == theStages.size,
+    require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
   }
@@ -157,6 +165,11 @@ class PipelineModel private[ml] (
     private[ml] val stages: Array[Transformer])
   extends Model[PipelineModel] with Logging {
 
+  override def validate(paramMap: ParamMap): Unit = {
+    val map = fittingParamMap ++ extractParamMap(paramMap)
+    stages.foreach(_.validate(map))
+  }
+
   /**
   * Gets the model produced by the input estimator. Throws a NoSuchElementException if the input
   * estimator does not exist in the pipeline.
@@ -168,7 +181,7 @@ class PipelineModel private[ml] (
     }
     if (matched.isEmpty) {
       throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
-    } else if (matched.size > 1) {
+    } else if (matched.length > 1) {
       throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
     } else {
       matched.head.asInstanceOf[M]

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index d2e052f..3d84986 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -103,21 +103,16 @@ final class GBTClassifier
    */
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
 
   setDefault(lossType -> "logistic")
 
   /** @group setParam */
-  def setLossType(value: String): this.type = {
-    val lossStr = value.toLowerCase
-    require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
-      s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
-    set(lossType, lossStr)
-    this
-  }
+  def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
-  def getLossType: String = getOrDefault(lossType)
+  def getLossType: String = getOrDefault(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
   override private[ml] def getOldLossType: OldLoss = {

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index b20f2fc..0b3128f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -32,10 +32,14 @@ import org.apache.spark.sql.types.DataType
 class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
 
   /**
-   * number of features
+   * Number of features.  Should be > 0.
+   * (default = 2^18^)
    * @group param
    */
-  val numFeatures = new IntParam(this, "numFeatures", "number of features")
+  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
+    ParamValidators.gt(0))
+
+  setDefault(numFeatures -> (1 << 18))
 
   /** @group getParam */
   def getNumFeatures: Int = getOrDefault(numFeatures)
@@ -43,8 +47,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
-  setDefault(numFeatures -> (1 << 18))
-
   override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index decaeb0..bd2b5f6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -32,10 +32,13 @@ import org.apache.spark.sql.types.DataType
 class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
 
   /**
-   * Normalization in L^p^ space, p = 2 by default.
+   * Normalization in L^p^ space.  Must be >= 1.
+   * (default: p = 2)
    * @group param
    */
-  val p = new DoubleParam(this, "p", "the p norm value")
+  val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))
+
+  setDefault(p -> 2.0)
 
   /** @group getParam */
   def getP: Double = getOrDefault(p)
@@ -43,8 +46,6 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  setDefault(p -> 2.0)
-
   override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
     val normalizer = new feature.Normalizer(paramMap(p))
     normalizer.transform

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index d855f04..1b7c939 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.types.DataType
 
@@ -37,10 +37,13 @@ import org.apache.spark.sql.types.DataType
 class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
 
   /**
-   * The polynomial degree to expand, which should be larger than 1.
+   * The polynomial degree to expand, which should be >= 1.  A value of 1 means no expansion.
+   * Default: 2
    * @group param
    */
-  val degree = new IntParam(this, "degree", "the polynomial degree to expand")
+  val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)",
+    ParamValidators.gt(1))
+
   setDefault(degree -> 2)
 
   /** @group getParam */

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 447851e..a0e9ed3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -31,17 +31,19 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * Params for [[StandardScaler]] and [[StandardScalerModel]].
  */
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
-  
+
   /**
-   * False by default. Centers the data with mean before scaling. 
+   * Centers the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input 
    * and will raise an exception.
+   * Default: false
    * @group param
    */
   val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
   
   /**
-   * True by default. Scales the data to unit standard deviation.
+   * Scales the data to unit standard deviation.
+   * Default: true
    * @group param
    */
   val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
@@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
 class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
 
   setDefault(withMean -> false, withStd -> true)
-  
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 376a004..01752ba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
 
 /**
@@ -43,20 +43,20 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default) 
- * or using it to split the text (set matching to false). Optional parameters also allow to fold
- * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length. 
+ * or using it to split the text (set matching to false). Optional parameters also allow filtering
+ * tokens using a minimal length.
  * It returns an array of strings that can be empty.
- * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true, 
- * lowercase = false, minTokenLength = 1
  */
 @AlphaComponent
 class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
 
   /**
-   * param for minimum token length, default is one to avoid returning empty strings
+   * Minimum token length, >= 0.
+   * Default: 1, to avoid returning empty strings
    * @group param
    */
-  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
+  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+    ParamValidators.gtEq(0))
 
   /** @group setParam */
   def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -65,7 +65,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getMinTokenLength: Int = getOrDefault(minTokenLength)
 
   /**
-   * param sets regex as splitting on gaps (true) or matching tokens (false)
+   * Indicates whether regex splits on gaps (true) or matching tokens (false).
+   * Default: false
    * @group param
    */
   val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -77,7 +78,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getGaps: Boolean = getOrDefault(gaps)
 
   /**
-   * param sets regex pattern used by tokenizer 
+   * Regex pattern used by tokenizer.
+   * Default: `"\\p{L}+|[^\\p{L}\\s]+"`
    * @group param
    */
   val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 1e5ffd1..ed833c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
   Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.{Row, DataFrame}
@@ -37,17 +37,19 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   /**
    * Threshold for the number of values a categorical feature can take.
    * If a feature is found to have > maxCategories values, then it is declared continuous.
+   * Must be >= 2.
    *
    * (default = 20)
    */
   val maxCategories = new IntParam(this, "maxCategories",
-    "Threshold for the number of values a categorical feature can take." +
-      " If a feature is found to have > maxCategories values, then it is declared continuous.")
+    "Threshold for the number of values a categorical feature can take (>= 2)." +
+      " If a feature is found to have > maxCategories values, then it is declared continuous.",
+    ParamValidators.gtEq(2))
+
+  setDefault(maxCategories -> 20)
 
   /** @group getParam */
   def getMaxCategories: Int = getOrDefault(maxCategories)
-
-  setDefault(maxCategories -> 20)
 }
 
 /**
@@ -90,11 +92,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
 class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
 
   /** @group setParam */
-  def setMaxCategories(value: Int): this.type = {
-    require(value > 1,
-      s"DatasetIndexer given maxCategories = value, but requires maxCategories > 1.")
-    set(maxCategories, value)
-  }
+  def setMaxCategories(value: Int): this.type = set(maxCategories, value)
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
index ab6281b..fb77062 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -38,14 +38,15 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
 private[ml] trait DecisionTreeParams extends PredictorParams {
 
   /**
-   * Maximum depth of the tree.
+   * Maximum depth of the tree (>= 0).
    * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
    * (default = 5)
    * @group param
    */
   final val maxDepth: IntParam =
-    new IntParam(this, "maxDepth", "Maximum depth of the tree." +
-      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+    new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
+      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
+      ParamValidators.gtEq(0))
 
   /**
    * Maximum number of bins used for discretizing continuous features and for choosing how to split
@@ -56,7 +57,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
    */
   final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
     " discretizing continuous features.  Must be >=2 and >= number of categories for any" +
-    " categorical feature.")
+    " categorical feature.", ParamValidators.gtEq(2))
 
   /**
    * Minimum number of instances each child must have after split.
@@ -69,7 +70,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
     " number of instances each child must have after split.  If a split causes the left or right" +
     " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
-    " Should be >= 1.")
+    " Should be >= 1.", ParamValidators.gtEq(1))
 
   /**
    * Minimum information gain for a split to be considered at a tree node.
@@ -85,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
    * @group expertParam
    */
   final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
-    "Maximum memory in MB allocated to histogram aggregation.")
+    "Maximum memory in MB allocated to histogram aggregation.",
+    ParamValidators.gtEq(0))
 
   /**
    * If false, the algorithm will pass trees to executors to match instances with nodes.
@@ -111,34 +113,26 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
     " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
     " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
-    " checkpoint directory is set in the SparkContext. Must be >= 1.")
+    " checkpoint directory is set in the SparkContext. Must be >= 1.",
+    ParamValidators.gtEq(1))
 
   setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
     maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
 
   /** @group setParam */
-  def setMaxDepth(value: Int): this.type = {
-    require(value >= 0, s"maxDepth parameter must be >= 0.  Given bad value: $value")
-    set(maxDepth, value)
-  }
+  def setMaxDepth(value: Int): this.type = set(maxDepth, value)
 
   /** @group getParam */
   final def getMaxDepth: Int = getOrDefault(maxDepth)
 
   /** @group setParam */
-  def setMaxBins(value: Int): this.type = {
-    require(value >= 2, s"maxBins parameter must be >= 2.  Given bad value: $value")
-    set(maxBins, value)
-  }
+  def setMaxBins(value: Int): this.type = set(maxBins, value)
 
   /** @group getParam */
   final def getMaxBins: Int = getOrDefault(maxBins)
 
   /** @group setParam */
-  def setMinInstancesPerNode(value: Int): this.type = {
-    require(value >= 1, s"minInstancesPerNode parameter must be >= 1.  Given bad value: $value")
-    set(minInstancesPerNode, value)
-  }
+  def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
 
   /** @group getParam */
   final def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
@@ -150,10 +144,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final def getMinInfoGain: Double = getOrDefault(minInfoGain)
 
   /** @group expertSetParam */
-  def setMaxMemoryInMB(value: Int): this.type = {
-    require(value > 0, s"maxMemoryInMB parameter must be > 0.  Given bad value: $value")
-    set(maxMemoryInMB, value)
-  }
+  def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
 
   /** @group expertGetParam */
   final def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
@@ -165,10 +156,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
 
   /** @group expertSetParam */
-  def setCheckpointInterval(value: Int): this.type = {
-    require(value >= 1, s"checkpointInterval parameter must be >= 1.  Given bad value: $value")
-    set(checkpointInterval, value)
-  }
+  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
 
   /** @group expertGetParam */
   final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
@@ -209,21 +197,16 @@ private[ml] trait TreeClassifierParams extends Params {
    */
   final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
+    (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
 
   setDefault(impurity -> "gini")
 
   /** @group setParam */
-  def setImpurity(value: String): this.type = {
-    val impurityStr = value.toLowerCase
-    require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
-      s"Tree-based classifier was given unrecognized impurity: $value." +
-      s"  Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
-    set(impurity, impurityStr)
-  }
+  def setImpurity(value: String): this.type = set(impurity, value)
 
   /** @group getParam */
-  final def getImpurity: String = getOrDefault(impurity)
+  final def getImpurity: String = getOrDefault(impurity).toLowerCase
 
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
@@ -256,21 +239,16 @@ private[ml] trait TreeRegressorParams extends Params {
    */
   final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
     " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+    (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
 
   setDefault(impurity -> "variance")
 
   /** @group setParam */
-  def setImpurity(value: String): this.type = {
-    val impurityStr = value.toLowerCase
-    require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
-      s"Tree-based regressor was given unrecognized impurity: $value." +
-        s"  Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
-    set(impurity, impurityStr)
-  }
+  def setImpurity(value: String): this.type = set(impurity, value)
 
   /** @group getParam */
-  final def getImpurity: String = getOrDefault(impurity)
+  final def getImpurity: String = getOrDefault(impurity).toLowerCase
 
   /** Convert new impurity to old impurity. */
   private[ml] def getOldImpurity: OldImpurity = {
@@ -299,21 +277,18 @@ private[ml] object TreeRegressorParams {
 private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
 
   /**
-   * Fraction of the training data used for learning each decision tree.
+   * Fraction of the training data used for learning each decision tree, in range (0, 1].
    * (default = 1.0)
    * @group param
    */
   final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
-    "Fraction of the training data used for learning each decision tree.")
+    "Fraction of the training data used for learning each decision tree, in range (0, 1].",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
 
   setDefault(subsamplingRate -> 1.0)
 
   /** @group setParam */
-  def setSubsamplingRate(value: Double): this.type = {
-    require(value > 0.0 && value <= 1.0,
-      s"Subsampling rate must be in range (0,1]. Bad rate: $value")
-    set(subsamplingRate, value)
-  }
+  def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
 
   /** @group getParam */
   final def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
@@ -350,7 +325,8 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
    * (default = 20)
    * @group param
    */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
 
   /**
    * The number of features to consider for splits at each tree node.
@@ -378,30 +354,23 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
    */
   final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
     "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
+      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+    (value: String) =>
+      RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
 
   setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
 
   /** @group setParam */
-  def setNumTrees(value: Int): this.type = {
-    require(value >= 1, s"Random Forest numTrees parameter cannot be $value; it must be >= 1.")
-    set(numTrees, value)
-  }
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
 
   /** @group getParam */
   final def getNumTrees: Int = getOrDefault(numTrees)
 
   /** @group setParam */
-  def setFeatureSubsetStrategy(value: String): this.type = {
-    val strategyStr = value.toLowerCase
-    require(RandomForestParams.supportedFeatureSubsetStrategies.contains(strategyStr),
-      s"RandomForestParams was given unrecognized featureSubsetStrategy: $value. Supported" +
-        s" options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}")
-    set(featureSubsetStrategy, strategyStr)
-  }
+  def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
 
   /** @group getParam */
-  final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy)
+  final def getFeatureSubsetStrategy: String = getOrDefault(featureSubsetStrategy).toLowerCase
 }
 
 private[ml] object RandomForestParams {
@@ -426,7 +395,8 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
    * @group param
    */
   final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
-    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator")
+    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
 
   /* TODO: Add this doc when we add this param.  SPARK-7132
    * Threshold for stopping early when runWithValidation is used.
@@ -442,17 +412,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
   setDefault(maxIter -> 20, stepSize -> 0.1)
 
   /** @group setParam */
-  def setMaxIter(value: Int): this.type = {
-    require(value >= 1, s"Gradient Boosting maxIter parameter cannot be $value; it must be >= 1.")
-    set(maxIter, value)
-  }
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /** @group setParam */
-  def setStepSize(value: Double): this.type = {
-    require(value > 0.0 && value <= 1.0,
-      s"GBT given invalid step size ($value).  Value should be in (0,1].")
-    set(stepSize, value)
-  }
+  def setStepSize(value: Double): this.type = set(stepSize, value)
 
   /** @group getParam */
   final def getStepSize: Double = getOrDefault(stepSize)

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 014e124..df6360d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -34,10 +34,35 @@ import org.apache.spark.ml.util.Identifiable
  * @param parent parent object
  * @param name param name
  * @param doc documentation
+ * @param isValid optional validation method which indicates if a value is valid.
+ *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
 @AlphaComponent
-class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable {
+class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
+  extends Serializable {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue[T])
+
+  /**
+   * Assert that the given value is valid for this parameter.
+   *
+   * Note: Parameter checks involving interactions between multiple parameters should be
+   *       implemented in [[Params.validate()]].  Checks for input/output columns should be
+   *       implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+   *
+   * DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
+   *             should be specified via [[ParamPair]].
+   *
+   * @throws IllegalArgumentException if the value is invalid
+   */
+  private[param] def validate(value: T): Unit = {
+    if (!isValid(value)) {
+      throw new IllegalArgumentException(s"$parent parameter $name given invalid value $value." +
+        s" Parameter description: $toString")
+    }
+  }
 
   /**
    * Creates a param pair with the given value (for Java).
@@ -65,38 +90,129 @@ class Param[T] (val parent: Params, val name: String, val doc: String) extends S
   }
 }
 
+/**
+ * Factory methods for common validation functions for [[Param.isValid]].
+ * The numerical methods only support Int, Long, Float, and Double.
+ */
+object ParamValidators {
+
+  /** (private[param]) Default validation always returns true */
+  private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
+
+  /**
+   * Private method for checking numerical types and converting to Double.
+   * This is mainly for the sake of compilation; type checks are really handled
+   * by [[Params]] setters and the [[ParamPair]] constructor.
+   */
+  private def getDouble[T](value: T): Double = value match {
+    case x: Int => x.toDouble
+    case x: Long => x.toDouble
+    case x: Float => x.toDouble
+    case x: Double => x.toDouble
+    case _ =>
+      // The type should be checked before this is ever called.
+      throw new IllegalArgumentException("Numerical Param validation failed because" +
+        s" of unexpected input type: ${value.getClass}")
+  }
+
+  /** Check if value > lowerBound */
+  def gt[T](lowerBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) > lowerBound
+  }
+
+  /** Check if value >= lowerBound */
+  def gtEq[T](lowerBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) >= lowerBound
+  }
+
+  /** Check if value < upperBound */
+  def lt[T](upperBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) < upperBound
+  }
+
+  /** Check if value <= upperBound */
+  def ltEq[T](upperBound: Double): T => Boolean = { (value: T) =>
+    getDouble(value) <= upperBound
+  }
+
+  /**
+   * Check for value in range lowerBound to upperBound.
+   * @param lowerInclusive  If true, check for value >= lowerBound.
+   *                        If false, check for value > lowerBound.
+   * @param upperInclusive  If true, check for value <= upperBound.
+   *                        If false, check for value < upperBound.
+   */
+  def inRange[T](
+      lowerBound: Double,
+      upperBound: Double,
+      lowerInclusive: Boolean,
+      upperInclusive: Boolean): T => Boolean = { (value: T) =>
+    val x: Double = getDouble(value)
+    val lowerValid = if (lowerInclusive) x >= lowerBound else x > lowerBound
+    val upperValid = if (upperInclusive) x <= upperBound else x < upperBound
+    lowerValid && upperValid
+  }
+
+  /** Version of [[inRange()]] which uses inclusive bounds by default: [lowerBound, upperBound] */
+  def inRange[T](lowerBound: Double, upperBound: Double): T => Boolean = {
+    inRange[T](lowerBound, upperBound, lowerInclusive = true, upperInclusive = true)
+  }
+
+  /** Check for value in an allowed set of values. */
+  def inArray[T](allowed: Array[T]): T => Boolean = { (value: T) =>
+    allowed.contains(value)
+  }
+
+  /** Check for value in an allowed set of values. */
+  def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) =>
+    allowed.contains(value)
+  }
+}
+
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
 
 /** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String)
-  extends Param[Double](parent, name, doc) {
+class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
+  extends Param[Double](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Double): ParamPair[Double] = super.w(value)
 }
 
 /** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String)
-  extends Param[Int](parent, name, doc) {
+class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
+  extends Param[Int](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Int): ParamPair[Int] = super.w(value)
 }
 
 /** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String)
-  extends Param[Float](parent, name, doc) {
+class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
+  extends Param[Float](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Float): ParamPair[Float] = super.w(value)
 }
 
 /** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String)
-  extends Param[Long](parent, name, doc) {
+class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
+  extends Param[Long](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Long): ParamPair[Long] = super.w(value)
 }
 
 /** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String)
+class BooleanParam(parent: Params, name: String, doc: String) // No need for isValid
   extends Param[Boolean](parent, name, doc) {
 
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
@@ -105,7 +221,11 @@ class BooleanParam(parent: Params, name: String, doc: String)
 /**
 * A param and its value.
  */
-case class ParamPair[T](param: Param[T], value: T)
+case class ParamPair[T](param: Param[T], value: T) {
+  // This is *the* place Param.validate is called.  Whenever a parameter is specified, we should
+  // always construct a ParamPair so that validate is called.
+  param.validate(value)
+}
 
 /**
  * :: AlphaComponent ::
@@ -132,12 +252,22 @@ trait Params extends Identifiable with Serializable {
   /**
    * Validates parameter values stored internally plus the input parameter map.
    * Raises an exception if any parameter is invalid.
+   *
+   * This only needs to check for interactions between parameters.
+   * Parameter value checks which do not depend on other parameters are handled by
+   * [[Param.validate()]].  This method does not handle input/output column parameters;
+   * those are checked during schema validation.
    */
-  def validate(paramMap: ParamMap): Unit = {}
+  def validate(paramMap: ParamMap): Unit = { }
 
   /**
    * Validates parameter values stored internally.
    * Raise an exception if any parameter value is invalid.
+   *
+   * This only needs to check for interactions between parameters.
+   * Parameter value checks which do not depend on other parameters are handled by
+   * [[Param.validate()]].  This method does not handle input/output column parameters;
+   * those are checked during schema validation.
    */
   def validate(): Unit = validate(ParamMap.empty)
 
@@ -221,6 +351,10 @@ trait Params extends Identifiable with Serializable {
 
   /**
    * Sets default values for a list of params.
+   *
+   * Note: Java developers should use the single-parameter [[setDefault()]].
+   *       Annotating this with varargs causes compilation failures.
+   *
    * @param paramPairs  a list of param pairs that specify params and their default values to set
    *                    respectively. Make sure that the params are initialized before this method
    *                    gets called.
@@ -306,6 +440,14 @@ private[spark] object Params {
 }
 
 /**
+ * Java-friendly wrapper for [[Params]].
+ * Java developers who need to extend [[Params]] should use this class instead.
+ * If you need to extend an abstract class which already extends [[Params]], then that abstract
+ * class should be Java-friendly as well.
+ */
+abstract class JavaParams extends Params
+
+/**
  * :: AlphaComponent ::
  * A param to value map.
  */
@@ -313,6 +455,12 @@ private[spark] object Params {
 final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
   extends Serializable {
 
+  /* DEVELOPERS: About validating parameter values
+   *   This and ParamPair are the only two collections of parameters.
+   *   This class should always create ParamPairs when
+   *   specifying new parameter values.  ParamPair will then call Param.validate().
+   */
+
   /**
    * Creates an empty param map.
    */
@@ -321,10 +469,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
   /**
    * Puts a (param, value) pair (overwrites if the input param exists).
    */
-  def put[T](param: Param[T], value: T): this.type = {
-    map(param.asInstanceOf[Param[Any]]) = value
-    this
-  }
+  def put[T](param: Param[T], value: T): this.type = put(ParamPair(param, value))
 
   /**
    * Puts a list of param pairs (overwrites if the input params exists).
@@ -332,7 +477,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
   @varargs
   def put(paramPairs: ParamPair[_]*): this.type = {
     paramPairs.foreach { p =>
-      put(p.param.asInstanceOf[Param[Any]], p.value)
+      map(p.param.asInstanceOf[Param[Any]]) = p.value
     }
     this
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 654cd72..7da4bb4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -21,6 +21,8 @@ import java.io.PrintWriter
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.ml.param.ParamValidators
+
 /**
  * Code generator for shared params (sharedParams.scala). Run under the Spark folder with
  * {{{
@@ -31,8 +33,10 @@ private[shared] object SharedParamsCodeGen {
 
   def main(args: Array[String]): Unit = {
     val params = Seq(
-      ParamDesc[Double]("regParam", "regularization parameter"),
-      ParamDesc[Int]("maxIter", "max number of iterations"),
+      ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
+        isValid = "ParamValidators.gtEq(0)"),
+      ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
+        isValid = "ParamValidators.gtEq(0)"),
       ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
       ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
       ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
@@ -40,14 +44,19 @@ private[shared] object SharedParamsCodeGen {
         Some("\"rawPrediction\"")),
       ParamDesc[String]("probabilityCol",
         "column name for predicted class conditional probabilities", Some("\"probability\"")),
-      ParamDesc[Double]("threshold", "threshold in binary classification prediction"),
+      ParamDesc[Double]("threshold",
+        "threshold in binary classification prediction, in range [0, 1]",
+        isValid = "ParamValidators.inRange(0, 1)"),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name"),
-      ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
+      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
+        isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
       ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
-      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
+        " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
+        isValid = "ParamValidators.inRange(0, 1)"),
       ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
       ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
 
@@ -62,7 +71,8 @@ private[shared] object SharedParamsCodeGen {
   private case class ParamDesc[T: ClassTag](
       name: String,
       doc: String,
-      defaultValueStr: Option[String] = None) {
+      defaultValueStr: Option[String] = None,
+      isValid: String = "") {
 
     require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
     require(doc.nonEmpty) // TODO: more rigorous on doc
@@ -113,20 +123,23 @@ private[shared] object SharedParamsCodeGen {
          |  setDefault($name, $v)
          |""".stripMargin
     }.getOrElse("")
+    val isValid = if (param.isValid != "") {
+      ", " + param.isValid
+    } else {
+      ""
+    }
 
     s"""
       |/**
-      | * :: DeveloperApi ::
-      | * Trait for shared param $name$defaultValueDoc.
+      | * (private[ml]) Trait for shared param $name$defaultValueDoc.
       | */
-      |@DeveloperApi
-      |trait Has$Name extends Params {
+      |private[ml] trait Has$Name extends Params {
       |
       |  /**
       |   * Param for $doc.
       |   * @group param
       |   */
-      |  final val $name: $Param = new $Param(this, "$name", "$doc")
+      |  final val $name: $Param = new $Param(this, "$name", "$doc"$isValid)
       |$setDefault
       |  /** @group getParam */
       |  final def get$Name: $T = getOrDefault($name)

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 96d11ed..e1549f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -26,45 +26,39 @@ import org.apache.spark.util.Utils
 // scalastyle:off
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param regParam.
+ * (private[ml]) Trait for shared param regParam.
  */
-@DeveloperApi
-trait HasRegParam extends Params {
+private[ml] trait HasRegParam extends Params {
 
   /**
-   * Param for regularization parameter.
+   * Param for regularization parameter (>= 0).
    * @group param
    */
-  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))
 
   /** @group getParam */
   final def getRegParam: Double = getOrDefault(regParam)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param maxIter.
+ * (private[ml]) Trait for shared param maxIter.
  */
-@DeveloperApi
-trait HasMaxIter extends Params {
+private[ml] trait HasMaxIter extends Params {
 
   /**
-   * Param for max number of iterations.
+   * Param for max number of iterations (>= 0).
    * @group param
    */
-  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
 
   /** @group getParam */
   final def getMaxIter: Int = getOrDefault(maxIter)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param featuresCol (default: "features").
+ * (private[ml]) Trait for shared param featuresCol (default: "features").
  */
-@DeveloperApi
-trait HasFeaturesCol extends Params {
+private[ml] trait HasFeaturesCol extends Params {
 
   /**
    * Param for features column name.
@@ -79,11 +73,9 @@ trait HasFeaturesCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param labelCol (default: "label").
+ * (private[ml]) Trait for shared param labelCol (default: "label").
  */
-@DeveloperApi
-trait HasLabelCol extends Params {
+private[ml] trait HasLabelCol extends Params {
 
   /**
    * Param for label column name.
@@ -98,11 +90,9 @@ trait HasLabelCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param predictionCol (default: "prediction").
+ * (private[ml]) Trait for shared param predictionCol (default: "prediction").
  */
-@DeveloperApi
-trait HasPredictionCol extends Params {
+private[ml] trait HasPredictionCol extends Params {
 
   /**
    * Param for prediction column name.
@@ -117,11 +107,9 @@ trait HasPredictionCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param rawPredictionCol (default: "rawPrediction").
+ * (private[ml]) Trait for shared param rawPredictionCol (default: "rawPrediction").
  */
-@DeveloperApi
-trait HasRawPredictionCol extends Params {
+private[ml] trait HasRawPredictionCol extends Params {
 
   /**
    * Param for raw prediction (a.k.a. confidence) column name.
@@ -136,11 +124,9 @@ trait HasRawPredictionCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param probabilityCol (default: "probability").
+ * (private[ml]) Trait for shared param probabilityCol (default: "probability").
  */
-@DeveloperApi
-trait HasProbabilityCol extends Params {
+private[ml] trait HasProbabilityCol extends Params {
 
   /**
    * Param for column name for predicted class conditional probabilities.
@@ -155,28 +141,24 @@ trait HasProbabilityCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param threshold.
+ * (private[ml]) Trait for shared param threshold.
  */
-@DeveloperApi
-trait HasThreshold extends Params {
+private[ml] trait HasThreshold extends Params {
 
   /**
-   * Param for threshold in binary classification prediction.
+   * Param for threshold in binary classification prediction, in range [0, 1].
    * @group param
    */
-  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction")
+  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
   final def getThreshold: Double = getOrDefault(threshold)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param inputCol.
+ * (private[ml]) Trait for shared param inputCol.
  */
-@DeveloperApi
-trait HasInputCol extends Params {
+private[ml] trait HasInputCol extends Params {
 
   /**
    * Param for input column name.
@@ -189,11 +171,9 @@ trait HasInputCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param inputCols.
+ * (private[ml]) Trait for shared param inputCols.
  */
-@DeveloperApi
-trait HasInputCols extends Params {
+private[ml] trait HasInputCols extends Params {
 
   /**
    * Param for input column names.
@@ -206,11 +186,9 @@ trait HasInputCols extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param outputCol.
+ * (private[ml]) Trait for shared param outputCol.
  */
-@DeveloperApi
-trait HasOutputCol extends Params {
+private[ml] trait HasOutputCol extends Params {
 
   /**
    * Param for output column name.
@@ -223,28 +201,24 @@ trait HasOutputCol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param checkpointInterval.
+ * (private[ml]) Trait for shared param checkpointInterval.
  */
-@DeveloperApi
-trait HasCheckpointInterval extends Params {
+private[ml] trait HasCheckpointInterval extends Params {
 
   /**
-   * Param for checkpoint interval.
+   * Param for checkpoint interval (>= 1).
    * @group param
    */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
 
   /** @group getParam */
   final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param fitIntercept (default: true).
+ * (private[ml]) Trait for shared param fitIntercept (default: true).
  */
-@DeveloperApi
-trait HasFitIntercept extends Params {
+private[ml] trait HasFitIntercept extends Params {
 
   /**
    * Param for whether to fit an intercept term.
@@ -259,11 +233,9 @@ trait HasFitIntercept extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param seed (default: Utils.random.nextLong()).
+ * (private[ml]) Trait for shared param seed (default: Utils.random.nextLong()).
  */
-@DeveloperApi
-trait HasSeed extends Params {
+private[ml] trait HasSeed extends Params {
 
   /**
    * Param for random seed.
@@ -278,28 +250,24 @@ trait HasSeed extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param elasticNetParam.
+ * (private[ml]) Trait for shared param elasticNetParam.
  */
-@DeveloperApi
-trait HasElasticNetParam extends Params {
+private[ml] trait HasElasticNetParam extends Params {
 
   /**
-   * Param for the ElasticNet mixing parameter.
+   * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * @group param
    */
-  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
   final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param tol.
+ * (private[ml]) Trait for shared param tol.
  */
-@DeveloperApi
-trait HasTol extends Params {
+private[ml] trait HasTol extends Params {
 
   /**
    * Param for the convergence tolerance for iterative algorithms.
@@ -312,11 +280,9 @@ trait HasTol extends Params {
 }
 
 /**
- * :: DeveloperApi ::
- * Trait for shared param stepSize.
+ * (private[ml]) Trait for shared param stepSize.
  */
-@DeveloperApi
-trait HasStepSize extends Params {
+private[ml] trait HasStepSize extends Params {
 
   /**
   * Param for step size to be used for each iteration of optimization.

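Each shared param above now carries an isValid predicate built from ParamValidators, so out-of-range values are rejected the moment they are bound to the param rather than deferred to validate(). A minimal sketch of the predicates themselves; the value names below are illustrative, not part of this patch:

  import org.apache.spark.ml.param.ParamValidators

  val nonNegative = ParamValidators.gtEq[Double](0)         // the check behind regParam
  nonNegative(0.0)                                           // true
  nonNegative(-0.1)                                          // false

  val unitInterval = ParamValidators.inRange[Double](0, 1)   // the check behind threshold / elasticNetParam
  unitInterval(0.5)                                          // true
  unitInterval(1.5)                                          // false
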
http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index bd793be..f9f2b27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -52,35 +52,40 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
   with HasPredictionCol with HasCheckpointInterval {
 
   /**
-   * Param for rank of the matrix factorization.
+   * Param for rank of the matrix factorization (>= 1).
+   * Default: 10
    * @group param
    */
-  val rank = new IntParam(this, "rank", "rank of the factorization")
+  val rank = new IntParam(this, "rank", "rank of the factorization", ParamValidators.gtEq(1))
 
   /** @group getParam */
   def getRank: Int = getOrDefault(rank)
 
   /**
-   * Param for number of user blocks.
+   * Param for number of user blocks (>= 1).
+   * Default: 10
    * @group param
    */
-  val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks")
+  val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks",
+    ParamValidators.gtEq(1))
 
   /** @group getParam */
   def getNumUserBlocks: Int = getOrDefault(numUserBlocks)
 
   /**
-   * Param for number of item blocks.
+   * Param for number of item blocks (>= 1).
+   * Default: 10
    * @group param
    */
-  val numItemBlocks =
-    new IntParam(this, "numItemBlocks", "number of item blocks")
+  val numItemBlocks = new IntParam(this, "numItemBlocks", "number of item blocks",
+      ParamValidators.gtEq(1))
 
   /** @group getParam */
   def getNumItemBlocks: Int = getOrDefault(numItemBlocks)
 
   /**
    * Param to decide whether to use implicit preference.
+   * Default: false
    * @group param
    */
   val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference")
@@ -89,16 +94,19 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
   def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs)
 
   /**
-   * Param for the alpha parameter in the implicit preference formulation.
+   * Param for the alpha parameter in the implicit preference formulation (>= 0).
+   * Default: 1.0
    * @group param
    */
-  val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference")
+  val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference",
+    ParamValidators.gtEq(0))
 
   /** @group getParam */
   def getAlpha: Double = getOrDefault(alpha)
 
   /**
    * Param for the column name for user ids.
+   * Default: "user"
    * @group param
    */
   val userCol = new Param[String](this, "userCol", "column name for user ids")
@@ -108,6 +116,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
 
   /**
    * Param for the column name for item ids.
+   * Default: "item"
    * @group param
    */
   val itemCol = new Param[String](this, "itemCol", "column name for item ids")
@@ -117,6 +126,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
 
   /**
    * Param for the column name for ratings.
+   * Default: "rating"
    * @group param
    */
   val ratingCol = new Param[String](this, "ratingCol", "column name for ratings")
@@ -126,6 +136,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
 
   /**
    * Param for whether to apply nonnegativity constraints.
+   * Default: false
    * @group param
    */
   val nonnegative = new BooleanParam(
@@ -136,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
 
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
-    ratingCol -> "rating", nonnegative -> false)
+    ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
 
   /**
    * Validates and transforms the input schema.
@@ -281,10 +292,6 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     this
   }
 
-  setMaxIter(20)
-  setRegParam(1.0)
-  setCheckpointInterval(10)
-
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = extractParamMap(paramMap)
     val ratings = dataset

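With the explicit setMaxIter/setRegParam/setCheckpointInterval calls removed from the ALS body, every default now lives in the single setDefault(...) above and invalid values fail at set time. A rough usage sketch; the specific values are placeholders:

  import org.apache.spark.ml.recommendation.ALS

  val als = new ALS()
    .setRank(20)                 // must be >= 1
    .setMaxIter(15)              // defaults to 10 when left unset
    .setRegParam(0.01)           // must be >= 0
    .setCheckpointInterval(10)   // must be >= 1
  // new ALS().setRank(0) now throws IllegalArgumentException immediately,
  // instead of surfacing later inside fit().
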
http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index c784cf3..76c9837 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -102,21 +102,16 @@ final class GBTRegressor
    */
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTRegressor.supportedLossTypes.mkString(", ")}")
+    s" ${GBTRegressor.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase))
 
   setDefault(lossType -> "squared")
 
   /** @group setParam */
-  def setLossType(value: String): this.type = {
-    val lossStr = value.toLowerCase
-    require(GBTRegressor.supportedLossTypes.contains(lossStr), "GBTRegressor was given bad loss" +
-      s" type: $value. Supported options: ${GBTRegressor.supportedLossTypes.mkString(", ")}")
-    set(lossType, lossStr)
-    this
-  }
+  def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
-  def getLossType: String = getOrDefault(lossType)
+  def getLossType: String = getOrDefault(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
   override private[ml] def getOldLossType: OldLoss = {

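The require(...) that used to live inside setLossType is now expressed as the Param's validator, so the same check applies no matter how lossType is bound (setter, ParamMap, or default). A small sketch of the resulting behavior, assuming the loss types listed in GBTRegressor.supportedLossTypes:

  import org.apache.spark.ml.regression.GBTRegressor

  val gbt = new GBTRegressor().setLossType("Squared")  // accepted: the validator lower-cases before checking
  gbt.getLossType                                      // "squared"
  // new GBTRegressor().setLossType("huber") fails the validator, since "huber"
  // is not among the supported loss types.
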
http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index cc9ad22..11c6cea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,7 +25,8 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction}
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared.{HasTol, HasElasticNetParam, HasMaxIter,
+  HasRegParam}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -46,6 +47,16 @@ private[regression] trait LinearRegressionParams extends RegressorParams
  * :: AlphaComponent ::
  *
  * Linear regression.
+ *
+ * The learning objective is to minimize the squared error, with regularization.
+ * The specific squared error loss function used is:
+ *   L = 1/2n ||A weights - y||^2^
+ *
+ * This supports multiple types of regularization:
+ *  - none (a.k.a. ordinary least squares)
+ *  - L2 (ridge regression)
+ *  - L1 (Lasso)
+ *  - L2 + L1 (elastic net)
  */
 @AlphaComponent
 class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
@@ -135,7 +146,8 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
     val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
     } else {
-      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+      new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam,
+        paramMap(tol))
     }
 
     val initialWeights = Vectors.zeros(numFeatures)

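The new class doc spells out the learning objective; with alpha = elasticNetParam and lambda = regParam, the penalized objective is roughly 1/2n ||A w - y||^2 + lambda * (alpha * ||w||_1 + (1 - alpha)/2 * ||w||_2^2). A hedged usage sketch; the setters follow the elastic-net patch this PR was merged with:

  import org.apache.spark.ml.regression.LinearRegression

  val lr = new LinearRegression()
    .setMaxIter(100)
    .setRegParam(0.1)          // overall regularization strength, must be >= 0
    .setElasticNetParam(0.5)   // 0 = ridge (L2), 1 = lasso (L1); must be in [0, 1]
  // With elasticNetParam == 0.0 or regParam == 0.0 the plain L-BFGS branch above
  // is taken; otherwise OWL-QN handles the L1 part of the penalty.
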
http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4bb4ed8..d1ad089 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -22,7 +22,7 @@ import com.github.fommil.netlib.F2jBLAS
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
-import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -61,10 +61,12 @@ private[ml] trait CrossValidatorParams extends Params {
   def getEvaluator: Evaluator = getOrDefault(evaluator)
 
   /**
-   * param for number of folds for cross validation
+   * Param for number of folds for cross validation.  Must be >= 2.
+   * Default: 3
    * @group param
    */
-  val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation")
+  val numFolds: IntParam = new IntParam(this, "numFolds",
+    "number of folds for cross validation (>= 2)", ParamValidators.gtEq(2))
 
   /** @group getParam */
   def getNumFolds: Int = getOrDefault(numFolds)
@@ -93,6 +95,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
   /** @group setParam */
   def setNumFolds(value: Int): this.type = set(numFolds, value)
 
+  override def validate(paramMap: ParamMap): Unit = {
+    getEstimatorParamMaps.foreach { eMap =>
+      getEstimator.validate(eMap ++ paramMap)
+    }
+  }
+
   override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
     val map = extractParamMap(paramMap)
     val schema = dataset.schema
@@ -101,8 +109,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     val est = map(estimator)
     val eval = map(evaluator)
     val epm = map(estimatorParamMaps)
-    val numModels = epm.size
-    val metrics = new Array[Double](epm.size)
+    val numModels = epm.length
+    val metrics = new Array[Double](epm.length)
     val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
@@ -148,6 +156,10 @@ class CrossValidatorModel private[ml] (
     val bestModel: Model[_])
   extends Model[CrossValidatorModel] with CrossValidatorParams {
 
+  override def validate(paramMap: ParamMap): Unit = {
+    bestModel.validate(paramMap)
+  }
+
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     bestModel.transform(dataset, paramMap)
   }

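The validate override lets a CrossValidator reject a bad parameter grid before any folds are trained: it replays each candidate ParamMap through the underlying estimator's own validate. A rough sketch; the estimator, grid values, and fold count are placeholders, and the evaluator is omitted for brevity:

  import org.apache.spark.ml.regression.LinearRegression
  import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

  val lr = new LinearRegression()
  val grid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.01, 0.1))
    .build()
  val cv = new CrossValidator()
    .setEstimator(lr)
    .setEstimatorParamMaps(grid)
    .setNumFolds(3)   // values below 2 are rejected by ParamValidators.gtEq(2)
  cv.validate()       // calls lr.validate(eMap ++ paramMap) for every candidate map
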
http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index deac390..1f77958 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -180,7 +180,9 @@ object GradientBoostedTrees extends Logging {
     val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) {
       input.persist(StorageLevel.MEMORY_AND_DISK)
       true
-    } else false
+    } else {
+      false
+    }
 
     timer.stop("init")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
new file mode 100644
index 0000000..e7df10d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+
+/**
+ * Test Param and related classes in Java
+ */
+public class JavaParamsSuite {
+
+  private transient JavaSparkContext jsc;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaParamsSuite");
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void testParams() {
+    JavaTestParams testParams = new JavaTestParams();
+    Assert.assertEquals(testParams.getMyIntParam(), 1);
+    testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a");
+    Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0);
+    Assert.assertEquals(testParams.getMyStringParam(), "a");
+  }
+
+  @Test
+  public void testParamValidate() {
+    ParamValidators.gt(1.0);
+    ParamValidators.gtEq(1.0);
+    ParamValidators.lt(1.0);
+    ParamValidators.ltEq(1.0);
+    ParamValidators.inRange(0, 1, true, false);
+    ParamValidators.inRange(0, 1);
+    ParamValidators.inArray(Lists.newArrayList(0, 1, 3));
+    ParamValidators.inArray(Lists.newArrayList("a", "b"));
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
new file mode 100644
index 0000000..8abe575
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+/**
+ * A subclass of Params for testing.
+ */
+public class JavaTestParams extends JavaParams {
+
+  public IntParam myIntParam;
+
+  public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+
+  public JavaTestParams setMyIntParam(int value) {
+    set(myIntParam, value); return this;
+  }
+
+  public DoubleParam myDoubleParam;
+
+  public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+
+  public JavaTestParams setMyDoubleParam(double value) {
+    set(myDoubleParam, value); return this;
+  }
+
+  public Param<String> myStringParam;
+
+  public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+
+  public JavaTestParams setMyStringParam(String value) {
+    set(myStringParam, value); return this;
+  }
+
+  public JavaTestParams() {
+    myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+    myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+      ParamValidators.inRange(0.0, 1.0));
+    List<String> validStrings = Lists.newArrayList("a", "b");
+    myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+      ParamValidators.inArray(validStrings));
+    setDefault(myIntParam, 1);
+    setDefault(myDoubleParam, 0.5);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 88ea679..f885260 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -26,14 +26,22 @@ class ParamsSuite extends FunSuite {
     import solver.{maxIter, inputCol}
 
     assert(maxIter.name === "maxIter")
-    assert(maxIter.doc === "max number of iterations")
+    assert(maxIter.doc === "max number of iterations (>= 0)")
     assert(maxIter.parent.eq(solver))
-    assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+    assert(maxIter.toString === "maxIter: max number of iterations (>= 0) (default: 10)")
+    assert(!maxIter.isValid(-1))
+    assert(maxIter.isValid(0))
+    assert(maxIter.isValid(1))
 
     solver.setMaxIter(5)
-    assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+    assert(maxIter.toString ===
+      "maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
 
     assert(inputCol.toString === "inputCol: input column name (undefined)")
+
+    intercept[IllegalArgumentException] {
+      solver.setMaxIter(-1)
+    }
   }
 
   test("param pair") {
@@ -47,6 +55,9 @@ class ParamsSuite extends FunSuite {
       assert(pair.param.eq(maxIter))
       assert(pair.value === 5)
     }
+    intercept[IllegalArgumentException] {
+      val pair = maxIter -> -1
+    }
   }
 
   test("param map") {
@@ -59,6 +70,9 @@ class ParamsSuite extends FunSuite {
     map0.put(maxIter, 10)
     assert(map0.contains(maxIter))
     assert(map0(maxIter) === 10)
+    intercept[IllegalArgumentException] {
+      map0.put(maxIter, -1)
+    }
 
     assert(!map0.contains(inputCol))
     intercept[NoSuchElementException] {
@@ -122,14 +136,57 @@ class ParamsSuite extends FunSuite {
     assert(solver.getInputCol === "input")
     solver.validate()
     intercept[IllegalArgumentException] {
-      solver.validate(ParamMap(maxIter -> -10))
+      ParamMap(maxIter -> -10)
     }
-    solver.setMaxIter(-10)
     intercept[IllegalArgumentException] {
-      solver.validate()
+      solver.setMaxIter(-10)
     }
 
     solver.clearMaxIter()
     assert(!solver.isSet(maxIter))
   }
+
+  test("ParamValidate") {
+    val alwaysTrue = ParamValidators.alwaysTrue[Int]
+    assert(alwaysTrue(1))
+
+    val gt1Int = ParamValidators.gt[Int](1)
+    assert(!gt1Int(1) && gt1Int(2))
+    val gt1Double = ParamValidators.gt[Double](1)
+    assert(!gt1Double(1.0) && gt1Double(1.1))
+
+    val gtEq1Int = ParamValidators.gtEq[Int](1)
+    assert(!gtEq1Int(0) && gtEq1Int(1))
+    val gtEq1Double = ParamValidators.gtEq[Double](1)
+    assert(!gtEq1Double(0.9) && gtEq1Double(1.0))
+
+    val lt1Int = ParamValidators.lt[Int](1)
+    assert(lt1Int(0) && !lt1Int(1))
+    val lt1Double = ParamValidators.lt[Double](1)
+    assert(lt1Double(0.9) && !lt1Double(1.0))
+
+    val ltEq1Int = ParamValidators.ltEq[Int](1)
+    assert(ltEq1Int(1) && !ltEq1Int(2))
+    val ltEq1Double = ParamValidators.ltEq[Double](1)
+    assert(ltEq1Double(1.0) && !ltEq1Double(1.1))
+
+    val inRange02IntInclusive = ParamValidators.inRange[Int](0, 2)
+    assert(inRange02IntInclusive(0) && inRange02IntInclusive(1) && inRange02IntInclusive(2) &&
+      !inRange02IntInclusive(-1) && !inRange02IntInclusive(3))
+    val inRange02IntExclusive =
+      ParamValidators.inRange[Int](0, 2, lowerInclusive = false, upperInclusive = false)
+    assert(!inRange02IntExclusive(0) && inRange02IntExclusive(1) && !inRange02IntExclusive(2))
+
+    val inRange02DoubleInclusive = ParamValidators.inRange[Double](0, 2)
+    assert(inRange02DoubleInclusive(0) && inRange02DoubleInclusive(1) &&
+      inRange02DoubleInclusive(2) &&
+      !inRange02DoubleInclusive(-0.1) && !inRange02DoubleInclusive(2.1))
+    val inRange02DoubleExclusive =
+      ParamValidators.inRange[Double](0, 2, lowerInclusive = false, upperInclusive = false)
+    assert(!inRange02DoubleExclusive(0) && inRange02DoubleExclusive(1) &&
+      !inRange02DoubleExclusive(2))
+
+    val inArray = ParamValidators.inArray[Int](Array(1, 2))
+    assert(inArray(1) && inArray(2) && !inArray(0))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/114bad60/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index 641b64b..6f9c9cb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -29,7 +29,7 @@ class TestParams extends Params with HasMaxIter with HasInputCol {
 
   override def validate(paramMap: ParamMap): Unit = {
     val m = extractParamMap(paramMap)
-    require(m(maxIter) >= 0)
+    // Note: maxIter is validated when it is set.
     require(m.contains(inputCol))
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org

