lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tomm...@apache.org
Subject svn commit: r1676997 - in /lucene/dev/trunk/lucene/classification/src: java/org/apache/lucene/classification/ java/org/apache/lucene/classification/utils/ test/org/apache/lucene/classification/
Date Thu, 30 Apr 2015 14:12:03 GMT
Author: tommaso
Date: Thu Apr 30 14:12:03 2015
New Revision: 1676997

URL: http://svn.apache.org/r1676997
Log:
LUCENE-6045 - refactor Classifier API to work better with multithreading

Modified:
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
    lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
    lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/BooleanPerceptronClassifier.java Thu Apr 30 14:12:03 2015
@@ -58,76 +58,14 @@ import org.apache.lucene.util.fst.Util;
  */
 public class BooleanPerceptronClassifier implements Classifier<Boolean> {
 
-  private Double threshold;
-  private final Integer batchSize;
-  private Terms textTerms;
-  private Analyzer analyzer;
-  private String textFieldName;
+  private final Double threshold;
+  private final Terms textTerms;
+  private final Analyzer analyzer;
+  private final String textFieldName;
   private FST<Long> fst;
 
-  /**
-   * Create a {@link BooleanPerceptronClassifier}
-   *
-   * @param threshold the binary threshold for perceptron output evaluation
-   */
-  public BooleanPerceptronClassifier(Double threshold, Integer batchSize) {
-    this.threshold = threshold;
-    this.batchSize = batchSize;
-  }
-
-  /**
-   * Default constructor, no batch updates of FST, perceptron threshold is
-   * calculated via underlying index metrics during
-   * {@link #train(org.apache.lucene.index.LeafReader, String, String, org.apache.lucene.analysis.Analyzer)
-   * training}
-   */
-  public BooleanPerceptronClassifier() {
-    batchSize = 1;
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public ClassificationResult<Boolean> assignClass(String text)
-      throws IOException {
-    if (textTerms == null) {
-      throw new IOException("You must first call Classifier#train");
-    }
-    Long output = 0l;
-    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
-      CharTermAttribute charTermAttribute = tokenStream
-          .addAttribute(CharTermAttribute.class);
-      tokenStream.reset();
-      while (tokenStream.incrementToken()) {
-        String s = charTermAttribute.toString();
-        Long d = Util.get(fst, new BytesRef(s));
-        if (d != null) {
-          output += d;
-        }
-      }
-      tokenStream.end();
-    }
-
-    double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
-    return new ClassificationResult<>(output >= threshold, score);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName,
-                    String classFieldName, Analyzer analyzer) throws IOException {
-    train(leafReader, textFieldName, classFieldName, analyzer, null);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName,
-                    String classFieldName, Analyzer analyzer, Query query) throws IOException {
+  public BooleanPerceptronClassifier(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer,
+                                     Query query, Integer batchSize, Double threshold) throws IOException {
     this.textTerms = MultiFields.getTerms(leafReader, textFieldName);
 
     if (textTerms == null) {
@@ -144,9 +82,11 @@ public class BooleanPerceptronClassifier
         this.threshold = (double) sumDocFreq / 2d;
       } else {
         throw new IOException(
-            "threshold cannot be assigned since term vectors for field "
-                + textFieldName + " do not exist");
+                "threshold cannot be assigned since term vectors for field "
+                        + textFieldName + " do not exist");
       }
+    } else {
+      this.threshold = threshold;
     }
 
     // TODO : remove this map as soon as we have a writable FST
@@ -170,7 +110,7 @@ public class BooleanPerceptronClassifier
     }
     // run the search and use stored field values
     for (ScoreDoc scoreDoc : indexSearcher.search(q,
-        Integer.MAX_VALUE).scoreDocs) {
+            Integer.MAX_VALUE).scoreDocs) {
       StoredDocument doc = indexSearcher.doc(scoreDoc.doc);
 
       StorableField textField = doc.getField(textFieldName);
@@ -187,7 +127,7 @@ public class BooleanPerceptronClassifier
         long modifier = correctClass.compareTo(assignedClass);
         if (modifier != 0) {
           updateWeights(leafReader, scoreDoc.doc, assignedClass,
-                weights, modifier, batchCount % batchSize == 0);
+                  weights, modifier, batchCount % batchSize == 0);
         }
         batchCount++;
       }
@@ -195,11 +135,6 @@ public class BooleanPerceptronClassifier
     weights.clear(); // free memory while waiting for GC
   }
 
