labs-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1711018 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/BackPropagationLearningStrategy.java test/java/org/apache/yay/core/WordVectorsTest.java test/resources/word2vec/test.txt
Date Wed, 28 Oct 2015 14:08:28 GMT
Author: tommaso
Date: Wed Oct 28 14:08:28 2015
New Revision: 1711018

URL: http://svn.apache.org/viewvc?rev=1711018&view=rev
Log:
simplified wordvec test, back to serial matrix update impl

Added:
    labs/yay/trunk/core/src/test/resources/word2vec/test.txt   (with props)
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1711018&r1=1711017&r2=1711018&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Oct 28 14:08:28 2015
@@ -18,15 +18,19 @@
  */
 package org.apache.yay.core;
 
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
-import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.*;
-
 import java.util.Arrays;
-import java.util.Collection;
 import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.concurrent.*;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.CostFunction;
+import org.apache.yay.DerivativeUpdateFunction;
+import org.apache.yay.LearningStrategy;
+import org.apache.yay.NeuralNetwork;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.WeightLearningException;
 
 /**
  * Back propagation learning algorithm for neural networks implementation (see
@@ -46,8 +50,6 @@ public class BackPropagationLearningStra
   private final int batch;
   private final int maxIterations;
 
-  private final ExecutorService executorService = Executors.newCachedThreadPool();
-
   public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double,
Double> predictionStrategy,
                                          CostFunction<RealMatrix, Double, Double> costFunction)
{
     this(alpha, 1, threshold, predictionStrategy, costFunction, MAX_ITERATIONS);
@@ -108,7 +110,7 @@ public class BackPropagationLearningStra
         } else if (iterations > 1 && (cost == newCost || newCost < threshold
|| iterations > maxIterations)) {
           System.out.println("successfully converged after " + (iterations - 1) + " iterations
(alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters
" + Arrays.toString(hypothesis.getParameters()));
           break;
-        } else if (Double.isNaN(newCost)) {
+        } else if (Double.isNaN(newCost)){
           throw new RuntimeException("failed to converge at iteration " + iterations + "
with alpha " + alpha + " : cost calculation underflow");
         }
 
@@ -137,7 +139,14 @@ public class BackPropagationLearningStra
     RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
     for (int l = 0; l < weightsMatrixSet.length; l++) {
       double[][] updatedWeights = weightsMatrixSet[l].getData();
-      updateMatrix(derivatives, alpha, l, updatedWeights);
+      for (int i = 0; i < updatedWeights.length; i++) {
+        for (int j = 0; j < updatedWeights[i].length; j++) {
+          double curVal = updatedWeights[i][j];
+          if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal ==
1d)) {
+            updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+          }
+        }
+      }
       if (updatedParameters[l] != null) {
         updatedParameters[l].setSubMatrix(updatedWeights, 0, 0);
       } else {
@@ -147,36 +156,4 @@ public class BackPropagationLearningStra
     return updatedParameters;
   }
 
-  private void updateMatrix(final RealMatrix[] derivatives, final double alpha, final int
l, final double[][] updatedWeights) {
-    Collection<Future<Double>> futures = new LinkedList<Future<Double>>();
-    for (int i = 0; i < updatedWeights.length; i++) {
-      for (int j = 0; j < updatedWeights[i].length; j++) {
-        final int finalI = i;
-        final int finalJ = j;
-        Callable<Double> callable = new Callable<Double>() {
-          @Override
-          public Double call() throws Exception {
-            double curVal = updatedWeights[finalI][finalJ];
-            double val;
-            if (!(finalI == 0 && curVal == 0d) && !(finalJ == 0 &&
curVal == 1d)) {
-              val = -alpha * derivatives[l].getData()[finalI][finalJ];
-              updatedWeights[finalI][finalJ] = val;
-            } else {
-              val = curVal;
-            }
-            return val;
-          }
-        };
-        futures.add(executorService.submit(callable));
-      }
-    }
-    for (Future<Double> f : futures) {
-      try {
-        f.get(3, TimeUnit.SECONDS);
-      } catch (Exception e) {
-        throw new RuntimeException(e);
-      }
-    }
-  }
-
-}
+}
\ No newline at end of file

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1711018&r1=1711017&r2=1711018&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Wed Oct 28 14:08:28 2015
@@ -18,37 +18,21 @@
  */
 package org.apache.yay.core;
 
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.*;
+import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
+import org.apache.yay.*;
+import org.junit.Test;
+
 import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.FileWriter;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
