labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1707564 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/ test/java/org/apache/yay/core/
Date Thu, 08 Oct 2015 15:03:49 GMT
Author: tommaso
Date: Thu Oct  8 15:03:49 2015
New Revision: 1707564

URL: http://svn.apache.org/viewvc?rev=1707564&view=rev
Log:
reduced boilerplate code in ff strategy for applying activation function, added layer specific
AFs, improved error derivative calculation

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Thu Oct  8 15:03:49 2015
@@ -18,6 +18,8 @@
  */
 package org.apache.yay.core;
 
+import java.util.Arrays;
+
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
@@ -107,19 +109,13 @@ class DefaultDerivativeUpdateFunction im
   private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
     RealVector output = activations[activations.length - 1];
 
-//    Double[] sampleOutput = new Double[output.getDimension()];
     Double[] actualOutput = trainingExample.getOutput();
-//    int sampleOutputIntValue = actualOutput.intValue();
-//    if (sampleOutputIntValue < sampleOutput.length) {
-//      sampleOutput[sampleOutputIntValue] = 1d;
-//    } else if (sampleOutput.length == 1) {
-//      sampleOutput[0] = actualOutput;
-//    } else {
-//      throw new RuntimeException("problem with multiclass output mapping");
-//    }
     RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
 
-    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
-    return output.subtract(learnedOutputRealVector);
+    double[] ones = new double[output.getDimension()];
+    Arrays.fill(ones, 1d);
+
+    // error calculation -> er_a = out_a * (1 - out_a) * (tgt_a - out_a) (was: output.subtract(learnedOutputRealVector)
+    return output.ebeMultiply(new ArrayRealVector(ones).subtract(output)).ebeMultiply(output.subtract(learnedOutputRealVector));
   }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Oct  8 15:03:49 2015
@@ -20,10 +20,14 @@ package org.apache.yay.core;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.Transformer;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.ActivationFunction;
 import org.apache.yay.PredictionStrategy;
@@ -40,10 +44,15 @@ import org.apache.yay.core.utils.Convers
  */
 public class FeedForwardStrategy implements PredictionStrategy<Double, Double> {
 
-  private final ActivationFunction<Double> activationFunction;
+  private final Map<Integer, ActivationFunction<Double>> activationFunctionMap;
 
   public FeedForwardStrategy(ActivationFunction<Double> activationFunction) {
-    this.activationFunction = activationFunction;
+    this.activationFunctionMap = new HashMap<Integer, ActivationFunction<Double>>();
+    this.activationFunctionMap.put(0, activationFunction);
+  }
+
+  public FeedForwardStrategy(Map<Integer, ActivationFunction<Double>> activationFunctionMap) {
+    this.activationFunctionMap = activationFunctionMap;
   }
 
   @Override
@@ -69,32 +78,27 @@ public class FeedForwardStrategy impleme
       x = x.multiply(currentWeightsMatrix.transpose());
 
       // apply the activation function to each element in the matrix
-      for (int i = 0; i < x.getRowDimension(); i++) {
-        double[] doubles = x.getRow(i);
-        final ArrayList<Double> row = new ArrayList<Double>(doubles.length);
-        for (int j = 0; j < doubles.length; j++) {
-          row.add(j, doubles[j]);
+      int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
+      final ActivationFunction<Double> af = activationFunctionMap.get(idx);
+      x.walkInRowOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return af.apply(value);
         }
-        // TODO : see if bias term is handled correctly here
-        CollectionUtils.transform(row, new ActivationRowTransformer());
-        double[] finRow = new double[row.size()];
-        for (int h = 0; h < finRow.length; h++) {
-          finRow[h] = row.get(h);
+
+        @Override
+        public double end() {
+          return 0;
         }
-        x.setRow(i, finRow);
-      }
+      });
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
   }
 
-  private class ActivationRowTransformer implements Transformer {
-    @Override
-    public Object transform(Object input) {
-      assert input instanceof Double;
-      final Double d = (Double) input;
-      return activationFunction.apply(d);
-    }
-  }
-
 }
\ No newline at end of file

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Thu Oct  8 15:03:49 2015
@@ -18,24 +18,21 @@
  */
 package org.apache.yay.core;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Random;
-
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.CreationException;
-import org.apache.yay.Feature;
-import org.apache.yay.Input;
-import org.apache.yay.LearningStrategy;
-import org.apache.yay.NeuralNetwork;
-import org.apache.yay.TrainingExample;
-import org.apache.yay.TrainingSet;
+import org.apache.commons.math3.ml.distance.CanberraDistance;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.yay.*;
 import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Random;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 /**
  * Integration test for NN
@@ -134,15 +131,20 @@ public class NeuralNetworkIntegrationTes
     int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
     Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, noOfFeatures, noOfOutputs);
     nn.learn(new TrainingSet<Double, Double>(samples));
+    DistanceMeasure distanceMeasure = new CanberraDistance();
     for (TrainingExample<Double, Double> sample : samples) {
       Double[] predictedOutput = nn.predict(sample);
+      double[] a1 = new double[predictedOutput.length];
+      for (int i = 0; i < a1.length; i++) {
+        a1[i] = predictedOutput[i];
+      }
       Double[] expectedOutput = sample.getOutput();
-      boolean equals = Arrays.equals(expectedOutput, predictedOutput);
-//      if (!equals) {
-//        System.err.println(Arrays.toString(expectedOutput) + " vs " + Arrays.toString(predictedOutput));
-//      } else {
-//        System.err.println("equals!");
-//      }
+      double[] a2 = new double[expectedOutput.length];
+      for (int i = 0; i < a2.length; i++) {
+        a2[i] = expectedOutput[i];
+      }
+      double dist = distanceMeasure.compute(a1, a2);
+      assertTrue("expected and actual outputs are distant " + dist, dist < 10d);
     }
 
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java Thu Oct  8 15:03:49 2015
@@ -31,21 +31,21 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.yay.ActivationFunction;
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
-import org.apache.yay.core.utils.ConversionUtils;
-import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
@@ -70,13 +70,17 @@ public class Word2VecTest {
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
     int inputSize = next.getFeatures().size() ;
     int outputSize = next.getOutput().length;
-    int n = new Random().nextInt(20);
+    int n = new Random().nextInt(20) + 5;
     RealMatrix[] randomWeights = createRandomWeights(inputSize, n, outputSize);
 
-    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(new IdentityActivationFunction<Double>());
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.0005d, -1,
+    Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
+    activationFunctions.put(0, new IdentityActivationFunction<Double>());
+    // TODO : place a softmax activation for the output layer
+    activationFunctions.put(0, new IdentityActivationFunction<Double>());
+    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.05d, 10,
             BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
-            30);
+            80);
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
 
     neuralNetwork.learn(trainingSet);



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message