mahout-commits mailing list archives

From sro...@apache.org
Subject svn commit: r1131481 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/classifier/sgd/ math/src/main/java/org/apache/mahout/math/function/
Date Sat, 04 Jun 2011 19:47:38 GMT
Author: srowen
Date: Sat Jun  4 19:47:37 2011
New Revision: 1131481

URL: http://svn.apache.org/viewvc?rev=1131481&view=rev
Log:
MAHOUT-703 implement Gradient machine classifier

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
Modified:
    mahout/trunk/math/src/main/java/org/apache/mahout/math/function/Functions.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java?rev=1131481&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java Sat Jun  4 19:47:37 2011
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that minimizes the label ranking hinge loss.
+ * Implements a gradient machine with one sigmoid hidden layer.
+ * Because it minimizes the ranking loss over a given set of labels,
+ * it can be used for multi-class or multi-label classification,
+ * or for auto-encoding of sparse data (e.g. text).
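+ *
+ * <p>A minimal usage sketch (hypothetical layer sizes; only the API defined in this
+ * class is used):</p>
+ * <pre>
+ *   GradientMachine gm = new GradientMachine(1000, 100, 20)
+ *       .learningRate(0.1)
+ *       .regularization(0.01);
+ *   gm.initWeights(RandomUtils.getRandom());
+ *   gm.train(actualLabel, instance);             // repeated online updates
+ *   Vector scores = gm.classifyNoLink(instance); // raw label scores
+ * </pre>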
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // the regularization term, a positive number that controls the size of the weight vector
+  private double regularization = 0.1;
+
+  // the sparsity term, a positive number in (0, 1) that controls the sparsity of the hidden layer.
+  private double sparsity = 0.1;
+
+  // the sparsity learning rate.
+  private double sparsityLearningRate = 0.1;
+
+  // the number of features
+  private int numFeatures = 10;
+  // the number of hidden nodes
+  private int numHidden = 100;
+  // the number of output nodes
+  private int numOutput = 2;
+
+  // coefficients for the input to hidden layer.
+  // There are numHidden Vectors of dimension numFeatures.
+  private Vector[] hiddenWeights;
+
+  // coefficients for the hidden to output layer.
+  // There are numOutput Vectors of dimension numHidden.
+  private Vector[] outputWeights;
+
+  // hidden unit bias
+  private Vector hiddenBias;
+
+  // output unit bias
+  private Vector outputBias;
+
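+  // Random number generator used to sample candidate "bad" labels during training.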
+  private final Random rnd;
+
+  public GradientMachine(int numFeatures, int numHidden, int numOutput) {
+    this.numFeatures = numFeatures;
+    this.numHidden = numHidden;
+    this.numOutput = numOutput;
+    hiddenWeights = new DenseVector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = new DenseVector(numFeatures);
+      hiddenWeights[i].assign(0);
+    }
+    hiddenBias = new DenseVector(numHidden);
+    hiddenBias.assign(0);
+    outputWeights = new DenseVector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = new DenseVector(numHidden);
+      outputWeights[i].assign(0);
+    }
+    outputBias = new DenseVector(numOutput);
+    outputBias.assign(0);
+    rnd = RandomUtils.getRandom();
+  }
+
+  /**
+   * Initialize weights.
+   *
+   * @param gen random number generator.
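+   *            Weights are drawn uniformly from [-1/sqrt(fanIn), +1/sqrt(fanIn)],
+   *            where fanIn is the size of the layer below.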
+   */
+  public void initWeights(Random gen) {
+    double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
+    for (int i = 0; i < numHidden; i++) {
+      for (int j = 0; j < numFeatures; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
+        hiddenWeights[i].setQuick(j, val);
+      }
+    }
+    double outputFanIn = 1.0 / Math.sqrt(numHidden);
+    for (int i = 0; i < numOutput; i++) {
+      for (int j = 0; j < numHidden; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
+        outputWeights[i].setQuick(j, val);
+      }
+    }
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param regularization A positive value that controls the weight vector size.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine regularization(double regularization) {
+    this.regularization = regularization;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsity A value between zero and one that controls the fraction of hidden units
+   *                 that are activated on average.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsity(double sparsity) {
+    this.sparsity = sparsity;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsityLearningRate New value of initial learning rate for sparsity.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+    this.sparsityLearningRate = sparsityLearningRate;
+    return this;
+  }
+
+  public void copyFrom(GradientMachine other) {
+    numFeatures = other.numFeatures;
+    numHidden = other.numHidden;
+    numOutput = other.numOutput;
+    learningRate = other.learningRate;
+    regularization = other.regularization;
+    sparsity = other.sparsity;
+    sparsityLearningRate = other.sparsityLearningRate;
+    hiddenWeights = new Vector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = other.hiddenWeights[i].clone();
+    }
+    hiddenBias = other.hiddenBias.clone();
+    outputWeights = new Vector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = other.outputWeights[i].clone();
+    }
+    outputBias = other.outputBias.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numOutput;
+  }
+
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  public int numHidden() {
+    return numHidden;
+  }
+
+  /**
+   * Feeds forward from input to hidden units.
+   *
+   * @return Hidden unit activations.
+   */
+  public DenseVector inputToHidden(Vector input) {
+    DenseVector activations = new DenseVector(numHidden);
+    for (int i = 0; i < numHidden; i++) {
+      activations.setQuick(i, hiddenWeights[i].dot(input));
+    }
+    hiddenBias.addTo(activations);
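+    // Clamp the pre-activations to [-40, 40] to keep the sigmoid input in a numerically safe range.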
+    activations.assign(Functions.min(40.0)).assign(Functions.max(-40.0));
+    activations.assign(Functions.SIGMOID);
+    return activations;
+  }
+
+  /**
+   * Feeds forward from hidden to output units.
+   *
+   * @return Output unit activations.
+   */
+  public DenseVector hiddenToOutput(Vector hiddenActivation) {
+    DenseVector activations = new DenseVector(numOutput);
+    for (int i = 0; i < numOutput; i++) {
+      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+    }
+    outputBias.addTo(activations);
+    return activations;
+  }
+
+  /**
+   * Updates using ranking loss.
+   *
+   * @param hiddenActivation the hidden unit's activation
+   * @param goodLabels       the labels you want ranked above others.
+   * @param numTrials        how many times you want to search for the highest scoring bad label.
+   * @param gen              Random number generator.
+   */
+  public void updateRanking(Vector hiddenActivation,
+                            Collection<Integer> goodLabels,
+                            int numTrials,
+                            Random gen) {
+    // All the labels are good, do nothing.
+    if (goodLabels.size() >= numOutput) {
+      return;
+    }
+    for (Integer good : goodLabels) {
+      double goodScore = outputWeights[good].dot(hiddenActivation);
+      int highestBad = -1;
+      double highestBadScore = Double.NEGATIVE_INFINITY;
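+      // Sample numTrials candidate labels uniformly (skipping the good ones) and keep the
+      // highest-scoring "bad" label: the strongest violator found by sampling.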
+      for (int i = 0; i < numTrials; i++) {
+        int bad = gen.nextInt(numOutput);
+        while (goodLabels.contains(bad)) {
+          bad = gen.nextInt(numOutput);
+        }
+        double badScore = outputWeights[bad].dot(hiddenActivation);
+        if (badScore > highestBadScore) {
+          highestBadScore = badScore;
+          highestBad = bad;
+        }
+      }
+      int bad = highestBad;
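+      // Margin-1 hinge loss: update only when the best bad score comes within 1.0 of the good score.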
+      double loss = 1.0 - goodScore + highestBadScore;
+      if (loss < 0.0) {
+        continue;
+      }
+      // From the loss above, the gradient dloss/dy is -1 for the good score and +1 for the
+      // bad score, where y = w' h + b is a label's score given hidden activations h.
+      // Since dy/dh = w, the chain rule gives dloss/dh = -w_good + w_bad, which is the
+      // signal propagated back to the hidden layer (propHidden below).
+      // For the regularization term 0.5 * lambda * w' w, the gradient is lambda * w.
+      // dy/db = 1, so the biases move by +/- learningRate.
+      Vector gradGood = outputWeights[good].clone();
+      gradGood.assign(Functions.NEGATE);
+      Vector propHidden = gradGood.clone();
+      Vector gradBad = outputWeights[bad].clone();
+      gradBad.addTo(propHidden);
+      gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
+      gradGood.addTo(outputWeights[good]);
+      gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
+      gradBad.addTo(outputWeights[bad]);
+      outputBias.setQuick(good, outputBias.get(good) + learningRate);
+      outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
+      // The gradient of the sigmoid is s * (1 - s), where s is the activation itself.
+      Vector gradSig = hiddenActivation.clone();
+      gradSig.assign(Functions.SIGMOIDGRADIENT);
+      // Multiply by the change caused by the ranking loss.
+      for (int i = 0; i < numHidden; i++) {
+        gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
+      }
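+      // Apply the backpropagated signal plus L2 regularization to the input-to-hidden weights.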
+      for (int i = 0; i < numHidden; i++) {
+        for (int j = 0; j < numFeatures; j++) {
+          double v = hiddenWeights[i].get(j);
+          v -= learningRate * (gradSig.get(i) + regularization * v);
+          hiddenWeights[i].setQuick(j, v);
+        }
+      }
+    }
+  }
+
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Find the max value's index.
+    int max = result.maxValueIndex();
+    result.assign(0);
+    result.setQuick(max, 1.0);
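+    // Per the AbstractVectorClassifier contract, return numCategories - 1 values; category 0 is implicit.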
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    DenseVector hidden = inputToHidden(instance);
+    return hiddenToOutput(hidden);
+  }
+
+  @Override
+  public double classifyScalar(Vector instance) {
+    Vector output = classifyNoLink(instance);
+    if (output.get(0) > output.get(1)) {
+      return 0;
+    }
+    return 1;
+  }
+
+  public GradientMachine copy() {
+    close();
+    GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
+    r.copyFrom(this);
+    return r;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
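+    // The field order here must match readFields() below exactly.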
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeDouble(regularization);
+    out.writeDouble(sparsity);
+    out.writeDouble(sparsityLearningRate);
+    out.writeInt(numFeatures);
+    out.writeInt(numHidden);
+    out.writeInt(numOutput);
+    VectorWritable.writeVector(out, hiddenBias);
+    for (int i = 0; i < numHidden; i++) {
+      VectorWritable.writeVector(out, hiddenWeights[i]);
+    }
+    VectorWritable.writeVector(out, outputBias);
+    for (int i = 0; i < numOutput; i++) {
+      VectorWritable.writeVector(out, outputWeights[i]);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version == WRITABLE_VERSION) {
+      learningRate = in.readDouble();
+      regularization = in.readDouble();
+      sparsity = in.readDouble();
+      sparsityLearningRate = in.readDouble();
+      numFeatures = in.readInt();
+      numHidden = in.readInt();
+      numOutput = in.readInt();
+      // Use Vector[] because readVector may return any Vector implementation.
+      hiddenWeights = new Vector[numHidden];
+      hiddenBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numHidden; i++) {
+        hiddenWeights[i] = VectorWritable.readVector(in);
+      }
+      outputWeights = new Vector[numOutput];
+      outputBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numOutput; i++) {
+        outputWeights[i] = VectorWritable.readVector(in);
+      }
+    } else {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
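+    // Forward pass to the hidden layer, then one ranking update against sampled bad labels.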
+    Vector hiddenActivation = inputToHidden(instance);
+    Collection<Integer> goodLabels = new HashSet<Integer>();
+    goodLabels.add(actual);
+    updateRanking(hiddenActivation, goodLabels, 2, rnd);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java?rev=1131481&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java Sat Jun  4 19:47:37 2011
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+public final class GradientMachineTest extends OnlineBaseTest {
+
+  @Test
+  public void testGradientMachine() throws IOException {
+    Vector target = readStandardData();
+    // Fix the test seed before constructing the classifier so its internal RNG is deterministic too.
+    RandomUtils.useTestSeed();
+    GradientMachine grad = new GradientMachine(8, 4, 2).learningRate(0.1).regularization(0.01);
+    Random gen = RandomUtils.getRandom();
+    grad.initWeights(gen);
+    train(getInput(), target, grad);
+    test(getInput(), target, grad, 0.05, 1);
+  }
+
+}

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/function/Functions.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/function/Functions.java?rev=1131481&r1=1131480&r2=1131481&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/function/Functions.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/function/Functions.java Sat Jun  4 19:47:37 2011
@@ -273,6 +273,22 @@ public final class Functions {
     }
   };
 
+  /** Function that returns <tt>1 / (1 + exp(-a))</tt>, the logistic sigmoid. */
+  public static final DoubleFunction SIGMOID = new DoubleFunction() {
+    @Override
+    public double apply(double a) {
+      return 1.0 / (1.0 + Math.exp(-a));
+    }
+  };
+
+  /** Function that returns <tt>a * (1 - a)</tt>, the derivative of the sigmoid expressed in terms of its output <tt>a</tt>. */
+  public static final DoubleFunction SIGMOIDGRADIENT = new DoubleFunction() {
+    @Override
+    public double apply(double a) {
+      return a * (1.0 - a);
+    }
+  };
+
   /** Function that returns <tt>Math.tan(a)</tt>. */
   public static final DoubleFunction TAN = new DoubleFunction() {
 


