Subject: svn commit: r1714045 - in /labs/yay/trunk: ./ core/ core/src/main/java/org/apache/yay/core/ core/src/main/java/org/apache/yay/core/utils/ core/src/test/java/org/apache/yay/core/
Date: Thu, 12 Nov 2015 13:39:19 -0000
To: commits@labs.apache.org
From: tommaso@apache.org
Message-Id: <20151112133919.38B373A0567@svn01-us-west.apache.org>

Author: tommaso
Date: Thu Nov 12 13:39:18 2015
New Revision: 1714045

URL: http://svn.apache.org/viewvc?rev=1714045&view=rev
Log:
performance improvements on training set creation and softmax

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java   (with props)
Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Nov 12 13:39:18 2015
@@ -52,5 +52,23 @@
       <artifactId>commons-collections</artifactId>
       <version>3.2.1</version>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <version>18.0</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <version>2.18.1</version>
+        <configuration>
+          <argLine>-Xmx8g</argLine>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
 </project>
\ No newline at end of file

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Thu Nov 12 13:39:18 2015
@@ -163,7 +163,7 @@ public class BackPropagationLearningStra
         }
       };
       realMatrix.walkInOptimizedOrder(visitor);
-      if (updatedParameters[l]== null) {
+      if (updatedParameters[l] == null) {
        updatedParameters[l] = realMatrix;
       }
     }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java?rev=1714045&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java Thu Nov 12 13:39:18 2015
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.yay.Feature;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+import org.apache.yay.core.utils.ExamplesFactory;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * An hot encoded {@link TrainingSet}, only indices values are stored
+ */
+public class EncodedTrainingSet extends TrainingSet {
+
+  private final List<String> vocabulary;
+  private final int window;
+
+  public EncodedTrainingSet(Collection<TrainingExample<Double, Double>> samples, List<String> vocabulary, int window) {
+    super(samples);
+    this.vocabulary = vocabulary;
+    this.window = window;
+  }
+
+  @Override
+  public int size() {
+    return super.size();
+  }
+
+  @Override
+  public Iterator<TrainingExample<Double, Double>> iterator() {
+    return new Iterator<TrainingExample<Double, Double>>() {
+
+      @Override
+      public boolean hasNext() {
+        return EncodedTrainingSet.super.iterator().hasNext();
+      }
+
+      @Override
+      public TrainingExample<Double, Double> next() {
+        TrainingExample<Double, Double> sample = EncodedTrainingSet.super.iterator().next();
+        Collection<Feature<Double>> features = sample.getFeatures();
+        int vocabularySize = vocabulary.size();
+        Double[] outputs = new Double[vocabularySize * (window - 1)];
+        Double[] inputs = new Double[vocabularySize];
+        for (Feature<Double> feature : features) {
+          inputs = ConversionUtils.hotEncode(feature.getValue().intValue(), vocabularySize);
+          break;
+        }
+        int k = 0;
+        for (Double d : sample.getOutput()) {
+          Double[] currentOutput = ConversionUtils.hotEncode(d.intValue(), vocabularySize);
+          System.arraycopy(currentOutput, 0, outputs, k, currentOutput.length);
+          k += vocabularySize;
+        }
+        return ExamplesFactory.createDoubleArrayTrainingExample(outputs, inputs);
+      }
+    };
+  }
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/core/EncodedTrainingSet.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Nov 12 13:39:18 2015
@@ -79,27 +79,43 @@ public class FeedForwardStrategy impleme
       // apply the activation function to each element in the matrix
       int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
       final ActivationFunction af = activationFunctionMap.get(idx);
-      RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
-        @Override
-        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
-
-        }
-
-        @Override
-        public double visit(int row, int column, double value) {
-          return af.apply(cm, value);
-        }
-
-        @Override
-        public double end() {
-          return 0;
-        }
-      };
-      x.walkInOptimizedOrder(visitor);
+      if (af instanceof SoftmaxActivationFunction) {
+        x = ((SoftmaxActivationFunction) af).applyMatrix(x);
+      } else {
+        x.walkInOptimizedOrder(new ActivationFunctionVisitor(af, cm));
+      }
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
   }

+  private static class ActivationFunctionVisitor implements RealMatrixChangingVisitor {
+
+    private final ActivationFunction af;
+    private final RealMatrix matrix;
+
+    ActivationFunctionVisitor(ActivationFunction af, RealMatrix matrix) {
+      this.af = af;
+      this.matrix = matrix;
+    }
+
+    @Override
+    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+    }
+
+    @Override
+    public double visit(int row, int column, double value) {
+      return af.apply(matrix, value);
+    }
+
+    @Override
+    public double end() {
+      return 0;
+    }
+
+  }
+
 }
\ No newline at end of file

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Thu Nov 12 13:39:18 2015
@@ -19,6 +19,9 @@
 package org.apache.yay.core;
 
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.commons.math3.stat.descriptive.rank.Max;
 import org.apache.yay.ActivationFunction;
 
 import java.util.Map;
@@ -31,6 +34,25 @@ public class SoftmaxActivationFunction i
 
   private static final Map<RealMatrix, Double> cache = new WeakHashMap<RealMatrix, Double>();
 
+  private static final Max m = new Max();
+
+  private static final RealMatrixChangingVisitor expVisitor = new RealMatrixChangingVisitor() {
+    @Override
+    public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+    }
+
+    @Override
+    public double visit(int row, int column, double value) {
+      return Math.exp(value);
+    }
+
+    @Override
+    public double end() {
+      return 0;
+    }
+  };
+
   @Override
   public Double apply(RealMatrix weights, Double signal) {
     double num = Math.exp(signal);
@@ -38,18 +60,49 @@ public class SoftmaxActivationFunction i
     return num / den;
   }
 
+  public RealMatrix applyMatrix(RealMatrix weights) {
+
+    RealMatrix matrix = weights.copy();
+    double d = expDen(matrix);
+    final double finalD = d;
+    matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public double visit(int row, int column, double value) {
+        return Math.exp(value) / finalD;
+      }
+
+      @Override
+      public double end() {
+        return 0;
+      }
+    });
+    return matrix;
+  }
+
+  private double expDen(RealMatrix matrix) {
+    double d = 0d;
+    for (int i = 0; i < matrix.getRowDimension(); i++) {
+      RealVector currentRow = matrix.getRowVector(i);
+      for (int j = 0; j < matrix.getColumnDimension(); j++) {
+        double entry = currentRow.getEntry(j);
+        d += Math.exp(entry);
+      }
+    }
+    return d;
+  }
+
   private double getDen(RealMatrix weights) {
     Double d = cache.get(weights);
-    if (d == null) {
-      double den = 0d;
-      for (int i = 0; i < weights.getRowDimension(); i++) {
-        double[] row1 = weights.getRow(i);
-        for (int j = 0; j < weights.getColumnDimension(); j++) {
-          den += Math.exp(row1[j]);
-        }
+    synchronized (cache) {
+      if (d == null) {
+        d = expDen(weights.copy());
+        cache.put(weights, d);
       }
-      d = den;
-      cache.put(weights, d);
     }
     return d;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Thu Nov 12 13:39:18 2015
@@ -18,8 +18,6 @@
  */
 package org.apache.yay.core.utils;
 
-import java.util.ArrayList;
-import java.util.Collection;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.OpenMapRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
@@ -27,11 +25,21 @@ import org.apache.commons.math3.linear.R
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.WeakHashMap;
+
 /**
  * Temporary class for conversion between model objects and commons-math matrices/vectors
  */
 public class ConversionUtils {
 
+  private static final WeakHashMap<String, Double[]> wordCache = new WeakHashMap<String, Double[]>();
+  private static final WeakHashMap<String, Integer> vocabularyCache = new WeakHashMap<String, Integer>();
+
   /**
    * Converts a set of examples to a matrix of inputs with features
    *
@@ -82,7 +90,7 @@ public class ConversionUtils {
    * T objects.
    *
    * @param featureVector the vector of features
-   * @param the type of features
+   * @param <T> the type of features
    * @return a vector of Doubles
    */
   public static <T> Collection<Double> toValuesCollection(Collection<Feature<T>> featureVector) {
@@ -107,4 +115,41 @@
     }
     return doubles;
   }
+
+  public static Double[] hotEncode(byte[] word, List<String> vocabulary) {
+    String wordString = new String(word);
+    Double[] vector = wordCache.get(wordString);
+    if (vector == null) {
+      vector = new Double[vocabulary.size()];
+      Integer index = vocabularyCache.get(wordString);
+      if (index == null || index < 0) {
+        index = Collections.binarySearch(vocabulary, wordString);
+        vocabularyCache.put(wordString, index);
+      }
+      Arrays.fill(vector, 0d);
+      vector[index] = 1d;
+      wordCache.put(wordString, vector);
+    }
+    return vector;
+  }
+
+  public static Double[] hotEncode(int index, int size) {
+    Double[] vector = new Double[size];
+    Arrays.fill(vector, 0d);
+    vector[index] = 1d;
+    return vector;
+  }
+
+  public static String hotDecode(Double[] doubles, List<String> vocabulary) {
+    double max = -Double.MAX_VALUE;
+    int index = -1;
+    for (int i = 0; i < doubles.length; i++) {
+      Double aDouble = doubles[i];
+      if (aDouble > max) {
+        max = aDouble;
+        index = i;
+      }
+    }
+    return vocabulary.get(index);
+  }
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Thu Nov 12 13:39:18 2015
@@ -18,6 +18,7 @@
  */
 package org.apache.yay.core;
 
+import com.google.common.base.Splitter;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.ml.distance.CanberraDistance;
@@ -41,15 +42,27 @@ import java.io.FileWriter;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
+import java.util.Queue;
 import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
 
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -59,22 +72,32 @@ import static org.junit.Assert.assertNot
  */
 public class WordVectorsTest {
 
+  private static final boolean measure = false;
+
+  private static final boolean serialize = true;
+
   @Test
   public void testSGM() throws Exception {
-    Collection<String> sentences = getSentences();
-    assertFalse(sentences.isEmpty());
-    List<String> vocabulary = getVocabulary(sentences);
-    assertFalse(vocabulary.isEmpty());
-    Collections.sort(vocabulary);
-    Collection<String> fragments = getFragments(sentences, 4);
+
+    Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
+
+    System.out.println("reading fragments");
+    int window = 4;
+    Queue<List<byte[]>> fragments = getFragments(path, window);
     assertFalse(fragments.isEmpty());
+    System.out.println("generating vocabulary");
+    List<String> vocabulary = getVocabulary(path);
+    assertFalse(vocabulary.isEmpty());
 
-    TrainingSet trainingSet = createTrainingSet(vocabulary, fragments);
+    System.out.println("creating training set");
+    TrainingSet trainingSet = createTrainingSet(vocabulary, fragments, window);
+    fragments.clear();
 
     TrainingExample next = trainingSet.iterator().next();
     int inputSize = next.getFeatures().size();
     int outputSize = next.getOutput().length;
-    int hiddenSize = 100;
+    int hiddenSize = 30;
 
+    System.out.println("initializing neural network");
     RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
 
     Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
@@ -83,137 +106,80 @@ public class WordVectorsTest {
 
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
     BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1,
             BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
-            100);
+            trainingSet.size());
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
 
+    System.out.println("learning...");
     RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
+    System.out.println("learning finished");
     RealMatrix wordVectors = learnedWeights[0];
     assertNotNull(wordVectors);
 
-    Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
-    measures.add(new EuclideanDistance());
-    measures.add(new CanberraDistance());
-    measures.add(new ChebyshevDistance());
-    measures.add(new ManhattanDistance());
-    measures.add(new EarthMoversDistance());
-    measures.add(new DistanceMeasure() {
-      private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
-
-      @Override
-      public double compute(double[] a, double[] b) {
-        return 1 / pearsonsCorrelation.correlation(a, b);
+    if (serialize) {
+      System.out.println("serializing word vectors");
+      BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
+      for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+        double[] a = wordVectors.getColumnVector(i).toArray();
+        String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
+        csq = csq.substring(1, csq.length() - 1);
+        bufferedWriter.append(csq);
+        bufferedWriter.append(",");
+        bufferedWriter.append(vocabulary.get(i - 1));
+        bufferedWriter.newLine();
       }
+      bufferedWriter.flush();
+      bufferedWriter.close();
+    }
 
-      @Override
-      public String toString() {
-        return "inverse pearson correlation distance measure";
-      }
-    });
-    measures.add(new DistanceMeasure() {
-      @Override
-      public double compute(double[] a, double[] b) {
-        double dp = 0.0;
-        double na = 0.0;
-        double nb = 0.0;
-        for (int i = 0; i < a.length; i++) {
-          dp += a[i] * b[i];
-          na += Math.pow(a[i], 2);
-          nb += Math.pow(b[i], 2);
+    if (measure) {
+      System.out.println("measuring similarities");
+      Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>();
+      measures.add(new EuclideanDistance());
+      measures.add(new CanberraDistance());
+      measures.add(new ChebyshevDistance());
+      measures.add(new ManhattanDistance());
+      measures.add(new EarthMoversDistance());
+      measures.add(new DistanceMeasure() {
+        private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation();
+
+        @Override
+        public double compute(double[] a, double[] b) {
+          return 1 / pearsonsCorrelation.correlation(a, b);
         }
-        double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
-        return 1 / cosineSimilarity;
-      }
 
-      @Override
-      public String toString() {
-        return "inverse cosine similarity distance measure";
-      }
-    });
+        @Override
+        public String toString() {
+          return "inverse pearson correlation distance measure";
+        }
+      });
+      measures.add(new DistanceMeasure() {
+        @Override
+        public double compute(double[] a, double[] b) {
+          double dp = 0.0;
+          double na = 0.0;
+          double nb = 0.0;
+          for (int i = 0; i < a.length; i++) {
+            dp += a[i] * b[i];
+            na += Math.pow(a[i], 2);
+            nb += Math.pow(b[i], 2);
+          }
+          double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+          return 1 / cosineSimilarity;
+        }
 
-    for (DistanceMeasure distanceMeasure : measures) {
-      System.out.println("computing similarity using " + distanceMeasure);
-      computeSimilarities(vocabulary, wordVectors, distanceMeasure);
-    }
+        @Override
+        public String toString() {
+          return "inverse cosine similarity distance measure";
+        }
+      });
 
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
-    for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
-      double[] a = wordVectors.getColumnVector(i).toArray();
-      String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
-      csq = csq.substring(1, csq.length() - 1);
-      bufferedWriter.append(csq);
-      bufferedWriter.append(",");
-      bufferedWriter.append(vocabulary.get(i-1));
-      bufferedWriter.newLine();
+      for (DistanceMeasure distanceMeasure : measures) {
+        System.out.println("computing similarity using " + distanceMeasure);
+        computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+      }
     }
-
-    bufferedWriter.flush();
-    bufferedWriter.close();
-
-//    RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
-//
-//    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
-//    int m = 0;
-//    for (String word : vocabulary) {
-//      final Double[] doubles = hotEncode(word, vocabulary);
-//      Input input = new TrainingExample() {
-//        @Override
-//        public ArrayList<Feature<Double>> getFeatures() {
-//          ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-//          Feature<Double> byasFeature = new Feature<Double>();
-//          byasFeature.setValue(1d);
-//          features.add(byasFeature);
-//          for (Double d : doubles) {
-//            Feature<Double> f = new Feature<Double>();
-//            f.setValue(d);
-//            features.add(f);
-//          }
-//          return features;
-//        }
-//
-//        @Override
-//        public Double[] getOutput() {
-//          return new Double[0];
-//        }
-//      };
-//      Double[] predict = neuralNetwork.predict(input);
-//      assertNotNull(predict);
-//      double[] row = new double[predict.length];
-//      for (int x = 0; x < row.length; x++) {
-//        row[x] = predict[x];
-//      }
-//      mappingsMatrix.setRow(m, row);
-//      m++;
-//
-//      String vectorString = Arrays.toString(predict);
-//      bufferedWriter.append(vectorString);
-//      bufferedWriter.newLine();
-//
-//      Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
-//      assertNotNull(wordVec1);
-//      Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size());
-//      assertNotNull(wordVec2);
-//      Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size());
-//      assertNotNull(wordVec3);
-//
-//      String word1 = hotDecode(wordVec1, vocabulary);
-//      assertNotNull(word1);
-//      assertTrue(vocabulary.contains(word1));
-//      String word2 = hotDecode(wordVec2, vocabulary);
-//      assertNotNull(word2);
-//      assertTrue(vocabulary.contains(word2));
-//      String word3 = hotDecode(wordVec3, vocabulary);
-//      assertNotNull(word3);
-//      assertTrue(vocabulary.contains(word3));
-//
-//      System.out.println(word + " generates " + word1 + " " + word2 + " " + word3);
-//    }
-//    bufferedWriter.flush();
-//    bufferedWriter.close();
-//
-//    ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
-//    MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
   }
 
   private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
@@ -272,28 +238,17 @@ public class WordVectorsTest {
     }
   }
 
-  private String hotDecode(Double[] doubles, List<String> vocabulary) {
-    double max = -Double.MAX_VALUE;
-    int index = -1;
-    for (int i = 0; i < doubles.length; i++) {
-      Double aDouble = doubles[i];
-      if (aDouble > max) {
-        max = aDouble;
-        index = i;
-      }
-    }
-    return vocabulary.get(index);
-  }
-
-  private TrainingSet createTrainingSet(List<String> vocabulary, Collection<String> fragments) {
-    Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>();
-    for (String fragment : fragments) {
-      String[] tokens = fragment.split(" ");
-      String inputWord = null;
-      for (int i = 0; i < tokens.length; i++) {
-        List<String> outputWords = new LinkedList<String>();
-        for (int j = 0; j < tokens.length; j++) {
-          String token = tokens[i];
+  private TrainingSet createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) {
+    long start = System.currentTimeMillis();
+    Path file = Paths.get("/Users/teofili/Desktop/ts.txt");
+    Collection<TrainingExample<Double, Double>> samples = new LinkedList<>();
+    List<byte[]> fragment;
+    while ((fragment = fragments.poll()) != null) {
+      byte[] inputWord = null;
+      for (int i = 0; i < fragment.size(); i++) {
+        List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
+        for (int j = 0; j < fragment.size(); j++) {
+          byte[] token = fragment.get(i);
           if (i == j) {
             inputWord = token;
           } else {
@@ -301,92 +256,152 @@ public class WordVectorsTest {
           }
         }
 
-        final Double[] input = hotEncode(inputWord, vocabulary);
-        final Double[] outputs = new Double[outputWords.size() * vocabulary.size()];
-        for (int k = 0; k < outputWords.size(); k++) {
-          Double[] doubles = hotEncode(outputWords.get(k), vocabulary);
-          for (int z = 0; z < doubles.length; z++) {
-            outputs[(k * doubles.length) + z] = doubles[z];
-          }
-        }
+        final byte[] finalInputWord = inputWord;
         samples.add(new TrainingExample() {
           @Override
           public Double[] getOutput() {
-            return outputs;
+            Double[] doubles = new Double[window - 1];
+            for (int i = 0; i < doubles.length; i++) {
+              doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
+            }
+            return doubles;
           }
 
           @Override
           public ArrayList<Feature<Double>> getFeatures() {
-            ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-            Feature<Double> byasFeature = new Feature<Double>();
-            byasFeature.setValue(1d);
-            features.add(byasFeature);
-            for (Double d : input) {
-              Feature<Double> e = new Feature<Double>();
-              e.setValue(d);
-              features.add(e);
-            }
+            ArrayList<Feature<Double>> features = new ArrayList<>();
+            Feature<Double> e = new Feature<>();
+            e.setValue((double) vocabulary.indexOf(new String(finalInputWord)));
+            features.add(e);
             return features;
           }
         });
       }
     }
-    return new TrainingSet(samples);
+    EncodedTrainingSet trainingSet = new EncodedTrainingSet(samples, vocabulary, window);
+
+    long end = System.currentTimeMillis();
+    System.out.println("training set created in " + (end - start) / 60000 + " minutes");
+
+    return trainingSet;
   }
 
-  private Double[] hotEncode(String word, List<String> vocabulary) {
-    Double[] vector = new Double[vocabulary.size()];
-    int index = Collections.binarySearch(vocabulary, word);
-    Arrays.fill(vector, 0d);
-    vector[index] = 1d;
-    return vector;
+
+  private List<String> getVocabulary(Path path) throws IOException {
+    long start = System.currentTimeMillis();
+    Set<String> vocabulary = new HashSet<String>();
+    SeekableByteChannel sbc = Files.newByteChannel(path);
+    ByteBuffer buf = ByteBuffer.allocate(100);
+    try {
+
+      String encoding = System.getProperty("file.encoding");
+      StringBuilder previous = new StringBuilder();
+      Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+      while (sbc.read(buf) > 0) {
+        buf.rewind();
+        CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+        String string = charBuffer.toString();
+        List<String> split = splitter.splitToList(string);
+        int splitSize = split.size();
+        if (splitSize > 1) {
+          String term = previous.append(split.get(0)).toString();
+          vocabulary.add(term.intern());
+          for (int i = 1; i < splitSize - 1; i++) {
+            String term2 = split.get(i);
+            vocabulary.add(term2.intern());
+          }
+          previous = new StringBuilder().append(split.get(splitSize - 1));
+        } else if (split.size() == 1) {
+          previous.append(string);
+        }
+        buf.flip();
+      }
+    } catch (IOException x) {
+      System.err.println("caught exception: " + x);
+    } finally {
+      sbc.close();
+      buf.clear();
+    }
+    long end = System.currentTimeMillis();
+    List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()]));
+    Collections.sort(list);
+    System.out.println("vocabulary read in " + (end - start) / 60000 + " minutes (" + (list.size()) + ")");
+    return list;
   }
 
-  private List<String> getVocabulary(Collection<String> sentences) {
+  private List<String> getVocabulary(Collection<byte[]> sentences) {
+    long start = System.currentTimeMillis();
     List<String> vocabulary = new LinkedList<String>();
-    for (String sentence : sentences) {
-      for (String token : sentence.split(" ")) {
+    for (byte[] sentence : sentences) {
+      for (String token : new String(sentence).split(" ")) {
         if (!vocabulary.contains(token)) {
           vocabulary.add(token);
         }
       }
     }
+    System.out.println("sorting vocabulary");
     Collections.sort(vocabulary);
+    long end = System.currentTimeMillis();
+    System.out.println("vocabulary generated in " + (end - start) / 60000 + " minutes");
     return vocabulary;
   }
 
-  private Collection<String> getFragments(Collection<String> vocabulary, int w) {
-    Collection<String> fragments = new LinkedList<String>();
-    for (String sentence : vocabulary) {
-      while (sentence.length() > 0) {
-        int idx = 0;
-        for (int i = 0; i < w; i++) {
-          idx = sentence.indexOf(' ', idx + 1);
-        }
-        if (idx > 0) {
-          String fragment = sentence.substring(0, idx);
-          if (fragment.split(" ").length == 4) {
-            fragments.add(fragment);
-            sentence = sentence.substring(sentence.indexOf(' ') + 1);
-          }
-        } else {
-          if (sentence.split(" ").length == 4) {
-            fragments.add(sentence);
-            sentence = "";
-          }
-        }
-      }
-    }
+  private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+    long start = System.currentTimeMillis();
+    Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<List<byte[]>>();
+
+    SeekableByteChannel sbc = Files.newByteChannel(path);
+    ByteBuffer buf = ByteBuffer.allocate(100);
+    try {
+
+      String encoding = System.getProperty("file.encoding");
+      StringBuilder previous = new StringBuilder();
+      Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+      int lastConsumedIndex = -1;
+      while (sbc.read(buf) > 0) {
+        buf.rewind();
+        CharBuffer charBuffer = Charset.forName(encoding).decode(buf);
+        String string = charBuffer.toString();
+        List<String> split = splitter.splitToList(string);
+        int splitSize = split.size();
+        if (splitSize > w) {
+          for (int j = 0; j < splitSize - w; j++) {
+            List<byte[]> fragment = new ArrayList<byte[]>(w);
+            fragment.add(previous.append(split.get(j)).toString().getBytes());
+            for (int i = 1; i < w; i++) {
+              fragment.add(split.get(i + j).getBytes());
+            }
+            // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
+            lastConsumedIndex = j + w;
+            fragments.add(fragment);
+            previous = new StringBuilder();
+          }
+          previous = new StringBuilder().append(split.get(splitSize - 1));
+        } else if (split.size() == w) {
+          previous.append(string);
+        }
+        buf.flip();
+      }
+    } catch (IOException x) {
+      System.err.println("caught exception: " + x);
+    } finally {
+      sbc.close();
+      buf.clear();
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
     return fragments;
   }
 
   private Collection<String> getSentences() throws IOException {
+    Collection<String> sentences = new LinkedList<String>();
+
     InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt");
     BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream));
-    Collection<String> sentences = new LinkedList<String>();
     String line;
     while ((line = bufferedReader.readLine()) != null) {
-      sentences.add(line.toLowerCase());
+      String cleanLine = line.toLowerCase().replaceAll("\\.", "").replaceAll("\\;", "").replaceAll("\\,", "").replaceAll("\\:", "");
+      sentences.add(cleanLine);
    }
    return sentences;
  }

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1714045&r1=1714044&r2=1714045&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Thu Nov 12 13:39:18 2015
@@ -152,8 +152,8 @@
         <artifactId>maven-compiler-plugin</artifactId>
         <version>2.0.2</version>
         <configuration>
-          <source>1.6</source>
-          <target>1.6</target>
+          <source>1.8</source>
+          <target>1.8</target>
           <encoding>UTF-8</encoding>
         </configuration>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org
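
The index-based one-hot helpers this commit adds to ConversionUtils (hotEncode/hotDecode) can be illustrated in isolation. The class and method names below are ours, written for illustration against plain arrays; they are not part of the yay API:

```java
import java.util.Arrays;
import java.util.List;

public class OneHotSketch {

  // Build a one-hot vector of the given size with a single 1 at the given index.
  public static double[] hotEncode(int index, int size) {
    double[] vector = new double[size]; // zero-initialized by default
    vector[index] = 1d;
    return vector;
  }

  // Recover the vocabulary entry at the position holding the largest value (argmax),
  // mirroring the decode direction of the commit's hotDecode.
  public static String hotDecode(double[] vector, List<String> vocabulary) {
    int index = 0;
    for (int i = 1; i < vector.length; i++) {
      if (vector[i] > vector[index]) {
        index = i;
      }
    }
    return vocabulary.get(index);
  }

  public static void main(String[] args) {
    List<String> vocabulary = Arrays.asList("brown", "fox", "quick", "the");
    double[] encoded = hotEncode(1, vocabulary.size());
    System.out.println(Arrays.toString(encoded));       // [0.0, 1.0, 0.0, 0.0]
    System.out.println(hotDecode(encoded, vocabulary)); // fox
  }
}
```

Encoding by integer index (rather than by string lookup each time) is what lets EncodedTrainingSet defer materializing these vectors until iteration.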
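
The softmax computed by the new applyMatrix method can be sketched on plain 2D arrays: every entry becomes exp(x) divided by the sum of exp over all entries (note the commit normalizes over the whole matrix, not per row). This is an illustrative re-implementation, not the project's code; production softmax usually also subtracts the maximum entry before exponentiating to avoid overflow, which the commit's Max import hints at but does not yet do:

```java
public class SoftmaxSketch {

  public static double[][] softmax(double[][] m) {
    // denominator: sum of exp over every entry of the matrix
    double den = 0d;
    for (double[] row : m) {
      for (double v : row) {
        den += Math.exp(v);
      }
    }
    // normalize each entry by the shared denominator
    double[][] out = new double[m.length][];
    for (int i = 0; i < m.length; i++) {
      out[i] = new double[m[i].length];
      for (int j = 0; j < m[i].length; j++) {
        out[i][j] = Math.exp(m[i][j]) / den;
      }
    }
    return out;
  }

  public static void main(String[] args) {
    double[][] s = softmax(new double[][]{{1d, 2d}, {3d, 4d}});
    double sum = 0d;
    for (double[] row : s) {
      for (double v : row) {
        sum += v;
      }
    }
    System.out.println(Math.abs(sum - 1d) < 1e-9); // entries sum to 1
  }
}
```

Computing the denominator once per matrix (and caching it per weights matrix, as getDen does) is the performance win over re-summing for every entry.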
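
The memory trick behind EncodedTrainingSet can be shown with a minimal standalone sketch (our own names, not the yay API): store only a vocabulary index per example and expand it to a full one-hot vector lazily, when the consumer iterates, so the full vectors never all live in memory at once:

```java
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class EncodedSamples implements Iterable<double[]> {

  private final List<Integer> indices; // one stored index per example
  private final int vocabularySize;

  public EncodedSamples(List<Integer> indices, int vocabularySize) {
    this.indices = indices;
    this.vocabularySize = vocabularySize;
  }

  @Override
  public Iterator<double[]> iterator() {
    final Iterator<Integer> it = indices.iterator();
    return new Iterator<double[]>() {
      @Override
      public boolean hasNext() {
        return it.hasNext();
      }

      @Override
      public double[] next() {
        // materialize the one-hot vector only when it is actually requested
        double[] vector = new double[vocabularySize];
        vector[it.next()] = 1d;
        return vector;
      }
    };
  }

  public static void main(String[] args) {
    EncodedSamples samples = new EncodedSamples(Arrays.asList(0, 2), 3);
    for (double[] v : samples) {
      System.out.println(Arrays.toString(v));
    }
    // prints [1.0, 0.0, 0.0] then [0.0, 0.0, 1.0]
  }
}
```

For a vocabulary of size V and window of size w, this keeps w integers per example instead of w * V doubles, which is why the commit's training-set creation gets faster and fits an 8 GB heap.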