-import java.io.ObjectOutputStream;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-
-import org.apache.commons.math3.linear.Array2DRowRealMatrix;
-import org.apache.commons.math3.linear.MatrixUtils;
-import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.ActivationFunction;
-import org.apache.yay.Feature;
-import org.apache.yay.Input;
-import org.apache.yay.NeuralNetwork;
-import org.apache.yay.TrainingExample;
-import org.apache.yay.TrainingSet;
-import org.junit.Test;
+import java.util.*;
 
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 
 /**
  * Integration test for using Yay to implement word vectors algorithms.
@@ -68,89 +52,190 @@ public class WordVectorsTest {
     TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
 
-    int inputSize = next.getFeatures().size() ;
+    int inputSize = next.getFeatures().size();
     int outputSize = next.getOutput().length;
-    int hiddenSize = new Random().nextInt(50) + 15;
-    System.err.println("i:"+inputSize+",h:"+hiddenSize+",o:"+outputSize);
+    int hiddenSize = 50;
     RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
 
     Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer,
ActivationFunction<Double>>();
     activationFunctions.put(0, new IdentityActivationFunction<Double>());
     activationFunctions.put(1, new SoftmaxActivationFunction());
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d,
1,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
10);
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.003d,
1,
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
+            1000);
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy,
predictionStrategy);
 
     RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
 
-    RealMatrix wordVectors = learnedWeights[learnedWeights.length - 1];
+    RealMatrix wordVectors = learnedWeights[0];
 
-    assertNotNull(wordVectors);
+    Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
+    measures.add(new EuclideanDistance());
+    measures.add(new CanberraDistance());
+    measures.add(new ChebyshevDistance());
+    measures.add(new ManhattanDistance());
+    measures.add(new EarthMoversDistance());
+    measures.add(new DistanceMeasure() {
+      private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
+
+      @Override
+      public double compute(double[] a, double[] b) {
+        return 1 / pearsonsCorrelation.correlation(a, b);
+      }
 
-    RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
+      @Override
+      public String toString() {
+        return "inverse pearson correlation distance measure";
+      }
+    });
+    measures.add(new DistanceMeasure() {
+      @Override
+      public double compute(double[] a, double[] b) {
+        double dp = 0.0;
+        double na = 0.0;
+        double nb = 0.0;
+        for (int i = 0; i < a.length; i++) {
+          dp += a[i] * b[i];
+          na += Math.pow(a[i], 2);
+          nb += Math.pow(b[i], 2);
+        }
+        double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+        return 1 / cosineSimilarity;
+      }
 
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
-    int m = 0;
-    for (String word : vocabulary) {
-      final Double[] doubles = hotEncode(word, vocabulary);
-      Input<Double> input = new TrainingExample<Double, Double>() {
-        @Override
-        public ArrayList<Feature<Double>> getFeatures() {
-          ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-          Feature<Double> byasFeature = new Feature<Double>();
-          byasFeature.setValue(1d);
-          features.add(byasFeature);
-          for (Double d : doubles) {
-            Feature<Double> f = new Feature<Double>();
-            f.setValue(d);
-            features.add(f);
-          }
-          return features;
-        }
-
-        @Override
-        public Double[] getOutput() {
-          return new Double[0];
-        }
-      };
-      Double[] predict = neuralNetwork.predict(input);
-      assertNotNull(predict);
-      double[] row = new double[predict.length];
-      for (int x = 0; x < row.length; x++) {
-        row[x] = predict[x];
-      }
-      mappingsMatrix.setRow(m, row);
-      m++;
-
-      String vectorString = Arrays.toString(predict);
-      bufferedWriter.append(vectorString);
-      bufferedWriter.newLine();
-
-      Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
-      assertNotNull(wordVec1);
-      Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
-      assertNotNull(wordVec2);
-      Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
-      assertNotNull(wordVec3);
-
-      String word1 = hotDecode(wordVec1, vocabulary);
-      assertNotNull(word1);
-      assertTrue(vocabulary.contains(word1));
-      String word2 = hotDecode(wordVec2, vocabulary);
-      assertNotNull(word2);
-      assertTrue(vocabulary.contains(word2));
-      String word3 = hotDecode(wordVec3, vocabulary);
-      assertNotNull(word3);
-      assertTrue(vocabulary.contains(word3));
+      @Override
+      public String toString() {
+        return "inverse cosine similarity distance measure";
+      }
+    });
 
-      System.out.println(word + " -> " + word1 + " " + word2 + " " + word3);
+    for (DistanceMeasure distanceMeasure : measures) {
+      System.out.println("computing similarity using " + distanceMeasure);
+      computeSimilarities(vocabulary, wordVectors, distanceMeasure);
     }
-    bufferedWriter.flush();
-    bufferedWriter.close();
 
-    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
-    MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
+    assertNotNull(wordVectors);
+
+//    RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(),
next.getOutput().length);
+//
+//    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
+//    int m = 0;
+//    for (String word : vocabulary) {
+//      final Double[] doubles = hotEncode(word, vocabulary);
+//      Input<Double> input = new TrainingExample<Double, Double>() {
+//        @Override
+//        public ArrayList<Feature<Double>> getFeatures() {
+//          ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
+//          Feature<Double> byasFeature = new Feature<Double>();
+//          byasFeature.setValue(1d);
+//          features.add(byasFeature);
+//          for (Double d : doubles) {
+//            Feature<Double> f = new Feature<Double>();
+//            f.setValue(d);
+//            features.add(f);
+//          }
+//          return features;
+//        }
+//
+//        @Override
+//        public Double[] getOutput() {
+//          return new Double[0];
+//        }
+//      };
+//      Double[] predict = neuralNetwork.predict(input);
+//      assertNotNull(predict);
+//      double[] row = new double[predict.length];
+//      for (int x = 0; x < row.length; x++) {
+//        row[x] = predict[x];
+//      }
+//      mappingsMatrix.setRow(m, row);
+//      m++;
+//
+//      String vectorString = Arrays.toString(predict);
+//      bufferedWriter.append(vectorString);
+//      bufferedWriter.newLine();
+//
+//      Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
+//      assertNotNull(wordVec1);
+//      Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
+//      assertNotNull(wordVec2);
+//      Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
+//      assertNotNull(wordVec3);
+//
+//      String word1 = hotDecode(wordVec1, vocabulary);
+//      assertNotNull(word1);
+//      assertTrue(vocabulary.contains(word1));
+//      String word2 = hotDecode(wordVec2, vocabulary);
+//      assertNotNull(word2);
+//      assertTrue(vocabulary.contains(word2));
+//      String word3 = hotDecode(wordVec3, vocabulary);
+//      assertNotNull(word3);
+//      assertTrue(vocabulary.contains(word3));
+//
+//      System.out.println(word + " generates " + word1 + " " + word2 + " " + word3);
+//    }
+//    bufferedWriter.flush();
+//    bufferedWriter.close();
+//
+//    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
+//    MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
+  }
+
+  private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors,
DistanceMeasure distanceMeasure) {
+    for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+      double[] subjectVector = wordVectors.getColumn(i);
+      subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length);
+      double maxSimilarity = -Double.MAX_VALUE;
+      double maxSimilarity1 = -Double.MAX_VALUE;
+      double maxSimilarity2 = -Double.MAX_VALUE;
+      double[] bestVector = null;
+      double[] bestVector1 = null;
+      double[] bestVector2 = null;
+      int j0 = -1;
+      int j1 = -1;
+      int j2 = -1;
+      for (int j = 1; j < wordVectors.getColumnDimension(); j++) {
+        if (i != j) {
+          double[] vector = wordVectors.getColumn(j);
+          vector = Arrays.copyOfRange(vector, 1, vector.length);
+          double similarity = 1 / distanceMeasure.compute(subjectVector, vector);
+          if (similarity > maxSimilarity) {
+            maxSimilarity2 = maxSimilarity1;
+            bestVector2 = bestVector1;
+            j2 = j1;
+
+            maxSimilarity1 = maxSimilarity;
+            bestVector1 = bestVector;
+            j1 = j0;
+
+            maxSimilarity = similarity;
+            bestVector = vector;
+            j0 = j;
+          } else if (similarity > maxSimilarity1) {
+            maxSimilarity2 = maxSimilarity1;
+            bestVector2 = bestVector1;
+            j2 = j1;
+
+            maxSimilarity1 = similarity;
+            bestVector1 = vector;
+            j1 = j;
+          } else if (similarity > maxSimilarity2) {
+            maxSimilarity2 = similarity;
+            bestVector2 = vector;
+            j2 = j;
+          }
+        }
+      }
+      if (bestVector != null && i > 0 && j0 > 0 && j1 >
0 && j2 > 0) {
+        System.out.println(vocabulary.get(i - 1) + " is similar to "
+                + vocabulary.get(j0 - 1) + ", "
+                + vocabulary.get(j1 - 1) + ", "
+                + vocabulary.get(j2 - 1));
+      } else {
+        throw new RuntimeException();
+      }
+    }
   }
 
   private String hotDecode(Double[] doubles, List<String> vocabulary) {
@@ -232,6 +317,7 @@ public class WordVectorsTest {
         }
       }
     }
+    Collections.sort(vocabulary);
     return vocabulary;
   }
 
@@ -261,7 +347,7 @@ public class WordVectorsTest {
   }
 
   private Collection<String> getSentences() throws IOException {
-    InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt");
+    InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt");
     BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream));
     Collection<String> sentences = new LinkedList<String>();
     String line;

Added: labs/yay/trunk/core/src/test/resources/word2vec/test.txt
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/test.txt?rev=1711018&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/resources/word2vec/test.txt (added)
+++ labs/yay/trunk/core/src/test/resources/word2vec/test.txt Wed Oct 28 14:08:28 2015
@@ -0,0 +1,3 @@
+the dog saw a cat
+the dog chased the cat
+the cat climbed a tree
\ No newline at end of file

Propchange: labs/yay/trunk/core/src/test/resources/word2vec/test.txt
------------------------------------------------------------------------------
    svn:eol-style = native



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org


Mime
View raw message