mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r997194 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Date Wed, 15 Sep 2010 06:19:41 GMT
Author: tdunning
Date: Wed Sep 15 06:19:40 2010
New Revision: 997194

URL: http://svn.apache.org/viewvc?rev=997194&view=rev
Log:
Added model dissector itself

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997194&r1=997193&r2=997194&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Wed Sep 15 06:19:40 2010
@@ -17,15 +17,28 @@
 
 package org.apache.mahout.classifier.sgd;
 
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.QRDecomposition;
 import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
 import org.apache.mahout.math.SparseRowMatrix;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.UnaryFunction;
+import org.apache.mahout.math.matrix.GaussSeidel;
 import org.apache.mahout.vectors.Dictionary;
 
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Map;
+import java.util.PriorityQueue;
 import java.util.Set;
 
 /**
@@ -35,52 +48,79 @@ import java.util.Set;
  * in the original space.
  */
 public class ModelDissector {
-  int records = 0;
-  private Dictionary dict;
-  private Matrix a;
-  private Matrix b;
+  private Map<String,Vector> weightMap;
 
   public ModelDissector(int n) {
-    a = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, Integer.MAX_VALUE}, true);
-    b = new SparseRowMatrix(new int[]{Integer.MAX_VALUE, n});
-
-    dict.intern("Intercept Value");
+    weightMap = Maps.newHashMap();
   }
 
-  public void addExample(Set<String> features, Vector score) {
-    for (Vector.Element element : score) {
-      b.set(records, element.index(), element.get());
+  public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+    features.assign(0);
+    final int numCategories = learner.numCategories();
+    for (String feature : traceDictionary.keySet()) {
+      weightMap = weightMap;
+      if (!weightMap.containsKey(feature)) {
+        for (Integer where : traceDictionary.get(feature)) {
+          features.set(where, 1);
+        }
+
+        Vector v = learner.classifyNoLink(features);
+        weightMap.put(feature, v);
+
+        for (Integer where : traceDictionary.get(feature)) {
+          features.set(where, 0);
+        }
+      }
     }
 
-    for (String feature : features) {
-      int j = dict.intern(feature);
-      a.set(records, j, 1);
+  }
+
+  public List<Weight> summary(int n) {
+    PriorityQueue<Weight> pq = new PriorityQueue<Weight>();
+    for (String s : weightMap.keySet()) {
+      pq.add(new Weight(s, weightMap.get(s)));
+      while (pq.size() > n) {
+        pq.poll();
+      }
     }
-    records++;
+    List<Weight> r = Lists.newArrayList(pq);
+    Collections.sort(r, Ordering.natural().reverse());
+    return r;
   }
 
-  public void addExample(Set<String> features, double score) {
-    b.set(records, 0, score);
+  public static class Weight implements Comparable<Weight> {
+    private String feature;
+    private double value;
+    private int maxIndex;
+    private Vector weights;
+
+    public Weight(String feature, Vector weights) {
+      this.weights = weights;
+      this.feature = feature;
+      value = weights.norm(1);
+      maxIndex = weights.maxValueIndex();
+    }
 
-    a.set(records, 0, 1);
-    for (String feature : features) {
-      int j = dict.intern(feature);
-      a.set(records, j, 1);
+    @Override
+    public int compareTo(Weight other) {
+      int r = Double.compare(this.value, other.value);
+      if (r != 0) {
+        return r;
+      } else {
+        return feature.compareTo(other.feature);
+      }
+    }
+
+    public String getFeature() {
+      return feature;
+    }
+
+    public double getWeight() {
+      return value;
     }
-    records++;
-  }
 
-  public Matrix solve() {
-    Matrix az = a.viewPart(new int[]{0, 0}, new int[]{records, dict.size()});
-    Matrix bz = b.viewPart(new int[]{0, 0}, new int[]{records, b.columnSize()});
-    QRDecomposition qr = new QRDecomposition(az.transpose().times(az));
-    Matrix x = qr.solve(bz);
-    Map<String, Integer> labels = Maps.newHashMap();
-    int i = 0;
-    for (String s : dict.values()) {
-      labels.put(s, i++);
+    public int getMaxImpact() {
+      return maxIndex;
     }
-    x.setRowLabelBindings(labels);
-    return x;
   }
 }



Mime
View raw message