mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r997193 - in /mahout/trunk/core/src/main/java/org/apache/mahout: classifier/ classifier/sgd/ vectors/
Date Wed, 15 Sep 2010 06:18:05 GMT
Author: tdunning
Date: Wed Sep 15 06:18:04 2010
New Revision: 997193

URL: http://svn.apache.org/viewvc?rev=997193&view=rev
Log:
Add model reverse engineering for classifiers that extend AbstractVectorClassifier.

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
    mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java Wed Sep 15 06:18:04 2010
@@ -31,6 +31,17 @@ public abstract class AbstractVectorClas
   public abstract Vector classify(Vector instance);
 
   /**
+   * Classify a vector, but don't apply the inverse link function.  For logistic regression
+   * and other generalized linear models, this is just the linear part of the classification.
+   * @param features  A feature vector to be classified.
+   * @return  A vector of scores.  If transformed by the link function, these will become
probabilities.
+   */
+  public Vector classifyNoLink(Vector features) {
+    throw new UnsupportedOperationException("Classifier " + this.getClass().getName() +
+      " doesn't support classification without a link");
+  }
+
+  /**
    * Classifies a vector in the special case of a binary classifier where
     * <code>classify(Vector)</code> would return a vector with only one element.  As such,
    * using this method can void the allocation of a vector.

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Wed Sep 15 06:18:04 2010
@@ -86,6 +86,12 @@ public abstract class AbstractOnlineLogi
     }
   }
 
+  /**
+   * Computes the raw linear scores {@code beta.times(instance)}, without the logistic link.
+   *
+   * <p>Regularization still pending for the coefficients touched by {@code instance} is
+   * applied first, so (presumably) this call can update those coefficients.</p>
+   *
+   * @param instance  The feature vector to score.
+   * @return  The linear scores; applying {@code logisticLink} to them yields the
+   *          probabilities returned by {@code classify}.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    // apply pending regularization to whichever coefficients matter
+    regularize(instance);
+    return beta.times(instance);
+  }
+
   /**
    * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
    * category is 1 - sum(this result).
@@ -94,11 +100,7 @@ public abstract class AbstractOnlineLogi
    * @return A vector of probabilities, one for each of the first n-1 categories.
    */
   public Vector classify(Vector instance) {
-    // apply pending regularization to whichever coefficients matter
-    regularize(instance);
-
-    Vector v = beta.times(instance);
-    return logisticLink(v);
+    // The logistic link turns the raw linear scores from classifyNoLink into probabilities.
+    return logisticLink(classifyNoLink(instance));
   }
 
   /**

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Wed Sep 15 06:18:04 2010
@@ -5,6 +5,7 @@ import org.apache.mahout.classifier.Abst
 import org.apache.mahout.classifier.OnlineLearner;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.stats.OnlineAuc;
 
@@ -131,9 +132,18 @@ public class CrossFoldLearner extends Ab
   @Override
   public Vector classify(Vector instance) {
     Vector r = new DenseVector(numCategories() - 1);
-    double scale = 1.0 / models.size();
+    BinaryFunction scale = Functions.plusMult(1.0 / models.size());
     for (OnlineLogisticRegression model : models) {
-      r.assign(model.classify(instance), Functions.plusMult(scale));
+      r.assign(model.classify(instance), scale);
+    }
+    return r;
+  }
+
+  /**
+   * Averages the raw (un-linked) scores of all of the fold models.
+   *
+   * <p>Note that {@code classify} averages the models' probabilities instead.  The link
+   * function is non-linear, so applying it to this result does not in general reproduce
+   * {@code classify(instance)}.</p>
+   *
+   * @param instance  The feature vector to score.
+   * @return  A vector of {@code numCategories() - 1} averaged linear scores.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    Vector r = new DenseVector(numCategories() - 1);
+    BinaryFunction scale = Functions.plusMult(1.0 / models.size());
+    for (OnlineLogisticRegression model : models) {
+      r.assign(model.classifyNoLink(instance), scale);
     }
     return r;
   }

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997193&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Wed Sep 15 06:18:04 2010
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectors.Dictionary;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Uses sample data to reverse engineer a feature-hashed model.
+ *
+ * <p>Each example pairs the set of named features present in an input with the score(s)
+ * the model produced for it.  For a generalized linear model, the scores from
+ * {@code classifyNoLink} are the natural choice because they are linear in the features.
+ * {@link #solve()} fits, by least squares, one weight per feature name plus an intercept.
+ * The result gives approximate weights for features and interactions in the original
+ * space.</p>
+ *
+ * <p>Examples are buffered until {@link #solve()} is called.  The design matrices are then
+ * built at exactly the required size.</p>
+ */
+public class ModelDissector {
+  /** Number of examples added so far. */
+  int records = 0;
+  /** Maps feature names to design-matrix columns; column 0 is the intercept. */
+  private final Dictionary dict = new Dictionary();
+  /** Number of scores (model outputs) per example. */
+  private final int n;
+  /** For each example, the design-matrix columns whose value is 1. */
+  private final List<int[]> featureColumns = new ArrayList<int[]>();
+  /** For each example, a dense copy of its {@code n} scores. */
+  private final List<double[]> targets = new ArrayList<double[]>();
+
+  /**
+   * @param n  Number of scores (model outputs) supplied with each example.
+   */
+  public ModelDissector(int n) {
+    this.n = n;
+    // Interned first, so the intercept occupies column 0 of the design matrix.
+    dict.intern("Intercept Value");
+  }
+
+  /**
+   * Adds an example whose model output is a vector of scores.
+   *
+   * @param features  Names of the features present in this example.
+   * @param score     The model's scores; element {@code k} becomes column {@code k} of the
+   *                  target matrix, so every index must be below {@code n}.
+   */
+  public void addExample(Set<String> features, Vector score) {
+    double[] target = new double[n];
+    for (Vector.Element element : score) {
+      target[element.index()] = element.get();
+    }
+    addRow(features, target);
+  }
+
+  /**
+   * Adds an example whose model output is a single score (stored as target column 0).
+   *
+   * @param features  Names of the features present in this example.
+   * @param score     The model's score for this example.
+   */
+  public void addExample(Set<String> features, double score) {
+    double[] target = new double[n];
+    target[0] = score;
+    addRow(features, target);
+  }
+
+  /** Interns an example's features and buffers its design-matrix row and target row. */
+  private void addRow(Set<String> features, double[] target) {
+    int[] columns = new int[features.size() + 1];
+    // Every example includes the intercept term; without it the intercept column would be
+    // all zero and the normal equations would be singular.
+    columns[0] = 0;
+    int k = 1;
+    for (String feature : features) {
+      columns[k++] = dict.intern(feature);
+    }
+    featureColumns.add(columns);
+    targets.add(target);
+    records++;
+  }
+
+  /**
+   * Fits one weight per feature name (plus the intercept) by least squares.  It solves the
+   * normal equations {@code (A'A) x = A'B}, where row {@code i} of {@code A} is the 0/1
+   * indicator of example {@code i}'s features and row {@code i} of {@code B} holds its scores.
+   *
+   * <p>NOTE(review): if some features are collinear (e.g. they always co-occur), {@code A'A}
+   * is singular and the QR solve presumably rejects it; confirm QRDecomposition's behavior.</p>
+   *
+   * @return  A matrix with one row per feature and one column per score.  Rows are labeled
+   *          by feature name, including {@code "Intercept Value"}.
+   * @throws IllegalStateException if no examples have been added.
+   */
+  public Matrix solve() {
+    int rows = featureColumns.size();
+    if (rows == 0) {
+      throw new IllegalStateException("No examples have been added");
+    }
+    int columns = dict.size();
+    Matrix az = new SparseRowMatrix(new int[]{rows, columns}, true);
+    Matrix bz = new SparseRowMatrix(new int[]{rows, n});
+    for (int row = 0; row < rows; row++) {
+      for (int column : featureColumns.get(row)) {
+        az.set(row, column, 1);
+      }
+      double[] target = targets.get(row);
+      for (int k = 0; k < n; k++) {
+        if (target[k] != 0) {
+          bz.set(row, k, target[k]);
+        }
+      }
+    }
+
+    Matrix azt = az.transpose();
+    // The right-hand side must be A'B (columns x n) to match A'A (columns x columns).
+    QRDecomposition qr = new QRDecomposition(azt.times(az));
+    Matrix x = qr.solve(azt.times(bz));
+
+    // intern() returns the column already assigned to a known feature, so the labels do
+    // not depend on the iteration order of dict.values().
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (String s : dict.values()) {
+      labels.put(s, dict.intern(s));
+    }
+    x.setRowLabelBindings(labels);
+    return x;
+  }
+}

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java?rev=997193&r1=997192&r2=997193&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/vectors/Dictionary.java Wed Sep 15 06:18:04 2010
@@ -41,6 +41,10 @@ public class Dictionary {
     return new ArrayList<String>(dict.keySet());
   }
 
+  public int size() {
+    return dict.size();
+  }
+
   public static Dictionary fromList(List<String> values) {
     Dictionary dict = new Dictionary();
     for (String value : values) {



Mime
View raw message