lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tflo...@apache.org
Subject [34/50] [abbrv] lucene-solr:jira/solr-10233: LUCENE-7838 - added knn classifier based on flt
Date Fri, 19 May 2017 00:13:55 GMT
LUCENE-7838 - added knn classifier based on flt


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/bd9e32d3
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/bd9e32d3
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/bd9e32d3

Branch: refs/heads/jira/solr-10233
Commit: bd9e32d358399af7c31e732314e1ef1dd89bcfa1
Parents: afd70a4
Author: Tommaso Teofili <tommaso@apache.org>
Authored: Thu May 18 14:35:53 2017 +0200
Committer: Tommaso Teofili <tommaso@apache.org>
Committed: Thu May 18 14:36:18 2017 +0200

----------------------------------------------------------------------
 .../lucene/classification/classification.iml    |   3 +-
 lucene/classification/build.xml                 |   8 +-
 .../classification/KNearestFuzzyClassifier.java | 225 +++++++++++++++++++
 .../classification/utils/DatasetSplitter.java   |   2 +-
 .../KNearestFuzzyClassifierTest.java            | 124 ++++++++++
 .../utils/ConfusionMatrixGeneratorTest.java     | 123 +++++-----
 6 files changed, 417 insertions(+), 68 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/dev-tools/idea/lucene/classification/classification.iml
----------------------------------------------------------------------
diff --git a/dev-tools/idea/lucene/classification/classification.iml b/dev-tools/idea/lucene/classification/classification.iml
index 0f20274..44af1e4 100644
--- a/dev-tools/idea/lucene/classification/classification.iml
+++ b/dev-tools/idea/lucene/classification/classification.iml
@@ -16,8 +16,9 @@
     <orderEntry type="module" scope="TEST" module-name="lucene-test-framework" />
     <orderEntry type="module" module-name="lucene-core" />
     <orderEntry type="module" module-name="queries" />
-    <orderEntry type="module" scope="TEST" module-name="analysis-common" />
+    <orderEntry type="module" module-name="analysis-common" />
     <orderEntry type="module" module-name="grouping" />
     <orderEntry type="module" module-name="misc" />
+    <orderEntry type="module" module-name="sandbox" />
   </component>
 </module>

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/build.xml
----------------------------------------------------------------------
diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml
index 704cae8..b3f1bfd 100644
--- a/lucene/classification/build.xml
+++ b/lucene/classification/build.xml
@@ -28,6 +28,8 @@
     <path refid="base.classpath"/>
     <pathelement path="${queries.jar}"/>
     <pathelement path="${grouping.jar}"/>
+    <pathelement path="${sandbox.jar}"/>
+    <pathelement path="${analyzers-common.jar}"/>
   </path>
 
   <path id="test.classpath">
@@ -36,16 +38,18 @@
     <path refid="test.base.classpath"/>
   </path>
 
-  <target name="compile-core" depends="jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
+  <target name="compile-core" depends="jar-sandbox,jar-grouping,jar-queries,jar-analyzers-common,common.compile-core" />
 
   <target name="jar-core" depends="common.jar-core" />
 
-  <target name="javadocs" depends="javadocs-grouping,compile-core,check-javadocs-uptodate"
+  <target name="javadocs" depends="javadocs-sandbox,javadocs-grouping,compile-core,check-javadocs-uptodate"
           unless="javadocs-uptodate-${name}">
     <invoke-module-javadoc>
       <links>
         <link href="../queries"/>
+        <link href="../analyzers/common"/>
         <link href="../grouping"/>
