horn-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From edwardy...@apache.org
Subject incubator-horn git commit: HORN-23: Add softmax function
Date Sat, 21 May 2016 08:54:53 GMT
Repository: incubator-horn
Updated Branches:
  refs/heads/master 9f35e9fb2 -> ca560628c


HORN-23: Add softmax function


Project: http://git-wip-us.apache.org/repos/asf/incubator-horn/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-horn/commit/ca560628
Tree: http://git-wip-us.apache.org/repos/asf/incubator-horn/tree/ca560628
Diff: http://git-wip-us.apache.org/repos/asf/incubator-horn/diff/ca560628

Branch: refs/heads/master
Commit: ca560628c8867028900daea4e7c6134dd7f0484d
Parents: 9f35e9f
Author: Edward J. Yoon <edwardyoon@apache.org>
Authored: Sat May 21 16:11:03 2016 +0900
Committer: Edward J. Yoon <edwardyoon@apache.org>
Committed: Sat May 21 17:37:02 2016 +0900

----------------------------------------------------------------------
 README.md                                       |  15 +--
 bin/horn                                        |   2 +-
 conf/horn-env.sh                                |   2 +-
 .../horn/core/AbstractLayeredNeuralNetwork.java |  14 ++-
 .../apache/horn/core/AbstractNeuralNetwork.java |   1 -
 .../apache/horn/core/IntermediateOutput.java    |  23 ++++
 .../org/apache/horn/core/LayerInterface.java    |  28 +++++
 .../apache/horn/core/LayeredNeuralNetwork.java  | 111 +++++++++++++------
 .../horn/core/LayeredNeuralNetworkTrainer.java  |  54 +++++----
 src/main/java/org/apache/horn/core/Neuron.java  |  22 ++++
 .../org/apache/horn/core/NeuronInterface.java   |   2 +-
 src/main/java/org/apache/horn/core/Synapse.java |   2 +-
 .../horn/examples/MultiLayerPerceptron.java     |  12 +-
 .../horn/funcs/CategoricalCrossEntropy.java     |  40 +++++++
 .../org/apache/horn/funcs/CrossEntropy.java     |   9 +-
 .../org/apache/horn/funcs/FunctionFactory.java  |   9 +-
 src/main/java/org/apache/horn/funcs/ReLU.java   |   7 +-
 .../java/org/apache/horn/funcs/Sigmoid.java     |   5 +
 .../java/org/apache/horn/funcs/SoftMax.java     |  58 ++++++++++
 .../org/apache/horn/utils/MNISTEvaluator.java   |  48 ++++----
 .../horn/examples/MultiLayerPerceptronTest.java |  40 +++++--
 21 files changed, 388 insertions(+), 116 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/README.md
----------------------------------------------------------------------
diff --git a/README.md b/README.md
index cb217f6..b797783 100644
--- a/README.md
+++ b/README.md
@@ -49,9 +49,8 @@ Also, Apache Horn provides a simplified and intuitive configuration interface. T
   ..
 
   job.inputLayer(784, Sigmoid.class, StandardNeuron.class);
-  job.addLayer(500, Sigmoid.class, StandardNeuron.class);
-  job.addLayer(500, Sigmoid.class, StandardNeuron.class);
-  job.outputLayer(10, Sigmoid.class, StandardNeuron.class);
+  job.addLayer(100, Sigmoid.class, StandardNeuron.class);
+  job.outputLayer(10, SoftMax.class, StandardNeuron.class);
   job.setCostFunction(CrossEntropy.class);
 ```
 
@@ -59,16 +58,18 @@ Also, Apache Horn provides a simplified and intuitive configuration interface. T
 
 Download a MNIST training and label datasets, and convert into a HDFS sequence file with following command:
 ```
- % bin/horn jar horn-0.x.0.jar MNISTConverter train-images.idx3-ubyte train-labels.idx1-ubyte /tmp/mnist.seq 
+ % bin/horn jar horn-0.x.0.jar MNISTConverter \
+   train-images.idx3-ubyte train-labels.idx1-ubyte /tmp/mnist.seq 
 ```
 
-Then, train it with following command (in this example, we used η 0.002, λ 0.1, 100 hidden units, and minibatch 10):
+Then, train it with the following command (in this example, we used η 0.01, α 0.9, λ 0.00075, 100 hidden units, and minibatch 10):
 ```
  % bin/horn jar horn-0.x.0.jar MultiLayerPerceptron /tmp/model /tmp/mnist.seq \
