labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1706816 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/ test/java/org/apache/yay/core/
Date Mon, 05 Oct 2015 12:02:33 GMT
Author: tommaso
Date: Mon Oct  5 12:02:32 2015
New Revision: 1706816

URL: http://svn.apache.org/viewvc?rev=1706816&view=rev
Log:
Added LMS (least mean squares) cost function; fixed backpropagation bias checks.

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1706816&r1=1706815&r2=1706816&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Oct  5 12:02:32 2015
@@ -38,9 +38,9 @@ import org.apache.yay.WeightLearningExce
  */
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double>
{
 
-  public static final double DEFAULT_THRESHOLD = 0.005;
+  public static final double DEFAULT_THRESHOLD = 0.05;
   public static final int MAX_ITERATIONS = 100000;
-  public static final double DEFAULT_ALPHA = 0.03;
+  public static final double DEFAULT_ALPHA = 0.000003;
 
   private final PredictionStrategy<Double, Double> predictionStrategy;
   private final CostFunction<RealMatrix, Double, Double> costFunction;
@@ -139,7 +139,7 @@ public class BackPropagationLearningStra
       for (int i = 0; i < updatedWeights.length; i++) {
         for (int j = 0; j < updatedWeights[i].length; j++) {
           double curVal = updatedWeights[i][j];
-          if (curVal > 0d && curVal < 1d) {
+          if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal ==
1d)) {
             updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
           }
         }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java?rev=1706816&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java Mon Oct  5 12:02:32 2015
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.Hypothesis;
+import org.apache.yay.NeuralNetworkCostFunction;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+
+/**
+ * Least mean square cost function
+ */
+public class LMSCostFunction implements NeuralNetworkCostFunction {
+    @Override
+    public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
+        TrainingExample<Double, Double>[] samples = new TrainingExample[trainingExamples.size()];
+        int i = 0;
+        for (TrainingExample<Double, Double> sample : trainingExamples) {
+            samples[i] = sample;
+            i++;
+        }
+        return calculateCost(hypothesis, samples);
+    }
+
+    @Override
+    public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis,
TrainingExample<Double, Double>... trainingExamples) throws Exception {
+        Double cost = 0d;
+        for (TrainingExample<Double, Double> example : trainingExamples) {
+            Double[] actualOutput = example.getOutput();
+            Double[] predictedOutput = hypothesis.predict(example);
+            for (int i = 0; i < actualOutput.length; i++) {
+                cost += actualOutput[i] - predictedOutput[i];
+            }
+        }
+        return Math.pow(cost, 2) / 2;
+    }
+}

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1706816&r1=1706815&r2=1706816&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Mon Oct  5 12:02:32 2015
@@ -137,7 +137,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d,
1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); //
1 x 4
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(10000,
2, 1);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(100,
2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights,
trainingSet);
     assertNotNull(learntWeights);

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?rev=1706816&r1=1706815&r2=1706816&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Mon Oct  5 12:02:32 2015
@@ -186,7 +186,7 @@ public class NeuralNetworkIntegrationTes
           if (c == 0) {
             d[0][c] = 1d;
           } else {
-            d[0][c] = r.nextInt(100) / 101d;;
+            d[0][c] = r.nextInt(100) / 101d;
           }
         } else {
           d[0][c] = 0;
@@ -199,7 +199,7 @@ public class NeuralNetworkIntegrationTes
           if (j == 0) {
             val = 1d;
           } else {
-            val = r.nextInt(100) / 101d;;
+            val = r.nextInt(100) / 101d;
           }
           d[k][j] = val;
         }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message