spark-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mengxr <...@git.apache.org>
Subject [GitHub] spark pull request: MLI-1 Decision Trees
Date Mon, 10 Mar 2014 18:30:13 GMT
Github user mengxr commented on a diff in the pull request:

    https://github.com/apache/spark/pull/79#discussion_r10443049
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala ---
    @@ -0,0 +1,1055 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.tree
    +
    +import org.apache.spark.SparkContext._
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.mllib.tree.model._
    +import org.apache.spark.{SparkContext, Logging}
    +import org.apache.spark.mllib.regression.LabeledPoint
    +import org.apache.spark.mllib.tree.model.Split
    +import org.apache.spark.mllib.tree.configuration.Strategy
    +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
    +import org.apache.spark.mllib.tree.configuration.FeatureType._
    +import org.apache.spark.mllib.tree.configuration.Algo._
    +import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity}
    +import scala.util.control.Breaks._
    +
    +/**
    + * A class that implements a decision tree algorithm for classification and regression. It
    + * supports both continuous and categorical features.
    + * @param strategy The configuration parameters for the tree algorithm which specify the type
    + *                 of algorithm (classification, regression, etc.), feature type (continuous,
    + *                 categorical), depth of the tree, quantile calculation strategy, etc.
    + */
    +class DecisionTree private(val strategy: Strategy) extends Serializable with Logging {
    +
    +  /**
    +   * Method to train a decision tree model over an RDD.
    +   *
    +   * Training is level-wise: all nodes of one level are split in a single pass over the data,
    +   * so a tree with l nodes needs about log2(l) passes instead of l. Training stops early as
    +   * soon as every node of a level is a leaf.
    +   *
    +   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    +   *              for DecisionTree
    +   * @return a DecisionTreeModel that can be used for prediction
    +   * @throws IllegalArgumentException if `strategy.maxDepth` is not between 1 and 30
    +   */
    +  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
    +
    +    // The per-node arrays below hold 2^maxDepth - 1 entries and are indexed with Int shifts:
    +    // a depth below 1 leaves no slot for the root, and a depth above 30 overflows the Int size.
    +    // Validate before doing any expensive work on the input.
    +    val maxDepth = strategy.maxDepth
    +    require(maxDepth >= 1 && maxDepth <= 30,
    +      "maxDepth must be between 1 and 30 (inclusive), but was " + maxDepth)
    +
    +    // Cache input RDD for speedup during multiple passes
    +    input.cache()
    +    logDebug("algo = " + strategy.algo)
    +
    +    // Finding the splits and the corresponding bins (interval between the splits) using a sample
    +    // of the input data.
    +    val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
    +    logDebug("numBins = " + bins(0).length)
    +
    +    // Noting numBins for the input data
    +    strategy.numBins = bins(0).length
    +
    +    // The max number of nodes possible given the depth of the tree
    +    val maxNumNodes = (1 << maxDepth) - 1
    +    // Filters applied to points for each node; the top node has no filters
    +    val filters = new Array[List[Filter]](maxNumNodes)
    +    filters(0) = List()
    +    // Parent impurity for each node; the top node's entry is only a placeholder since it has
    +    // no parent
    +    val parentImpurities = new Array[Double](maxNumNodes)
    +    // Tree nodes in breadth-first order: the children of node i are stored at 2i+1 and 2i+2
    +    val nodes = new Array[Node](maxNumNodes)
    +
    +    // Each data sample is checked for validity w.r.t. each node at a given level -- i.e., the
    +    // sample is only used for the split calculation at the node if it would have still
    +    // survived the filters of the parent nodes.
    +    var level = 0
    +    var allLeaf = false
    +    while (level < maxDepth && !allLeaf) {
    +
    +      logDebug("#####################################")
    +      logDebug("level = " + level)
    +      logDebug("#####################################")
    +
    +      // Find best split for all nodes at a level
    +      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
    +        level, filters, splits, bins)
    +
    +      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
    +        // Extract info for nodes at the current level
    +        extractNodeInfo(nodeSplitStats, level, index, nodes)
    +        // Extract info for nodes at the next lower level
    +        extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
    +          filters)
    +        logDebug("final best split = " + nodeSplitStats._1)
    +      }
    +      require((1 << level) == splitsStatsForLevel.length)
    +      // Check whether all the nodes at the current level are leaves; if so, no more tree
    +      // construction is needed
    +      allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
    +      logDebug("all leaf = " + allLeaf)
    +      level += 1
    +    }
    +
    +    // Build the full tree from the root using the node info calculated in the level-wise best
    +    // split calculations
    +    val topNode = nodes(0)
    +    topNode.build(nodes)
    +
    +    new DecisionTreeModel(topNode, strategy.algo)
    +  }
    +
    +  /**
    +   * Extract the decision tree node information for the given tree level and node index, and
    +   * store the resulting node at its breadth-first position in `nodes`.
    +   */
    +  private def extractNodeInfo(
    +      nodeSplitStats: (Split, InformationGainStats),
    +      level: Int,
    +      index: Int,
    +      nodes: Array[Node])
    +    : Unit = {
    +
    +    val (split, stats) = nodeSplitStats
    +    val nodeIndex = (1 << level) - 1 + index
    +    // A node is a leaf if its best split gains nothing or it lies on the deepest allowed level
    +    val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1)
    +    val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
    +    logDebug("Node = " + node)
    +    nodes(nodeIndex) = node
    +  }
    +
    +  /**
    +   * Extract the decision tree node information for the children of the node: record the
    +   * impurity each child inherits and the filters a sample must pass to reach it. Nothing is
    +   * recorded for nodes on the last level, whose children would exceed the maximum depth.
    +   */
    +  private def extractInfoForLowerLevels(
    +      level: Int,
    +      index: Int,
    +      maxDepth: Int,
    +      nodeSplitStats: (Split, InformationGainStats),
    +      parentImpurities: Array[Double],
    +      filters: Array[List[Filter]])
    +    : Unit = {
    +
    +    if (level < maxDepth - 1) {
    +      // 0 corresponds to the left child node and 1 corresponds to the right child node.
    +      for (i <- 0 to 1) {
    +        // Calculating the index of the node from the node level and the index at the current
    +        // level
    +        val nodeIndex = (1 << (level + 1)) - 1 + 2 * index + i
    +        val impurity = if (i == 0) {
    +          nodeSplitStats._2.leftImpurity
    +        } else {
    +          nodeSplitStats._2.rightImpurity
    +        }
    +        logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity)
    +        // noting the parent impurities
    +        parentImpurities(nodeIndex) = impurity
    +        // noting the parent's filters for the child node: the child's own filter (-1 for the
    +        // left side of the split, 1 for the right side) prepended to its parent's filters
    +        val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1)
    +        filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2)
    +        for (filter <- filters(nodeIndex)) {
    +          logDebug("Filter = " + filter)
    +        }
    +      }
    +    }
    +  }
    +}
    +
    +object DecisionTree extends Serializable with Logging {
    +
    +  /**
    +   * Train a decision tree model over an RDD using a fully specified configuration.
    +   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    +   *              for DecisionTree
    +   * @param strategy The configuration parameters for the tree algorithm which specify the type
    +   *                 of algorithm (classification, regression, etc.), feature type (continuous,
    +   *                 categorical), depth of the tree, quantile calculation strategy, etc.
    +   * @return a DecisionTreeModel that can be used for prediction
    +   */
    +  def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
    +    val decisionTree = new DecisionTree(strategy)
    +    decisionTree.train(input)
    +  }
    +
    +  /**
    +   * Train a decision tree model over an RDD given only the algorithm, the impurity measure
    +   * and the depth. The remaining settings are the ones chosen by the three-argument
    +   * [[org.apache.spark.mllib.tree.configuration.Strategy]] constructor.
    +   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    +   * @param algo classification or regression
    +   * @param impurity impurity criterion used for information gain calculation
    +   * @param maxDepth maximum depth of the tree
    +   * @return a DecisionTreeModel that can be used for prediction
    +   */
    +  def train(
    +      input: RDD[LabeledPoint],
    +      algo: Algo,
    +      impurity: Impurity,
    +      maxDepth: Int)
    +    : DecisionTreeModel = {
    +    val strategy = new Strategy(algo, impurity, maxDepth)
    +    train(input, strategy)
    +  }
    +
    +
    +  /**
    +   * Train a decision tree model over an RDD with explicit control over binning, quantile
    +   * calculation and categorical features.
    +   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
    +   *              for DecisionTree
    +   * @param algo classification or regression
    +   * @param impurity criterion used for information gain calculation
    +   * @param maxDepth maximum depth of the tree
    +   * @param maxBins maximum number of bins used for splitting features
    +   * @param quantileCalculationStrategy algorithm for calculating quantiles
    +   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
    +   *                                the number of discrete values they take. For example, an
    +   *                                entry (n -> k) implies the feature n is categorical with k
    +   *                                categories 0, 1, 2, ... , k-1. It's important to note that
    +   *                                features are zero-indexed.
    +   * @return a DecisionTreeModel that can be used for prediction
    +   */
    +  def train(
    +      input: RDD[LabeledPoint],
    +      algo: Algo,
    +      impurity: Impurity,
    +      maxDepth: Int,
    +      maxBins: Int,
    +      quantileCalculationStrategy: QuantileStrategy,
    +      categoricalFeaturesInfo: Map[Int, Int])
    +    : DecisionTreeModel = {
    +    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
    +      categoricalFeaturesInfo)
    +    train(input, strategy)
    +  }
    +
    +  /**
    +   * Returns an array of optimal splits for all nodes at a given level
    +   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training
data
    +   *              for DecisionTree
    +   * @param parentImpurities Impurities for all parent nodes for the current level
    +   * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance
containing
    +   *                parameters for construction the DecisionTree
    +   * @param level Level of the tree
    +   * @param filters Filters for all nodes at a given level
    +   * @param splits possible splits for all features
    +   * @param bins possible bins for all features
    +   * @return array of splits with best splits for all nodes at a given level.
    +   */
    +  def findBestSplits(
    +      input: RDD[LabeledPoint],
    +      parentImpurities: Array[Double],
    +      strategy: Strategy,
    +      level: Int,
    +      filters: Array[List[Filter]],
    +      splits: Array[Array[Split]],
    +      bins: Array[Array[Bin]])
    +    : Array[(Split, InformationGainStats)] = {
    +
    +    //Common calculations for multiple nested methods
    +    val numNodes = scala.math.pow(2, level).toInt
    +    logDebug("numNodes = " + numNodes)
    +    //Find the number of features by looking at the first sample
    +    val numFeatures = input.take(1)(0).features.length
    +    logDebug("numFeatures = " + numFeatures)
    +    val numBins = strategy.numBins
    +    logDebug("numBins = " + numBins)
    +
    +    /*Find the filters used before reaching the current code*/
    +    def findParentFilters(nodeIndex: Int): List[Filter] = {
    +      if (level == 0) {
    +        List[Filter]()
    +      } else {
    +        val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex
    +        filters(nodeFilterIndex)
    +      }
    +    }
    +
    +    /**
    +     * Find whether the sample is valid input for the current node. In other words,
    +     * does it pass through all the filters for the current node.
    +    */
    +    def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean
= {
    +
    +      //Leaf
    +      if ((level > 0) & (parentFilters.length == 0) ){
    +        return false
    +      }
    +
    +      for (filter <- parentFilters) {
    +        val features = labeledPoint.features
    +        val featureIndex = filter.split.feature
    +        val threshold = filter.split.threshold
    +        val comparison = filter.comparison
    +        val categories = filter.split.categories
    +        val isFeatureContinuous = filter.split.featureType == Continuous
    +        val feature =  features(featureIndex)
    +        if (isFeatureContinuous){
    +          comparison match {
    +            case(-1) => if (feature > threshold) return false
    +            case(1) => if (feature <= threshold) return false
    +          }
    +        } else {
    +          val containsFeature = categories.contains(feature)
    +          comparison match {
    +            case(-1) =>  if (!containsFeature) return false
    +            case(1) =>  if (containsFeature) return false
    +          }
    +
    +        }
    +      }
    +      true
    +    }
    +
    +    /**
    +     * Finds the right bin for the given feature
    +    */
    +    def findBin(
    +        featureIndex: Int,
    +        labeledPoint: LabeledPoint,
    +        isFeatureContinuous: Boolean)
    +      : Int = {
    +
    +      if (isFeatureContinuous){
    +        for (binIndex <- 0 until strategy.numBins) {
    +          val bin = bins(featureIndex)(binIndex)
    +          val lowThreshold = bin.lowSplit.threshold
    +          val highThreshold = bin.highSplit.threshold
    +          val features = labeledPoint.features
    +          if ((lowThreshold < features(featureIndex)) & (highThreshold >= features(featureIndex)))
{
    +            return binIndex
    +          }
    +        }
    +        throw new UnknownError("no bin was found for continuous variable.")
    +      } else {
    +
    +        for (binIndex <- 0 until strategy.numBins) {
    +          val bin = bins(featureIndex)(binIndex)
    +          val category = bin.category
    +          val features = labeledPoint.features
    +          if (category == features(featureIndex)) {
    +            return binIndex
    +          }
    +        }
    +        throw new UnknownError("no bin was found for categorical variable.")
    +
    +      }
    +
    +    }
    +
    +    /**
    +     * Finds bins for all nodes (and all features) at a given level k features,
    +     * l nodes (level = log2(l)).
    +     * Storage label, b11, b12, b13, .., b1k,
    +     * b21, b22, .. , b2k,
    +     * bl1, bl2, .. , blk
    +     * Denotes invalid sample for tree by noting bin for feature 1 as -1
    +    */
    --- End diff --
    
    This needs one extra space of indentation, so that the closing `*/` lines up with the `*` on the lines above it.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message