-   0.002 0.0 0.1 784 100 10 10 12000
- 
+   0.01 0.9 0.00075 784 100 10 10 12000
 ```
 
+With this default example, you'll reach over 95% accuracy. In local mode, multithread-based parallel synchronous SGD takes around 1 hour to train.
+
 ## High Scalability
 
 The Apache Horn is an Sync and Async hybrid distributed training framework. Within single BSP job, each task group works asynchronously using region barrier synchronization instead of global barrier synchronization, and trains large-scale neural network model using assigned data sets in synchronous way.

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/bin/horn
----------------------------------------------------------------------
diff --git a/bin/horn b/bin/horn
index e697695..8cbd106 100755
--- a/bin/horn
+++ b/bin/horn
@@ -58,7 +58,7 @@ if [ "$JAVA_HOME" = "" ]; then
 fi
 
 JAVA=$JAVA_HOME/bin/java
-JAVA_HEAP_MAX=-Xmx512m
+JAVA_HEAP_MAX=-Xmx2048m
 
 # check envvars which might override default args
 if [ "$HORN_HEAPSIZE" != "" ]; then

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/conf/horn-env.sh
----------------------------------------------------------------------
diff --git a/conf/horn-env.sh b/conf/horn-env.sh
index a033fe0..ca7ed32 100644
--- a/conf/horn-env.sh
+++ b/conf/horn-env.sh
@@ -22,5 +22,5 @@
 # Set environment variables here.
 
 # The java implementation to use.  Required.
-export JAVA_HOME=/usr/lib/jvm/java-8-oracle/
+export JAVA_HOME=/Library/Java/JavaVirtualMachines/jdk1.8.0_25.jdk/Contents/Home
 

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
index e415a25..b82ad41 100644
--- a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -30,7 +30,10 @@ import org.apache.hama.commons.math.DoubleMatrix;
 import org.apache.hama.commons.math.DoubleVector;
 import org.apache.horn.core.Constants.LearningStyle;
 import org.apache.horn.core.Constants.TrainingMethod;
+import org.apache.horn.funcs.CategoricalCrossEntropy;
+import org.apache.horn.funcs.CrossEntropy;
 import org.apache.horn.funcs.FunctionFactory;
+import org.mortbay.log.Log;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -65,7 +68,7 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
   protected List<Integer> layerSizeList;
 
   protected TrainingMethod trainingMethod;
-  
+
   protected LearningStyle learningStyle;
 
   public AbstractLayeredNeuralNetwork() {
@@ -77,6 +80,11 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
 
   public AbstractLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
     super(conf, modelPath);
+    if (this.layerSizeList.get(this.layerSizeList.size() - 1) > 1
+        && this.costFunction.getFunctionName().equalsIgnoreCase(
+            CrossEntropy.class.getSimpleName())) {
+      this.setCostFunction(new CategoricalCrossEntropy());
+    }
   }
 
   /**
@@ -118,11 +126,11 @@ abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
   public TrainingMethod getTrainingMethod() {
     return this.trainingMethod;
   }
-  
+
   public void setLearningStyle(LearningStyle style) {
     this.learningStyle = style;
   }
-  
+
   public LearningStyle getLearningStyle() {
     return this.learningStyle;
   }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
index 5624e49..77d6af0 100644
--- a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
@@ -24,7 +24,6 @@ import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.net.URI;
 import java.net.URISyntaxException;
-import java.util.Map;
 
 import org.apache.commons.lang.SerializationUtils;
 import org.apache.hadoop.conf.Configuration;

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/IntermediateOutput.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/IntermediateOutput.java b/src/main/java/org/apache/horn/core/IntermediateOutput.java
new file mode 100644
index 0000000..272fed0
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/IntermediateOutput.java
@@ -0,0 +1,23 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+
+public abstract class IntermediateOutput implements LayerInterface {
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/LayerInterface.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayerInterface.java b/src/main/java/org/apache/horn/core/LayerInterface.java
new file mode 100644
index 0000000..c010cc9
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/LayerInterface.java
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+
+import org.apache.hama.commons.math.DoubleVector;
+
+public interface LayerInterface {
+
+  public DoubleVector interlayer(DoubleVector intermediateOutput) throws IOException;
+  
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
index fe5d3a3..d4f2f3e 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -46,6 +46,7 @@ import org.apache.hama.util.ReflectionUtils;
 import org.apache.horn.core.Constants.LearningStyle;
 import org.apache.horn.core.Constants.TrainingMethod;
 import org.apache.horn.funcs.FunctionFactory;
+import org.apache.horn.funcs.SoftMax;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
@@ -76,7 +77,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
   protected List<DoubleFunction> squashingFunctionList;
 
   protected List<Class<? extends Neuron>> neuronClassList;
-  
+
   protected int finalLayerIdx;
 
   public LayeredNeuralNetwork() {
@@ -97,9 +98,20 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
    */
   public int addLayer(int size, boolean isFinalLayer,
       DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass) {
+    return addLayer(size, isFinalLayer, squashingFunction, neuronClass, null);
+  }
+
+  public int addLayer(int size, boolean isFinalLayer,
+      DoubleFunction squashingFunction, Class<? extends Neuron> neuronClass,
+      Class<? extends IntermediateOutput> interlayer) {
     Preconditions.checkArgument(size > 0,
         "Size of layer must be larger than 0.");
     if (!isFinalLayer) {
+      if (this.layerSizeList.size() == 0) {
+        LOG.info("add input layer: " + size + " neurons");
+      } else {
+        LOG.info("add hidden layer: " + size + " neurons");
+      }
       size += 1;
     }
 
@@ -107,6 +119,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
     int layerIdx = this.layerSizeList.size() - 1;
     if (isFinalLayer) {
       this.finalLayerIdx = layerIdx;
+      LOG.info("add output layer: " + size + " neurons");
     }
 
     // add weights between current layer and previous layer, and input layer has
@@ -133,6 +146,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
       this.weightMatrixList.add(weightMatrix);
       this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
       this.squashingFunctionList.add(squashingFunction);
+
       this.neuronClassList.add(neuronClass);
     }
     return layerIdx;
@@ -152,6 +166,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
 
   /**
    * Set the previous weight matrices.
+   * 
    * @param prevUpdates
    */
   void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
@@ -263,12 +278,12 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
     for (Class<? extends Neuron> clazz : this.neuronClassList) {
       output.writeUTF(clazz.getName());
     }
-    
+
     // write squashing functions
     output.writeInt(this.squashingFunctionList.size());
     for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
-      WritableUtils.writeString(output, aSquashingFunctionList
-              .getFunctionName());
+      WritableUtils.writeString(output,
+          aSquashingFunctionList.getFunctionName());
     }
 
     // write weight matrices
@@ -327,21 +342,18 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
     outputCache.add(intermediateOutput);
 
     for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
-      intermediateOutput = forward(i, intermediateOutput);
-      outputCache.add(intermediateOutput);
+      forward(i, outputCache);
     }
     return outputCache;
   }
-  
+
   /**
    * @param neuronClass
    * @return a new neuron instance
    */
-  @SuppressWarnings({ "unchecked", "rawtypes" })
-  public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance(
-      Class<? extends Neuron> neuronClass) {
-    return (Neuron<Synapse<DoubleWritable, DoubleWritable>>) ReflectionUtils
-        .newInstance(neuronClass);
+  @SuppressWarnings({ "rawtypes" })
+  public static Neuron newNeuronInstance(Class<? extends Neuron> neuronClass) {
+    return (Neuron) ReflectionUtils.newInstance(neuronClass);
   }
 
   /**
@@ -351,25 +363,33 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
    * @param intermediateOutput The intermediateOutput of previous layer.
    * @return a new vector with the result of the operation.
    */
-  protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
+  protected void forward(int fromLayer, List<DoubleVector> outputCache) {
+    DoubleVector previousOutput = outputCache.get(fromLayer * 2); // skip
+                                                                  // intermediate
+                                                                  // output
+
     DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
 
     // LOG.info("intermediate: " + intermediateOutput.toString());
     // DoubleVector vec = weightMatrix.multiplyVectorUnsafe(intermediateOutput);
     // vec = vec.applyToElements(this.squashingFunctionList.get(fromLayer));
-   
+
+    DoubleFunction squashingFunction = getSquashingFunction(fromLayer);
+
     DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
+
     for (int row = 0; row < weightMatrix.getRowCount(); row++) {
       List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
       for (int col = 0; col < weightMatrix.getColumnCount(); col++) {
         msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
-            new DoubleWritable(intermediateOutput.get(col)),
-            new DoubleWritable(weightMatrix.get(row, col))));
+            new DoubleWritable(previousOutput.get(col)), new DoubleWritable(
+                weightMatrix.get(row, col))));
       }
       Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
-          .get(fromLayer));
-      n.setSquashingFunction(this.squashingFunctionList.get(fromLayer));
+      Neuron n = newNeuronInstance(this.neuronClassList.get(fromLayer));
+      n.setSquashingFunction(squashingFunction);
+      n.setLayerIndex(fromLayer);
+
       try {
         n.forward(iterable);
       } catch (IOException e) {
@@ -378,14 +398,30 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
       }
       vec.set(row, n.getOutput());
     }
-    
+
+    if (squashingFunction.getFunctionName().equalsIgnoreCase(
+        SoftMax.class.getSimpleName())) {
+      IntermediateOutput interlayer = (IntermediateOutput) ReflectionUtils
+          .newInstance(SoftMax.SoftMaxOutputComputer.class);
+      try {
+        outputCache.add(vec);
+        vec = interlayer.interlayer(vec);
+      } catch (IOException e) {
+        // TODO Auto-generated catch block
+        e.printStackTrace();
+      }
+    } else {
+      outputCache.add(null);
+    }
+
     // add bias
     DoubleVector vecWithBias = new DenseDoubleVector(vec.getDimension() + 1);
     vecWithBias.set(0, 1);
     for (int i = 0; i < vec.getDimension(); ++i) {
       vecWithBias.set(i + 1, vec.get(i));
     }
-    return vecWithBias;
+
+    outputCache.add(vecWithBias);
   }
 
   /**
@@ -472,6 +508,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
       List<DoubleVector> internalResults) {
 
     DoubleVector output = internalResults.get(internalResults.size() - 1);
+
     // initialize weight update matrices
     DenseDoubleMatrix[] weightUpdateMatrices = new DenseDoubleMatrix[this.weightMatrixList
         .size()];
@@ -487,22 +524,27 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
 
     DoubleMatrix lastWeightMatrix = this.weightMatrixList
         .get(this.weightMatrixList.size() - 1);
+
     for (int i = 0; i < deltaVec.getDimension(); ++i) {
       double costFuncDerivative = this.costFunction.applyDerivative(
           labels.get(i), output.get(i + 1));
       // add regularization
       costFuncDerivative += this.regularizationWeight
           * lastWeightMatrix.getRowVector(i).sum();
-      deltaVec.set(
-          i,
-          costFuncDerivative
-              * squashingFunction.applyDerivative(output.get(i + 1)));
+
+      if (!squashingFunction.getFunctionName().equalsIgnoreCase(
+          SoftMax.class.getSimpleName())) {
+        costFuncDerivative *= squashingFunction.applyDerivative(output
+            .get(i + 1));
+      }
+
+      deltaVec.set(i, costFuncDerivative);
     }
 
     // start from previous layer of output layer
     for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
-      output = internalResults.get(layer);
-      deltaVec = backpropagate(layer, deltaVec, internalResults,
+      output = internalResults.get(layer * 2); // skip intermediate output
+      deltaVec = backpropagate(layer, deltaVec, output,
           weightUpdateMatrices[layer]);
     }
 
@@ -521,13 +563,12 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
    * @return the squashing function of the specified position.
    */
   private DoubleVector backpropagate(int curLayerIdx,
-      DoubleVector nextLayerDelta, List<DoubleVector> outputCache,
+      DoubleVector nextLayerDelta, DoubleVector curLayerOutput,
       DenseDoubleMatrix weightUpdateMatrix) {
 
     // get layer related information
     DoubleFunction squashingFunction = this.squashingFunctionList
         .get(curLayerIdx);
-    DoubleVector curLayerOutput = outputCache.get(curLayerIdx);
     DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
     DoubleMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
 
@@ -536,16 +577,16 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
       nextLayerDelta = nextLayerDelta.slice(1,
           nextLayerDelta.getDimension() - 1);
     }
-    
+
     DoubleVector deltaVector = new DenseDoubleVector(
         weightMatrix.getColumnCount());
-    
+
     for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
-      Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance(this.neuronClassList
-          .get(curLayerIdx));
+      Neuron n = newNeuronInstance(this.neuronClassList.get(curLayerIdx));
       // calls setup method
       n.setLearningRate(this.learningRate);
       n.setMomentumWeight(this.momentumWeight);
+      n.setLayerIndex(curLayerIdx);
 
       n.setSquashingFunction(squashingFunction);
       n.setOutput(curLayerOutput.get(row));
@@ -568,7 +609,7 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
         // TODO Auto-generated catch block
         e.printStackTrace();
       }
-      
+
       // update weights
       weightUpdateMatrix.setColumn(row, n.getWeights());
       deltaVector.set(row, n.getDelta());
@@ -628,5 +669,5 @@ public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
   public DoubleFunction getSquashingFunction(int idx) {
     return this.squashingFunctionList.get(idx);
   }
-  
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
index ce6d6e4..350200f 100644
--- a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
@@ -18,6 +18,9 @@
 package org.apache.horn.core;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -41,8 +44,9 @@ public final class LayeredNeuralNetworkTrainer
     extends
     BSP<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> {
 
-  private static final Log LOG = LogFactory.getLog(LayeredNeuralNetworkTrainer.class);
-  
+  private static final Log LOG = LogFactory
+      .getLog(LayeredNeuralNetworkTrainer.class);
+
   private LayeredNeuralNetwork inMemoryModel;
   private HamaConfiguration conf;
   /* Default batch size */
@@ -90,10 +94,9 @@ public final class LayeredNeuralNetworkTrainer
     // write model to modelPath
     if (peer.getPeerIndex() == 0) {
       try {
-        LOG.info(String.format("End of training, number of iterations: %d.\n",
+        LOG.info(String.format("End of training, number of iterations: %d.",
             this.iterations));
-        LOG.info(String.format("Write model back to %s\n",
-            inMemoryModel.getModelPath()));
+        LOG.info(String.format("Write model back to %s", inMemoryModel.getModelPath()));
         this.inMemoryModel.writeModelToFile();
       } catch (IOException e) {
         e.printStackTrace();
@@ -101,10 +104,21 @@ public final class LayeredNeuralNetworkTrainer
     }
   }
 
+  private List<DoubleVector> trainingSet = new ArrayList<DoubleVector>();
+  private Random r = new Random();
+
   @Override
   public void bsp(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
       throws IOException, SyncException, InterruptedException {
+    // load local data into memory
+    LongWritable key = new LongWritable();
+    VectorWritable value = new VectorWritable();
+    while (peer.readNext(key, value)) {
+      DoubleVector v = value.getVector();
+      trainingSet.add(v);
+    }
+
     while (this.iterations++ < maxIterations) {
       // each groom calculate the matrices updates according to local data
       calculateUpdates(peer);
@@ -121,6 +135,10 @@ public final class LayeredNeuralNetworkTrainer
     }
   }
 
+  private DoubleVector getRandomInstance() {
+    return trainingSet.get(r.nextInt(trainingSet.size()));
+  }
+
   /**
    * Calculate the matrices updates according to local partition of data.
    * 
@@ -154,18 +172,8 @@ public final class LayeredNeuralNetworkTrainer
 
     // continue to train
     double avgTrainingError = 0.0;
-    LongWritable key = new LongWritable();
-    VectorWritable value = new VectorWritable();
     for (int recordsRead = 0; recordsRead < batchSize; ++recordsRead) {
-      if (!peer.readNext(key, value)) {
-        peer.reopenInput();
-        if (peer.getPeerIndex() == 0) {
-          epoch++;
-          LOG.info("Training loss: " + curAvgTrainingError + " at " + (epoch) + " epoch.");
-        }
-        peer.readNext(key, value);
-      }
-      DoubleVector trainingInstance = value.getVector();
+      DoubleVector trainingInstance = getRandomInstance();
       LayeredNeuralNetwork.matricesAdd(weightUpdates,
           this.inMemoryModel.trainByInstance(trainingInstance));
       avgTrainingError += this.inMemoryModel.trainingError;
@@ -179,8 +187,8 @@ public final class LayeredNeuralNetworkTrainer
 
     DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
         .getPrevMatricesUpdates();
-    ParameterMessage outMessage = new ParameterMessage(
-        avgTrainingError, false, weightUpdates, prevWeightUpdates);
+    ParameterMessage outMessage = new ParameterMessage(avgTrainingError, false,
+        weightUpdates, prevWeightUpdates);
     peer.send(peer.getPeerName(0), outMessage);
   }
 
@@ -215,7 +223,7 @@ public final class LayeredNeuralNetworkTrainer
         LayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
             message.getPrevMatrices());
       }
-      
+
       avgTrainingError += message.getTrainingError();
     }
 
@@ -229,7 +237,7 @@ public final class LayeredNeuralNetworkTrainer
 
     this.inMemoryModel.updateWeightMatrices(matricesUpdates);
     this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
-    
+
     // check convergence
     if (iterations % convergenceCheckInterval == 0) {
       if (prevAvgTrainingError < curAvgTrainingError) {
@@ -238,14 +246,16 @@ public final class LayeredNeuralNetworkTrainer
       }
       // update
       prevAvgTrainingError = curAvgTrainingError;
+      LOG.info("Training error: " + curAvgTrainingError + " at " + (iterations)
+          + " iteration.");
       curAvgTrainingError = 0;
     }
     curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
 
     // broadcast updated weight matrices
     for (String peerName : peer.getAllPeerNames()) {
-      ParameterMessage msg = new ParameterMessage(
-          0, isConverge, this.inMemoryModel.getWeightMatrices(),
+      ParameterMessage msg = new ParameterMessage(0, isConverge,
+          this.inMemoryModel.getWeightMatrices(),
           this.inMemoryModel.getPrevMatricesUpdates());
       peer.send(peerName, msg);
     }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
index 4471b45..af18c79 100644
--- a/src/main/java/org/apache/horn/core/Neuron.java
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.Writable;
 import org.apache.hama.commons.math.DoubleFunction;
 
 public abstract class Neuron<M extends Writable> implements Writable, NeuronInterface<M> {
+  int id;
   double output;
   double weight;
   double delta;
@@ -32,8 +33,27 @@ public abstract class Neuron<M extends Writable> implements Writable, NeuronInte
   double momentumWeight;
   double learningRate;
 
+  int layerIndex;
+  boolean isOutputLayer;
+  
   protected DoubleFunction squashingFunction;
 
+  public void setNeuronID(int id) {
+    this.id = id;
+  }
+  
+  public int getID() {
+    return id;
+  }
+  
+  public int getLayerIndex() {
+    return layerIndex;
+  }
+
+  public void setLayerIndex(int index) {
+    this.layerIndex = index;
+  }
+  
   public void feedforward(double sum) {
     this.output = sum;
   }
@@ -103,6 +123,7 @@ public abstract class Neuron<M extends Writable> implements Writable, NeuronInte
 
   @Override
   public void readFields(DataInput in) throws IOException {
+    id = in.readInt();
     output = in.readDouble();
     weight = in.readDouble();
     delta = in.readDouble();
@@ -113,6 +134,7 @@ public abstract class Neuron<M extends Writable> implements Writable, NeuronInte
 
   @Override
   public void write(DataOutput out) throws IOException {
+    out.writeInt(id);
     out.writeDouble(output);
     out.writeDouble(weight);
     out.writeDouble(delta);

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/NeuronInterface.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/NeuronInterface.java b/src/main/java/org/apache/horn/core/NeuronInterface.java
index ef5a2d3..73d8220 100644
--- a/src/main/java/org/apache/horn/core/NeuronInterface.java
+++ b/src/main/java/org/apache/horn/core/NeuronInterface.java
@@ -25,7 +25,7 @@ public interface NeuronInterface<M extends Writable> {
 
   /**
    * This method is called when the messages are propagated from the next layer.
-   * It can be used to determine if the neuron would activate, or fire.
+   * It can be used to calculate the activation or intermediate output.
    * 
    * @param messages
    * @throws IOException

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/core/Synapse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Synapse.java b/src/main/java/org/apache/horn/core/Synapse.java
index 714767b..6dbada8 100644
--- a/src/main/java/org/apache/horn/core/Synapse.java
+++ b/src/main/java/org/apache/horn/core/Synapse.java
@@ -69,7 +69,7 @@ public class Synapse<M extends Writable, W extends Writable> implements
   public double getPrevWeight() {
     return prevWeight.get();
   }
-
+  
   @Override
   public void readFields(DataInput in) throws IOException {
     message.readFields(in);

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index 4c0df95..ac17cc4 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -27,6 +27,7 @@ import org.apache.horn.core.Neuron;
 import org.apache.horn.core.Synapse;
 import org.apache.horn.funcs.CrossEntropy;
 import org.apache.horn.funcs.Sigmoid;
+import org.apache.horn.funcs.SoftMax;
 
 public class MultiLayerPerceptron {
 
@@ -41,8 +42,7 @@ public class MultiLayerPerceptron {
       for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
         sum += m.getInput() * m.getWeight();
       }
-
-      this.feedforward(this.squashingFunction.apply(sum));
+      this.feedforward(squashingFunction.apply(sum));
     }
 
     @Override
@@ -61,7 +61,7 @@ public class MultiLayerPerceptron {
       }
 
       this.backpropagate(gradient
-          * this.squashingFunction.applyDerivative(this.getOutput()));
+          * squashingFunction.applyDerivative(getOutput()));
     }
   }
 
@@ -78,15 +78,15 @@ public class MultiLayerPerceptron {
     job.setLearningRate(learningRate);
     job.setMomentumWeight(momemtumWeight);
     job.setRegularizationWeight(regularizationWeight);
-    
+
     job.setConvergenceCheckInterval(600);
     job.setBatchSize(miniBatch);
-    
+
     job.setTrainingMethod(TrainingMethod.GRADIENT_DESCENT);
 
     job.inputLayer(features, Sigmoid.class, StandardNeuron.class);
     job.addLayer(hu, Sigmoid.class, StandardNeuron.class);
-    job.outputLayer(labels, Sigmoid.class, StandardNeuron.class);
+    job.outputLayer(labels, SoftMax.class, StandardNeuron.class);
 
     job.setCostFunction(CrossEntropy.class);
 

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
new file mode 100644
index 0000000..96c228a
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/CategoricalCrossEntropy.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * Categorical cross-entropy cost, intended for a softmax output layer.
+ * <p>
+ * Per output unit the cost is {@code -target * log(actual)}, with
+ * {@code actual} clamped below at {@code 1e-8} so that a zero prediction
+ * yields a finite cost instead of infinity.
+ */
+public class CategoricalCrossEntropy extends DoubleDoubleFunction {
+
+  /** Lower bound applied to {@code actual} before taking its logarithm. */
+  private static final double EPSILON = 1e-8;
+
+  /**
+   * Returns one output unit's contribution to the cost.
+   *
+   * @param target the expected value, e.g. one entry of a one-hot label
+   * @param actual the predicted value; clamped below at {@link #EPSILON}
+   * @return {@code -target * log(max(actual, EPSILON))}
+   */
+  @Override
+  public double apply(double target, double actual) {
+    double clampedActual = Math.max(actual, EPSILON);
+    return -target * Math.log(clampedActual);
+  }
+
+  /**
+   * Returns {@code -(target - actual)}, i.e. output minus target.
+   * <p>
+   * This is not d(cost)/d(actual), which would be {@code -target / actual}.
+   * It is the simplified gradient of softmax followed by this cost with
+   * respect to the softmax input, valid when the targets sum to 1.
+   * NOTE(review): confirm callers do not multiply it by the softmax
+   * derivative a second time.
+   */
+  @Override
+  public double applyDerivative(double target, double actual) {
+    // o - y
+    double error = target - actual;
+    return -error;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/CrossEntropy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/CrossEntropy.java b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
index 7cc5e6a..a096be0 100644
--- a/src/main/java/org/apache/horn/funcs/CrossEntropy.java
+++ b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
@@ -29,6 +29,8 @@ import org.apache.hama.commons.math.DoubleDoubleFunction;
  */
 public class CrossEntropy extends DoubleDoubleFunction {
 
+  private static final double epsilon = 1e-8;
+  
   @Override
   public double apply(double target, double actual) {
     double adjustedTarget = (target == 0 ? 0.000001 : target);
@@ -36,10 +38,11 @@ public class CrossEntropy extends DoubleDoubleFunction {
     double adjustedActual = (actual == 0 ? 0.000001 : actual);
     adjustedActual = (actual == 1 ? 0.999999 : adjustedActual);
     
-    return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget)
-        * Math.log(1 - adjustedActual);
+    return -target * Math.log(Math.max(actual, epsilon)) - (1 - target)
+        * Math.log(Math.max(1 - actual, epsilon));
+    // return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget) *  Math.log(adjustedActual);
   }
-
+  
   @Override
   public double applyDerivative(double target, double actual) {
     double adjustedTarget = (target == 0 ? 0.000001 : target);

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/FunctionFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/FunctionFactory.java b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
index 9b38a0d..4310a95 100644
--- a/src/main/java/org/apache/horn/funcs/FunctionFactory.java
+++ b/src/main/java/org/apache/horn/funcs/FunctionFactory.java
@@ -37,6 +37,10 @@ public class FunctionFactory {
       return new Sigmoid();
     } else if (functionName.equalsIgnoreCase(Tanh.class.getSimpleName())) {
       return new Tanh();
+    } else if (functionName.equalsIgnoreCase(ReLU.class.getSimpleName())) {
+      return new ReLU();
+    } else if (functionName.equalsIgnoreCase(SoftMax.class.getSimpleName())) {
+      return new SoftMax();
     } else if (functionName.equalsIgnoreCase(IdentityFunction.class
         .getSimpleName())) {
       return new IdentityFunction();
@@ -59,7 +63,10 @@ public class FunctionFactory {
     } else if (functionName
         .equalsIgnoreCase(CrossEntropy.class.getSimpleName())) {
       return new CrossEntropy();
-    }
+    } else if (functionName
+        .equalsIgnoreCase(CategoricalCrossEntropy.class.getSimpleName())) {
+      return new CategoricalCrossEntropy();
+    } 
 
     throw new IllegalArgumentException(String.format(
         "No double double function with name '%s' exists.", functionName));

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/ReLU.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/ReLU.java b/src/main/java/org/apache/horn/funcs/ReLU.java
index 425137f..85af867 100644
--- a/src/main/java/org/apache/horn/funcs/ReLU.java
+++ b/src/main/java/org/apache/horn/funcs/ReLU.java
@@ -30,12 +30,15 @@ public class ReLU extends DoubleFunction {
 
   @Override
   public double apply(double value) {
-    return Math.max(0, value);
+    return Math.max(0.001, value);
   }
 
   @Override
   public double applyDerivative(double value) {
-    return (value > Double.MIN_VALUE) ? 1 : 0;
+    if (value > 0)
+      return 0.999;
+    else
+      return 0.001;
   }
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/Sigmoid.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/Sigmoid.java b/src/main/java/org/apache/horn/funcs/Sigmoid.java
index cc393e3..bcccf76 100644
--- a/src/main/java/org/apache/horn/funcs/Sigmoid.java
+++ b/src/main/java/org/apache/horn/funcs/Sigmoid.java
@@ -30,6 +30,11 @@ public class Sigmoid extends DoubleFunction {
 
   @Override
   public double apply(double value) {
+    if(value > 100) { // to avoid overflow and underflow
+      return 0.9999;
+    } else if (value < -100) {
+      return 0.0001;
+    }
     return 1.0 / (1 + Math.exp(-value));
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/funcs/SoftMax.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/funcs/SoftMax.java b/src/main/java/org/apache/horn/funcs/SoftMax.java
new file mode 100644
index 0000000..6e0bf76
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/SoftMax.java
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import java.io.IOException;
+
+import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.horn.core.IntermediateOutput;
+
+public class SoftMax extends DoubleFunction {
+
+  /**
+   * Returns {@code value} unchanged.
+   * <p>
+   * Softmax cannot be computed element-wise because every output depends on
+   * the whole layer, so normalization is done by
+   * {@link SoftMaxOutputComputer}. NOTE(review): presumably the network
+   * installs that computer as the layer's intermediate-output handler;
+   * confirm against the caller.
+   */
+  @Override
+  public double apply(double value) {
+    // it will be handled by intermediate output handler
+    return value;
+  }
+
+  /**
+   * Returns {@code value * (1 - value)}: the diagonal term of the softmax
+   * Jacobian when {@code value} is an already-normalized softmax output.
+   * Off-diagonal terms are not represented.
+   */
+  @Override
+  public double applyDerivative(double value) {
+    return value * (1d - value);
+  }
+
+  /**
+   * Turns a layer's raw values into a probability distribution via softmax.
+   */
+  public static class SoftMaxOutputComputer extends IntermediateOutput {
+
+    /**
+     * Computes {@code exp(x_i) / sum_j exp(x_j)} for every element.
+     * <p>
+     * The largest element is subtracted before exponentiating. Softmax is
+     * invariant to that shift, and it keeps every exponent at or below zero,
+     * so {@link Math#exp} cannot overflow to infinity; without the shift a
+     * single value above roughly 709.78 makes the sum infinite and every
+     * result NaN. For finite, non-empty input the shift also makes the sum
+     * at least 1, so the final division is never by zero.
+     *
+     * @param output the layer's values before normalization; not modified
+     * @return a new vector of the same dimension whose elements sum to 1
+     *         (up to rounding)
+     * @throws IOException never thrown by this implementation
+     */
+    @Override
+    public DoubleVector interlayer(DoubleVector output) throws IOException {
+      int dimension = output.getDimension();
+
+      // shift by the maximum so the largest exponent becomes exp(0) = 1
+      double max = Double.NEGATIVE_INFINITY;
+      for (int i = 0; i < dimension; ++i) {
+        max = Math.max(max, output.get(i));
+      }
+
+      DoubleVector softmaxed = new DenseDoubleVector(dimension);
+      double sum = 0.0;
+      for (int i = 0; i < dimension; ++i) {
+        double exp = Math.exp(output.get(i) - max);
+        sum += exp;
+        softmaxed.set(i, exp);
+      }
+
+      // divide by the sum of exponentials of the whole vector
+      for (int i = 0; i < dimension; ++i) {
+        softmaxed.set(i, softmaxed.get(i) / sum);
+      }
+      return softmaxed;
+    }
+
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/utils/MNISTEvaluator.java b/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
index a5b68e0..839be97 100644
--- a/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
+++ b/src/main/java/org/apache/horn/utils/MNISTEvaluator.java
@@ -21,6 +21,7 @@ import java.io.DataInputStream;
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.util.Random;
 
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.math.DenseDoubleVector;
@@ -30,11 +31,11 @@ import org.apache.horn.core.LayeredNeuralNetwork;
 public class MNISTEvaluator {
 
   private static int PIXELS = 28 * 28;
-  
+
   private static double rescale(double x) {
     return 1 - (255 - x) / 255;
   }
-  
+
   public static void main(String[] args) throws IOException {
     if (args.length < 3) {
       System.out.println("Usage: <TRAINED_MODEL> <TEST_IMAGES> <TEST_LABELS>");
@@ -51,15 +52,13 @@ public class MNISTEvaluator {
         new File(training_data)));
     DataInputStream labelsIn = new DataInputStream(new FileInputStream(
         new File(labels_data)));
-    
+
     imagesIn.readInt(); // Magic number
     int count = imagesIn.readInt();
     labelsIn.readInt(); // Magic number
     labelsIn.readInt(); // Count
     imagesIn.readInt(); // Rows
     imagesIn.readInt(); // Cols
-    
-    System.out.println("Evaluating " + count + " images");
 
     byte[][] images = new byte[count][PIXELS];
     byte[] labels = new byte[count];
@@ -70,28 +69,33 @@ public class MNISTEvaluator {
 
     HamaConfiguration conf = new HamaConfiguration();
     LayeredNeuralNetwork ann = new LayeredNeuralNetwork(conf, modelPath);
-    
+
+    Random generator = new Random();
     int correct = 0;
+    int total = 0;
     for (int i = 0; i < count; i++) {
-      double[] vals = new double[PIXELS];
-      for (int j = 0; j < PIXELS; j++) {
-        vals[j] = rescale((images[i][j] & 0xff));
-      }
-      int label = (labels[i] & 0xff);
+      if (generator.nextInt(10) == 1) {
+        double[] vals = new double[PIXELS];
+        for (int j = 0; j < PIXELS; j++) {
+          vals[j] = rescale((images[i][j] & 0xff));
+        }
+        int label = (labels[i] & 0xff);
+
+        DoubleVector instance = new DenseDoubleVector(vals);
+        DoubleVector result = ann.getOutput(instance);
 
-      DoubleVector instance = new DenseDoubleVector(vals);
-      DoubleVector result = ann.getOutput(instance);
-      
-      if(getNumber(result) == label) {
-        correct++;
+        if (getNumber(result) == label) {
+          correct++;
+        }
+        total++;
       }
     }
 
-    System.out.println((double) correct / count);
+    System.out.println(((double) correct / total * 100) + "%");
     // TODO System.out.println("Precision = " + (tp / (tp + fp)));
-    //System.out.println("Recall = " + (tp / (tp + fn)));
-    //System.out.println("Accuracy = " + ((tp + tn) / (tp + tn + fp + fn)));
-    
+    // System.out.println("Recall = " + (tp / (tp + fn)));
+    // System.out.println("Accuracy = " + ((tp + tn) / (tp + tn + fp + fn)));
+
     imagesIn.close();
     labelsIn.close();
   }
@@ -99,9 +103,9 @@ public class MNISTEvaluator {
   private static int getNumber(DoubleVector result) {
     double max = 0;
     int index = -1;
-    for(int x = 0; x < result.getLength(); x++) {
+    for (int x = 0; x < result.getLength(); x++) {
       double curr = result.get(x);
-      if(max < curr) {
+      if (max < curr) {
         max = curr;
         index = x;
       }

http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ca560628/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java b/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
index bb404d6..9110088 100644
--- a/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
+++ b/src/test/java/org/apache/horn/examples/MultiLayerPerceptronTest.java
@@ -23,8 +23,6 @@ import java.io.IOException;
 import java.io.InputStreamReader;
 import java.net.URI;
 
-import junit.framework.TestCase;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.fs.FileSystem;
@@ -32,18 +30,24 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hama.Constants;
+import org.apache.hama.HamaCluster;
 import org.apache.hama.HamaConfiguration;
 import org.apache.hama.commons.io.VectorWritable;
 import org.apache.hama.commons.math.DenseDoubleVector;
 import org.apache.hama.commons.math.DoubleVector;
 import org.apache.horn.core.HornJob;
 import org.apache.horn.core.LayeredNeuralNetwork;
+import org.apache.horn.core.Constants.TrainingMethod;
+import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
+import org.apache.horn.funcs.CrossEntropy;
+import org.apache.horn.funcs.Sigmoid;
 
 /**
  * Test the functionality of NeuralNetwork Example.
  */
-public class MultiLayerPerceptronTest extends TestCase { // HamaCluster {
-  private static final Log LOG = LogFactory.getLog(MultiLayerPerceptronTest.class);
+public class MultiLayerPerceptronTest extends HamaCluster {
+  private static final Log LOG = LogFactory
+      .getLog(MultiLayerPerceptronTest.class);
   private HamaConfiguration conf;
   private FileSystem fs;
   private String MODEL_PATH = "/tmp/neuralnets.model";
@@ -51,7 +55,7 @@ public class MultiLayerPerceptronTest extends TestCase { // HamaCluster {
   private String SEQTRAIN_DATA = "/tmp/test-neuralnets.data";
 
   public MultiLayerPerceptronTest() {
-    conf = new HamaConfiguration();/*
+    conf = new HamaConfiguration();
     conf.set("bsp.master.address", "localhost");
     conf.setBoolean("hama.child.redirect.log.console", true);
     conf.setBoolean("hama.messenger.runtime.compression", false);
@@ -62,7 +66,7 @@ public class MultiLayerPerceptronTest extends TestCase { // HamaCluster {
     conf.setInt(Constants.ZOOKEEPER_CLIENT_PORT, 21810);
     conf.set("hama.sync.client.class",
         org.apache.hama.bsp.sync.ZooKeeperSyncClientImpl.class
-            .getCanonicalName());*/
+            .getCanonicalName());
   }
 
   @Override
@@ -163,12 +167,28 @@ public class MultiLayerPerceptronTest extends TestCase { // HamaCluster {
     }
 
     try {
-      HornJob ann = MultiLayerPerceptron.createJob(conf, MODEL_PATH,
-          SEQTRAIN_DATA, 0.4, 0.2, 0.01, featureDimension, 8, labelDimension,
-          300, 10000);
+      HornJob job = new HornJob(conf, MultiLayerPerceptronTest.class);
+      job.setTrainingSetPath(SEQTRAIN_DATA);
+      job.setModelPath(MODEL_PATH);
+
+      job.setMaxIteration(1000);
+      job.setLearningRate(0.4);
+      job.setMomentumWeight(0.2);
+      job.setRegularizationWeight(0.001);
+
+      job.setConvergenceCheckInterval(100);
+      job.setBatchSize(300);
+
+      job.setTrainingMethod(TrainingMethod.GRADIENT_DESCENT);
+
+      job.inputLayer(featureDimension, Sigmoid.class, StandardNeuron.class);
+      job.addLayer(featureDimension, Sigmoid.class, StandardNeuron.class);
+      job.outputLayer(labelDimension, Sigmoid.class, StandardNeuron.class);
 
+      job.setCostFunction(CrossEntropy.class);
+      
       long startTime = System.currentTimeMillis();
-      if (ann.waitForCompletion(true)) {
+      if (job.waitForCompletion(true)) {
         LOG.info("Job Finished in " + (System.currentTimeMillis() - startTime)
             / 1000.0 + " seconds");
       }


Mime
View raw message