lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1700914 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/ java/org/apache/lucene/classification/utils/ test/org/apache/lucene/classification/ test/org/apache/lucene/classification/utils/
Date Wed, 02 Sep 2015 22:21:54 GMT
Author: tommaso
Date: Wed Sep  2 22:21:53 2015
New Revision: 1700914

URL: http://svn.apache.org/r1700914
Log:
LUCENE-6479 - improved cm testing, added stats, minor fixes

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java	Wed Sep  2 22:21:53 2015
@@ -173,7 +173,7 @@ public class BooleanPerceptronClassifier
         // update weights
         Long previousValue = Util.get(fst, term);
         String termString = term.utf8ToString();
-        weights.put(termString, previousValue + modifier * termFreqLocal);
+        weights.put(termString, previousValue == null ? 0 : previousValue + modifier * termFreqLocal);
       }
     }
     if (updateFST) {
@@ -214,6 +214,7 @@ public class BooleanPerceptronClassifier
         }
       }
       tokenStream.end();
+      tokenStream.close();
     }
 
     double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java	Wed Sep  2 22:21:53 2015
@@ -80,7 +80,7 @@ public class CachingNaiveBayesClassifier
   }
 
 
-  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String
inputDocument) throws IOException {
+  protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String
inputDocument) throws IOException {
     String[] tokenizedDoc = tokenizeDoc(inputDocument);
 
     List<ClassificationResult<BytesRef>> dataList = calculateLogLikelihood(tokenizedDoc);
@@ -200,7 +200,7 @@ public class CachingNaiveBayesClassifier
         }
       }
       if (insertPoint != null) {
-        // threadsafe and concurent write
+        // threadsafe and concurrent write
         termCClassHitCache.put(word, searched);
       }
     }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java	Wed Sep  2 22:21:53 2015
@@ -134,7 +134,13 @@ public class SimpleNaiveBayesClassifier
     return doclist.subList(0, max);
   }
 
-  private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String
inputDocument) throws IOException {
+  /**
+   * Calculate probabilities for all classes for a given input text
+   * @param inputDocument the input text as a {@code String}
+   * @return a {@code List} of {@code ClassificationResult}, one for each existing class
+   * @throws IOException if assigning probabilities fails
+   */
+  protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String
inputDocument) throws IOException {
     List<ClassificationResult<BytesRef>> dataList = new ArrayList<>();
 
     Terms terms = MultiFields.getTerms(leafReader, classFieldName);
@@ -143,8 +149,10 @@ public class SimpleNaiveBayesClassifier
     String[] tokenizedDoc = tokenizeDoc(inputDocument);
     int docsWithClassSize = countDocsWithClass();
     while ((next = termsEnum.next()) != null) {
-      double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc,
next, docsWithClassSize);
-      dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
+      if (next.length > 0) {
+        double clVal = calculateLogPrior(next, docsWithClassSize) + calculateLogLikelihood(tokenizedDoc,
next, docsWithClassSize);
+        dataList.add(new ClassificationResult<>(BytesRef.deepCopyOf(next), clVal));
+      }
     }
 
     // normalization; the values transforms to a 0-1 range
@@ -212,6 +220,7 @@ public class SimpleNaiveBayesClassifier
           result.add(charTermAttribute.toString());
         }
         tokenStream.end();
+        tokenStream.close();
       }
     }
     return result.toArray(new String[result.size()]);

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/ConfusionMatrixGenerator.java	Wed Sep  2 22:21:53 2015
@@ -21,11 +21,21 @@ import java.io.IOException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.StoredDocument;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
 import org.apache.lucene.util.BytesRef;
 
 /**
@@ -49,37 +59,67 @@ public class ConfusionMatrixGenerator {
    * @return a {@link org.apache.lucene.classification.utils.ConfusionMatrixGenerator.ConfusionMatrix}
    * @throws IOException if problems occurr while reading the index or using the classifier
    */
