labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1710995 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/BackPropagationLearningStrategy.java test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
Date Wed, 28 Oct 2015 11:43:25 GMT
Author: tommaso
Date: Wed Oct 28 11:43:25 2015
New Revision: 1710995

URL: http://svn.apache.org/viewvc?rev=1710995&view=rev
Log:
slightly parallelize weight matrix update

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1710995&r1=1710994&r2=1710995&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java	(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java	Wed Oct 28 11:43:25 2015
@@ -18,19 +18,15 @@
  */
 package org.apache.yay.core;
 
-import java.util.Arrays;
-import java.util.Iterator;
-
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.CostFunction;
-import org.apache.yay.DerivativeUpdateFunction;
-import org.apache.yay.LearningStrategy;
-import org.apache.yay.NeuralNetwork;
-import org.apache.yay.PredictionStrategy;
-import org.apache.yay.TrainingExample;
-import org.apache.yay.TrainingSet;
-import org.apache.yay.WeightLearningException;
+import org.apache.yay.*;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.concurrent.*;
 
 /**
  * Back propagation learning algorithm for neural networks implementation (see
@@ -50,6 +46,8 @@ public class BackPropagationLearningStra
   private final int batch;
   private final int maxIterations;
 
+  private final ExecutorService executorService = Executors.newCachedThreadPool();
+
   public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double,
Double> predictionStrategy,
                                          CostFunction<RealMatrix, Double, Double> costFunction)
{
     this(alpha, 1, threshold, predictionStrategy, costFunction, MAX_ITERATIONS);
@@ -110,7 +108,7 @@ public class BackPropagationLearningStra
         } else if (iterations > 1 && (cost == newCost || newCost < threshold
|| iterations > maxIterations)) {
           System.out.println("successfully converged after " + (iterations - 1) + " iterations
(alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters
" + Arrays.toString(hypothesis.getParameters()));
           break;
-        } else if (Double.isNaN(newCost)){
+        } else if (Double.isNaN(newCost)) {
           throw new RuntimeException("failed to converge at iteration " + iterations + "
with alpha " + alpha + " : cost calculation underflow");
         }
 
@@ -139,14 +137,7 @@ public class BackPropagationLearningStra
     RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
     for (int l = 0; l < weightsMatrixSet.length; l++) {
       double[][] updatedWeights = weightsMatrixSet[l].getData();
-      for (int i = 0; i < updatedWeights.length; i++) {
-        for (int j = 0; j < updatedWeights[i].length; j++) {
-          double curVal = updatedWeights[i][j];
-          if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal ==
1d)) {
-            updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
-          }
-        }
-      }
+      updateMatrix(derivatives, alpha, l, updatedWeights);
       if (updatedParameters[l] != null) {
         updatedParameters[l].setSubMatrix(updatedWeights, 0, 0);
       } else {
@@ -156,4 +147,36 @@ public class BackPropagationLearningStra
     return updatedParameters;
   }
 
+  /**
+   * Applies one gradient-descent step, {@code w = w - alpha * dw}, in place to the weights of layer
+   * {@code l}. Rows are processed concurrently on {@code executorService}; each task writes only its
+   * own row, and each task is awaited (up to 3 seconds) before this method returns.
+   * Entries in row 0 that equal 0 and entries in column 0 that equal 1 are left unchanged
+   * (presumably fixed bias-related entries - TODO confirm).
+   *
+   * @param derivatives    the per-layer derivative matrices
+   * @param alpha          the learning rate
+   * @param l              the index of the layer being updated
+   * @param updatedWeights the weights of layer {@code l}, updated in place
+   * @throws RuntimeException if a row update fails, times out or the calling thread is interrupted
+   */
+  private void updateMatrix(final RealMatrix[] derivatives, final double alpha, final int l, final double[][] updatedWeights) {
+    // RealMatrix#getData() returns a fresh copy, so fetch it once instead of once per entry
+    final double[][] derivativeData = derivatives[l].getData();
+    // one task per row rather than per entry: a per-entry task costs far more than its arithmetic
+    Collection<Future<?>> futures = new LinkedList<Future<?>>();
+    for (int i = 0; i < updatedWeights.length; i++) {
+      final int row = i;
+      futures.add(executorService.submit(new Runnable() {
+        @Override
+        public void run() {
+          double[] weights = updatedWeights[row];
+          for (int j = 0; j < weights.length; j++) {
+            double curVal = weights[j];
+            if (!(row == 0 && curVal == 0d) && !(j == 0 && curVal == 1d)) {
+              // subtract the scaled derivative from the current weight (not just store -alpha * dw)
+              weights[j] = curVal - alpha * derivativeData[row][j];
+            }
+          }
+        }
+      }));
+    }
+    for (Future<?> f : futures) {
+      try {
+        f.get(3, TimeUnit.SECONDS);
+      } catch (InterruptedException e) {
+        // preserve the interrupt status for callers further up the stack
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      } catch (ExecutionException e) {
+        throw new RuntimeException(e.getCause());
+      } catch (TimeoutException e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1710995&r1=1710994&r2=1710995&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java	(original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java	Wed Oct 28 11:43:25 2015
@@ -143,9 +143,9 @@ public class BackPropagationLearningStra
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights,
trainingSet);
     assertNotNull(learntWeights);
 
-    assertFalse(learntWeights[0].equals(initialWeights[0]));
-    assertFalse(learntWeights[1].equals(initialWeights[1]));
-    assertFalse(learntWeights[2].equals(initialWeights[2]));
+    for (int i = 0; i < learntWeights.length; i++) {
+      assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
+    }
 
     backPropagationLearningStrategy = new BackPropagationLearningStrategy(BackPropagationLearningStrategy.DEFAULT_ALPHA,
-1,
             BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new FeedForwardStrategy(new
SigmoidFunction()),
@@ -154,9 +154,9 @@ public class BackPropagationLearningStra
     learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
 
-    assertFalse(learntWeights[0].equals(initialWeights[0]));
-    assertFalse(learntWeights[1].equals(initialWeights[1]));
-    assertFalse(learntWeights[2].equals(initialWeights[2]));
+    for (int i = 0; i < learntWeights.length; i++) {
+      assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
+    }
   }
 
   @Test



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message