+        <link href="../sandbox"/>
       </links>
     </invoke-module-javadoc>
   </target>

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
new file mode 100644
index 0000000..1cde468
--- /dev/null
+++ b/lucene/classification/src/java/org/apache/lucene/classification/KNearestFuzzyClassifier.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.sandbox.queries.FuzzyLikeThisQuery;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.WildcardQuery;
+import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.search.similarities.ClassicSimilarity;
+import org.apache.lucene.search.similarities.Similarity;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * A k-Nearest Neighbor classifier based on {@link FuzzyLikeThisQuery}.
+ *
+ * @lucene.experimental
+ */
+public class KNearestFuzzyClassifier implements Classifier<BytesRef> {
+
+  /**
+   * the name of the fields used as the input text
+   */
+  protected final String[] textFieldNames;
+
+  /**
+   * the name of the field used as the output text
+   */
+  protected final String classFieldName;
+
+  /**
+   * an {@link IndexSearcher} used to perform queries
+   */
+  protected final IndexSearcher indexSearcher;
+
+  /**
+   * the no. of docs to compare in order to find the nearest neighbor to the input text
+   */
+  protected final int k;
+
+  /**
+   * a {@link Query} used to filter the documents that should be used from this classifier's underlying {@link LeafReader}
+   */
+  protected final Query query;
+  private final Analyzer analyzer;
+
+  /**
+   * Creates a {@link KNearestFuzzyClassifier}.
+   *
+   * @param indexReader    the reader on the index to be used for classification
+   * @param analyzer       an {@link Analyzer} used to analyze unseen text
+   * @param similarity     the {@link Similarity} to be used by the underlying {@link IndexSearcher} or {@code null}
+   *                       (defaults to {@link BM25Similarity})
+   * @param query          a {@link Query} to eventually filter the docs used for training the classifier, or {@code null}
+   *                       if all the indexed docs should be used
+   * @param k              the no. of docs to select in the MLT results to find the nearest neighbor
+   * @param classFieldName the name of the field used as the output for the classifier
+   * @param textFieldNames the name of the fields used as the inputs for the classifier, they can contain boosting indication e.g. title^10
+   */
+  public KNearestFuzzyClassifier(IndexReader indexReader, Similarity similarity, Analyzer analyzer, Query query, int k,
+                                 String classFieldName, String... textFieldNames) {
+    this.textFieldNames = textFieldNames;
+    this.classFieldName = classFieldName;
+    this.analyzer = analyzer;
+    this.indexSearcher = new IndexSearcher(indexReader);
+    if (similarity != null) {
+      this.indexSearcher.setSimilarity(similarity);
+    } else {
+      this.indexSearcher.setSimilarity(new BM25Similarity());
+    }
+    this.query = query;
+    this.k = k;
+  }
+
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public ClassificationResult<BytesRef> assignClass(String text) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    ClassificationResult<BytesRef> assignedClass = null;
+    double maxscore = -Double.MAX_VALUE;
+    for (ClassificationResult<BytesRef> cl : assignedClasses) {
+      if (cl.getScore() > maxscore) {
+        assignedClass = cl;
+        maxscore = cl.getScore();
+      }
+    }
+    return assignedClass;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    return assignedClasses;
+  }
+
+  /**
+   * {@inheritDoc}
+   */
+  @Override
+  public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
+    TopDocs knnResults = knnSearch(text);
+    List<ClassificationResult<BytesRef>> assignedClasses = buildListFromTopDocs(knnResults);
+    Collections.sort(assignedClasses);
+    return assignedClasses.subList(0, max);
+  }
+
+  private TopDocs knnSearch(String text) throws IOException {
+    BooleanQuery.Builder bq = new BooleanQuery.Builder();
+    FuzzyLikeThisQuery fuzzyLikeThisQuery = new FuzzyLikeThisQuery(300, analyzer);
+    for (String fieldName : textFieldNames) {
+      fuzzyLikeThisQuery.addTerms(text, fieldName, 1f, 2); // TODO: make this parameters configurable
+    }
+    bq.add(fuzzyLikeThisQuery, BooleanClause.Occur.MUST);
+    Query classFieldQuery = new WildcardQuery(new Term(classFieldName, "*"));
+    bq.add(new BooleanClause(classFieldQuery, BooleanClause.Occur.MUST));
+    if (query != null) {
+      bq.add(query, BooleanClause.Occur.MUST);
+    }
+    return indexSearcher.search(bq.build(), k);
+  }
+
+  /**
+   * build a list of classification results from search results
+   *
+   * @param topDocs the search results as a {@link TopDocs} object
+   * @return a {@link List} of {@link ClassificationResult}, one for each existing class
+   * @throws IOException if it's not possible to get the stored value of class field
+   */
+  protected List<ClassificationResult<BytesRef>> buildListFromTopDocs(TopDocs topDocs) throws IOException {
+    Map<BytesRef, Integer> classCounts = new HashMap<>();
+    Map<BytesRef, Double> classBoosts = new HashMap<>(); // this is a boost based on class ranking positions in topDocs
+    float maxScore = topDocs.getMaxScore();
+    for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+      IndexableField storableField = indexSearcher.doc(scoreDoc.doc).getField(classFieldName);
+      if (storableField != null) {
+        BytesRef cl = new BytesRef(storableField.stringValue());
+        //update count
+        Integer count = classCounts.get(cl);
+        if (count != null) {
+          classCounts.put(cl, count + 1);
+        } else {
+          classCounts.put(cl, 1);
+        }
+        //update boost, the boost is based on the best score
+        Double totalBoost = classBoosts.get(cl);
+        double singleBoost = scoreDoc.score / maxScore;
+        if (totalBoost != null) {
+          classBoosts.put(cl, totalBoost + singleBoost);
+        } else {
+          classBoosts.put(cl, singleBoost);
+        }
+      }
+    }
+    List<ClassificationResult<BytesRef>> returnList = new ArrayList<>();
+    List<ClassificationResult<BytesRef>> temporaryList = new ArrayList<>();
+    int sumdoc = 0;
+    for (Map.Entry<BytesRef, Integer> entry : classCounts.entrySet()) {
+      Integer count = entry.getValue();
+      Double normBoost = classBoosts.get(entry.getKey()) / count; //the boost is normalized to be 0<b<1
+      temporaryList.add(new ClassificationResult<>(entry.getKey().clone(), (count * normBoost) / (double) k));
+      sumdoc += count;
+    }
+
+    //correction
+    if (sumdoc < k) {
+      for (ClassificationResult<BytesRef> cr : temporaryList) {
+        returnList.add(new ClassificationResult<>(cr.getAssignedClass(), cr.getScore() * k / (double) sumdoc));
+      }
+    } else {
+      returnList = temporaryList;
+    }
+    return returnList;
+  }
+
+  @Override
+  public String toString() {
+    return "KNearestFuzzyClassifier{" +
+        "textFieldNames=" + Arrays.toString(textFieldNames) +
+        ", classFieldName='" + classFieldName + '\'' +
+        ", k=" + k +
+        ", query=" + query +
+        ", similarity=" + indexSearcher.getSimilarity(true) +
+        '}';
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
index 7ab674e..913fb7f 100644
--- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
+++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java
@@ -121,7 +121,7 @@ public class DatasetSplitter {
       int b = 0;
 
       // iterate over existing documents
-      for (GroupDocs group : topGroups.groups) {
+      for (GroupDocs<Object> group : topGroups.groups) {
         int totalHits = group.totalHits;
         double testSize = totalHits * testRatio;
         int tc = 0;

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
new file mode 100644
index 0000000..6e4c404
--- /dev/null
+++ b/lucene/classification/src/test/org/apache/lucene/classification/KNearestFuzzyClassifierTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.classification;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.Tokenizer;
+import org.apache.lucene.analysis.core.KeywordTokenizer;
+import org.apache.lucene.analysis.ngram.EdgeNGramTokenFilter;
+import org.apache.lucene.analysis.reverse.ReverseStringFilter;
+import org.apache.lucene.classification.utils.ConfusionMatrixGenerator;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.util.BytesRef;
+import org.junit.Test;
+
+/**
+ * Testcase for {@link KNearestFuzzyClassifier}
+ */
+public class KNearestFuzzyClassifierTest extends ClassificationTestBase<BytesRef> {
+
+  @Test
+  public void testBasicUsage() throws Exception {
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = getSampleIndex(analyzer);
+      Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+      checkCorrectClassification(classifier, POLITICS_INPUT, POLITICS_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testBasicUsageWithQuery() throws Exception {
+    LeafReader leafReader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      leafReader = getSampleIndex(analyzer);
+      TermQuery query = new TermQuery(new Term(textFieldName, "not"));
+      Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, query, 3, categoryFieldName, textFieldName);
+      checkCorrectClassification(classifier, TECHNOLOGY_INPUT, TECHNOLOGY_RESULT);
+    } finally {
+      if (leafReader != null) {
+        leafReader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testPerformance() throws Exception {
+    MockAnalyzer analyzer = new MockAnalyzer(random());
+    LeafReader leafReader = getRandomIndex(analyzer, 100);
+    try {
+      long trainStart = System.currentTimeMillis();
+      Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(leafReader, null, analyzer, null, 3, categoryFieldName, textFieldName);
+      long trainEnd = System.currentTimeMillis();
+      long trainTime = trainEnd - trainStart;
+      assertTrue("training took more than 10s: " + trainTime / 1000 + "s", trainTime < 10000);
+
+      long evaluationStart = System.currentTimeMillis();
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(leafReader,
+          classifier, categoryFieldName, textFieldName, -1);
+      assertNotNull(confusionMatrix);
+      long evaluationEnd = System.currentTimeMillis();
+      long evaluationTime = evaluationEnd - evaluationStart;
+      assertTrue("evaluation took more than 2m: " + evaluationTime / 1000 + "s", evaluationTime < 120000);
+      double avgClassificationTime = confusionMatrix.getAvgClassificationTime();
+      assertTrue(5000 > avgClassificationTime);
+      double accuracy = confusionMatrix.getAccuracy();
+      assertTrue(accuracy >= 0d);
+      assertTrue(accuracy <= 1d);
+
+      double recall = confusionMatrix.getRecall();
+      assertTrue(recall >= 0d);
+      assertTrue(recall <= 1d);
+
+      double precision = confusionMatrix.getPrecision();
+      assertTrue(precision >= 0d);
+      assertTrue(precision <= 1d);
+
+      Terms terms = MultiFields.getTerms(leafReader, categoryFieldName);
+      TermsEnum iterator = terms.iterator();
+      BytesRef term;
+      while ((term = iterator.next()) != null) {
+        String s = term.utf8ToString();
+        recall = confusionMatrix.getRecall(s);
+        assertTrue(recall >= 0d);
+        assertTrue(recall <= 1d);
+        precision = confusionMatrix.getPrecision(s);
+        assertTrue(precision >= 0d);
+        assertTrue(precision <= 1d);
+        double f1Measure = confusionMatrix.getF1Measure(s);
+        assertTrue(f1Measure >= 0d);
+        assertTrue(f1Measure <= 1d);
+      }
+    } finally {
+      leafReader.close();
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/bd9e32d3/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
----------------------------------------------------------------------
diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
index 63cce2a..edb76b5 100644
--- a/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
+++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/ConfusionMatrixGeneratorTest.java
@@ -21,11 +21,13 @@ import java.io.IOException;
 import java.util.List;
 
 import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.classification.BM25NBClassifier;
 import org.apache.lucene.classification.BooleanPerceptronClassifier;
 import org.apache.lucene.classification.CachingNaiveBayesClassifier;
 import org.apache.lucene.classification.ClassificationResult;
 import org.apache.lucene.classification.ClassificationTestBase;
 import org.apache.lucene.classification.Classifier;
+import org.apache.lucene.classification.KNearestFuzzyClassifier;
 import org.apache.lucene.classification.KNearestNeighborClassifier;
 import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
 import org.apache.lucene.index.LeafReader;
@@ -94,22 +96,43 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       Classifier<BytesRef> classifier = new SimpleNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
           classifier, categoryFieldName, textFieldName, -1);
-      assertNotNull(confusionMatrix);
-      assertNotNull(confusionMatrix.getLinearizedMatrix());
-      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
-      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      double accuracy = confusionMatrix.getAccuracy();
-      assertTrue(accuracy >= 0d);
-      assertTrue(accuracy <= 1d);
-      double precision = confusionMatrix.getPrecision();
-      assertTrue(precision >= 0d);
-      assertTrue(precision <= 1d);
-      double recall = confusionMatrix.getRecall();
-      assertTrue(recall >= 0d);
-      assertTrue(recall <= 1d);
-      double f1Measure = confusionMatrix.getF1Measure();
-      assertTrue(f1Measure >= 0d);
-      assertTrue(f1Measure <= 1d);
+      checkCM(confusionMatrix);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
+  private void checkCM(ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix) {
+    assertNotNull(confusionMatrix);
+    assertNotNull(confusionMatrix.getLinearizedMatrix());
+    assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
+    assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
+    double accuracy = confusionMatrix.getAccuracy();
+    assertTrue(accuracy >= 0d);
+    assertTrue(accuracy <= 1d);
+    double precision = confusionMatrix.getPrecision();
+    assertTrue(precision >= 0d);
+    assertTrue(precision <= 1d);
+    double recall = confusionMatrix.getRecall();
+    assertTrue(recall >= 0d);
+    assertTrue(recall <= 1d);
+    double f1Measure = confusionMatrix.getF1Measure();
+    assertTrue(f1Measure >= 0d);
+    assertTrue(f1Measure <= 1d);
+  }
+
+  @Test
+  public void testGetConfusionMatrixWithBM25NB() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(analyzer);
+      Classifier<BytesRef> classifier = new BM25NBClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
+      checkCM(confusionMatrix);
     } finally {
       if (reader != null) {
         reader.close();
@@ -126,22 +149,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       Classifier<BytesRef> classifier = new CachingNaiveBayesClassifier(reader, analyzer, null, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
           classifier, categoryFieldName, textFieldName, -1);
-      assertNotNull(confusionMatrix);
-      assertNotNull(confusionMatrix.getLinearizedMatrix());
-      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
-      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      double accuracy = confusionMatrix.getAccuracy();
-      assertTrue(accuracy >= 0d);
-      assertTrue(accuracy <= 1d);
-      double precision = confusionMatrix.getPrecision();
-      assertTrue(precision >= 0d);
-      assertTrue(precision <= 1d);
-      double recall = confusionMatrix.getRecall();
-      assertTrue(recall >= 0d);
-      assertTrue(recall <= 1d);
-      double f1Measure = confusionMatrix.getF1Measure();
-      assertTrue(f1Measure >= 0d);
-      assertTrue(f1Measure <= 1d);
+      checkCM(confusionMatrix);
     } finally {
       if (reader != null) {
         reader.close();
@@ -158,22 +166,24 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       Classifier<BytesRef> classifier = new KNearestNeighborClassifier(reader, null, analyzer, null, 1, 0, 0, categoryFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
           classifier, categoryFieldName, textFieldName, -1);
-      assertNotNull(confusionMatrix);
-      assertNotNull(confusionMatrix.getLinearizedMatrix());
-      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
-      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      double accuracy = confusionMatrix.getAccuracy();
-      assertTrue(accuracy >= 0d);
-      assertTrue(accuracy <= 1d);
-      double precision = confusionMatrix.getPrecision();
-      assertTrue(precision >= 0d);
-      assertTrue(precision <= 1d);
-      double recall = confusionMatrix.getRecall();
-      assertTrue(recall >= 0d);
-      assertTrue(recall <= 1d);
-      double f1Measure = confusionMatrix.getF1Measure();
-      assertTrue(f1Measure >= 0d);
-      assertTrue(f1Measure <= 1d);
+      checkCM(confusionMatrix);
+    } finally {
+      if (reader != null) {
+        reader.close();
+      }
+    }
+  }
+
+  @Test
+  public void testGetConfusionMatrixWithFLTKNN() throws Exception {
+    LeafReader reader = null;
+    try {
+      MockAnalyzer analyzer = new MockAnalyzer(random());
+      reader = getSampleIndex(analyzer);
+      Classifier<BytesRef> classifier = new KNearestFuzzyClassifier(reader, null, analyzer, null, 1, categoryFieldName, textFieldName);
+      ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
+          classifier, categoryFieldName, textFieldName, -1);
+      checkCM(confusionMatrix);
     } finally {
       if (reader != null) {
         reader.close();
@@ -190,22 +200,7 @@ public class ConfusionMatrixGeneratorTest extends ClassificationTestBase<Object>
       Classifier<Boolean> classifier = new BooleanPerceptronClassifier(reader, analyzer, null, 1, null, booleanFieldName, textFieldName);
       ConfusionMatrixGenerator.ConfusionMatrix confusionMatrix = ConfusionMatrixGenerator.getConfusionMatrix(reader,
           classifier, booleanFieldName, textFieldName, -1);
-      assertNotNull(confusionMatrix);
-      assertNotNull(confusionMatrix.getLinearizedMatrix());
-      assertEquals(7, confusionMatrix.getNumberOfEvaluatedDocs());
-      assertTrue(confusionMatrix.getAvgClassificationTime() >= 0d);
-      double accuracy = confusionMatrix.getAccuracy();
-      assertTrue(accuracy >= 0d);
-      assertTrue(accuracy <= 1d);
-      double precision = confusionMatrix.getPrecision();
-      assertTrue(precision >= 0d);
-      assertTrue(precision <= 1d);
-      double recall = confusionMatrix.getRecall();
-      assertTrue(recall >= 0d);
-      assertTrue(recall <= 1d);
-      double f1Measure = confusionMatrix.getF1Measure();
-      assertTrue(f1Measure >= 0d);
-      assertTrue(f1Measure <= 1d);
+      checkCM(confusionMatrix);
       assertTrue(confusionMatrix.getPrecision("true") >= 0d);
       assertTrue(confusionMatrix.getPrecision("true") <= 1d);
       assertTrue(confusionMatrix.getPrecision("false") >= 0d);


Mime
View raw message