spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbradley <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-3207][MLLIB]Choose splits for continuou...
Date Sat, 18 Oct 2014 08:21:34 GMT
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/2780#discussion_r19052556
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -1011,4 +1014,99 @@ object DecisionTree extends Serializable with Logging {
         categories
       }
     
    +  /**
    +   * Find splits for a continuous feature
    +   * NOTE: Returned number of splits is set based on `featureSamples` and
    +   *       may be different with `numSplits`.
    +   *       MetaData's number of splits will be set accordingly.
    +   * @param featureSamples feature values of each sample
    +   * @param metadata decision tree metadata
    +   * @param featureIndex feature index to find splits
    +   * @return array of splits
    +   */
    +  private[tree] def findSplitsForContinuousFeature(
    +      featureSamples: Array[Double],
    +      metadata: DecisionTreeMetadata,
    +      featureIndex: Int): Array[Double] = {
    +    require(metadata.isContinuous(featureIndex),
    +      s"findSplitsForContinuousFeature can only be used " +
    +        s"to find splits for a continuous feature.")
    +
    +    /**
    +     * Get count for each distinct value
    +     */
    +    def getValueCount(arr: Array[Double]): Array[(Double, Int)] = {
    +      val valueCount = new ArrayBuffer[(Double, Int)]
    +      var index = 1
    +      var currentValue = arr(0)
    +      var currentCount = 1
    +      while (index < arr.length) {
    +        if (currentValue != arr(index)) {
    +          valueCount.append((currentValue, currentCount))
    +          currentCount = 1
    +          currentValue = arr(index)
    +        } else {
    +          currentCount += 1
    +        }
    +        index += 1
    +      }
    +
    +      valueCount.append((currentValue, currentCount))
    +
    +      valueCount.toArray
    +    }
    +
    +
    +    val splits = {
    +      val numSplits = metadata.numSplits(featureIndex)
    +
    +      // sort feature samples first
    +      val sortedFeatureSamples = featureSamples.sorted
    +
    +      // get count for each distinct value
    +      val valueCount = getValueCount(sortedFeatureSamples)
    +
    +      // if possible splits is not enough or just enough,
    +      // just return all possible splits
    +      val possibleSplits = valueCount.length
    +      if (possibleSplits <= numSplits) {
    +        valueCount.map(_._1)
    +      } else {
    +        // stride between splits
    +        val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
    +        logDebug("stride = " + stride)
    +
    +        // iterate `valueCount` to find splits
    +        val splits = new ArrayBuffer[Double]
    +        var index = 1
    +        // currentCount: sum of counts of values that have been visited
    +        var currentCount = valueCount(0)._2
    +        // expectedCount: expected value for `currentCount`.
    +        // If `currentCount` is closest value to `expectedCount`,
    +        // then current value is a split threshold.
    +        // After finding a split threshold, `expectedCount` is added by stride.
    +        var expectedCount = stride
    +        while (index < valueCount.length) {
    +          // If adding count of current value to currentCount
    +          // makes currentCount less close to expectedCount,
    +          // previous value is a split threshold.
    +          if (math.abs(currentCount - expectedCount) <
    +            math.abs(currentCount + valueCount(index)._2 - expectedCount)) {
    +            splits.append(valueCount(index-1)._1)
    +            expectedCount += stride
    +          }
    +          currentCount += valueCount(index)._2
    +          index += 1
    +        }
    +
    +        splits.toArray
    +      }
    +    }
    +
    +    assert(splits.length > 0)
    --- End diff --
    
    I think this will fail when all values for a feature are the same. We should either change
the code so that it always produces at least 1 split (even though that split may never be used),
or allow a feature to have 0 splits. If we allow 0 splits, we should check that DecisionTree
training still works both when a single feature has 0 splits and when all features have 0 splits.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Mime
View raw message