hama-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From edwardy...@apache.org
Subject svn commit: r1491624 - in /hama/trunk: ./ ml/src/main/java/org/apache/hama/ml/perception/ ml/src/test/java/org/apache/hama/ml/perception/
Date Mon, 10 Jun 2013 22:11:07 GMT
Author: edwardyoon
Date: Mon Jun 10 22:11:06 2013
New Revision: 1491624

URL: http://svn.apache.org/r1491624
Log:
HAMA-760: Add new features to existing Multi Layer Perceptron (Yexi Jiang via edwardyoon)

Added:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java
Modified:
    hama/trunk/CHANGES.txt
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
    hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java
    hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java

Modified: hama/trunk/CHANGES.txt
URL: http://svn.apache.org/viewvc/hama/trunk/CHANGES.txt?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/CHANGES.txt (original)
+++ hama/trunk/CHANGES.txt Mon Jun 10 22:11:06 2013
@@ -15,6 +15,7 @@ Release 0.7 (unreleased changes)
 
   IMPROVEMENTS
 
+   HAMA-760: Add new features to existing Multi Layer Perceptron (Yexi Jiang via edwardyoon)
    HAMA-758: Send message to non-exist vertex makes the job fail (MaoYuan Xian via edwardyoon)
    HAMA-757: The partitioning job output should be un-splitable (MaoYuan Xian via edwardyoon)
    HAMA-754: PartitioningRunner should write raw records to partition files (edwardyoon)

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunction.java Mon Jun 10 22:11:06 2013
@@ -19,7 +19,6 @@ package org.apache.hama.ml.perception;
 
 /**
  * The common interface for cost functions.
- * 
  */
 public abstract class CostFunction {
 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CostFunctionFactory.java Mon Jun 10 22:11:06 2013
@@ -32,9 +32,10 @@ public class CostFunctionFactory {
   public static CostFunction getCostFunction(String name) {
     if (name.equalsIgnoreCase("SquaredError")) {
       return new SquaredError();
-    } else if (name.equalsIgnoreCase("LogisticError")) {
-      return new LogisticCostFunction();
+    } else if (name.equalsIgnoreCase("CrossEntropy")) {
+      return new CrossEntropy();
     }
-    return new SquaredError();
+    throw new IllegalStateException(String.format(
+        "No cost function with name '%s' found.", name));
   }
 }

Added: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java?rev=1491624&view=auto
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java (added)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/CrossEntropy.java Mon Jun 10 22:11:06 2013
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hama.ml.perception;
+
+/**
+ * The cross entropy cost function.
+ * 
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ */
+public class CrossEntropy extends CostFunction {
+
+  @Override
+  public double calculate(double target, double actual) {
+    return -target * Math.log(actual) - (1 - target) * Math.log(1 - actual);
+  }
+
+  @Override
+  public double calculateDerivative(double target, double actual) {
+    double adjustedTarget = target;
+    double adjustedActual = actual;
+    if (adjustedActual == 1) {
+      adjustedActual = 0.999;
+    } else if (actual == 0) {
+      adjustedActual = 0.001;
+    }
+    if (adjustedTarget == 1) {
+      adjustedTarget = 0.999;
+    } else if (adjustedTarget == 0) {
+      adjustedTarget = 0.001;
+    }
+    return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+        / (1 - adjustedActual);
+  }
+
+}

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/MultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -36,7 +36,7 @@ public abstract class MultiLayerPerceptr
   /* Model meta-data */
   protected String MLPType;
   protected double learningRate;
-  protected boolean regularization;
+  protected double regularization;
   protected double momentum;
   protected int numberOfLayers;
   protected String squashingFunctionName;
@@ -50,20 +50,33 @@ public abstract class MultiLayerPerceptr
    * Initialize the MLP.
    * 
    * @param learningRate Larger learningRate makes MLP learn more aggressive.
-   * @param regularization Turn on regularization make MLP less likely to
-   *          overfit.
+   *          Learning rate cannot be negative.
+   * @param regularization Regularization makes MLP less likely to overfit. The
+   *          value of regularization cannot be negative or too large,
+   *          otherwise it will affect the precision.
    * @param momentum The momentum makes the historical adjust have affect to
-   *          current adjust.
+   *          current adjust. The weight of momentum cannot be negative.
    * @param squashingFunctionName The name of squashing function.
    * @param costFunctionName The name of the cost function.
    * @param layerSizeArray The number of neurons for each layer. Note that the
    *          actual size of each layer is one more than the input size.
    */
-  public MultiLayerPerceptron(double learningRate, boolean regularization,
+  public MultiLayerPerceptron(double learningRate, double regularization,
       double momentum, String squashingFunctionName, String costFunctionName,
       int[] layerSizeArray) {
+    this.MLPType = getTypeName();
+    if (learningRate <= 0) {
+      throw new IllegalStateException("learning rate cannot be negative.");
+    }
     this.learningRate = learningRate;
+    if (regularization < 0 || regularization >= 0.5) {
+      throw new IllegalStateException(
+          "regularization weight must be in range (0, 0.5).");
+    }
     this.regularization = regularization; // no regularization
+    if (momentum < 0) {
+      throw new IllegalStateException("momentum weight cannot be negative.");
+    }
     this.momentum = momentum; // no momentum
     this.squashingFunctionName = squashingFunctionName;
     this.costFunctionName = costFunctionName;
@@ -101,8 +114,12 @@ public abstract class MultiLayerPerceptr
    * @param featureVector The feature of an instance to feed the perceptron.
    * @return The results.
    */
-  public abstract DoubleVector output(DoubleVector featureVector)
-      throws Exception;
+  public abstract DoubleVector output(DoubleVector featureVector);
+
+  /**
+   * Use the class name as the type name.
+   */
+  protected abstract String getTypeName();
 
   /**
    * Read the model meta-data from the specified location.
@@ -131,7 +148,7 @@ public abstract class MultiLayerPerceptr
     return learningRate;
   }
 
-  public boolean isRegularization() {
+  public double isRegularization() {
     return regularization;
   }
 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/Sigmoid.java Mon Jun 10 22:11:06 2013
@@ -17,7 +17,6 @@
  */
 package org.apache.hama.ml.perception;
 
-
 /**
  * The Sigmoid function
  * 

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPMessage.java Mon Jun 10 22:11:06 2013
@@ -29,18 +29,47 @@ import org.apache.hama.ml.writable.Matri
  * {@link SmallMultiLayerPerceptron}. It send the whole parameter matrix from
  * one task to another.
  */
-public class SmallMLPMessage extends MLPMessage {
+class SmallMLPMessage extends MLPMessage {
 
   private int owner; // the ID of the task who creates the message
+  private int numOfUpdatedMatrices;
   private DenseDoubleMatrix[] weightUpdatedMatrices;
-  private int numOfMatrices;
+  private int numOfPrevUpdatedMatrices;
+  private DenseDoubleMatrix[] prevWeightUpdatedMatrices;
 
-  public SmallMLPMessage(int owner, boolean terminated, DenseDoubleMatrix[] mat) {
+  /**
+   * When slave send message to master, use this constructor.
+   * 
+   * @param owner The owner that create the message
+   * @param terminated Whether the training is terminated for the owner task
+   * @param weightUpdatedMatrics The weight updates
+   */
+  public SmallMLPMessage(int owner, boolean terminated,
+      DenseDoubleMatrix[] weightUpdatedMatrics) {
     super(terminated);
     this.owner = owner;
-    this.weightUpdatedMatrices = mat;
-    this.numOfMatrices = this.weightUpdatedMatrices == null ? 0
+    this.weightUpdatedMatrices = weightUpdatedMatrics;
+    this.numOfUpdatedMatrices = this.weightUpdatedMatrices == null ? 0
         : this.weightUpdatedMatrices.length;
+    this.numOfPrevUpdatedMatrices = 0;
+    this.prevWeightUpdatedMatrices = null;
+  }
+
+  /**
+   * When master send message to slave, use this constructor.
+   * 
+   * @param owner The owner that create the message
+   * @param terminated Whether the training is terminated for the owner task
+   * @param weightUpdatedMatrics The weight updates
+   * @param prevWeightUpdatedMatrices
+   */
+  public SmallMLPMessage(int owner, boolean terminated,
+      DenseDoubleMatrix[] weightUpdatedMatrices,
+      DenseDoubleMatrix[] prevWeightUpdatedMatrices) {
+    this(owner, terminated, weightUpdatedMatrices);
+    this.prevWeightUpdatedMatrices = prevWeightUpdatedMatrices;
+    this.numOfPrevUpdatedMatrices = this.prevWeightUpdatedMatrices == null ? 0
+        : this.prevWeightUpdatedMatrices.length;
   }
 
   /**
@@ -57,30 +86,44 @@ public class SmallMLPMessage extends MLP
    * 
    * @return
    */
-  public DenseDoubleMatrix[] getWeightsUpdatedMatrices() {
+  public DenseDoubleMatrix[] getWeightUpdatedMatrices() {
     return this.weightUpdatedMatrices;
   }
 
+  public DenseDoubleMatrix[] getPrevWeightsUpdatedMatrices() {
+    return this.prevWeightUpdatedMatrices;
+  }
+
   @Override
   public void readFields(DataInput input) throws IOException {
     this.owner = input.readInt();
     this.terminated = input.readBoolean();
-    this.numOfMatrices = input.readInt();
-    this.weightUpdatedMatrices = new DenseDoubleMatrix[this.numOfMatrices];
-    for (int i = 0; i < this.numOfMatrices; ++i) {
+    this.numOfUpdatedMatrices = input.readInt();
+    this.weightUpdatedMatrices = new DenseDoubleMatrix[this.numOfUpdatedMatrices];
+    for (int i = 0; i < this.numOfUpdatedMatrices; ++i) {
       this.weightUpdatedMatrices[i] = (DenseDoubleMatrix) MatrixWritable
           .read(input);
     }
+    this.numOfPrevUpdatedMatrices = input.readInt();
+    this.prevWeightUpdatedMatrices = new DenseDoubleMatrix[this.numOfPrevUpdatedMatrices];
+    for (int i = 0; i < this.numOfPrevUpdatedMatrices; ++i) {
+      this.prevWeightUpdatedMatrices[i] = (DenseDoubleMatrix) MatrixWritable
+          .read(input);
+    }
   }
 
   @Override
   public void write(DataOutput output) throws IOException {
     output.writeInt(this.owner);
     output.writeBoolean(this.terminated);
-    output.writeInt(this.numOfMatrices);
-    for (int i = 0; i < this.numOfMatrices; ++i) {
+    output.writeInt(this.numOfUpdatedMatrices);
+    for (int i = 0; i < this.numOfUpdatedMatrices; ++i) {
       MatrixWritable.write(this.weightUpdatedMatrices[i], output);
     }
+    output.writeInt(this.numOfPrevUpdatedMatrices);
+    for (int i = 0; i < this.numOfPrevUpdatedMatrices; ++i) {
+      MatrixWritable.write(this.prevWeightUpdatedMatrices[i], output);
+    }
   }
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMLPTrainer.java Mon Jun 10 22:11:06 2013
@@ -33,7 +33,7 @@ import org.apache.hama.ml.writable.Vecto
 /**
  * The perceptron trainer for small scale MLP.
  */
-public class SmallMLPTrainer extends PerceptronTrainer {
+class SmallMLPTrainer extends PerceptronTrainer {
 
   private static final Log LOG = LogFactory.getLog(SmallMLPTrainer.class);
   /* used by master only, check whether all slaves finishes reading */
@@ -66,7 +66,7 @@ public class SmallMLPTrainer extends Per
     // build model from scratch
     if (modelPath == null || modelPath.trim().length() == 0) {
       double learningRate = Double.parseDouble(conf.get("learningRate"));
-      boolean regularization = Boolean.parseBoolean(conf.get("regularization"));
+      double regularization = Double.parseDouble(conf.get("regularization"));
       double momentum = Double.parseDouble(conf.get("momentum"));
       String squashingFunctionName = conf.get("squashingFunctionName");
       String costFunctionName = conf.get("costFunctionName");
@@ -184,7 +184,7 @@ public class SmallMLPTrainer extends Per
         this.statusSet.set(message.getOwner());
       }
 
-      DenseDoubleMatrix[] weightUpdates = message.getWeightsUpdatedMatrices();
+      DenseDoubleMatrix[] weightUpdates = message.getWeightUpdatedMatrices();
       for (int m = 0; m < mergedUpdates.length; ++m) {
         mergedUpdates[m] = (DenseDoubleMatrix) mergedUpdates[m]
             .add(weightUpdates[m]);
@@ -206,12 +206,14 @@ public class SmallMLPTrainer extends Per
 
       // update the weight matrices
       this.inMemoryPerceptron.updateWeightMatrices(mergedUpdates);
+      this.inMemoryPerceptron.setPrevWeightUpdateMatrices(mergedUpdates);
     }
 
     // broadcast updated weight matrices
     for (String peerName : peer.getAllPeerNames()) {
       SmallMLPMessage msg = new SmallMLPMessage(peer.getPeerIndex(),
-          this.terminateTraining, this.inMemoryPerceptron.getWeightMatrices());
+          this.terminateTraining, this.inMemoryPerceptron.getWeightMatrices(),
+          this.inMemoryPerceptron.getPrevWeightUpdateMatrices());
       peer.send(peerName, msg);
     }
 
@@ -233,7 +235,9 @@ public class SmallMLPTrainer extends Per
       this.terminateTraining = message.isTerminated();
       // each slave renew its weight matrices
       this.inMemoryPerceptron.setWeightMatrices(message
-          .getWeightsUpdatedMatrices());
+          .getWeightUpdatedMatrices());
+      this.inMemoryPerceptron.setPrevWeightUpdateMatrices(message
+          .getPrevWeightsUpdatedMatrices());
       if (this.terminateTraining) {
         return true;
       }
@@ -272,8 +276,8 @@ public class SmallMLPTrainer extends Per
       weightUpdates[m] = (DenseDoubleMatrix) weightUpdates[m].divide(count);
     }
 
-    LOG.info(String.format("Task %d has read %d records.",
-        peer.getPeerIndex(), this.numTrainingInstanceRead));
+    LOG.info(String.format("Task %d has read %d records.", peer.getPeerIndex(),
+        this.numTrainingInstanceRead));
 
     // send the weight updates to master task
     SmallMLPMessage message = new SmallMLPMessage(peer.getPeerIndex(),

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SmallMultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -65,16 +65,19 @@ public final class SmallMultiLayerPercep
   /* The in-memory weight matrix */
   private DenseDoubleMatrix[] weightMatrice;
 
+  /* Previous weight updates, used for momentum */
+  private DenseDoubleMatrix[] prevWeightUpdateMatrices;
+
   /**
    * {@inheritDoc}
    */
-  public SmallMultiLayerPerceptron(double learningRate, boolean regularization,
+  public SmallMultiLayerPerceptron(double learningRate, double regularization,
       double momentum, String squashingFunctionName, String costFunctionName,
       int[] layerSizeArray) {
     super(learningRate, regularization, momentum, squashingFunctionName,
         costFunctionName, layerSizeArray);
-    this.MLPType = "SmallMLP";
     initializeWeightMatrix();
+    this.initializePrevWeightUpdateMatrix();
   }
 
   /**
@@ -85,6 +88,7 @@ public final class SmallMultiLayerPercep
     if (modelPath != null) {
       try {
         this.readFromModel();
+        this.initializePrevWeightUpdateMatrix();
       } catch (IOException e) {
         e.printStackTrace();
       }
@@ -113,20 +117,30 @@ public final class SmallMultiLayerPercep
     }
   }
 
+  /**
+   * Initial the momentum weight matrices.
+   */
+  private void initializePrevWeightUpdateMatrix() {
+    this.prevWeightUpdateMatrices = new DenseDoubleMatrix[this.numberOfLayers - 1];
+    for (int i = 0; i < this.prevWeightUpdateMatrices.length; ++i) {
+      int row = this.layerSizeArray[i] + 1;
+      int col = this.layerSizeArray[i + 1];
+      this.prevWeightUpdateMatrices[i] = new DenseDoubleMatrix(row, col);
+    }
+  }
+
   @Override
   /**
    * {@inheritDoc}
    * The model meta-data is stored in memory.
    */
-  public DoubleVector output(DoubleVector featureVector) throws Exception {
+  public DoubleVector output(DoubleVector featureVector) {
     List<double[]> outputCache = this.outputInternal(featureVector);
     // the output of the last layer is the output of the MLP
     return new DenseDoubleVector(outputCache.get(outputCache.size() - 1));
   }
 
-  private List<double[]> outputInternal(DoubleVector featureVector)
-      throws Exception {
-
+  private List<double[]> outputInternal(DoubleVector featureVector) {
     // store the output of the hidden layers and output layer, each array store
     // one layer
     List<double[]> outputCache = new ArrayList<double[]>();
@@ -134,7 +148,7 @@ public final class SmallMultiLayerPercep
     // start from the first hidden layer
     double[] intermediateResults = new double[this.layerSizeArray[0] + 1];
     if (intermediateResults.length - 1 != featureVector.getDimension()) {
-      throw new Exception(
+      throw new IllegalStateException(
           "Input feature dimension incorrect! The dimension of input layer is "
               + (this.layerSizeArray[0] - 1)
               + ", but the dimension of input feature is "
@@ -227,17 +241,31 @@ public final class SmallMultiLayerPercep
     double[] outputLayerOutput = outputCache.get(outputCache.size() - 1);
     double[] lastHiddenLayerOutput = outputCache.get(outputCache.size() - 2);
 
+    DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[this.prevWeightUpdateMatrices.length - 1];
     for (int j = 0; j < delta.length; ++j) {
-      delta[j] = this.squashingFunction
-          .calculateDerivative(outputLayerOutput[j])
-          * this.costFunction.calculateDerivative(trainingLabels[j],
-              outputLayerOutput[j]);
+      delta[j] = this.costFunction.calculateDerivative(trainingLabels[j],
+          outputLayerOutput[j]);
+      // add regularization term
+      if (this.regularization != 0.0) {
+        double derivativeRegularization = 0.0;
+        DenseDoubleMatrix weightMatrix = this.weightMatrice[this.weightMatrice.length - 1];
+        for (int k = 0; k < this.layerSizeArray[this.layerSizeArray.length - 1]; ++k) {
+          derivativeRegularization += weightMatrix.get(k, j);
+        }
+        derivativeRegularization /= this.layerSizeArray[this.layerSizeArray.length - 1];
+        delta[j] += this.regularization * derivativeRegularization;
+      }
+
+      delta[j] *= this.squashingFunction
+          .calculateDerivative(outputLayerOutput[j]);
 
       // calculate the weight update matrix between the last hidden layer and
       // the output layer
       for (int i = 0; i < this.layerSizeArray[this.layerSizeArray.length - 2] + 1; ++i) {
-        double updatedValue = this.learningRate * delta[j]
+        double updatedValue = -this.learningRate * delta[j]
             * lastHiddenLayerOutput[i];
+        // add momentum
+        updatedValue += this.momentum * prevWeightUpdateMatrix.get(i, j);
         weightUpdateMatrices[weightUpdateMatrices.length - 1].set(i, j,
             updatedValue);
       }
@@ -270,6 +298,7 @@ public final class SmallMultiLayerPercep
     double[] curLayerOutput = outputCache.get(curLayerIdx);
     double[] prevLayerOutput = outputCache.get(prevLayerIdx);
 
+    DenseDoubleMatrix prevWeightUpdateMatrix = this.prevWeightUpdateMatrices[curLayerIdx - 1];
     // for each neuron j in nextLayer, calculate the delta
     for (int j = 0; j < delta.length; ++j) {
       // aggregate delta from next layer
@@ -283,7 +312,10 @@ public final class SmallMultiLayerPercep
       // calculate the weight update matrix between the previous layer and the
       // current layer
       for (int i = 0; i < weightUpdateMatrices[prevLayerIdx].getRowCount(); ++i) {
-        double updatedValue = this.learningRate * delta[j] * prevLayerOutput[i];
+        double updatedValue = -this.learningRate * delta[j]
+            * prevLayerOutput[i];
+        // add momemtum
+        updatedValue += this.momentum * prevWeightUpdateMatrix.get(i, j);
         weightUpdateMatrices[prevLayerIdx].set(i, j, updatedValue);
       }
     }
@@ -349,7 +381,7 @@ public final class SmallMultiLayerPercep
   public void readFields(DataInput input) throws IOException {
     this.MLPType = WritableUtils.readString(input);
     this.learningRate = input.readDouble();
-    this.regularization = input.readBoolean();
+    this.regularization = input.readDouble();
     this.momentum = input.readDouble();
     this.numberOfLayers = input.readInt();
     this.squashingFunctionName = WritableUtils.readString(input);
@@ -373,7 +405,7 @@ public final class SmallMultiLayerPercep
   public void write(DataOutput output) throws IOException {
     WritableUtils.writeString(output, MLPType);
     output.writeDouble(learningRate);
-    output.writeBoolean(regularization);
+    output.writeDouble(regularization);
     output.writeDouble(momentum);
     output.writeInt(numberOfLayers);
     WritableUtils.writeString(output, squashingFunctionName);
@@ -402,6 +434,11 @@ public final class SmallMultiLayerPercep
       FileSystem fs = FileSystem.get(uri, conf);
       FSDataInputStream is = new FSDataInputStream(fs.open(new Path(modelPath)));
       this.readFields(is);
+      if (!this.MLPType.equals(this.getClass().getName())) {
+        throw new IllegalStateException(String.format(
+            "Model type incorrect, cannot load model '%s' for '%s'.",
+            this.MLPType, this.getClass().getName()));
+      }
     } catch (URISyntaxException e) {
       e.printStackTrace();
     }
@@ -425,10 +462,19 @@ public final class SmallMultiLayerPercep
     return this.weightMatrice;
   }
 
+  DenseDoubleMatrix[] getPrevWeightUpdateMatrices() {
+    return this.prevWeightUpdateMatrices;
+  }
+
   void setWeightMatrices(DenseDoubleMatrix[] newMatrices) {
     this.weightMatrice = newMatrices;
   }
 
+  void setPrevWeightUpdateMatrices(
+      DenseDoubleMatrix[] newPrevWeightUpdateMatrices) {
+    this.prevWeightUpdateMatrices = newPrevWeightUpdateMatrices;
+  }
+
   /**
    * Update the weight matrices with given updates.
    * 
@@ -462,4 +508,9 @@ public final class SmallMultiLayerPercep
     return sb.toString();
   }
 
+  @Override
+  protected String getTypeName() {
+    return this.getClass().getName();
+  }
+
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquaredError.java Mon Jun 10 22:11:06 2013
@@ -40,7 +40,8 @@ public class SquaredError extends CostFu
    * {@inheritDoc}
    */
   public double calculateDerivative(double target, double actual) {
-    return target - actual;
+    // return target - actual;
+    return actual - target;
   }
 
 }

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/perception/SquashingFunctionFactory.java Mon Jun 10 22:11:06 2013
@@ -36,7 +36,8 @@ public class SquashingFunctionFactory {
     } else if (name.equalsIgnoreCase("Tanh")) {
       return new Tanh();
     }
-    return new Sigmoid();
+    throw new IllegalStateException(String.format(
+        "No squashing function with name '%s' found.", name));
   }
 
 }

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMLPMessage.java Mon Jun 10 22:11:06 2013
@@ -1,4 +1,3 @@
-
 /**
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -33,7 +32,6 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hama.ml.math.DenseDoubleMatrix;
 import org.junit.Test;
 
-
 /**
  * Test the functionalities of SmallMLPMessage
  * 
@@ -41,12 +39,10 @@ import org.junit.Test;
 public class TestSmallMLPMessage {
 
   @Test
-  public void testReadWrite() {
+  public void testReadWriteWithoutPrevUpdate() {
     int owner = 101;
     double[][] mat = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
-
     double[][] mat2 = { { 10, 20 }, { 30, 40 }, { 50, 60 } };
-
     double[][][] mats = { mat, mat2 };
 
     DenseDoubleMatrix[] matrices = new DenseDoubleMatrix[] {
@@ -68,11 +64,10 @@ public class TestSmallMLPMessage {
       outMessage.readFields(in);
 
       assertEquals(owner, outMessage.getOwner());
-      DenseDoubleMatrix[] outMatrices = outMessage.getWeightsUpdatedMatrices();
+      DenseDoubleMatrix[] outMatrices = outMessage.getWeightUpdatedMatrices();
       // check each matrix
       for (int i = 0; i < outMatrices.length; ++i) {
-        double[][] outMat = outMessage.getWeightsUpdatedMatrices()[i]
-            .getValues();
+        double[][] outMat = outMatrices[i].getValues();
         for (int j = 0; j < outMat.length; ++j) {
           assertArrayEquals(mats[i][j], outMat[j], 0.0001);
         }
@@ -84,6 +79,69 @@ public class TestSmallMLPMessage {
     } catch (URISyntaxException e) {
       e.printStackTrace();
     }
+  }
+
+  @Test
+  public void testReadWriteWithPrevUpdate() {
+    int owner = 101;
+    double[][] mat = { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
+    double[][] mat2 = { { 10, 20 }, { 30, 40 }, { 50, 60 } };
+    double[][][] mats = { mat, mat2 };
+
+    double[][] prevMat = { { 0.1, 0.2, 0.3 }, { 0.4, 0.5, 0.6 },
+        { 0.7, 0.8, 0.9 } };
+    double[][] prevMat2 = { { 1, 2 }, { 3, 4 }, { 5, 6 } };
+    double[][][] prevMats = { prevMat, prevMat2 };
+
+    DenseDoubleMatrix[] matrices = new DenseDoubleMatrix[] {
+        new DenseDoubleMatrix(mat), new DenseDoubleMatrix(mat2) };
+
+    DenseDoubleMatrix[] prevMatrices = new DenseDoubleMatrix[] {
+        new DenseDoubleMatrix(prevMat), new DenseDoubleMatrix(prevMat2) };
+
+    boolean terminated = false;
+    SmallMLPMessage message = new SmallMLPMessage(owner, terminated, matrices,
+        prevMatrices);
+
+    Configuration conf = new Configuration();
+    String strPath = "/tmp/testSmallMLPMessageWithPrevMatrices";
+    Path path = new Path(strPath);
+    try {
+      FileSystem fs = FileSystem.get(new URI(strPath), conf);
+      FSDataOutputStream out = fs.create(path, true);
+      message.write(out);
+      out.close();
+
+      FSDataInputStream in = fs.open(path);
+      SmallMLPMessage outMessage = new SmallMLPMessage(0, false, null);
+      outMessage.readFields(in);
+
+      assertEquals(owner, outMessage.getOwner());
+      assertEquals(terminated, outMessage.isTerminated());
+      DenseDoubleMatrix[] outMatrices = outMessage.getWeightUpdatedMatrices();
+      // check each matrix
+      for (int i = 0; i < outMatrices.length; ++i) {
+        double[][] outMat = outMatrices[i].getValues();
+        for (int j = 0; j < outMat.length; ++j) {
+          assertArrayEquals(mats[i][j], outMat[j], 0.0001);
+        }
+      }
+
+      DenseDoubleMatrix[] outPrevMatrices = outMessage
+          .getPrevWeightsUpdatedMatrices();
+      // check each matrix
+      for (int i = 0; i < outPrevMatrices.length; ++i) {
+        double[][] outMat = outPrevMatrices[i].getValues();
+        for (int j = 0; j < outMat.length; ++j) {
+          assertArrayEquals(prevMats[i][j], outMat[j], 0.0001);
+        }
+      }
 
+      fs.delete(path, true);
+    } catch (IOException e) {
+      e.printStackTrace();
+    } catch (URISyntaxException e) {
+      e.printStackTrace();
+    }
   }
 }

Modified: hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java?rev=1491624&r1=1491623&r2=1491624&view=diff
==============================================================================
--- hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java (original)
+++ hama/trunk/ml/src/test/java/org/apache/hama/ml/perception/TestSmallMultiLayerPerceptron.java Mon Jun 10 22:11:06 2013
@@ -50,8 +50,8 @@ public class TestSmallMultiLayerPerceptr
   @Test
   public void testWriteReadMLP() {
     String modelPath = "/tmp/sampleModel-testWriteReadMLP.data";
-    double learningRate = 0.5;
-    boolean regularization = false; // no regularization
+    double learningRate = 0.3;
+    double regularization = 0.0; // no regularization
     double momentum = 0; // no momentum
     String squashingFunctionName = "Sigmoid";
     String costFunctionName = "SquaredError";
@@ -70,9 +70,9 @@ public class TestSmallMultiLayerPerceptr
       Configuration conf = new Configuration();
       FileSystem fs = FileSystem.get(conf);
       mlp = new SmallMultiLayerPerceptron(modelPath);
-      assertEquals("SmallMLP", mlp.getMLPType());
+      assertEquals(mlp.getClass().getName(), mlp.getMLPType());
       assertEquals(learningRate, mlp.getLearningRate(), 0.001);
-      assertEquals(regularization, mlp.isRegularization());
+      assertEquals(regularization, mlp.isRegularization(), 0.001);
       assertEquals(layerSizeArray.length, mlp.getNumberOfLayers());
       assertEquals(momentum, mlp.getMomentum(), 0.001);
       assertEquals(squashingFunctionName, mlp.getSquashingFunctionName());
@@ -97,10 +97,10 @@ public class TestSmallMultiLayerPerceptr
       FileSystem fs = FileSystem.get(conf);
       FSDataOutputStream output = fs.create(new Path(modelPath), true);
 
-      String MLPType = "SmallMLP";
+      String MLPType = SmallMultiLayerPerceptron.class.getName();
       double learningRate = 0.5;
-      boolean regularization = false;
-      double momentum = 0;
+      double regularization = 0.0;
+      double momentum = 0.1;
       String squashingFunctionName = "Sigmoid";
       String costFunctionName = "SquaredError";
       int[] layerSizeArray = new int[] { 3, 2, 3, 3 };
@@ -108,7 +108,7 @@ public class TestSmallMultiLayerPerceptr
 
       WritableUtils.writeString(output, MLPType);
       output.writeDouble(learningRate);
-      output.writeBoolean(regularization);
+      output.writeDouble(regularization);
       output.writeDouble(momentum);
       output.writeInt(numberOfLayers);
       WritableUtils.writeString(output, squashingFunctionName);
@@ -162,10 +162,10 @@ public class TestSmallMultiLayerPerceptr
   }
 
   /**
-   * Test the MLP on XOR problem.
+   * Test training with squared error on the XOR problem.
    */
   @Test
-  public void testSingleInstanceTraining() {
+  public void testTrainWithSquaredError() {
     // generate training data
     DoubleVector[] trainingData = new DenseDoubleVector[] {
         new DenseDoubleVector(new double[] { 0, 0, 0 }),
@@ -174,8 +174,8 @@ public class TestSmallMultiLayerPerceptr
         new DenseDoubleVector(new double[] { 1, 1, 0 }) };
 
     // set parameters
-    double learningRate = 0.6;
-    boolean regularization = false; // no regularization
+    double learningRate = 0.5;
+    double regularization = 0.02; // small regularization term
     double momentum = 0; // no momentum
     String squashingFunctionName = "Sigmoid";
     String costFunctionName = "SquaredError";
@@ -207,6 +207,142 @@ public class TestSmallMultiLayerPerceptr
   }
 
   /**
+   * Test training with cross entropy on the XOR problem.
+   */
+  @Test
+  public void testTrainWithCrossEntropy() {
+    // generate training data
+    DoubleVector[] trainingData = new DenseDoubleVector[] {
+        new DenseDoubleVector(new double[] { 0, 0, 0 }),
+        new DenseDoubleVector(new double[] { 0, 1, 1 }),
+        new DenseDoubleVector(new double[] { 1, 0, 1 }),
+        new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+    // set parameters
+    double learningRate = 0.5;
+    double regularization = 0.0; // no regularization
+    double momentum = 0; // no momentum
+    String squashingFunctionName = "Sigmoid";
+    String costFunctionName = "CrossEntropy";
+    int[] layerSizeArray = new int[] { 2, 7, 1 };
+    SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+        regularization, momentum, squashingFunctionName, costFunctionName,
+        layerSizeArray);
+
+    try {
+      // train by multiple instances
+      Random rnd = new Random();
+      for (int i = 0; i < 20000; ++i) {
+        DenseDoubleMatrix[] weightUpdates = mlp
+            .trainByInstance(trainingData[rnd.nextInt(4)]);
+        mlp.updateWeightMatrices(weightUpdates);
+      }
+
+      // System.out.printf("Weight matrices: %s\n",
+      // mlp.weightsToString(mlp.getWeightMatrices()));
+      for (int i = 0; i < trainingData.length; ++i) {
+        DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+            .slice(2);
+        assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+            .toArray()[0], 0.2);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  /**
+   * Test training with regularization.
+   */
+  @Test
+  public void testWithRegularization() {
+    // generate training data
+    DoubleVector[] trainingData = new DenseDoubleVector[] {
+        new DenseDoubleVector(new double[] { 0, 0, 0 }),
+        new DenseDoubleVector(new double[] { 0, 1, 1 }),
+        new DenseDoubleVector(new double[] { 1, 0, 1 }),
+        new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+    // set parameters
+    double learningRate = 0.5;
+    double regularization = 0.02; // regularization should be a tiny number
+    double momentum = 0; // no momentum
+    String squashingFunctionName = "Sigmoid";
+    String costFunctionName = "CrossEntropy";
+    int[] layerSizeArray = new int[] { 2, 7, 1 };
+    SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+        regularization, momentum, squashingFunctionName, costFunctionName,
+        layerSizeArray);
+
+    try {
+      // train by multiple instances
+      Random rnd = new Random();
+      for (int i = 0; i < 10000; ++i) {
+        DenseDoubleMatrix[] weightUpdates = mlp
+            .trainByInstance(trainingData[rnd.nextInt(4)]);
+        mlp.updateWeightMatrices(weightUpdates);
+      }
+
+      // System.out.printf("Weight matrices: %s\n",
+      // mlp.weightsToString(mlp.getWeightMatrices()));
+      for (int i = 0; i < trainingData.length; ++i) {
+        DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+            .slice(2);
+        assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+            .toArray()[0], 0.2);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+  
+  /**
+   * Test training with momentum.
+   * The MLP can converge faster.
+   */
+  @Test
+  public void testWithMomentum() {
+    // generate training data
+    DoubleVector[] trainingData = new DenseDoubleVector[] {
+        new DenseDoubleVector(new double[] { 0, 0, 0 }),
+        new DenseDoubleVector(new double[] { 0, 1, 1 }),
+        new DenseDoubleVector(new double[] { 1, 0, 1 }),
+        new DenseDoubleVector(new double[] { 1, 1, 0 }) };
+
+    // set parameters
+    double learningRate = 0.5;
+    double regularization = 0.02; // regularization should be a tiny number
+    double momentum = 0.5; // momentum enabled
+    String squashingFunctionName = "Sigmoid";
+    String costFunctionName = "CrossEntropy";
+    int[] layerSizeArray = new int[] { 2, 7, 1 };
+    SmallMultiLayerPerceptron mlp = new SmallMultiLayerPerceptron(learningRate,
+        regularization, momentum, squashingFunctionName, costFunctionName,
+        layerSizeArray);
+
+    try {
+      // train by multiple instances
+      Random rnd = new Random();
+      for (int i = 0; i < 3000; ++i) {
+        DenseDoubleMatrix[] weightUpdates = mlp
+            .trainByInstance(trainingData[rnd.nextInt(4)]);
+        mlp.updateWeightMatrices(weightUpdates);
+      }
+
+      // System.out.printf("Weight matrices: %s\n",
+      // mlp.weightsToString(mlp.getWeightMatrices()));
+      for (int i = 0; i < trainingData.length; ++i) {
+        DenseDoubleVector testVec = (DenseDoubleVector) trainingData[i]
+            .slice(2);
+        assertEquals(trainingData[i].toArray()[2], mlp.output(testVec)
+            .toArray()[0], 0.2);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  /**
    * Test the XOR problem.
    */
   @Test
@@ -246,8 +382,8 @@ public class TestSmallMultiLayerPerceptr
     // begin training
     String modelPath = "/tmp/xorModel-training-by-xor.data";
     double learningRate = 0.6;
-    boolean regularization = false; // no regularization
-    double momentum = 0; // no momentum
+    double regularization = 0.02; // small regularization term
+    double momentum = 0.3; // momentum enabled
     String squashingFunctionName = "Tanh";
     String costFunctionName = "SquaredError";
     int[] layerSizeArray = new int[] { 2, 5, 1 };
@@ -256,7 +392,7 @@ public class TestSmallMultiLayerPerceptr
         layerSizeArray);
 
     Map<String, String> trainingParams = new HashMap<String, String>();
-    trainingParams.put("training.iteration", "10000");
+    trainingParams.put("training.iteration", "1000");
     trainingParams.put("training.mode", "minibatch.gradient.descent");
     trainingParams.put("training.batch.size", "100");
     trainingParams.put("tasks", "3");



Mime
View raw message