mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From adene...@apache.org
Subject svn commit: r1029738 - in /mahout/trunk: core/src/main/java/org/apache/mahout/df/builder/ core/src/test/java/org/apache/mahout/df/builder/ examples/src/main/java/org/apache/mahout/df/mapreduce/
Date Mon, 01 Nov 2010 16:43:36 GMT
Author: adeneche
Date: Mon Nov  1 16:43:36 2010
New Revision: 1029738

URL: http://svn.apache.org/viewvc?rev=1029738&view=rev
Log:
MAHOUT-526 Fixed the Infinite Recursion in Decision Forests

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=1029738&r1=1029737&r2=1029738&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Mon Nov  1 16:43:36 2010
@@ -20,6 +20,7 @@ package org.apache.mahout.df.builder;
 import java.util.Random;
 
 import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.Dataset;
 import org.apache.mahout.df.data.Instance;
 import org.apache.mahout.df.data.conditions.Condition;
 import org.apache.mahout.df.node.CategoricalNode;
@@ -81,7 +82,10 @@ public class DefaultTreeBuilder implemen
     }
     
     int[] attributes = randomAttributes(rng, selected, m);
-    
+    if (attributes == null) { // we tried all the attributes and could not split the data anymore
+      return new Leaf(data.majorityLabel(rng));
+    }
+
     // find the best split
     Split best = null;
     for (int attr : attributes) {
@@ -92,7 +96,6 @@ public class DefaultTreeBuilder implemen
     }
     
     boolean alreadySelected = selected[best.getAttr()];
-    
     if (alreadySelected) {
       // attribute already selected
       log.warn("attribute {} already selected in a parent node", best.getAttr());
@@ -100,12 +103,30 @@ public class DefaultTreeBuilder implemen
     
     Node childNode;
     if (data.getDataset().isNumerical(best.getAttr())) {
+      boolean[] temp = null;
+
       Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
-      Node loChild = build(rng, loSubset);
-      
       Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+      if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+        // the selected attribute did not change the data, avoid using it in the child notes
+        selected[best.getAttr()] = true;
+      } else {
+        // the data changed, so we can unselect all previousely selected NUMERICAL attributes
+        temp = selected;
+        selected = cloneCategoricalAttributes(data.getDataset(), selected);
+      }
+
+      Node loChild = build(rng, loSubset);
       Node hiChild = build(rng, hiSubset);
-      
+
+      // restore the selection state of the attributes
+      if (temp != null) {
+        selected = temp;
+      } else {
+        selected[best.getAttr()] = alreadySelected;
+      }
+
       childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
     } else { // CATEGORICAL attribute
       selected[best.getAttr()] = true;
@@ -117,12 +138,10 @@ public class DefaultTreeBuilder implemen
         Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
         children[index] = build(rng, subset);
       }
+
+      selected[best.getAttr()] = alreadySelected;
       
       childNode = new CategoricalNode(best.getAttr(), values, children);
-      
-      if (!alreadySelected) {
-        selected[best.getAttr()] = false;
-      }
     }
     
     return childNode;
@@ -154,7 +173,25 @@ public class DefaultTreeBuilder implemen
     
     return true;
   }
-  
+
+
+  /**
+   * Make a copy of the selection state of the attributes, unselect all numerical attributes
+   * @param dataset
+   * @param selected selection state to clone
+   * @return cloned selection state
+   */
+  protected static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+    boolean[] cloned = new boolean[selected.length];
+
+    for (int i = 0; i < selected.length; i++) {
+      if (dataset.isNumerical(i)) cloned[i] = false;
+      else cloned[i] = selected[i];
+    }
+
+    return cloned;
+  }
+
   /**
    * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
    * 
@@ -164,6 +201,7 @@ public class DefaultTreeBuilder implemen
    *          attributes' state (selected or not)
    * @param m
    *          number of attributes to choose
+   * @return list of selected attributes' indices, or null if all attributes have already been selected
    */
   protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
     int nbNonSelected = 0; // number of non selected attributes
@@ -175,6 +213,7 @@ public class DefaultTreeBuilder implemen
     
     if (nbNonSelected == 0) {
       log.warn("All attributes are selected !");
+      return null;
     }
     
     int[] result;

Added: mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java?rev=1029738&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java Mon Nov  1 16:43:36 2010
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.builder;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class InfiniteRecursionTest extends MahoutTestCase {
+
+  static private double[][] dData = {
+          {0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 4},
+          {0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 3},
+          {0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 4},
+          {0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 3}
+  };
+
+  /**
+   * make sure DefaultTreeBuilder.build() does not throw a StackOverflowException
+   */
+  @Test
+  public void testBuild() throws Exception {
+    Random rng = RandomUtils.getRandom();
+
+    DefaultTreeBuilder builder = new DefaultTreeBuilder();
+
+    String[] source = Utils.double2String(dData);
+    String descriptor = "N N N N N N N N L";
+    Dataset dataset = DataLoader.generateDataset(descriptor, source);
+    Data data = DataLoader.loadData(dataset, source);
+
+    builder.build(rng, data);
+  }
+}

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=1029738&r1=1029737&r2=1029738&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Mon Nov  1 16:43:36 2010
@@ -196,7 +196,10 @@ public class BuildForest extends Configu
     
     time = System.currentTimeMillis() - time;
     log.info("Build Time: {}", DFUtils.elapsedTime(time));
-    
+    log.info("Forest num Nodes: {}", forest.nbNodes());
+    log.info("Forest mean num Nodes: {}", forest.meanNbNodes());
+    log.info("Forest mean max Depth: {}", forest.meanMaxDepth());
+
     if (isOob) {
       Random rng;
       if (seed != null) {



Mime
View raw message