labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1714192 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/neuron/ core/src/main/java/org/apache/yay/core/utils/ core/src/test/java/org/apache/yay/core/
Date Fri, 13 Nov 2015 11:40:47 GMT
Author: tommaso
Date: Fri Nov 13 11:40:47 2015
New Revision: 1714192

URL: http://svn.apache.org/viewvc?rev=1714192&view=rev
Log:
various performance improvements

Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/ActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -38,4 +38,13 @@ public interface ActivationFunction<T> {
    */
   T apply(RealMatrix weights, T signal);
 
+  /**
+   * Apply this <code>ActivationFunction</code> to the given matrix of signals, generating a new matrix of transformed
+   * signals.
+   *
+   * @param weights the matrix of weights the activation should be applied to
+   * @return the output signal generated
+   */
+  RealMatrix applyMatrix(RealMatrix weights);
+
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Fri Nov 13 11:40:47 2015
@@ -83,11 +83,12 @@ public class BackPropagationLearningStra
     try {
       int iterations = 0;
 
-      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy);
+      NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<>(), predictionStrategy);
       Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
 
       double cost = Double.MAX_VALUE;
       while (true) {
+        System.err.println(iterations);
         TrainingSet<Double, Double> samples;
         if (batch == -1) {
           samples = trainingExamples;
@@ -103,12 +104,12 @@ public class BackPropagationLearningStra
         }
 
         // calculate cost
-        double newCost = costFunction.calculateAggregatedCost(samples, hypothesis);
+        double newCost = costFunction.calculateAggregatedCost(samples, neuralNetwork);
 
         if (Double.POSITIVE_INFINITY == newCost || newCost > cost && batch == -1) {
           throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + newCost);
         } else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > maxIterations)) {
-          System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
+          System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(neuralNetwork.getParameters()));
           break;
         } else if (Double.isNaN(newCost)) {
           throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost calculation underflow");
@@ -124,7 +125,7 @@ public class BackPropagationLearningStra
         updatedWeights = updateWeights(updatedWeights, derivatives, alpha);
 
         // update parameters in the hypothesis
-        hypothesis.setParameters(updatedWeights);
+        neuralNetwork.setParameters(updatedWeights);
 
         iterations++;
       }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Fri Nov 13 11:40:47 2015
@@ -18,8 +18,7 @@
  */
 package org.apache.yay.core;
 
-import java.util.Collection;
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.Input;
 import org.apache.yay.LearningException;
@@ -30,6 +29,8 @@ import org.apache.yay.TrainingSet;
 import org.apache.yay.core.neuron.BinaryThresholdNeuron;
 import org.apache.yay.core.utils.ConversionUtils;
 
+import java.util.Collection;
+
 /**
  * A perceptron {@link org.apache.yay.NeuralNetwork} implementation based on
  * {@link org.apache.yay.core.neuron.BinaryThresholdNeuron}s
@@ -56,7 +57,7 @@ public class BasicPerceptron implements
     for (TrainingExample<Double, Double> example : trainingExamples) {
       learn(example);
     }
-    return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+    return new RealMatrix[]{MatrixUtils.createRowRealMatrix(currentWeights)};
   }
 
   public void learn(TrainingExample<Double, Double> example) {
@@ -87,7 +88,7 @@ public class BasicPerceptron implements
 
   @Override
   public RealMatrix[] getParameters() {
-    return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+    return new RealMatrix[]{MatrixUtils.createRowRealMatrix(currentWeights)};
   }
 
   @Override

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Fri Nov 13 11:40:47 2015
@@ -18,9 +18,8 @@
  */
 package org.apache.yay.core;
 
-import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.ActivationFunction;
 import org.apache.yay.PredictionStrategy;