-  @Override
-  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    throw new IOException("training with multiple fields not supported by boolean perceptron classifier");
-  }
-
   private void updateWeights(LeafReader leafReader,
                              int docId, Boolean assignedClass, SortedMap<String, Double> weights,
                              double modifier, boolean updateFST) throws IOException {
@@ -210,7 +145,7 @@ public class BooleanPerceptronClassifier
 
     if (terms == null) {
       throw new IOException("term vectors must be stored for field "
-          + textFieldName);
+              + textFieldName);
     }
 
     TermsEnum termsEnum = terms.iterator();
@@ -240,17 +175,46 @@ public class BooleanPerceptronClassifier
     for (Map.Entry<String, Double> entry : weights.entrySet()) {
       scratchBytes.copyChars(entry.getKey());
       fstBuilder.add(Util.toIntsRef(scratchBytes.get(), scratchInts), entry
-          .getValue().longValue());
+              .getValue().longValue());
     }
     fst = fstBuilder.finish();
   }
 
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ClassificationResult<Boolean> assignClass(String text)
+          throws IOException {
+    if (textTerms == null) {
+      throw new IOException("You must first call Classifier#train");
+    }
+    Long output = 0l;
+    try (TokenStream tokenStream = analyzer.tokenStream(textFieldName, text)) {
+      CharTermAttribute charTermAttribute = tokenStream
+              .addAttribute(CharTermAttribute.class);
+      tokenStream.reset();
+      while (tokenStream.incrementToken()) {
+        String s = charTermAttribute.toString();
+        Long d = Util.get(fst, new BytesRef(s));
+        if (d != null) {
+          output += d;
+        }
+      }
+      tokenStream.end();
+    }
+
+    double score = 1 - Math.exp(-1 * Math.abs(threshold - output.doubleValue()) / threshold);
+    return new ClassificationResult<>(output >= threshold, score);
+  }
+
   /**
    * {@inheritDoc}
    */
   @Override
   public List<ClassificationResult<Boolean>> getClasses(String text)
