Return-Path: X-Original-To: apmail-lucene-commits-archive@www.apache.org Delivered-To: apmail-lucene-commits-archive@www.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id BEC6C19C45 for ; Wed, 13 Apr 2016 23:21:19 +0000 (UTC) Received: (qmail 21408 invoked by uid 500); 13 Apr 2016 23:21:17 -0000 Mailing-List: contact commits-help@lucene.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@lucene.apache.org Delivered-To: mailing list commits@lucene.apache.org Received: (qmail 20767 invoked by uid 99); 13 Apr 2016 23:21:17 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Wed, 13 Apr 2016 23:21:17 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 5B8B5EAB48; Wed, 13 Apr 2016 23:21:17 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: thelabdude@apache.org To: commits@lucene.apache.org Date: Wed, 13 Apr 2016 23:21:29 -0000 Message-Id: In-Reply-To: References: X-Mailer: ASF-Git Admin Mailer Subject: [13/50] lucene-solr:jira/SOLR-8908: LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class LUCENE-7196 - guaranteed class coverage in split indexes through grouping by class Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/112078ea Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/112078ea Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/112078ea Branch: refs/heads/jira/SOLR-8908 Commit: 112078eaf909bdda27a092d13a8431c046a82ac0 Parents: d7867b8 Author: Tommaso Teofili Authored: Fri Apr 8 16:59:41 2016 +0200 Committer: Tommaso Teofili Committed: Mon Apr 11 09:59:22 2016 +0200 
---------------------------------------------------------------------- lucene/classification/build.xml | 6 +- .../classification/utils/DatasetSplitter.java | 126 +++++++++++++------ .../classification/utils/DataSplitterTest.java | 33 ++--- 3 files changed, 106 insertions(+), 59 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/build.xml ---------------------------------------------------------------------- diff --git a/lucene/classification/build.xml b/lucene/classification/build.xml index 930b1fa..fd15239 100644 --- a/lucene/classification/build.xml +++ b/lucene/classification/build.xml @@ -27,6 +27,7 @@ + @@ -35,15 +36,16 @@ - + - + http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java ---------------------------------------------------------------------- diff --git a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java index 0b03b94..fce786b 100644 --- a/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java +++ b/lucene/classification/src/java/org/apache/lucene/classification/utils/DatasetSplitter.java @@ -18,6 +18,7 @@ package org.apache.lucene.classification.utils; import java.io.IOException; +import java.util.HashMap; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.document.Document; @@ -28,11 +29,16 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.Terms; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import 
org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.grouping.GroupDocs; +import org.apache.lucene.search.grouping.GroupingSearch; +import org.apache.lucene.search.grouping.TopGroups; import org.apache.lucene.store.Directory; +import org.apache.lucene.uninverting.UninvertingReader; /** * Utility class for creating training / test / cross validation indexes from the original index. @@ -61,67 +67,78 @@ public class DatasetSplitter { * @param testIndex a {@link Directory} used to write the test index * @param crossValidationIndex a {@link Directory} used to write the cross validation index * @param analyzer {@link Analyzer} used to create the new docs + * @param termVectors {@code true} if term vectors should be kept + * @param classFieldName name of the field used as the label for classification * @param fieldNames names of fields that need to be put in the new indexes or null if all should be used * @throws IOException if any writing operation fails on any of the indexes */ public void split(LeafReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, - Analyzer analyzer, String... fieldNames) throws IOException { + Analyzer analyzer, boolean termVectors, String classFieldName, String... fieldNames) throws IOException { // create IWs for train / test / cv IDXs IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer)); IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer)); IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer)); + // try to get the exact no. 
of existing classes + Terms terms = originalIndex.terms(classFieldName); + long noOfClasses = -1; + if (terms != null) { + noOfClasses = terms.size(); + + } + if (noOfClasses == -1) { + noOfClasses = 10000; // fallback + } + + HashMap mapping = new HashMap<>(); + mapping.put(classFieldName, UninvertingReader.Type.SORTED); + UninvertingReader uninvertingReader = new UninvertingReader(originalIndex, mapping); + try { - int size = originalIndex.maxDoc(); - IndexSearcher indexSearcher = new IndexSearcher(originalIndex); - TopDocs topDocs = indexSearcher.search(new MatchAllDocsQuery(), Integer.MAX_VALUE); + IndexSearcher indexSearcher = new IndexSearcher(uninvertingReader); + GroupingSearch gs = new GroupingSearch(classFieldName); + gs.setGroupSort(Sort.INDEXORDER); + gs.setSortWithinGroup(Sort.INDEXORDER); + gs.setAllGroups(true); + gs.setGroupDocsLimit(originalIndex.maxDoc()); + TopGroups topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, (int) noOfClasses); // set the type to be indexed, stored, with term vectors FieldType ft = new FieldType(TextField.TYPE_STORED); - ft.setStoreTermVectors(true); - ft.setStoreTermVectorOffsets(true); - ft.setStoreTermVectorPositions(true); + if (termVectors) { + ft.setStoreTermVectors(true); + ft.setStoreTermVectorOffsets(true); + ft.setStoreTermVectorPositions(true); + } int b = 0; // iterate over existing documents - for (ScoreDoc scoreDoc : topDocs.scoreDocs) { - - // create a new document for indexing - Document doc = new Document(); - Document document = originalIndex.document(scoreDoc.doc); - if (fieldNames != null && fieldNames.length > 0) { - for (String fieldName : fieldNames) { - IndexableField field = document.getField(fieldName); - if (field != null) { - doc.add(new Field(fieldName, field.stringValue(), ft)); - } + for (GroupDocs group : topGroups.groups) { + int totalHits = group.totalHits; + double testSize = totalHits * testRatio; + int tc = 0; + double cvSize = totalHits * crossValidationRatio; + int cvc 
= 0; + for (ScoreDoc scoreDoc : group.scoreDocs) { + + // create a new document for indexing + Document doc = createNewDoc(originalIndex, ft, scoreDoc, fieldNames); + + // add it to one of the IDXs + if (b % 2 == 0 && tc < testSize) { + testWriter.addDocument(doc); + tc++; + } else if (cvc < cvSize) { + cvWriter.addDocument(doc); + cvc++; + } else { + trainingWriter.addDocument(doc); } - } else { - for (IndexableField field : document.getFields()) { - if (field.readerValue() != null) { - doc.add(new Field(field.name(), field.readerValue(), ft)); - } else if (field.binaryValue() != null) { - doc.add(new Field(field.name(), field.binaryValue(), ft)); - } else if (field.stringValue() != null) { - doc.add(new Field(field.name(), field.stringValue(), ft)); - } else if (field.numericValue() != null) { - doc.add(new Field(field.name(), field.numericValue().toString(), ft)); - } - } - } - - // add it to one of the IDXs - if (b % 2 == 0 && testWriter.maxDoc() < size * testRatio) { - testWriter.addDocument(doc); - } else if (cvWriter.maxDoc() < size * crossValidationRatio) { - cvWriter.addDocument(doc); - } else { - trainingWriter.addDocument(doc); + b++; } - b++; } // commit testWriter.commit(); @@ -139,7 +156,34 @@ public class DatasetSplitter { testWriter.close(); cvWriter.close(); trainingWriter.close(); + uninvertingReader.close(); + } + } + + private Document createNewDoc(LeafReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException { + Document doc = new Document(); + Document document = originalIndex.document(scoreDoc.doc); + if (fieldNames != null && fieldNames.length > 0) { + for (String fieldName : fieldNames) { + IndexableField field = document.getField(fieldName); + if (field != null) { + doc.add(new Field(fieldName, field.stringValue(), ft)); + } + } + } else { + for (IndexableField field : document.getFields()) { + if (field.readerValue() != null) { + doc.add(new Field(field.name(), field.readerValue(), ft)); + } else if 
(field.binaryValue() != null) { + doc.add(new Field(field.name(), field.binaryValue(), ft)); + } else if (field.stringValue() != null) { + doc.add(new Field(field.name(), field.stringValue(), ft)); + } else if (field.numericValue() != null) { + doc.add(new Field(field.name(), field.numericValue().toString(), ft)); + } + } } + return doc; } } http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/112078ea/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java ---------------------------------------------------------------------- diff --git a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java index 2984bb5..0b6f077 100644 --- a/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java +++ b/lucene/classification/src/test/org/apache/lucene/classification/utils/DataSplitterTest.java @@ -17,27 +17,28 @@ package org.apache.lucene.classification.utils; -import org.apache.lucene.analysis.Analyzer; +import java.io.IOException; +import java.util.Random; + import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.TextField; -import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.store.BaseDirectoryWrapper; import org.apache.lucene.store.Directory; -import org.apache.lucene.util.TestUtil; +import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.TestUtil; import org.junit.After; import org.junit.Before; import 
org.junit.Test; -import java.io.IOException; -import java.util.Random; - /** * Testcase for {@link org.apache.lucene.classification.utils.DatasetSplitter} */ @@ -47,9 +48,9 @@ public class DataSplitterTest extends LuceneTestCase { private RandomIndexWriter indexWriter; private Directory dir; - private String textFieldName = "text"; - private String classFieldName = "class"; - private String idFieldName = "id"; + private static final String textFieldName = "text"; + private static final String classFieldName = "class"; + private static final String idFieldName = "id"; @Override @Before @@ -65,11 +66,11 @@ public class DataSplitterTest extends LuceneTestCase { Document doc; Random rnd = random(); - for (int i = 0; i < 100; i++) { + for (int i = 0; i < 1000; i++) { doc = new Document(); - doc.add(new Field(idFieldName, Integer.toString(i), ft)); + doc.add(new Field(idFieldName, "id" + Integer.toString(i), ft)); doc.add(new Field(textFieldName, TestUtil.randomUnicodeString(rnd, 1024), ft)); - doc.add(new Field(classFieldName, TestUtil.randomUnicodeString(rnd, 10), ft)); + doc.add(new Field(classFieldName, Integer.toString(rnd.nextInt(10)), ft)); indexWriter.addDocument(doc); } @@ -108,18 +109,18 @@ public class DataSplitterTest extends LuceneTestCase { try { DatasetSplitter datasetSplitter = new DatasetSplitter(testRatio, crossValidationRatio); - datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), fieldNames); + datasetSplitter.split(originalIndex, trainingIndex, testIndex, crossValidationIndex, new MockAnalyzer(random()), true, classFieldName, fieldNames); assertNotNull(trainingIndex); assertNotNull(testIndex); assertNotNull(crossValidationIndex); DirectoryReader trainingReader = DirectoryReader.open(trainingIndex); - assertTrue((int) (originalIndex.maxDoc() * (1d - testRatio - crossValidationRatio)) == trainingReader.maxDoc()); + assertEquals((int) (originalIndex.maxDoc() * (1d - testRatio - 
crossValidationRatio)), trainingReader.maxDoc(), 20); DirectoryReader testReader = DirectoryReader.open(testIndex); - assertTrue((int) (originalIndex.maxDoc() * testRatio) == testReader.maxDoc()); + assertEquals((int) (originalIndex.maxDoc() * testRatio), testReader.maxDoc(), 20); DirectoryReader cvReader = DirectoryReader.open(crossValidationIndex); - assertTrue((int) (originalIndex.maxDoc() * crossValidationRatio) == cvReader.maxDoc()); + assertEquals((int) (originalIndex.maxDoc() * crossValidationRatio), cvReader.maxDoc(), 20); trainingReader.close(); testReader.close();