-  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T>
classifier, String classFieldName, String textFieldName) throws IOException {
+  public static <T> ConfusionMatrix getConfusionMatrix(LeafReader reader, Classifier<T>
classifier, String classFieldName,
+                                                       String textFieldName) throws IOException
{
 
-    Map<String, Map<String, Long>> counts = new HashMap<>();
+    ExecutorService executorService = Executors.newFixedThreadPool(1);
+
+    try {
+
+      Map<String, Map<String, Long>> counts = new HashMap<>();
+      IndexSearcher indexSearcher = new IndexSearcher(reader);
+      TopDocs topDocs = indexSearcher.search(new WildcardQuery(new Term(classFieldName, "*")),
Integer.MAX_VALUE);
+      double time = 0d;
+
+      for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+        StoredDocument doc = reader.document(scoreDoc.doc);
+        String correctAnswer = doc.get(classFieldName);
+
+        if (correctAnswer != null && correctAnswer.length() > 0) {
+          ClassificationResult<T> result;
+          String text = doc.get(textFieldName);
+          if (text != null) {
+            try {
+              // fail if classification takes more than 5s
+              long start = System.currentTimeMillis();
+              result = executorService.submit(() -> classifier.assignClass(text)).get(5,
TimeUnit.SECONDS);
+              long end = System.currentTimeMillis();
+              time += end - start;
+
+              if (result != null) {
+                T assignedClass = result.getAssignedClass();
+                if (assignedClass != null) {
+                  String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString()
: assignedClass.toString();
+
+                  Map<String, Long> stringLongMap = counts.get(correctAnswer);
+                  if (stringLongMap != null) {
+                    Long aLong = stringLongMap.get(classified);
+                    if (aLong != null) {
+                      stringLongMap.put(classified, aLong + 1);
+                    } else {
+                      stringLongMap.put(classified, 1l);
+                    }
+                  } else {
+                    stringLongMap = new HashMap<>();
+                    stringLongMap.put(classified, 1l);
+                    counts.put(correctAnswer, stringLongMap);
+                  }
+                }
+              }
+            } catch (TimeoutException timeoutException) {
+              // add timeout
+              time += 5000;
+            } catch (ExecutionException | InterruptedException executionException) {
+              throw new RuntimeException(executionException);
+            }
 
-    for (int i = 0; i < reader.maxDoc(); i++) {
-      StoredDocument doc = reader.document(i);
-      String correctAnswer = doc.get(classFieldName);
-
-      if (correctAnswer != null && correctAnswer.length() > 0) {
-
-        ClassificationResult<T> result = classifier.assignClass(doc.get(textFieldName));
-        T assignedClass = result.getAssignedClass();
-        String classified = assignedClass instanceof BytesRef ? ((BytesRef) assignedClass).utf8ToString()
: assignedClass.toString();
-
-        Map<String, Long> stringLongMap = counts.get(correctAnswer);
-        if (stringLongMap != null) {
-          Long aLong = stringLongMap.get(classified);
-          if (aLong != null) {
-            stringLongMap.put(classified, aLong + 1);
-          } else {
-            stringLongMap.put(classified, 1l);
           }
-        } else {
-          stringLongMap = new HashMap<>();
-          stringLongMap.put(classified, 1l);
-          counts.put(correctAnswer, stringLongMap);
         }
-
       }
+      return new ConfusionMatrix(counts, time / topDocs.totalHits, topDocs.totalHits);
+    } finally {
+      executorService.shutdown();
     }
-    return new ConfusionMatrix(counts);
   }
 
   /**
@@ -88,9 +128,13 @@ public class ConfusionMatrixGenerator {
   public static class ConfusionMatrix {
 
     private final Map<String, Map<String, Long>> linearizedMatrix;
+    private final double avgClassificationTime;
+    private final int numberOfEvaluatedDocs;
 
-    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix)
{
+    private ConfusionMatrix(Map<String, Map<String, Long>> linearizedMatrix,
double avgClassificationTime, int numberOfEvaluatedDocs) {
       this.linearizedMatrix = linearizedMatrix;
+      this.avgClassificationTime = avgClassificationTime;
+      this.numberOfEvaluatedDocs = numberOfEvaluatedDocs;
     }
 
     /**
@@ -104,8 +148,26 @@ public class ConfusionMatrixGenerator {
     @Override
     public String toString() {
       return "ConfusionMatrix{" +
-              "linearizedMatrix=" + linearizedMatrix +
-              '}';
+          "linearizedMatrix=" + linearizedMatrix +
+          ", avgClassificationTime=" + avgClassificationTime +
+          ", numberOfEvaluatedDocs=" + numberOfEvaluatedDocs +
+          '}';
+    }
+
+    /**
+     * get the average classification time in milliseconds
+     * @return the avg classification time
+     */
+    public double getAvgClassificationTime() {
+      return avgClassificationTime;
+    }
+
+    /**
+     * get the no. of documents evaluated while generating this confusion matrix
+     * @return the no. of documents evaluated
+     */
+    public int getNumberOfEvaluatedDocs() {
+      return numberOfEvaluatedDocs;
     }
   }
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java	Wed Sep  2 22:21:53 2015
@@ -17,6 +17,7 @@
 package org.apache.lucene.classification;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
@@ -32,7 +33,7 @@ public class BooleanPerceptronClassifier
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, null,
1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
     } finally {
       if (leafReader != null) {
@@ -46,7 +47,7 @@ public class BooleanPerceptronClassifier
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader,
analyzer, null, 1, 50d, booleanFieldName, textFieldName);
       checkCorrectClassification(classifier, TECHNOLOGY_INPUT, false);
       checkCorrectClassification(classifier, POLITICS_INPUT, true);
@@ -63,7 +64,7 @@ public class BooleanPerceptronClassifier
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, analyzer, query,
1, null, booleanFieldName, textFieldName), TECHNOLOGY_INPUT, false);
     } finally {
       if (leafReader != null) {
@@ -72,4 +73,29 @@ public class BooleanPerceptronClassifier
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      BooleanPerceptronClassifier classifier = new BooleanPerceptronClassifier(leafReader,
analyzer, null, 1, 0d, booleanFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime <
10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          classifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime
< 60000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+  }
+
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java	Wed Sep  2 22:21:53 2015
@@ -23,8 +23,8 @@ import org.apache.lucene.analysis.Tokeni
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
@@ -40,7 +40,7 @@ public class CachingNaiveBayesClassifier
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
     } finally {
@@ -55,7 +55,7 @@ public class CachingNaiveBayesClassifier
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -70,7 +70,7 @@ public class CachingNaiveBayesClassifier
     LeafReader leafReader = null;
     try {
       NGramAnalyzer analyzer = new NGramAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
       if (leafReader != null) {
@@ -87,4 +87,31 @@ public class CachingNaiveBayesClassifier
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      CachingNaiveBayesClassifier simpleNaiveBayesClassifier = new CachingNaiveBayesClassifier(leafReader,
+          analyzer, null, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime <
10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          simpleNaiveBayesClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 1m: " + evaluationTime / 1000 + "s", evaluationTime
< 60000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+
+  }
+
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java	Wed Sep  2 22:21:53 2015
@@ -40,21 +40,21 @@ import org.junit.Before;
  * Base class for testing {@link Classifier}s
  */
 public abstract class ClassificationTestBase<T> extends LuceneTestCase {
-  public final static String POLITICS_INPUT = "Here are some interesting questions and answers
about Mitt Romney.. " +
+  protected static final String POLITICS_INPUT = "Here are some interesting questions and
answers about Mitt Romney.. " +
           "If you don't know the answer to the question about Mitt Romney, then simply click
on the answer below the question section.";
-  public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
+  protected static final BytesRef POLITICS_RESULT = new BytesRef("politics");
 
-  public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook,
Google and Apple know about users." +
+  protected static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook,
Google and Apple know about users." +
           " Truth is, Amazon may know more.";
 
-  public static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes of
Facebook, Google and Apple know about users." +
+  protected static final String STRONG_TECHNOLOGY_INPUT = "Much is made of what the likes
of Facebook, Google and Apple know about users." +
       " Truth is, Amazon may know more. This technology observation is extracted from the
internet.";
 
-  public static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people
trust Google with their e-mail, and 50 million store files" +
+  protected static final String SUPER_STRONG_TECHNOLOGY_INPUT = "More than 400 million people
trust Google with their e-mail, and 50 million store files" +
       " in the cloud using the Dropbox service. People manage their bank accounts, pay bills,
trade stocks and " +
       "generally transfer or store huge volumes of personal data online. traveling seeks
raises some questions Republican presidential. ";
 
-  public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
+  protected static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
   protected RandomIndexWriter indexWriter;
   private Directory dir;
@@ -101,7 +101,7 @@ public abstract class ClassificationTest
   }
 
   protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc,
T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query)
throws Exception {
-    populateSampleIndex(analyzer);
+    getSampleIndex(analyzer);
 
     ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
     assertNotNull(classificationResult.getAssignedClass());
@@ -115,7 +115,7 @@ public abstract class ClassificationTest
 
   }
 
-  protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
+  protected LeafReader getSampleIndex(Analyzer analyzer) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
     indexWriter.commit();
@@ -193,34 +193,27 @@ public abstract class ClassificationTest
     return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
   }
 
-  protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String
classFieldName) throws Exception {
-    long trainStart = System.currentTimeMillis();
-    populatePerformanceIndex(analyzer);
-    long trainEnd = System.currentTimeMillis();
-    long trainTime = trainEnd - trainStart;
-    assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime <
120000);
-  }
-
-  private void populatePerformanceIndex(Analyzer analyzer) throws IOException {
+  protected LeafReader getRandomIndex(Analyzer analyzer, int size) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
+    indexWriter.deleteAll();
     indexWriter.commit();
 
     FieldType ft = new FieldType(TextField.TYPE_STORED);
     ft.setStoreTermVectors(true);
     ft.setStoreTermVectorOffsets(true);
     ft.setStoreTermVectorPositions(true);
-    int docs = 1000;
     Random random = random();
-    for (int i = 0; i < docs; i++) {
+    for (int i = 0; i < size; i++) {
       boolean b = random.nextBoolean();
       Document doc = new Document();
       doc.add(new Field(textFieldName, createRandomString(random), ft));
-      doc.add(new Field(categoryFieldName, b ? "technology" : "politics", ft));
+      doc.add(new Field(categoryFieldName, String.valueOf(random.nextInt(1000)), ft));
       doc.add(new Field(booleanFieldName, String.valueOf(b), ft));
       indexWriter.addDocument(doc);
     }
     indexWriter.commit();
+    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
   }
 
   private String createRandomString(Random random) {

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java	(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java	Wed Sep  2 22:21:53 2015
@@ -21,6 +21,7 @@ import java.util.List;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.analysis.en.EnglishAnalyzer;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
@@ -38,7 +39,7 @@ public class KNearestNeighborClassifierT
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer,
null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, new LMDirichletSimilarity(),
analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
       ClassificationResult<BytesRef> resultDS =  checkCorrectClassification(new KNearestNeighborClassifier(leafReader,
null, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
@@ -63,7 +64,7 @@ public class KNearestNeighborClassifierT
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new EnglishAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader,
null, analyzer, null, 6, 1, 1, categoryFieldName, textFieldName);
       List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(STRONG_TECHNOLOGY_INPUT);
       assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@@ -88,7 +89,7 @@ public class KNearestNeighborClassifierT
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new EnglishAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       KNearestNeighborClassifier knnClassifier = new KNearestNeighborClassifier(leafReader,
null,analyzer, null, 3, 1, 1, categoryFieldName, textFieldName);
       List<ClassificationResult<BytesRef>> classes = knnClassifier.getClasses(SUPER_STRONG_TECHNOLOGY_INPUT);
       assertTrue(classes.get(0).getScore() > classes.get(1).getScore());
@@ -105,7 +106,7 @@ public class KNearestNeighborClassifierT
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new KNearestNeighborClassifier(leafReader, null, analyzer,
query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -115,4 +116,30 @@ public class KNearestNeighborClassifierT
     }
   }
 
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      KNearestNeighborClassifier kNearestNeighborClassifier = new KNearestNeighborClassifier(leafReader,
null,
+          analyzer, null, 1, 2, 2, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime <
10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          kNearestNeighborClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime
< 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+  }
+
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
Wed Sep  2 22:21:53 2015
@@ -22,11 +22,12 @@ import org.apache.lucene.analysis.Tokeni
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
 import org.apache.lucene.index.LeafReader;
-import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
+import org.junit.Ignore;
 import org.junit.Test;
 
 /**
@@ -39,9 +40,10 @@ public class SimpleNaiveBayesClassifierT
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
-      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
-      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
+      leafReader = getSampleIndex(analyzer);
+      SimpleNaiveBayesClassifier classifier = new SimpleNaiveBayesClassifier(leafReader,
analyzer, null, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
     } finally {
       if (leafReader != null) {
         leafReader.close();
@@ -54,7 +56,7 @@ public class SimpleNaiveBayesClassifierT
     LeafReader leafReader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       TermQuery query = new TermQuery(new Term(textFieldName, "it"));
       checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
@@ -69,7 +71,7 @@ public class SimpleNaiveBayesClassifierT
     LeafReader leafReader = null;
     try {
       Analyzer analyzer = new NGramAnalyzer();
-      leafReader = populateSampleIndex(analyzer);
+      leafReader = getSampleIndex(analyzer);
       checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null,
categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
     } finally {
       if (leafReader != null) {
@@ -86,4 +88,32 @@ public class SimpleNaiveBayesClassifierT
     }
   }
 
+  @Ignore
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      SimpleNaiveBayesClassifier simpleNaiveBayesClassifier = new SimpleNaiveBayesClassifier(leafReader,
+          analyzer, null, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime <
10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          simpleNaiveBayesClassifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime
< 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue("avg classification time: " + avgClassificationTime, 5000 > avgClassificationTime);
+    } finally {
+      leafReader.close();
+    }
+
+  }
+
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java?rev=1700914&r1=1700913&r2=1700914&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
(original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
Wed Sep  2 22:21:53 2015
@@ -17,9 +17,13 @@ package org.apache.lucene.classification
  * limitations under the License.
  */
 
+import java.io.IOException;
+import java.util.List;
+
 import org.apache.lucene.analysis.MockAnalyzer;
 import org.apache.lucene.classification.BooleanPerceptronClassifier;
 import org.apache.lucene.classification.CachingNaiveBayesClassifier;
+import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.ClassificationTestBase;
 import org.apache.lucene.classification.Classifier;
 import org.apache.lucene.classification.KNearestNeighborClassifier;
@@ -34,15 +38,52 @@ import org.junit.Test;
 public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object> {
 
   @Test
+  public void testGetConfusionMatrix() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(analyzer);
+      Classifier<BytesRef> classifier = new Classifier<BytesRef>() {
+        @Override
+        public ClassificationResult<BytesRef> assignClass(String text) throws IOException
{
+          return new ClassificationResult<>(new BytesRef(), 1 / (1 + Math.exp(-random().nextInt())));
+        }
+
+        @Override
+        public List<ClassificationResult<BytesRef>> getClasses(String text) throws
IOException {
+          return null;
+        }
+
+        @Override
+        public List<ClassificationResult<BytesRef>> getClasses(String text, int
max) throws IOException {
+          return null;
+        }
+      };
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName);
+      assertNotNull(confusionMatrix);
+      assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(avgClassificationTime >= 0d );
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
+  @Test
   public void testGetConfusionMatrixWithSNB() throws Exception {
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer,
null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -55,11 +96,13 @@ public class ConfusionMatrixGeneratorTes
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer,
null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -72,11 +115,13 @@ public class ConfusionMatrixGeneratorTes
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null,
analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, categoryFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() > 0d);
     } finally {
       if (reader != null) {
         reader.close();
@@ -89,11 +134,13 @@ public class ConfusionMatrixGeneratorTes
     LeafReader reader = null;
     try {
       MockAnalyzer analyzer = new MockAnalyzer(random());
-      reader = populateSampleIndex(analyzer);
+      reader = getSampleIndex(analyzer);
       Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer,
null, 1, null, booleanFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
classifier, booleanFieldName, textFieldName);
       assertNotNull(confusionMatrix);
       assertNotNull(confusionMatrix.getLinearizedMatrix());
+      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
     } finally {
       if (reader != null) {
         reader.close();



Mime
View raw message