spark-commits mailing list archives

From m...@apache.org
Subject git commit: [SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)
Date Wed, 09 Jul 2014 02:17:53 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.0 885489112 -> d569838bc


[SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)

Hi, this pull fixes (what I believe to be) a bug in DecisionTree.scala.

In the extractLeftRightNodeAggregates function, the first set of rightNodeAgg values for Regression
are set in line 792 as follows:

rightNodeAgg(featureIndex)(2 * (numBins - 2))
  = binData(shift + (2 * (numBins - 1)))

Then there is a loop that sets the rest of the values, as in line 809:

rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
  binData(shift + (2 *(numBins - 2 - splitIndex))) +
  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))

But since splitIndex starts at 1, the first iteration reads the bin at index numBins - 3, so
the binData values for bin numBins - 2 are never added to any right node aggregate.

The changes here address this issue, for both the Regression and Classification cases.
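
To make the off-by-one concrete, here is a minimal sketch of the right-to-left accumulation,
assuming a simplified layout with a single feature, shift = 0, and `stride` statistics per bin.
The object name, the `stride` parameter, and the `fixed` flag are hypothetical and exist only
for illustration; this is not the code in DecisionTree.scala.

```scala
object RightAggSketch {
  // Hypothetical, simplified model: one feature, shift = 0,
  // `stride` statistics stored contiguously per bin in binData.
  def rightNodeAgg(
      binData: Array[Double],
      numBins: Int,
      stride: Int,
      fixed: Boolean): Array[Double] = {
    val agg = new Array[Double](stride * (numBins - 1))

    // Highest split: the right side contains only the last bin.
    for (k <- 0 until stride) {
      agg(stride * (numBins - 2) + k) = binData(stride * (numBins - 1) + k)
    }

    // Old behaviour reads bin (numBins - 2 - splitIndex);
    // the patched behaviour reads bin (numBins - 1 - splitIndex).
    val binOffset = if (fixed) 1 else 2

    var splitIndex = 1
    while (splitIndex < numBins - 1) {
      for (k <- 0 until stride) {
        agg(stride * (numBins - 2 - splitIndex) + k) =
          binData(stride * (numBins - binOffset - splitIndex) + k) +
            agg(stride * (numBins - 1 - splitIndex) + k)
      }
      splitIndex += 1
    }
    agg
  }

  def main(args: Array[String]): Unit = {
    // Four bins with one statistic each, chosen so every bin is visible in the sums.
    val binData = Array(1.0, 10.0, 100.0, 1000.0)
    println(rightNodeAgg(binData, 4, 1, fixed = false).mkString(", "))
    println(rightNodeAgg(binData, 4, 1, fixed = true).mkString(", "))
  }
}
```

With the bin values 1, 10, 100, 1000, the old offset yields 1011, 1010, 1000: bin 2 (value 100)
is never added, and bin 0 leaks into the right side of the lowest split. The patched offset
yields 1110, 1100, 1000, which is the sum of all bins to the right of each split.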

Author: johnnywalleye <jsondag@gmail.com>

Closes #1316 from johnnywalleye/master and squashes the following commits:

73809da [johnnywalleye] fix bin offset in DecisionTree node aggregations

(cherry picked from commit 1114207cc8e4ef94cb97bbd5a2ef3ae4d51f73fa)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d569838b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d569838b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d569838b

Branch: refs/heads/branch-1.0
Commit: d569838bc067f2b64f6c10e54ba8e5973f8fc93a
Parents: 8854891
Author: johnnywalleye <jsondag@gmail.com>
Authored: Tue Jul 8 19:17:26 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Jul 8 19:17:43 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/mllib/tree/DecisionTree.scala  | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d569838b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3b13e52..74d5d7b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
               // calculating right node aggregate for a split as a sum of right node aggregate of a
               // higher split and the right bin aggregate of a bin where the split is a low split
               rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-                binData(shift + (2 *(numBins - 2 - splitIndex))) +
+                binData(shift + (2 *(numBins - 1 - splitIndex))) +
                 rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
               rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+                binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
                   rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
 
               splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
               // calculating right node aggregate for a split as a sum of right node aggregate of a
               // higher split and the right bin aggregate of a bin where the split is a low split
               rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-                binData(shift + (3 * (numBins - 2 - splitIndex))) +
+                binData(shift + (3 * (numBins - 1 - splitIndex))) +
                   rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
               rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+                binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
                   rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
               rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-                binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+                binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
                   rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
 
               splitIndex += 1