-      throws IOException {
+          throws IOException {
     throw new RuntimeException("not implemented");
   }
 
@@ -259,7 +223,7 @@ public class BooleanPerceptronClassifier
    */
   @Override
   public List<ClassificationResult<Boolean>> getClasses(String text, int max)
-      throws IOException {
+          throws IOException {
     throw new RuntimeException("not implemented");
   }
 

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/CachingNaiveBayesClassifier.java Thu Apr 30 14:12:03 2015
@@ -49,50 +49,30 @@ import org.apache.lucene.util.BytesRef;
  */
 public class CachingNaiveBayesClassifier extends SimpleNaiveBayesClassifier {
   //for caching classes this will be the classification class list
-  private ArrayList<BytesRef> cclasses = new ArrayList<>();
+  private final ArrayList<BytesRef> cclasses = new ArrayList<>();
   // it's a term-inmap style map, where the inmap contains class-hit pairs to the
   // upper term
-  private Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
+  private final Map<String, Map<BytesRef, Integer>> termCClassHitCache = new HashMap<>();
   // the term frequency in classes
-  private Map<BytesRef, Double> classTermFreq = new HashMap<>();
+  private final Map<BytesRef, Double> classTermFreq = new HashMap<>();
   private boolean justCachedTerms;
   private int docsWithClassSize;
 
   /**
-   * Creates a new NaiveBayes classifier with inside caching. Note that you must
-   * call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before
-   * you can classify any documents. If you want less memory usage you could
+   * Creates a new NaiveBayes classifier with inside caching. If you want less memory usage you could
    * call {@link #reInitCache(int, boolean) reInitCache()}.
    */
-  public CachingNaiveBayesClassifier() {
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
-    train(leafReader, textFieldName, classFieldName, analyzer, null);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    super.train(leafReader, textFieldNames, classFieldName, analyzer, query);
+  public CachingNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
+    super(leafReader, analyzer, query, classFieldName, textFieldNames);
     // building the cache
-    reInitCache(0, true);
+    try {
+      reInitCache(0, true);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
   }
 
+
   private List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
     if (leafReader == null) {
       throw new IOException("You must first call Classifier#train");

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/ClassificationResult.java Thu Apr 30 14:12:03 2015
@@ -18,17 +18,19 @@ package org.apache.lucene.classification
 
 /**
  * The result of a call to {@link Classifier#assignClass(String)} holding an assigned class of type <code>T</code> and a score.
+ *
  * @lucene.experimental
  */
-public class ClassificationResult<T> implements Comparable<ClassificationResult<T>>{
+public class ClassificationResult<T> implements Comparable<ClassificationResult<T>> {
 
   private final T assignedClass;
   private double score;
 
   /**
    * Constructor
+   *
    * @param assignedClass the class <code>T</code> assigned by a {@link Classifier}
-   * @param score the score for the assignedClass as a <code>double</code>
+   * @param score         the score for the assignedClass as a <code>double</code>
    */
   public ClassificationResult(T assignedClass, double score) {
     this.assignedClass = assignedClass;
@@ -37,6 +39,7 @@ public class ClassificationResult<T> imp
 
   /**
    * retrieve the result class
+   *
    * @return a <code>T</code> representing an assigned class
    */
   public T getAssignedClass() {
@@ -45,14 +48,16 @@ public class ClassificationResult<T> imp
 
   /**
    * retrieve the result score
+   *
    * @return a <code>double</code> representing a result score
    */
   public double getScore() {
     return score;
   }
-  
+
   /**
    * set the score value
+   *
    * @param score the score for the assignedClass as a <code>double</code>
    */
   public void setScore(double score) {

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/Classifier.java Thu Apr 30 14:12:03 2015
@@ -22,7 +22,6 @@ import java.util.List;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.search.Query;
-import org.apache.lucene.util.BytesRef;
 
 /**
  * A classifier, see <code>http://en.wikipedia.org/wiki/Classifier_(mathematics)</code>, which assign classes of type
@@ -39,7 +38,7 @@ public interface Classifier<T> {
    * @return a {@link ClassificationResult} holding assigned class of type <code>T</code> and score
    * @throws IOException If there is a low-level I/O error.
    */
-  public ClassificationResult<T> assignClass(String text) throws IOException;
+  ClassificationResult<T> assignClass(String text) throws IOException;
 
   /**
    * Get all the classes (sorted by score, descending) assigned to the given text String.
@@ -48,7 +47,7 @@ public interface Classifier<T> {
    * @return the whole list of {@link ClassificationResult}, the classes and scores. Returns <code>null</code> if the classifier can't make lists.
    * @throws IOException If there is a low-level I/O error.
    */
-  public List<ClassificationResult<T>> getClasses(String text) throws IOException;
+  List<ClassificationResult<T>> getClasses(String text) throws IOException;
 
   /**
    * Get the first <code>max</code> classes (sorted by score, descending) assigned to the given text String.
@@ -58,44 +57,6 @@ public interface Classifier<T> {
    * @return the whole list of {@link ClassificationResult}, the classes and scores. Cut for "max" number of elements. Returns <code>null</code> if the classifier can't make lists.
    * @throws IOException If there is a low-level I/O error.
    */
-  public List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
-
-  /**
-   * Train the classifier using the underlying Lucene index
-   *
-   * @param leafReader   the reader to use to access the Lucene index
-   * @param textFieldName  the name of the field used to compare documents
-   * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer       the analyzer used to tokenize / filter the unseen text
-   * @throws IOException If there is a low-level I/O error.
-   */
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer)
-      throws IOException;
-
-  /**
-   * Train the classifier using the underlying Lucene index
-   *
-   * @param leafReader   the reader to use to access the Lucene index
-   * @param textFieldName  the name of the field used to compare documents
-   * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer       the analyzer used to tokenize / filter the unseen text
-   * @param query          the query to filter which documents use for training
-   * @throws IOException If there is a low-level I/O error.
-   */
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
-      throws IOException;
-
-  /**
-   * Train the classifier using the underlying Lucene index
-   *
-   * @param leafReader   the reader to use to access the Lucene index
-   * @param textFieldNames the names of the fields to be used to compare documents
-   * @param classFieldName the name of the field containing the class assigned to documents
-   * @param analyzer       the analyzer used to tokenize / filter the unseen text
-   * @param query          the query to filter which documents use for training
-   * @throws IOException If there is a low-level I/O error.
-   */
-  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
-      throws IOException;
+  List<ClassificationResult<T>> getClasses(String text, int max) throws IOException;
 
 }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/KNearestNeighborClassifier.java Thu Apr 30 14:12:03 2015
@@ -26,6 +26,7 @@ import java.util.Map;
 
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.StorableField;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.queries.mlt.MoreLikeThis;
 import org.apache.lucene.search.BooleanClause;
@@ -45,37 +46,31 @@ import org.apache.lucene.util.BytesRef;
  */
 public class KNearestNeighborClassifier implements Classifier<BytesRef> {
 
-  private MoreLikeThis mlt;
-  private String[] textFieldNames;
-  private String classFieldName;
-  private IndexSearcher indexSearcher;
+  private final MoreLikeThis mlt;
+  private final String[] textFieldNames;
+  private final String classFieldName;
+  private final IndexSearcher indexSearcher;
   private final int k;
-  private Query query;
+  private final Query query;
 
-  private int minDocsFreq;
-  private int minTermFreq;
-
-  /**
-   * Create a {@link Classifier} using kNN algorithm
-   *
-   * @param k the number of neighbors to analyze as an <code>int</code>
-   */
-  public KNearestNeighborClassifier(int k) {
+  public KNearestNeighborClassifier(LeafReader leafReader, Analyzer analyzer, Query query, int k, int minDocsFreq,
+                                    int minTermFreq, String classFieldName, String... textFieldNames) {
+    this.textFieldNames = textFieldNames;
+    this.classFieldName = classFieldName;
+    this.mlt = new MoreLikeThis(leafReader);
+    this.mlt.setAnalyzer(analyzer);
+    this.mlt.setFieldNames(textFieldNames);
+    this.indexSearcher = new IndexSearcher(leafReader);
+    if (minDocsFreq > 0) {
+      mlt.setMinDocFreq(minDocsFreq);
+    }
+    if (minTermFreq > 0) {
+      mlt.setMinTermFreq(minTermFreq);
+    }
+    this.query = query;
     this.k = k;
   }
 
-  /**
-   * Create a {@link Classifier} using kNN algorithm
-   *
-   * @param k           the number of neighbors to analyze as an <code>int</code>
-   * @param minDocsFreq the minimum number of docs frequency for MLT to be set with {@link MoreLikeThis#setMinDocFreq(int)}
-   * @param minTermFreq the minimum number of term frequency for MLT to be set with {@link MoreLikeThis#setMinTermFreq(int)}
-   */
-  public KNearestNeighborClassifier(int k, int minDocsFreq, int minTermFreq) {
-    this.k = k;
-    this.minDocsFreq = minDocsFreq;
-    this.minTermFreq = minTermFreq;
-  }
 
   /**
    * {@inheritDoc}
@@ -136,12 +131,15 @@ public class KNearestNeighborClassifier
   private List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
     Map<BytesRef, Integer> classCounts = new HashMap<>();
     for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
-      BytesRef cl = new BytesRef(indexSearcher.doc(scoreDoc.doc).getField(classFieldName).stringValue());
-      Integer count = classCounts.get(cl);
-      if (count != null) {
-        classCounts.put(cl, count + 1);
-      } else {
-        classCounts.put(cl, 1);
+      StorableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
+      if (storableField != null) {
+        BytesRef cl = new BytesRef(storableField.stringValue());
+        Integer count = classCounts.get(cl);
+        if (count != null) {
+          classCounts.put(cl, count + 1);
+        } else {
+          classCounts.put(cl, 1);
+        }
       }
     }
     List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
@@ -161,39 +159,4 @@ public class KNearestNeighborClassifier
     return returnList;
   }
 
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
-    train(leafReader, textFieldName, classFieldName, analyzer, null);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query) throws IOException {
-    this.textFieldNames = textFieldNames;
-    this.classFieldName = classFieldName;
-    mlt = new MoreLikeThis(leafReader);
-    mlt.setAnalyzer(analyzer);
-    mlt.setFieldNames(textFieldNames);
-    indexSearcher = new IndexSearcher(leafReader);
-    if (minDocsFreq > 0) {
-      mlt.setMinDocFreq(minDocsFreq);
-    }
-    if (minTermFreq > 0) {
-      mlt.setMinTermFreq(minTermFreq);
-    }
-    this.query = query;
-  }
 }

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/SimpleNaiveBayesClassifier.java Thu Apr 30 14:12:03 2015
@@ -51,64 +51,38 @@ public class SimpleNaiveBayesClassifier
    * {@link org.apache.lucene.index.LeafReader} used to access the {@link org.apache.lucene.classification.Classifier}'s
    * index
    */
-  protected LeafReader leafReader;
+  protected final LeafReader leafReader;
 
   /**
    * names of the fields to be used as input text
    */
-  protected String[] textFieldNames;
+  protected final String[] textFieldNames;
 
   /**
    * name of the field to be used as a class / category output
    */
-  protected String classFieldName;
+  protected final String classFieldName;
 
   /**
    * {@link org.apache.lucene.analysis.Analyzer} to be used for tokenizing unseen input text
    */
-  protected Analyzer analyzer;
+  protected final Analyzer analyzer;
 
   /**
    * {@link org.apache.lucene.search.IndexSearcher} to run searches on the index for retrieving frequencies
    */
-  protected IndexSearcher indexSearcher;
+  protected final IndexSearcher indexSearcher;
 
   /**
    * {@link org.apache.lucene.search.Query} used to eventually filter the document set to be used to classify
    */
-  protected Query query;
+  protected final Query query;
 
   /**
    * Creates a new NaiveBayes classifier.
-   * Note that you must call {@link #train(org.apache.lucene.index.LeafReader, String, String, Analyzer) train()} before you can
    * classify any documents.
    */
-  public SimpleNaiveBayesClassifier() {
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer) throws IOException {
-    train(leafReader, textFieldName, classFieldName, analyzer, null);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String textFieldName, String classFieldName, Analyzer analyzer, Query query)
-      throws IOException {
-    train(leafReader, new String[]{textFieldName}, classFieldName, analyzer, query);
-  }
-
-  /**
-   * {@inheritDoc}
-   */
-  @Override
-  public void train(LeafReader leafReader, String[] textFieldNames, String classFieldName, Analyzer analyzer, Query query)
-      throws IOException {
+  public SimpleNaiveBayesClassifier(LeafReader leafReader, Analyzer analyzer, Query query, String classFieldName, String... textFieldNames) {
     this.leafReader = leafReader;
     this.indexSearcher = new IndexSearcher(this.leafReader);
     this.textFieldNames = textFieldNames;

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/package-info.java Thu Apr 30 14:12:03 2015
@@ -18,7 +18,7 @@
 /**
  * Uses already seen data (the indexed documents) to classify new documents.
  * <p>
- * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest 
+ * Currently contains a (simplistic) Naive Bayes classifier, a k-Nearest
  * Neighbor classifier and a Perceptron based classifier.
  */
 package org.apache.lucene.classification;

Modified: lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java (original)
+++ lucene/dev/trunk/lucene/classification/src/java/org/apache/lucene/classification/utils/DocToDoubleVectorUtils.java Thu Apr 30 14:12:03 2015
@@ -33,7 +33,8 @@ public class DocToDoubleVectorUtils {
 
   /**
    * create a sparse <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
-   * @param docTerms term vectors for a given document
+   *
+   * @param docTerms   term vectors for a given document
    * @param fieldTerms field term vectors
    * @return a sparse vector of <code>Double</code>s as an array
    * @throws IOException in case accessing the underlying index fails
@@ -54,8 +55,7 @@ public class DocToDoubleVectorUtils {
         if (seekStatus.equals(TermsEnum.SeekStatus.FOUND)) {
           long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
           freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
-        }
-        else {
+        } else {
           freqVector[i] = 0d;
         }
         i++;
@@ -66,6 +66,7 @@ public class DocToDoubleVectorUtils {
 
   /**
    * create a dense <code>Double</code> vector given doc and field term vectors using local frequency of the terms in the doc
+   *
    * @param docTerms term vectors for a given document
    * @return a dense vector of <code>Double</code>s as an array
    * @throws IOException in case accessing the underlying index fails
@@ -73,16 +74,16 @@ public class DocToDoubleVectorUtils {
   public static Double[] toDenseLocalFreqDoubleArray(Terms docTerms) throws IOException {
     Double[] freqVector = null;
     if (docTerms != null) {
-        freqVector = new Double[(int) docTerms.size()];
-        int i = 0;
-        TermsEnum docTermsEnum = docTerms.iterator();
+      freqVector = new Double[(int) docTerms.size()];
+      int i = 0;
+      TermsEnum docTermsEnum = docTerms.iterator();
 
-        while (docTermsEnum.next() != null) {
-            long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
-            freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
-            i++;
-        }
+      while (docTermsEnum.next() != null) {
+        long termFreqLocal = docTermsEnum.totalTermFreq(); // the total number of occurrences of this term in the given document
+        freqVector[i] = Long.valueOf(termFreqLocal).doubleValue();
+        i++;
+      }
     }
     return freqVector;
-}
+  }
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/BooleanPerceptronClassifierTest.java Thu Apr 30 14:12:03 2015
@@ -17,6 +17,8 @@
 package org.apache.lucene.classification;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.junit.Test;
@@ -28,22 +30,45 @@ public class BooleanPerceptronClassifier
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, null), TECHNOLOGY_INPUT, false);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testExplicitThreshold() throws Exception {
-    checkCorrectClassification(new BooleanPerceptronClassifier(100d, 1), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, null, 1, 100d), TECHNOLOGY_INPUT, false);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testBasicUsageWithQuery() throws Exception {
-    checkCorrectClassification(new BooleanPerceptronClassifier(), TECHNOLOGY_INPUT, false, new MockAnalyzer(random()), textFieldName, booleanFieldName, new TermQuery(new Term(textFieldName, "it")));
-  }
-
-  @Test
-  public void testPerformance() throws Exception {
-    checkPerformance(new BooleanPerceptronClassifier(), new MockAnalyzer(random()), booleanFieldName);
+    TermQuery query = new TermQuery(new Term(textFieldName, "it"));
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new BooleanPerceptronClassifier(leafReader, textFieldName, booleanFieldName, analyzer, query, 1, null), TECHNOLOGY_INPUT, false);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/CachingNaiveBayesClassifierTest.java Thu Apr 30 14:12:03 2015
@@ -23,6 +23,8 @@ import org.apache.lucene.analysis.Tokeni
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
@@ -35,18 +37,46 @@ public class CachingNaiveBayesClassifier
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
-    checkCorrectClassification(new CachingNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testBasicUsageWithQuery() throws Exception {
-    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
+      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testNGramUsage() throws Exception {
-    checkCorrectClassification(new CachingNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      NGramAnalyzer analyzer = new NGramAnalyzer();
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   private class NGramAnalyzer extends Analyzer {
@@ -57,9 +87,4 @@ public class CachingNaiveBayesClassifier
     }
   }
 
-  @Test
-  public void testPerformance() throws Exception {
-    checkPerformance(new CachingNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
-  }
-
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/ClassificationTestBase.java Thu Apr 30 14:12:03 2015
@@ -41,14 +41,14 @@ import org.junit.Before;
  */
 public abstract class ClassificationTestBase<T> extends LuceneTestCase {
   public final static String POLITICS_INPUT = "Here are some interesting questions and answers about Mitt Romney.. " +
-      "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
+          "If you don't know the answer to the question about Mitt Romney, then simply click on the answer below the question section.";
   public static final BytesRef POLITICS_RESULT = new BytesRef("politics");
 
   public static final String TECHNOLOGY_INPUT = "Much is made of what the likes of Facebook, Google and Apple know about users." +
-      " Truth is, Amazon may know more.";
+          " Truth is, Amazon may know more.";
   public static final BytesRef TECHNOLOGY_RESULT = new BytesRef("technology");
 
-  private RandomIndexWriter indexWriter;
+  protected RandomIndexWriter indexWriter;
   private Directory dir;
   private FieldType ft;
 
@@ -79,53 +79,34 @@ public abstract class ClassificationTest
     dir.close();
   }
 
-  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
-    checkCorrectClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
+  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult) throws Exception {
+    ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
+    assertNotNull(classificationResult.getAssignedClass());
+    assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
+    double score = classificationResult.getScore();
+    assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
   }
 
-  protected void checkCorrectClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
-    LeafReader leafReader = null;
-    try {
-      populateSampleIndex(analyzer);
-      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
-      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
-      assertNotNull(classificationResult.getAssignedClass());
-      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
-      double score = classificationResult.getScore();
-      assertTrue("score should be between 0 and 1, got:" + score, score <= 1 && score >= 0);
-    } finally {
-      if (leafReader != null)
-        leafReader.close();
-    }
-  }
   protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName) throws Exception {
     checkOnlineClassification(classifier, inputDoc, expectedResult, analyzer, textFieldName, classFieldName, null);
   }
 
   protected void checkOnlineClassification(Classifier<T> classifier, String inputDoc, T expectedResult, Analyzer analyzer, String textFieldName, String classFieldName, Query query) throws Exception {
-    LeafReader leafReader = null;
-    try {
-      populateSampleIndex(analyzer);
-      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(leafReader, textFieldName, classFieldName, analyzer, query);
-      ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
-      assertNotNull(classificationResult.getAssignedClass());
-      assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
-      double score = classificationResult.getScore();
-      assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
-      updateSampleIndex();
-      ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
-      assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
-      assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
-
-    } finally {
-      if (leafReader != null)
-        leafReader.close();
-    }
+    populateSampleIndex(analyzer);
+
+    ClassificationResult<T> classificationResult = classifier.assignClass(inputDoc);
+    assertNotNull(classificationResult.getAssignedClass());
+    assertEquals("got an assigned class of " + classificationResult.getAssignedClass(), expectedResult, classificationResult.getAssignedClass());
+    double score = classificationResult.getScore();
+    assertTrue("score should be between 0 and 1, got: " + score, score <= 1 && score >= 0);
+    updateSampleIndex();
+    ClassificationResult<T> secondClassificationResult = classifier.assignClass(inputDoc);
+    assertEquals(classificationResult.getAssignedClass(), secondClassificationResult.getAssignedClass());
+    assertEquals(Double.valueOf(score), Double.valueOf(secondClassificationResult.getScore()));
+
   }
 
-  private void populateSampleIndex(Analyzer analyzer) throws IOException {
+  protected LeafReader populateSampleIndex(Analyzer analyzer) throws IOException {
     indexWriter.close();
     indexWriter = new RandomIndexWriter(random(), dir, newIndexWriterConfig(analyzer).setOpenMode(IndexWriterConfig.OpenMode.CREATE));
     indexWriter.commit();
@@ -134,8 +115,8 @@ public abstract class ClassificationTest
 
     Document doc = new Document();
     text = "The traveling press secretary for Mitt Romney lost his cool and cursed at reporters " +
-        "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
-        "the Unknown Soldier in Warsaw Tuesday.";
+            "who attempted to ask questions of the Republican presidential candidate in a public plaza near the Tomb of " +
+            "the Unknown Soldier in Warsaw Tuesday.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
     doc.add(new Field(booleanFieldName, "true", ft));
@@ -144,7 +125,7 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "Mitt Romney seeks to assure Israel and Iran, as well as Jewish voters in the United" +
-        " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
+            " States, that he will be tougher against Iran's nuclear ambitions than President Barack Obama.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
     doc.add(new Field(booleanFieldName, "true", ft));
@@ -152,8 +133,8 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "And there's a threshold question that he has to answer for the American people and " +
-        "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
-        "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
+            "that's whether he is prepared to be commander-in-chief,\" she continued. \"As we look to the past events, we " +
+            "know that this raises some questions about his preparedness and we'll see how the rest of his trip goes.\"";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
     doc.add(new Field(booleanFieldName, "true", ft));
@@ -161,8 +142,8 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "Still, when it comes to gun policy, many congressional Democrats have \"decided to " +
-        "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
-        "Albany's School of Criminal Justice.";
+            "keep quiet and not go there,\" said Alan Lizotte, dean and professor at the State University of New York at " +
+            "Albany's School of Criminal Justice.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "politics", ft));
     doc.add(new Field(booleanFieldName, "true", ft));
@@ -170,8 +151,8 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "Standing amongst the thousands of people at the state Capitol, Jorstad, director of " +
-        "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
-        "world through the Internet.";
+            "technology at the University of Wisconsin-La Crosse, documented the historic moment and shared it with the " +
+            "world through the Internet.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
     doc.add(new Field(booleanFieldName, "false", ft));
@@ -179,7 +160,7 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "So, about all those experts and analysts who've spent the past year or so saying " +
-        "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
+            "Facebook was going to make a phone. A new expert has stepped forward to say it's not going to happen.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
     doc.add(new Field(booleanFieldName, "false", ft));
@@ -187,8 +168,8 @@ public abstract class ClassificationTest
 
     doc = new Document();
     text = "More than 400 million people trust Google with their e-mail, and 50 million store files" +
-        " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
-        "generally transfer or store huge volumes of personal data online.";
+            " in the cloud using the Dropbox service. People manage their bank accounts, pay bills, trade stocks and " +
+            "generally transfer or store huge volumes of personal data online.";
     doc.add(new Field(textFieldName, text, ft));
     doc.add(new Field(categoryFieldName, "technology", ft));
     doc.add(new Field(booleanFieldName, "false", ft));
@@ -200,22 +181,15 @@ public abstract class ClassificationTest
     indexWriter.addDocument(doc);
 
     indexWriter.commit();
+    return SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
   }
 
   protected void checkPerformance(Classifier<T> classifier, Analyzer analyzer, String classFieldName) throws Exception {
-    LeafReader leafReader = null;
     long trainStart = System.currentTimeMillis();
-    try {
-      populatePerformanceIndex(analyzer);
-      leafReader = SlowCompositeReaderWrapper.wrap(indexWriter.getReader());
-      classifier.train(leafReader, textFieldName, classFieldName, analyzer);
-      long trainEnd = System.currentTimeMillis();
-      long trainTime = trainEnd - trainStart;
-      assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
-    } finally {
-      if (leafReader != null)
-        leafReader.close();
-    }
+    populatePerformanceIndex(analyzer);
+    long trainEnd = System.currentTimeMillis();
+    long trainTime = trainEnd - trainStart;
+    assertTrue("training took more than 2 mins : " + trainTime / 1000 + "s", trainTime < 120000);
   }
 
   private void populatePerformanceIndex(Analyzer analyzer) throws IOException {

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/KNearestNeighborClassifierTest.java Thu Apr 30 14:12:03 2015
@@ -17,6 +17,8 @@
 package org.apache.lucene.classification;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
@@ -29,20 +31,32 @@ public class KNearestNeighborClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    // usage with default MLT min docs / term freq
-    checkCorrectClassification(new KNearestNeighborClassifier(3), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
-    // usage without custom min docs / term freq for MLT
-    checkCorrectClassification(new KNearestNeighborClassifier(3, 2, 1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, null, 3, 2, 1, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testBasicUsageWithQuery() throws Exception {
-    checkCorrectClassification(new KNearestNeighborClassifier(1), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
-  }
-
-  @Test
-  public void testPerformance() throws Exception {
-    checkPerformance(new KNearestNeighborClassifier(100), new MockAnalyzer(random()), categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
+      checkCorrectClassification(new KNearestNeighborClassifier(leafReader, analyzer, query, 1, 0, 0, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
 }

Modified: lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java?rev=1676997&r1=1676996&r2=1676997&view=diff
==============================================================================
--- lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java (original)
+++ lucene/dev/trunk/lucene/classification/src/test/org/apache/lucene/classification/SimpleNaiveBayesClassifierTest.java Thu Apr 30 14:12:03 2015
@@ -22,14 +22,13 @@ import org.apache.lucene.analysis.Tokeni
 import org.apache.lucene.analysis.core.KeywordTokenizer;
 import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
 import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.SlowCompositeReaderWrapper;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.LuceneTestCase;
 import org.junit.Test;
 
-import java.io.Reader;
-
 /**
  * Testcase for {@link SimpleNaiveBayesClassifier}
  */
@@ -37,18 +36,46 @@ public class SimpleNaiveBayesClassifierT
 
   @Test
   public void testBasicUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), POLITICS_INPUT, POLITICS_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), POLITICS_INPUT, POLITICS_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testBasicUsageWithQuery() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new MockAnalyzer(random()), textFieldName, categoryFieldName, new TermQuery(new Term(textFieldName, "it")));
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = populateSampleIndex(analyzer);
+      TermQuery query = new TermQuery(new Term(textFieldName, "it"));
+      checkCorrectClassification(new SimpleNaiveBayesClassifier(leafReader, analyzer, query, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   @Test
   public void testNGramUsage() throws Exception {
-    checkCorrectClassification(new SimpleNaiveBayesClassifier(), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT, new NGramAnalyzer(), textFieldName, categoryFieldName);
+    LeafReader leafReader = null;
+    try {
+      Analyzer analyzer = new NGramAnalyzer();
+      leafReader = populateSampleIndex(analyzer);
+      checkCorrectClassification(new CachingNaiveBayesClassifier(leafReader, analyzer, null, categoryFieldName, textFieldName), TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
   }
 
   private class NGramAnalyzer extends Analyzer {
@@ -59,9 +86,4 @@ public class SimpleNaiveBayesClassifierT
     }
   }
 
-  @Test
-  public void testPerformance() throws Exception {
-    checkPerformance(new SimpleNaiveBayesClassifier(), new MockAnalyzer(random()), categoryFieldName);
-  }
-
 }



Mime
View raw message