labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From: tomm...@apache.org
Subject: svn commit: r1707760 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/neuron/ core/src/test/java/org/apache/yay/core/
Date: Fri, 09 Oct 2015 15:28:35 GMT
Author: tommaso
Date: Fri Oct  9 15:28:34 2015
New Revision: 1707760

URL: http://svn.apache.org/viewvc?rev=1707760&view=rev
Log:
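Added a softmax activation function; changed ActivationFunction.apply() to also receive a weights matrix and Neuron to take its weights in the constructor; adjusted the tests accordingly.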

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/Neuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/BinaryThresholdNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/SigmoidNeuron.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/VoidLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,8 @@
  */
 package org.apache.yay;
 
+import org.apache.commons.math3.linear.RealMatrix;
+
 /**
  * An activation function AF : S -* S receives a signal and generates a new signal.
  * An activation function AF has horizontal asymptotes at 0 and 1 and a non
@@ -30,9 +32,10 @@ public interface ActivationFunction<T> {
   /**
    * Apply this <code>ActivationFunction</code> to the given signal, generating
a new signal
    *
-   * @param signal the input signal
+   * @param signal  the input signal
+   * @param weights the matrix of weights the activation should be applied to
    * @return the output signal generated
    */
-  T apply(T signal);
+  T apply(RealMatrix weights, T signal);
 
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Neuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Neuron.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Neuron.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Neuron.java Fri Oct  9 15:28:34 2015
@@ -18,19 +18,28 @@
  */
 package org.apache.yay;
 
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+
 /**
  * A node in a neural network
  */
 public abstract class Neuron<T> {
 
+  protected double[] weights;
+
+  private final RealMatrix matrix;
+
   protected ActivationFunction<T> activationFunction;
 
-  public Neuron(ActivationFunction<T> activationFunction) {
+  public Neuron(ActivationFunction<T> activationFunction, double... weights) {
     this.activationFunction = activationFunction;
+    this.weights = weights;
+    this.matrix = MatrixUtils.createRowRealMatrix(weights);
   }
 
   public T elaborate(T... inputs) {
-    return activationFunction.apply(combine(inputs));
+    return activationFunction.apply(matrix, combine(inputs));
   }
 
   protected abstract T combine(T... inputs);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Fri Oct  9 15:28:34 2015
@@ -88,7 +88,6 @@ public class BackPropagationLearningStra
 
       double cost = Double.MAX_VALUE;
       while (true) {
-
         TrainingSet<Double, Double> samples;
         if (batch == -1) {
           samples = trainingExamples;

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Fri Oct  9 15:28:34 2015
@@ -18,13 +18,6 @@
  */
 package org.apache.yay.core;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.commons.collections.CollectionUtils;
-import org.apache.commons.collections.Transformer;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
@@ -33,6 +26,10 @@ import org.apache.yay.ActivationFunction
 import org.apache.yay.PredictionStrategy;
 import org.apache.yay.core.utils.ConversionUtils;
 
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
 /**
  * Octave code for FF to be converted :
  * m = size(X, 1);
@@ -73,14 +70,17 @@ public class FeedForwardStrategy impleme
     RealVector v = ConversionUtils.toRealVector(input);
     RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); //
a 1xN matrix
     for (int w = 0; w < realMatrixSet.length; w++) {
-      RealMatrix currentWeightsMatrix = realMatrixSet[w];
+      final RealMatrix currentWeightsMatrix = realMatrixSet[w];
       // compute matrix multiplication
       x = x.multiply(currentWeightsMatrix.transpose());
 
+      final RealMatrix cm = x.getRowMatrix(0);
+
       // apply the activation function to each element in the matrix
       int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
       final ActivationFunction<Double> af = activationFunctionMap.get(idx);
-      x.walkInRowOrder(new RealMatrixChangingVisitor() {
+      RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
+
         @Override
         public void start(int rows, int columns, int startRow, int endRow, int startColumn,
int endColumn) {
 
@@ -88,14 +88,18 @@ public class FeedForwardStrategy impleme
 
         @Override
         public double visit(int row, int column, double value) {
-          return af.apply(value);
+          final RealMatrix rowMatrix = cm.getRowMatrix(row);
+          final RealMatrix columnMatrix = cm.getColumnMatrix(column);
+          Double newValue = af.apply(cm, value);
+          return newValue;
         }
 
         @Override
         public double end() {
           return 0;
         }
-      });
+      };
+      x.walkInRowOrder(visitor);
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -26,7 +27,7 @@ import org.apache.yay.ActivationFunction
 public class IdentityActivationFunction<T> implements ActivationFunction<T> {
 
   @Override
-  public T apply(T signal) {
+  public T apply(RealMatrix matrix, T signal) {
     return signal;
   }
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java Fri Oct  9 15:28:34 2015
@@ -18,7 +18,9 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.Hypothesis;
 import org.apache.yay.NeuralNetworkCostFunction;
 import org.apache.yay.TrainingExample;
@@ -28,27 +30,32 @@ import org.apache.yay.TrainingSet;
  * Least mean square cost function
  */
 class LMSCostFunction implements NeuralNetworkCostFunction {
-    @Override
-    public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
-        TrainingExample<Double, Double>[] samples = new TrainingExample[trainingExamples.size()];
-        int i = 0;
-        for (TrainingExample<Double, Double> sample : trainingExamples) {
-            samples[i] = sample;
-            i++;
-        }
-        return calculateCost(hypothesis, samples);
+  @Override
+  public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
+    TrainingExample<Double, Double>[] samples = new TrainingExample[trainingExamples.size()];
+    int i = 0;
+    for (TrainingExample<Double, Double> sample : trainingExamples) {
+      samples[i] = sample;
+      i++;
     }
+    return calculateCost(hypothesis, samples);
+  }
 
-    @Override
-    public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis,
TrainingExample<Double, Double>... trainingExamples) throws Exception {
-        Double cost = 0d;
-        for (TrainingExample<Double, Double> example : trainingExamples) {
-            Double[] actualOutput = example.getOutput();
-            Double[] predictedOutput = hypothesis.predict(example);
-            for (int i = 0; i < actualOutput.length; i++) {
-                cost += actualOutput[i] - predictedOutput[i];
-            }
-        }
-        return Math.pow(cost, 2) / 2;
+  @Override
+  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double,
Double>... trainingExamples) throws Exception {
+    Double cost = 0d;
+    for (TrainingExample<Double, Double> example : trainingExamples) {
+      Double[] actualOutput = example.getOutput();
+      Double[] predictedOutput = hypothesis.predict(example);
+      RealVector actualVector = new ArrayRealVector(actualOutput);
+      RealVector predictedVector = new ArrayRealVector(predictedOutput);
+      RealVector diffVector = actualVector.subtract(predictedVector);
+      for (int i = 0; i < diffVector.getDimension(); i++) {
+        double entry = diffVector.getEntry(i);
+        cost += Math.pow(entry, 2);
+      }
     }
+    cost /= 2;
+    return cost;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -25,7 +26,7 @@ import org.apache.yay.ActivationFunction
  */
 public class SigmoidFunction implements ActivationFunction<Double> {
 
-  public Double apply(final Double input) {
+  public Double apply(RealMatrix matrix, final Double input) {
     return 1d / (1d + Math.exp(-1d * input));
   }
 

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1707760&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Fri Oct  9 15:28:34 2015
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.ActivationFunction;
+
+/**
+ * Softmax activation function
+ */
+public class SoftmaxActivationFunction implements ActivationFunction<Double> {
+
+    @Override
+    public Double apply(RealMatrix weights, Double signal) {
+        double num = Math.exp(signal);
+        double den = 0d;
+        for (int i = 0; i < weights.getRowDimension(); i++) {
+            double[] row1 = weights.getRow(i);
+            for (int j = 0; j < weights.getColumnDimension(); j++) {
+                den += Math.exp(row1[j]);
+            }
+        }
+      double v = num / den;
+      return v;
+    }
+}

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -32,7 +33,7 @@ public class StepActivationFunction impl
   }
 
   @Override
-  public Double apply(Double signal) {
+  public Double apply(RealMatrix matrix, Double signal) {
     return signal >= center ? 1d : 0d;
   }
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -25,7 +26,7 @@ import org.apache.yay.ActivationFunction
  */
 public class TanhFunction implements ActivationFunction<Double> {
   @Override
-  public Double apply(Double signal) {
+  public Double apply(RealMatrix matrix, Double signal) {
     return Math.tanh(signal);
   }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/BinaryThresholdNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/BinaryThresholdNeuron.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/BinaryThresholdNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/BinaryThresholdNeuron.java Fri Oct  9 15:28:34 2015
@@ -30,11 +30,8 @@ import org.apache.yay.core.StepActivatio
  */
 public class BinaryThresholdNeuron extends Neuron<Double> {
 
-  private double[] weights;
-
   public BinaryThresholdNeuron(double threshold, double... weights) {
-    super(new StepActivationFunction(threshold));
-    this.weights = weights;
+    super(new StepActivationFunction(threshold), weights);
   }
 
   public void updateWeights(double... weights) {

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java Fri Oct  9 15:28:34 2015
@@ -26,12 +26,12 @@ import org.apache.yay.core.IdentityActiv
  */
 class LinearNeuron extends Neuron<Double> {
 
-  private final Double[] weights;
+  private final double[] weights;
 
-  private final Double bias;
+  private final double bias;
 
-  LinearNeuron(Double bias, Double... weights) {
-    super(new IdentityActivationFunction<Double>());
+  LinearNeuron(double bias, double... weights) {
+    super(new IdentityActivationFunction<Double>(), weights);
     this.bias = bias;
     this.weights = weights;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java Fri Oct  9 15:28:34 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core.neuron;
 
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -26,11 +27,11 @@ import org.apache.yay.ActivationFunction
  */
 class RectifiedLinearNeuron extends LinearNeuron {
 
-  public RectifiedLinearNeuron(Double bias, Double... weights) {
+  public RectifiedLinearNeuron(double bias, double... weights) {
     super(bias, weights);
     this.activationFunction = new ActivationFunction<Double>() {
       @Override
-      public Double apply(Double signal) {
+      public Double apply(RealMatrix matrix, Double signal) {
         return signal > 0 ? signal : 0;
       }
     };

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/SigmoidNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/SigmoidNeuron.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/SigmoidNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/SigmoidNeuron.java Fri Oct  9 15:28:34 2015
@@ -25,7 +25,7 @@ import org.apache.yay.core.SigmoidFuncti
  */
 class SigmoidNeuron extends LinearNeuron {
 
-  public SigmoidNeuron(Double bias, Double... weights) {
+  public SigmoidNeuron(double bias, double... weights) {
     super(bias, weights);
     this.activationFunction = new SigmoidFunction();
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Fri Oct  9 15:28:34 2015
@@ -33,7 +33,7 @@ import static junit.framework.Assert.ass
 import static junit.framework.Assert.assertNotNull;
 
 /**
- * Testcase for {@link org.apache.yay.core.BackPropagationLearningStrategy}
+ * Tests for {@link org.apache.yay.core.BackPropagationLearningStrategy}
  */
 public class BackPropagationLearningStrategyTest {
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java Fri Oct  9 15:28:34 2015
@@ -32,7 +32,7 @@ import static org.junit.Assert.assertEqu
 import static org.junit.Assert.assertTrue;
 
 /**
- * Testcase for {@link org.apache.yay.core.BasicPerceptron}
+ * Tests for {@link org.apache.yay.core.BasicPerceptron}
  */
 public class BasicPerceptronTest {
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java Fri Oct  9 15:28:34 2015
@@ -28,7 +28,7 @@ import org.junit.Test;
 import static junit.framework.Assert.assertNotNull;
 
 /**
- * Testcase for {@link FeedForwardStrategy}
+ * Tests for {@link FeedForwardStrategy}
  */
 public class FeedForwardStrategyTest {
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java Fri Oct  9 15:28:34 2015
@@ -33,7 +33,7 @@ import org.junit.Test;
 import static org.junit.Assert.assertTrue;
 
 /**
- * Testcase for {@link org.apache.yay.core.LogisticRegressionCostFunction}
+ * Tests for {@link org.apache.yay.core.LogisticRegressionCostFunction}
  */
 public class LogisticRegressionCostFunctionTest {
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java Fri Oct  9 15:28:34 2015
@@ -23,26 +23,26 @@ import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 
 /**
- * Testcase for {@link org.apache.yay.core.SigmoidFunction}
+ * Tests for {@link org.apache.yay.core.SigmoidFunction}
  */
 public class SigmoidFunctionTest {
 
   @Test
   public void testCorrectOutput() throws Exception {
     SigmoidFunction sigmoidFunction = new SigmoidFunction();
-    Double output = sigmoidFunction.apply(38d);
+    Double output = sigmoidFunction.apply(null, 38d);
     assertEquals(Double.valueOf(1d), output);
 
-    output = sigmoidFunction.apply(6d);
+    output = sigmoidFunction.apply(null, 6d);
     assertEquals(Double.valueOf(0.9975273768433653d), output);
 
-    output = sigmoidFunction.apply(2.5d);
+    output = sigmoidFunction.apply(null, 2.5d);
     assertEquals(Double.valueOf(0.9241418199787566d), output);
 
-    output = sigmoidFunction.apply(-2.5d);
+    output = sigmoidFunction.apply(null, -2.5d);
     assertEquals(Double.valueOf(0.07585818002124355d), output);
 
-    output = sigmoidFunction.apply(0d);
+    output = sigmoidFunction.apply(null, 0d);
     assertEquals(Double.valueOf(0.5d), output);
   }
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/VoidLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/VoidLearningStrategyTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/VoidLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/VoidLearningStrategyTest.java Fri Oct  9 15:28:34 2015
@@ -31,7 +31,7 @@ import static org.junit.Assert.assertNot
 import static org.junit.Assert.assertTrue;
 
 /**
- * Testcase for {@link org.apache.yay.core.VoidLearningStrategy}
+ * Tests for {@link org.apache.yay.core.VoidLearningStrategy}
  */
 public class VoidLearningStrategyTest {
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1707760&r1=1707759&r2=1707760&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Fri Oct  9 15:28:34 2015
@@ -70,24 +70,22 @@ public class WordVectorsTest {
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
     int inputSize = next.getFeatures().size() ;
     int outputSize = next.getOutput().length;
-    int n = new Random().nextInt(20) + 5;
-    RealMatrix[] randomWeights = createRandomWeights(inputSize, n, outputSize);
+    int hiddenSize = new Random().nextInt(50) + 15;
+    RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
 
     Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer,
ActivationFunction<Double>>();
     activationFunctions.put(0, new IdentityActivationFunction<Double>());
-    // TODO : place a softmax activation for the output layer
-    activationFunctions.put(1, new SigmoidFunction());
+    activationFunctions.put(1, new SoftmaxActivationFunction());
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.05d,
10,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
-            80);
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d,
10,
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
10);
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy,
predictionStrategy);
 
     neuralNetwork.learn(trainingSet);
 
     RealMatrix vectorsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
 
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/vectors.txt")));
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
     int m = 0;
     for (String word : vocabulary) {
       final Double[] doubles = hotEncode(word, vocabulary);
@@ -146,7 +144,7 @@ public class WordVectorsTest {
     bufferedWriter.flush();
     bufferedWriter.close();
 
-    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/vectors.bin")));
+    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
     MatrixUtils.serializeRealMatrix(vectorsMatrix, os);
 
   }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message