spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbradley <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-6113] [ml] Tree ensembles for Pipelines...
Date Thu, 23 Apr 2015 19:57:01 GMT
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/5626#discussion_r28997275
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala ---
    @@ -298,3 +302,200 @@ private[ml] object TreeRegressorParams {
       // These options should be lowercase.
       val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
     }
    +
    +/**
    + * :: DeveloperApi ::
    + * Parameters for Decision Tree-based ensemble algorithms.
    + *
    + * Note: Marked as private and DeveloperApi since this may be made public in the future.
    + */
    +@DeveloperApi
    +private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
    +
    +  /**
    +   * Fraction of the training data used for learning each decision tree.
    +   * (default = 1.0)
    +   * @group param
    +   */
    +  final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
    +    "Fraction of the training data used for learning each decision tree.")
    +
    +  /**
    +   * Random seed for bootstrapping and choosing feature subsets.
    +   * @group param
    +   */
    +  final val seed: LongParam = new LongParam(this, "seed",
    +    "Random seed for bootstrapping and choosing feature subsets.")
    +
    +  setDefault(subsamplingRate -> 1.0, seed -> Utils.random.nextLong())
    +
    +  /** @group setParam */
    +  def setSubsamplingRate(value: Double): this.type = {
    +    require(value > 0.0 && value <= 1.0,
    +      s"Subsampling rate must be in range (0,1]. Bad rate: $value")
    +    set(subsamplingRate, value)
    +    this
    +  }
    +
    +  /** @group getParam */
    +  def getSubsamplingRate: Double = getOrDefault(subsamplingRate)
    +
    +  /** @group setParam */
    +  def setSeed(value: Long): this.type = {
    +    set(seed, value)
    +    this
    +  }
    +
    +  /** @group getParam */
    +  def getSeed: Long = getOrDefault(seed)
    +
    +  /**
    +   * Create a Strategy instance to use with the old API.
    +   * NOTE: The caller should set impurity and seed.
    +   */
    +  private[ml] def getOldStrategy(
    +      categoricalFeatures: Map[Int, Int],
    +      numClasses: Int,
    +      oldAlgo: OldAlgo.Algo,
    +      oldImpurity: OldImpurity): OldStrategy = {
    +    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
    +  }
    +}
    +
    +/**
    + * :: DeveloperApi ::
    + * Parameters for Random Forest algorithms.
    + *
    + * Note: Marked as private and DeveloperApi since this may be made public in the future.
    + */
    +@DeveloperApi
    +private[ml] trait RandomForestParams extends TreeEnsembleParams {
    +
    +  /**
    +   * Number of trees to train (>= 1).
    +   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
    +   * TODO: Change to always do bootstrapping (simpler).
    +   * (default = 20)
    +   * @group param
    +   */
    +  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)")
    +
    +  /**
    +   * The number of features to consider for splits at each tree node.
    +   * Supported options:
    +   *  - "auto": Choose automatically for task:
    +   *            If numTrees == 1, set to "all."
    +   *            If numTrees > 1 (forest), set to "sqrt" for classification and
    +   *              to "onethird" for regression.
    +   *  - "all": use all features
    +   *  - "onethird": use 1/3 of the features
    +   *  - "sqrt": use sqrt(number of features)
    +   *  - "log2": use log2(number of features)
    +   * (default = "auto")
    +   *
    +   * These various settings are based on the following references:
    +   *  - log2: tested in Breiman (2001)
    +   *  - sqrt: recommended by Breiman manual for random forests
    +   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
    +   *    package.
    +   * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman (2001)]]
    +   * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  Breiman manual for
    +   *     random forests]]
    +   *
    +   * @group param
    +   */
    +  final val featuresPerNode: Param[String] = new Param[String](this, "featuresPerNode",
    +    "The number of features to consider for splits at each tree node." +
    --- End diff --
    
    "Strategy" makes it sound like it's something fancier than just setting a number based
on the total number of features.  I'm ambivalent, though; there aren't any great options.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message