@@ -29,6 +28,7 @@ import org.apache.yay.core.utils.Convers
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Stream;
 
 /**
  * Octave code for FF to be converted :
@@ -64,58 +64,23 @@ public class FeedForwardStrategy impleme
   }
 
   private RealVector[] applyFF(Collection<Double> input, RealMatrix[] realMatrixSet) {
-    RealVector[] debugOutput = new ArrayRealVector[realMatrixSet.length];
+    RealVector[] debugOutput = new RealVector[realMatrixSet.length];
 
-    // TODO : fix this impl as it's very slow
-    RealVector v = ConversionUtils.toRealVector(input);
-    RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); // a 1xN matrix
+    Double[] doubles = input.toArray(new Double[input.size()]);
+    RealMatrix x = MatrixUtils.createRowRealMatrix(Stream.of(doubles).mapToDouble(Double::doubleValue).toArray());
     for (int w = 0; w < realMatrixSet.length; w++) {
       final RealMatrix currentWeightsMatrix = realMatrixSet[w];
       // compute matrix multiplication
       x = x.multiply(currentWeightsMatrix.transpose());
 
-      final RealMatrix cm = x.getRowMatrix(0);
-
       // apply the activation function to each element in the matrix
+      final RealMatrix cm = x.getRowMatrix(0);
       int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
       final ActivationFunction<Double> af = activationFunctionMap.get(idx);
-
-      if (af instanceof SoftmaxActivationFunction) {
-        x = ((SoftmaxActivationFunction) af).applyMatrix(x);
-      } else {
-        x.walkInOptimizedOrder(new ActivationFunctionVisitor(af, cm));
-      }
+      x = af.applyMatrix(cm);
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
   }
 
-  private static class ActivationFunctionVisitor implements RealMatrixChangingVisitor {
-
-    private final ActivationFunction<Double> af;
-    private final RealMatrix matrix;
-
-    ActivationFunctionVisitor(ActivationFunction<Double> af, RealMatrix matrix) {
-      this.af = af;
-      this.matrix = matrix;
-    }
-
-    @Override
-    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
-    }
-
-    @Override
-    public double visit(int row, int column, double value) {
-      return af.apply(matrix, value);
-    }
-
-    @Override
-    public double end() {
-      return 0;
-    }
-
-
-  }
-
 }
\ No newline at end of file

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/IdentityActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -31,4 +31,9 @@ public class IdentityActivationFunction<
     return signal;
   }
 
+  @Override
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    return weights;
+  }
+
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -27,7 +28,32 @@ import org.apache.yay.ActivationFunction
 public class SigmoidFunction implements ActivationFunction<Double> {
 
   public Double apply(RealMatrix matrix, final Double input) {
+    return sigmoid(input);
+  }
+
+  private double sigmoid(Double input) {
     return 1d / (1d + Math.exp(-1d * input));
   }
 
+  @Override
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    RealMatrix matrix = weights.copy();
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return sigmoid(value);
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -21,50 +21,23 @@ package org.apache.yay.core;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.commons.math3.linear.RealVector;
-import org.apache.commons.math3.stat.descriptive.rank.Max;
 import org.apache.yay.ActivationFunction;
 
-import java.util.Map;
-import java.util.WeakHashMap;
-
 /**
  * Softmax activation function
  */
 public class SoftmaxActivationFunction implements ActivationFunction<Double> {
 
-  private static final Map<RealMatrix, Double> cache = new WeakHashMap<RealMatrix, Double>();
-
-  private static final Max m = new Max();
-
-  private static final RealMatrixChangingVisitor expVisitor = new RealMatrixChangingVisitor() {
-    @Override
-    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
-    }
-
-    @Override
-    public double visit(int row, int column, double value) {
-      return Math.exp(value);
-    }
-
-    @Override
-    public double end() {
-      return 0;
-    }
-  };
-
   @Override
   public Double apply(RealMatrix weights, Double signal) {
     double num = Math.exp(signal);
-    double den = getDen(weights);
+    double den = expDen(weights);
     return num / den;
   }
 
   public RealMatrix applyMatrix(RealMatrix weights) {
-
     RealMatrix matrix = weights.copy();
-    double d = expDen(matrix);
-    final double finalD = d;
+    final double finalD = expDen(matrix);
     matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
       @Override
       public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
@@ -96,14 +69,4 @@ public class SoftmaxActivationFunction i
     return d;
   }
 
-  private double getDen(RealMatrix weights) {
-    Double d = cache.get(weights);
-    synchronized (cache) {
-      if (d == null) {
-        d = expDen(weights.copy());
-        cache.put(weights, d);
-      }
-    }
-    return d;
-  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -34,7 +35,32 @@ public class StepActivationFunction impl
 
   @Override
   public Double apply(RealMatrix matrix, Double signal) {
+    return step(signal);
+  }
+
+  private double step(Double signal) {
     return signal >= center ? 1d : 0d;
   }
 
+  @Override
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    RealMatrix matrix = weights.copy();
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return step(value);
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -29,4 +30,26 @@ public class TanhFunction implements Act
   public Double apply(RealMatrix matrix, Double signal) {
     return Math.tanh(signal);
   }
+
+  @Override
+  public RealMatrix applyMatrix(RealMatrix weights) {
+    RealMatrix matrix = weights.copy();
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return Math.tanh(value);
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/RectifiedLinearNeuron.java Fri Nov 13 11:40:47 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core.neuron;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.yay.ActivationFunction;
 
 /**
@@ -32,8 +33,34 @@ class RectifiedLinearNeuron extends Line
     this.activationFunction = new ActivationFunction<Double>() {
       @Override
       public Double apply(RealMatrix matrix, Double signal) {
-        return signal > 0 ? signal : 0;
+        return rect(signal);
+      }
+
+      @Override
+      public RealMatrix applyMatrix(RealMatrix weights) {
+        RealMatrix matrix = weights.copy();
+        matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+          @Override
+          public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+          }
+
+          @Override
+          public double visit(int row, int column, double value) {
+            return rect(value);
+          }
+
+          @Override
+          public double end() {
+            return 0;
+          }
+        });
+        return matrix;
       }
     };
   }
+
+  private double rect(Double signal) {
+    return signal > 0 ? signal : 0;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Fri Nov 13 11:40:47 2015
@@ -18,7 +18,7 @@
  */
 package org.apache.yay.core.utils;
 
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.OpenMapRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
@@ -56,7 +56,7 @@ public class ConversionUtils {
       i++;
     }
 
-    return new Array2DRowRealMatrix(matrixData);
+    return MatrixUtils.createRealMatrix(matrixData);
   }
 
   /**

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714192&r1=1714191&r2=1714192&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Fri Nov 13 11:40:47 2015
@@ -19,7 +19,7 @@
 package org.apache.yay.core;
 
 import com.google.common.base.Splitter;
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.ml.distance.CanberraDistance;
 import org.apache.commons.math3.ml.distance.ChebyshevDistance;
@@ -445,7 +445,7 @@ public class WordVectorsTest {
           d[k][j] = val;
         }
       }
-      initialWeights[i] = new Array2DRowRealMatrix(d);
+      initialWeights[i] = MatrixUtils.createRealMatrix(d);
     }
     return initialWeights;
   }



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message