From: krisden@apache.org
To: commits@lucene.apache.org
Date: Mon, 28 Nov 2016 17:26:51 -0000
Subject: [22/42] lucene-solr:jira/solr-8593: SOLR-8871 - various improvements to ClassificationURP

SOLR-8871 - various improvements to ClassificationURP

Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/5ad741ee
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/5ad741ee
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/5ad741ee

Branch: refs/heads/jira/solr-8593
Commit: 5ad741eef8241de86945e710cdcb32e77a7183a3
Parents: e9e4715
Author: Tommaso Teofili
Authored: Thu Nov 24 23:43:57 2016 +0100
Committer: Tommaso Teofili
Committed: Thu Nov 24 23:43:57 2016 +0100

----------------------------------------------------------------------
 .../UIMAUpdateRequestProcessorTest.java         |  25 -
 .../ClassificationUpdateProcessor.java          |  59 ++-
 .../ClassificationUpdateProcessorFactory.java   | 197 +++----
 .../ClassificationUpdateProcessorParams.java    | 112 ++++
 .../conf/solrconfig-classification.xml          |  15 +
 ...lassificationUpdateProcessorFactoryTest.java | 208 ++------
 ...ificationUpdateProcessorIntegrationTest.java | 192 +++++++
 .../ClassificationUpdateProcessorTest.java      | 507 +++++++++++++++++++
 .../SignatureUpdateProcessorFactoryTest.java    |  28 +-
 .../TestPartialUpdateDeduplication.java         |   2 -
 .../java/org/apache/solr/SolrTestCaseJ4.java    |  22 +
 11 files changed, 1012 insertions(+), 355 deletions(-)
----------------------------------------------------------------------
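For orientation before the per-file diffs: the commit replaces the factory's loose configuration fields with a single parameter object. A minimal sketch of how the new pieces fit together, assuming only the classes and setters introduced in the diffs below (field names and values are illustrative, not part of the commit):

    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.TermQuery;
    import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
    import org.apache.solr.update.processor.ClassificationUpdateProcessorParams;

    class ParamsSketch {
      static ClassificationUpdateProcessorParams knnParams() {
        ClassificationUpdateProcessorParams params = new ClassificationUpdateProcessorParams();
        params.setInputFieldNames(new String[]{"title^1.5", "content", "author^2.5"}); // per-field boosts now allowed
        params.setTrainingClassField("cat");        // field holding the human-assigned class
        params.setPredictedClassField("predicted"); // output field may now differ from the training field
        params.setMaxPredictedClasses(2);           // more than one class can be assigned
        params.setAlgorithm(Algorithm.KNN);
        params.setMinTf(1);
        params.setMinDf(1);
        params.setK(10);
        params.setTrainingFilterQuery(new TermQuery(new Term("cat", "class1"))); // restrict the training set
        return params;
      }
    }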

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
----------------------------------------------------------------------
diff --git a/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java b/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
index 5879c78..3833696 100644
--- a/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
+++ b/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
@@ -16,22 +16,12 @@
  */
 package org.apache.solr.uima.processor;
 
-import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.Map;
 
 import org.apache.lucene.util.LuceneTestCase.Slow;
 import org.apache.solr.SolrTestCaseJ4;
 import org.apache.solr.common.SolrException;
-import org.apache.solr.common.params.MultiMapSolrParams;
-import org.apache.solr.common.params.SolrParams;
-import org.apache.solr.common.params.UpdateParams;
-import org.apache.solr.common.util.ContentStream;
-import org.apache.solr.common.util.ContentStreamBase;
 import org.apache.solr.core.SolrCore;
-import org.apache.solr.handler.UpdateRequestHandler;
-import org.apache.solr.request.SolrQueryRequestBase;
-import org.apache.solr.response.SolrQueryResponse;
 import org.apache.solr.uima.processor.SolrUIMAConfiguration.MapField;
 import org.apache.solr.update.processor.UpdateRequestProcessor;
 import org.apache.solr.update.processor.UpdateRequestProcessorChain;
@@ -188,19 +178,4 @@ public class UIMAUpdateRequestProcessorTest extends SolrTestCaseJ4 {
     }
   }
 
-  private void addDoc(String chain, String doc) throws Exception {
-    Map<String, String[]> params = new HashMap<>();
-    params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
-    MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
-    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), (SolrParams) mmparams) {
-    };
-
-    UpdateRequestHandler handler = new UpdateRequestHandler();
-    handler.init(null);
-    ArrayList<ContentStream> streams = new ArrayList<>(2);
-    streams.add(new ContentStreamBase.StringStream(doc));
-    req.setContentStreams(streams);
-    handler.handleRequestBody(req, new SolrQueryResponse());
-  }
-
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
index 050fff0..8ce9814 100644
--- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
@@ -19,6 +19,7 @@ package org.apache.solr.update.processor;
 
 import java.io.IOException;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import org.apache.lucene.analysis.Analyzer;
@@ -33,6 +34,7 @@
 import org.apache.solr.common.SolrInputDocument;
 import org.apache.solr.schema.IndexSchema;
 import org.apache.solr.schema.SchemaField;
 import org.apache.solr.update.AddUpdateCommand;
+import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
 
 /**
  * This Class is a Request Update Processor to classify the document in input and add a field
  *
@@ -42,43 +44,54 @@ class ClassificationUpdateProcessor
     extends UpdateRequestProcessor {
 
-  private String classFieldName; // the field to index the assigned class
-
+  private final String trainingClassField;
+  private final String predictedClassField;
+  private final int maxOutputClasses;
   private DocumentClassifier<BytesRef> classifier;
 
   /**
    * Sole constructor
    *
-   * @param inputFieldNames fields to be used as classifier's inputs
-   * @param classFieldName field to be used as classifier's output
-   * @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
-   * @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
-   * @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
-   * @param algorithm the name of the classifier to use
+   * @param classificationParams classification advanced params
    * @param next next update processor in the chain
    * @param indexReader index reader
    * @param schema schema
    */
-  public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
-                                       UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
+  public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams, UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
     super(next);
-    this.classFieldName = classFieldName;
-    Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
+    this.trainingClassField = classificationParams.getTrainingClassField();
+    this.predictedClassField = classificationParams.getPredictedClassField();
+    this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
+    String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
+    Algorithm classificationAlgorithm = classificationParams.getAlgorithm();
+
+    Map<String, Analyzer> field2analyzer = new HashMap<>();
+    String[] inputFieldNames = this.removeBoost(inputFieldNamesWithBoost);
     for (String fieldName : inputFieldNames) {
      SchemaField fieldFromSolrSchema = schema.getField(fieldName);
      Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
      field2analyzer.put(fieldName, indexAnalyzer);
    }
-    switch (algorithm) {
-      case "knn":
-        classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
+    switch (classificationAlgorithm) {
+      case KNN:
+        classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationParams.getTrainingFilterQuery(), classificationParams.getK(), classificationParams.getMinDf(), classificationParams.getMinTf(), trainingClassField, field2analyzer, inputFieldNamesWithBoost);
         break;
-      case "bayes":
-        classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
+      case BAYES:
+        classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, trainingClassField, field2analyzer, inputFieldNamesWithBoost);
         break;
     }
   }
 
+  private String[] removeBoost(String[] inputFieldNamesWithBoost) {
+    String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
+    for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
+      String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
+      String[] fieldName2boost = singleFieldNameWithBoost.split("\\^");
+      inputFieldNames[i] = fieldName2boost[0];
+    }
+    return inputFieldNames;
+  }
+
   /**
    * @param cmd the update command in input containing the Document to classify
    * @throws IOException If there is a low-level I/O error
@@ -89,12 +102,14 @@ class ClassificationUpdateProcessor
     SolrInputDocument doc = cmd.getSolrInputDocument();
     Document luceneDocument = cmd.getLuceneDocument();
     String assignedClass;
-    Object documentClass = doc.getFieldValue(classFieldName);
+    Object documentClass = doc.getFieldValue(trainingClassField);
     if (documentClass == null) {
-      ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
-      if (classificationResult != null) {
-        assignedClass = classificationResult.getAssignedClass().utf8ToString();
-        doc.addField(classFieldName, assignedClass);
+      List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses);
+      if (assignedClassifications != null) {
+        for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
+          assignedClass = singleClassification.getAssignedClass().utf8ToString();
+          doc.addField(predictedClassField, assignedClass);
+        }
       }
     }
     super.processAdd(cmd);
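The removeBoost helper above is what lets inputFields accept Lucene-style boost suffixes while schema/analyzer lookups still use the bare field name. A self-contained illustration of the split (hypothetical class name, same regex as the diff):

    class BoostSyntaxDemo {
      // mirrors ClassificationUpdateProcessor#removeBoost: everything after '^' is dropped
      static String stripBoost(String fieldNameWithBoost) {
        return fieldNameWithBoost.split("\\^")[0];
      }

      public static void main(String[] args) {
        System.out.println(stripBoost("title^1.5")); // title
        System.out.println(stripBoost("content"));   // content (unboosted names pass through)
      }
    }

Note that the boosted form is still handed to the Lucene classifier (inputFieldNamesWithBoost), which is expected to interpret the weights itself; only the per-field analyzer lookup needs the plain name.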

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
index 81bec2f..19e0dfe 100644
--- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
@@ -18,12 +18,17 @@ package org.apache.solr.update.processor;
 
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Query;
 import org.apache.solr.common.SolrException;
 import org.apache.solr.common.params.SolrParams;
 import org.apache.solr.common.util.NamedList;
 import org.apache.solr.request.SolrQueryRequest;
 import org.apache.solr.response.SolrQueryResponse;
 import org.apache.solr.schema.IndexSchema;
+import org.apache.solr.search.LuceneQParser;
+import org.apache.solr.search.SyntaxError;
+
+import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;
 
 /**
  * This class implements an UpdateProcessorFactory for the Classification Update Processor.
@@ -33,49 +38,67 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
   // Update Processor Config params
   private static final String INPUT_FIELDS_PARAM = "inputFields";
-  private static final String CLASS_FIELD_PARAM = "classField";
+  private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
+  private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
+  private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
   private static final String ALGORITHM_PARAM = "algorithm";
   private static final String KNN_MIN_TF_PARAM = "knn.minTf";
   private static final String KNN_MIN_DF_PARAM = "knn.minDf";
   private static final String KNN_K_PARAM = "knn.k";
+  private static final String KNN_FILTER_QUERY = "knn.filterQuery";
+
+  public enum Algorithm {KNN, BAYES}
 
   //Update Processor Defaults
+  private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
   private static final int DEFAULT_MIN_TF = 1;
   private static final int DEFAULT_MIN_DF = 1;
   private static final int DEFAULT_K = 10;
-  private static final String DEFAULT_ALGORITHM = "knn";
-
-  private String[] inputFieldNames; // the array of fields to be sent to the Classifier
-
-  private String classFieldName; // the field containing the class for the Document
-
-  private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
-
-  private int minTf; // knn specific - the minimum Term Frequency for considering a term
-
-  private int minDf; // knn specific - the minimum Document Frequency for considering a term
+  private static final Algorithm DEFAULT_ALGORITHM = KNN;
 
-  private int k; // knn specific - thw window of top results to evaluate, when assigning the class
+  private SolrParams params;
+  private ClassificationUpdateProcessorParams classificationParams;
 
   @Override
   public void init(final NamedList args) {
     if (args != null) {
-      SolrParams params = SolrParams.toSolrParams(args);
+      params = SolrParams.toSolrParams(args);
+      classificationParams = new ClassificationUpdateProcessorParams();
 
       String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
       checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
-      inputFieldNames = fieldNames.split("\\,");
-
-      classFieldName = params.get(CLASS_FIELD_PARAM);
-      checkNotNull(CLASS_FIELD_PARAM, classFieldName);
-
-      algorithm = params.get(ALGORITHM_PARAM);
-      if (algorithm == null)
-        algorithm = DEFAULT_ALGORITHM;
-
-      minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
-      minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
-      k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
+      classificationParams.setInputFieldNames(fieldNames.split("\\,"));
+
+      String trainingClassField = (params.get(TRAINING_CLASS_FIELD_PARAM));
+      checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
+      classificationParams.setTrainingClassField(trainingClassField);
+
+      String predictedClassField = (params.get(PREDICTED_CLASS_FIELD_PARAM));
+      if (predictedClassField == null || predictedClassField.isEmpty()) {
+        predictedClassField = trainingClassField;
+      }
+      classificationParams.setPredictedClassField(predictedClassField);
+
+      classificationParams.setMaxPredictedClasses(getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));
+
+      String algorithmString = params.get(ALGORITHM_PARAM);
+      Algorithm classificationAlgorithm;
+      try {
+        if (algorithmString == null || Algorithm.valueOf(algorithmString.toUpperCase()) == null) {
+          classificationAlgorithm = DEFAULT_ALGORITHM;
+        } else {
+          classificationAlgorithm = Algorithm.valueOf(algorithmString.toUpperCase());
+        }
+      } catch (IllegalArgumentException e) {
+        throw new SolrException
+            (SolrException.ErrorCode.SERVER_ERROR,
+                "Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported");
+      }
+      classificationParams.setAlgorithm(classificationAlgorithm);
+
+      classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
+      classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
+      classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
     }
   }
 
@@ -108,116 +131,34 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
   @Override
   public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
+    String trainingFilterQueryString = (params.get(KNN_FILTER_QUERY));
+    try {
+      if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
+        Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
+        classificationParams.setTrainingFilterQuery(trainingFilterQuery);
+      }
+    } catch (SyntaxError | RuntimeException syntaxError) {
+      throw new SolrException
+          (SolrException.ErrorCode.SERVER_ERROR,
+              "Classification UpdateProcessor Training Filter Query: '" + trainingFilterQueryString + "' is not supported", syntaxError);
+    }
+
     IndexSchema schema = req.getSchema();
     IndexReader indexReader = req.getSearcher().getIndexReader();
-    return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
-  }
 
-  /**
-   * get field names used as classifier's inputs
-   *
-   * @return the input field names
-   */
-  public String[] getInputFieldNames() {
-    return inputFieldNames;
-  }
-
-  /**
-   * set field names used as classifier's inputs
-   *
-   * @param inputFieldNames the input field names
-   */
-  public void setInputFieldNames(String[] inputFieldNames) {
-    this.inputFieldNames = inputFieldNames;
+    return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
   }
 
-  /**
-   * get field names used as classifier's output
-   *
-   * @return the output field name
-   */
-  public String getClassFieldName() {
-    return classFieldName;
+  private Query parseFilterQuery(String trainingFilterQueryString, SolrParams params, SolrQueryRequest req) throws SyntaxError {
+    LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
+    return parser.parse();
   }
 
-  /**
-   * set field names used as classifier's output
-   *
-   * @param classFieldName the output field name
-   */
-  public void setClassFieldName(String classFieldName) {
-    this.classFieldName = classFieldName;
+  public ClassificationUpdateProcessorParams getClassificationParams() {
+    return classificationParams;
   }
 
-  /**
-   * get the name of the classifier algorithm used
-   *
-   * @return the classifier algorithm used
-   */
-  public String getAlgorithm() {
-    return algorithm;
-  }
-
-  /**
-   * set the name of the classifier algorithm used
-   *
-   * @param algorithm the classifier algorithm used
-   */
-  public void setAlgorithm(String algorithm) {
-    this.algorithm = algorithm;
-  }
-
-  /**
-   * get the min term frequency value to be used in case algorithm is {@code "knn"}
-   *
-   * @return the min term frequency
-   */
-  public int getMinTf() {
-    return minTf;
-  }
-
-  /**
-   * set the min term frequency value to be used in case algorithm is {@code "knn"}
-   *
-   * @param minTf the min term frequency
-   */
-  public void setMinTf(int minTf) {
-    this.minTf = minTf;
-  }
-
-  /**
-   * get the min document frequency value to be used in case algorithm is {@code "knn"}
-   *
-   * @return the min document frequency
-   */
-  public int getMinDf() {
-    return minDf;
-  }
-
-  /**
-   * set the min document frequency value to be used in case algorithm is {@code "knn"}
-   *
-   * @param minDf the min document frequency
-   */
-  public void setMinDf(int minDf) {
-    this.minDf = minDf;
-  }
-
-  /**
-   * get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
-   *
-   * @return the no. of neighbors to analyze
-   */
-  public int getK() {
-    return k;
-  }
-
-  /**
-   * set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
-   *
-   * @param k the no. of neighbors to analyze
-   */
-  public void setK(int k) {
-    this.k = k;
+  public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
+    this.classificationParams = classificationParams;
  }
 }
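The tests further down drive the factory exactly this way; as a quick reference, a hedged sketch of programmatic initialization using the parameter names defined above (field values illustrative):

    import org.apache.solr.common.util.NamedList;

    class FactoryInitSketch {
      static ClassificationUpdateProcessorFactory newFactory() {
        ClassificationUpdateProcessorFactory factory = new ClassificationUpdateProcessorFactory();
        NamedList<String> args = new NamedList<>();
        args.add("inputFields", "title,content,author");
        args.add("classField", "cat");                // training field; also the default output field
        args.add("predictedClassField", "predicted"); // optional, new in this commit
        args.add("predictedClass.maxCount", "2");     // optional, defaults to 1
        args.add("algorithm", "knn");                 // parsed case-insensitively into Algorithm.KNN
        args.add("knn.filterQuery", "cat:(class1 OR class2)"); // optional training-set filter
        factory.init(args);
        return factory;
      }
    }

Note that the filter query is only parsed in getInstance, since parsing needs a live SolrQueryRequest; a syntax error therefore surfaces at chain instantiation rather than at init.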

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
new file mode 100644
index 0000000..536cec3
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.update.processor;
+
+import org.apache.lucene.search.Query;
+
+public class ClassificationUpdateProcessorParams {
+
+  private String[] inputFieldNames; // the array of fields to be sent to the Classifier
+
+  private Query trainingFilterQuery; // a filter query to reduce the training set to a subset
+
+  private String trainingClassField; // the field containing the class for the Document
+
+  private String predictedClassField; // the field that will contain the predicted class
+
+  private int maxPredictedClasses; // the max number of classes to assign
+
+  private ClassificationUpdateProcessorFactory.Algorithm algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
+
+  private int minTf; // knn specific - the minimum Term Frequency for considering a term
+
+  private int minDf; // knn specific - the minimum Document Frequency for considering a term
+
+  private int k; // knn specific - the window of top results to evaluate, when assigning the class
+
+  public String[] getInputFieldNames() {
+    return inputFieldNames;
+  }
+
+  public void setInputFieldNames(String[] inputFieldNames) {
+    this.inputFieldNames = inputFieldNames;
+  }
+
+  public Query getTrainingFilterQuery() {
+    return trainingFilterQuery;
+  }
+
+  public void setTrainingFilterQuery(Query trainingFilterQuery) {
+    this.trainingFilterQuery = trainingFilterQuery;
+  }
+
+  public String getTrainingClassField() {
+    return trainingClassField;
+  }
+
+  public void setTrainingClassField(String trainingClassField) {
+    this.trainingClassField = trainingClassField;
+  }
+
+  public String getPredictedClassField() {
+    return predictedClassField;
+  }
+
+  public void setPredictedClassField(String predictedClassField) {
+    this.predictedClassField = predictedClassField;
+  }
+
+  public int getMaxPredictedClasses() {
+    return maxPredictedClasses;
+  }
+
+  public void setMaxPredictedClasses(int maxPredictedClasses) {
+    this.maxPredictedClasses = maxPredictedClasses;
+  }
+
+  public ClassificationUpdateProcessorFactory.Algorithm getAlgorithm() {
+    return algorithm;
+  }
+
+  public void setAlgorithm(ClassificationUpdateProcessorFactory.Algorithm algorithm) {
+    this.algorithm = algorithm;
+  }
+
+  public int getMinTf() {
+    return minTf;
+  }
+
+  public void setMinTf(int minTf) {
+    this.minTf = minTf;
+  }
+
+  public int getMinDf() {
+    return minDf;
+  }
+
+  public void setMinDf(int minDf) {
+    this.minDf = minDf;
+  }
+
+  public int getK() {
+    return k;
+  }
+
+  public void setK(int k) {
+    this.k = k;
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
----------------------------------------------------------------------
diff --git a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
index 3656335..f688ed1 100644
--- a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
+++ b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
@@ -47,6 +47,21 @@
       <str name="knn.minTf">1</str>
       <str name="knn.minDf">1</str>
       <str name="knn.k">5</str>
+      <str name="knn.filterQuery">cat:(class1 OR class2)</str>
     </processor>
     <processor class="solr.RunUpdateProcessorFactory"/>
   </updateRequestProcessorChain>
+
+  <updateRequestProcessorChain name="classification-unsupported-filterQuery">
+    <processor class="solr.ClassificationUpdateProcessorFactory">
+      <str name="inputFields">title,content,author</str>
+      <str name="classField">cat</str>
+      <str name="algorithm">knn</str>
+      <str name="knn.minTf">1</str>
+      <str name="knn.minDf">1</str>
+      <str name="knn.k">5</str>
+      <str name="knn.filterQuery">not valid ( lucene query</str>
+    </processor>
+    <processor class="solr.RunUpdateProcessorFactory"/>
+  </updateRequestProcessorChain>

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
index 05d112f..fe22918 100644
--- a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
@@ -14,71 +14,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.solr.update.processor;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.lucene.document.Document;
-import org.apache.lucene.index.Term;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TermQuery;
-import org.apache.lucene.search.TopDocs;
 import org.apache.solr.SolrTestCaseJ4;
 import org.apache.solr.common.SolrException;
-import org.apache.solr.common.params.MultiMapSolrParams;
-import org.apache.solr.common.params.SolrParams;
-import org.apache.solr.common.params.UpdateParams;
-import org.apache.solr.common.util.ContentStream;
-import org.apache.solr.common.util.ContentStreamBase;
 import org.apache.solr.common.util.NamedList;
-import org.apache.solr.handler.UpdateRequestHandler;
 import org.apache.solr.request.SolrQueryRequest;
-import org.apache.solr.request.SolrQueryRequestBase;
 import org.apache.solr.response.SolrQueryResponse;
-import org.apache.solr.search.SolrIndexSearcher;
 import org.junit.Before;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Mockito.mock;
+
 /**
- * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
+ * Tests for {@link ClassificationUpdateProcessorFactory}
 */
 public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
-  // field names are used in accordance with the solrconfig and schema supplied
-  private static final String ID = "id";
-  private static final String TITLE = "title";
-  private static final String CONTENT = "content";
-  private static final String AUTHOR = "author";
-  private static final String CLASS = "cat";
-
-  private static final String CHAIN = "classification";
-
   private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
   private NamedList args = new NamedList();
 
-  @BeforeClass
-  public static void beforeClass() throws Exception {
-    System.setProperty("enable.update.log", "false");
-    initCore("solrconfig-classification.xml", "schema-classification.xml");
-  }
-
-  @Override
-  @Before
-  public void setUp() throws Exception {
-    super.setUp();
-    clearIndex();
-    assertU(commit());
-  }
-
   @Before
   public void initArgs() {
     args.add("inputFields", "inputField1,inputField2");
     args.add("classField", "classField1");
+    args.add("predictedClassField", "classFieldX");
     args.add("algorithm", "bayes");
     args.add("knn.k", "9");
     args.add("knn.minDf", "8");
@@ -86,22 +46,23 @@
   }
 
   @Test
-  public void testFullInit() {
+  public void init_fullArgs_shouldInitFullClassificationParams() {
     cFactoryToTest.init(args);
+    ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
 
-    String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
+    String[] inputFieldNames = classificationParams.getInputFieldNames();
     assertEquals("inputField1", inputFieldNames[0]);
     assertEquals("inputField2", inputFieldNames[1]);
-    assertEquals("classField1", cFactoryToTest.getClassFieldName());
-    assertEquals("bayes", cFactoryToTest.getAlgorithm());
-    assertEquals(8, cFactoryToTest.getMinDf());
-    assertEquals(10, cFactoryToTest.getMinTf());
-    assertEquals(9, cFactoryToTest.getK());
-
+    assertEquals("classField1", classificationParams.getTrainingClassField());
+    assertEquals("classFieldX", classificationParams.getPredictedClassField());
+    assertEquals(ClassificationUpdateProcessorFactory.Algorithm.BAYES, classificationParams.getAlgorithm());
+    assertEquals(8, classificationParams.getMinDf());
+    assertEquals(10, classificationParams.getMinTf());
+    assertEquals(9, classificationParams.getK());
   }
 
   @Test
-  public void testInitEmptyInputField() {
+  public void init_emptyInputFields_shouldThrowExceptionWithDetailedMessage() {
     args.removeAll("inputFields");
     try {
       cFactoryToTest.init(args);
@@ -111,7 +72,7 @@
   }
 
   @Test
-  public void testInitEmptyClassField() {
+  public void init_emptyClassField_shouldThrowExceptionWithDetailedMessage() {
     args.removeAll("classField");
     try {
       cFactoryToTest.init(args);
@@ -121,114 +82,53 @@
   }
 
   @Test
-  public void testDefaults() {
-    args.removeAll("algorithm");
-    args.removeAll("knn.k");
-    args.removeAll("knn.minDf");
-    args.removeAll("knn.minTf");
-    cFactoryToTest.init(args);
-    assertEquals("knn", cFactoryToTest.getAlgorithm());
-    assertEquals(1, cFactoryToTest.getMinDf());
-    assertEquals(1, cFactoryToTest.getMinTf());
-    assertEquals(10, cFactoryToTest.getK());
-  }
+  public void init_emptyPredictedClassField_shouldDefaultToTrainingClassField() {
+    args.removeAll("predictedClassField");
 
-  @Test
-  public void testBasicClassification() throws Exception {
-    prepareTrainedIndex();
-    // To be classified,we index documents without a class and verify the expected one is returned
-    addDoc(adoc(ID, "10",
-        TITLE, "word4 word4 word4",
-        CONTENT, "word5 word5 ",
-        AUTHOR, "Name1 Surname1"));
-    addDoc(adoc(ID, "11",
-        TITLE, "word1 word1",
-        CONTENT, "word2 word2",
-        AUTHOR, "Name Surname"));
-    addDoc(commit());
+    cFactoryToTest.init(args);
 
-    Document doc10 = getDoc("10");
-    assertEquals("class2", doc10.get(CLASS));
-    Document doc11 = getDoc("11");
-    assertEquals("class1", doc11.get(CLASS));
+    ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
+    assertThat(classificationParams.getPredictedClassField(), is("classField1"));
   }
 
-  /**
-   * Index some example documents with a class manually assigned.
-   * This will be our trained model.
-   *
-   * @throws Exception If there is a low-level I/O error
-   */
-  private void prepareTrainedIndex() throws Exception {
-    //class1
-    addDoc(adoc(ID, "1",
-        TITLE, "word1 word1 word1",
-        CONTENT, "word2 word2 word2",
-        AUTHOR, "Name Surname",
-        CLASS, "class1"));
-    addDoc(adoc(ID, "2",
-        TITLE, "word1 word1",
-        CONTENT, "word2 word2",
-        AUTHOR, "Name Surname",
-        CLASS, "class1"));
-    addDoc(adoc(ID, "3",
-        TITLE, "word1 word1 word1",
-        CONTENT, "word2",
-        AUTHOR, "Name Surname",
-        CLASS, "class1"));
-    addDoc(adoc(ID, "4",
-        TITLE, "word1 word1 word1",
-        CONTENT, "word2 word2 word2",
-        AUTHOR, "Name Surname",
-        CLASS, "class1"));
-    //class2
-    addDoc(adoc(ID, "5",
-        TITLE, "word4 word4 word4",
-        CONTENT, "word5 word5",
-        AUTHOR, "Name1 Surname1",
-        CLASS, "class2"));
-    addDoc(adoc(ID, "6",
-        TITLE, "word4 word4",
-        CONTENT, "word5",
-        AUTHOR, "Name1 Surname1",
-        CLASS, "class2"));
-    addDoc(adoc(ID, "7",
-        TITLE, "word4 word4 word4",
-        CONTENT, "word5 word5 word5",
-        AUTHOR, "Name1 Surname1",
-        CLASS, "class2"));
-    addDoc(adoc(ID, "8",
-        TITLE, "word4",
-        CONTENT, "word5 word5 word5 word5",
-        AUTHOR, "Name1 Surname1",
-        CLASS, "class2"));
-    addDoc(commit());
+  @Test
+  public void init_unsupportedAlgorithm_shouldThrowExceptionWithDetailedMessage() {
+    args.removeAll("algorithm");
+    args.add("algorithm", "unsupported");
+    try {
+      cFactoryToTest.init(args);
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Algorithm: 'unsupported' not supported", e.getMessage());
+    }
   }
 
-  private Document getDoc(String id) throws IOException {
-    try (SolrQueryRequest req = req()) {
-      SolrIndexSearcher searcher = req.getSearcher();
-      TermQuery query = new TermQuery(new Term(ID, id));
-      TopDocs doc1 = searcher.search(query, 1);
-      ScoreDoc scoreDoc = doc1.scoreDocs[0];
-      return searcher.doc(scoreDoc.doc);
+  @Test
+  public void init_unsupportedFilterQuery_shouldThrowExceptionWithDetailedMessage() {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    SolrQueryRequest mockRequest = mock(SolrQueryRequest.class);
+    SolrQueryResponse mockResponse = mock(SolrQueryResponse.class);
+    args.add("knn.filterQuery", "not supported query");
+    try {
+      cFactoryToTest.init(args);
+      /* parsing failure happens because of the mocks, fine enough to check a proper exception propagation */
+      cFactoryToTest.getInstance(mockRequest, mockResponse, mockProcessor);
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Training Filter Query: 'not supported query' is not supported", e.getMessage());
     }
   }
 
-  static void addDoc(String doc) throws Exception {
-    Map<String, String[]> params = new HashMap<>();
-    MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
-    params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
-    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
-        (SolrParams) mmparams) {
-    };
+  @Test
+  public void init_emptyArgs_shouldDefaultClassificationParams() {
+    args.removeAll("algorithm");
+    args.removeAll("knn.k");
+    args.removeAll("knn.minDf");
+    args.removeAll("knn.minTf");
+    cFactoryToTest.init(args);
+    ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
 
-    UpdateRequestHandler handler = new UpdateRequestHandler();
-    handler.init(null);
-    ArrayList<ContentStream> streams = new ArrayList<>(2);
-    streams.add(new ContentStreamBase.StringStream(doc));
-    req.setContentStreams(streams);
-    handler.handleRequestBody(req, new SolrQueryResponse());
-    req.close();
+    assertEquals(ClassificationUpdateProcessorFactory.Algorithm.KNN, classificationParams.getAlgorithm());
+    assertEquals(1, classificationParams.getMinDf());
+    assertEquals(1, classificationParams.getMinTf());
+    assertEquals(10, classificationParams.getK());
   }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
new file mode 100644
index 0000000..3aee1be
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.update.processor;
+
+import java.io.IOException;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.hamcrest.core.Is.is;
+
+/**
+ * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
+ */
+public class ClassificationUpdateProcessorIntegrationTest extends SolrTestCaseJ4 {
+  /* field names are used in accordance with the solrconfig and schema supplied */
+  private static final String ID = "id";
+  private static final String TITLE = "title";
+  private static final String CONTENT = "content";
+  private static final String AUTHOR = "author";
+  private static final String CLASS = "cat";
+
+  private static final String CHAIN = "classification";
+  private static final String BROKEN_CHAIN_FILTER_QUERY = "classification-unsupported-filterQuery";
+
+  private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
+  private NamedList args = new NamedList();
+
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    System.setProperty("enable.update.log", "false");
+    initCore("solrconfig-classification.xml", "schema-classification.xml");
+  }
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    clearIndex();
+    assertU(commit());
+  }
+
+  @Test
+  public void classify_fullConfiguration_shouldAutoClassify() throws Exception {
+    indexTrainingSet();
+    // To be classified, we index documents without a class and verify the expected one is returned
+    addDoc(adoc(ID, "22",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 ",
+        AUTHOR, "Name1 Surname1"), CHAIN);
+    addDoc(adoc(ID, "21",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname"), CHAIN);
+    addDoc(commit());
+
+    Document doc22 = getDoc("22");
+    assertThat(doc22.get(CLASS),is("class2"));
+    Document doc21 = getDoc("21");
+    assertThat(doc21.get(CLASS),is("class1"));
+  }
+
+  @Test
+  public void classify_unsupportedFilterQueryConfiguration_shouldThrowExceptionWithDetailedMessage() throws Exception {
+    indexTrainingSet();
+    try {
+      addDoc(adoc(ID, "21",
+          TITLE, "word4 word4 word4",
+          CONTENT, "word5 word5 ",
+          AUTHOR, "Name1 Surname1"), BROKEN_CHAIN_FILTER_QUERY);
+      addDoc(adoc(ID, "22",
+          TITLE, "word1 word1",
+          CONTENT, "word2 word2",
+          AUTHOR, "Name Surname"), BROKEN_CHAIN_FILTER_QUERY);
+      addDoc(commit());
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Training Filter Query: 'not valid ( lucene query' is not supported", e.getMessage());
+    }
+  }
+
+  /**
+   * Index some example documents with a class manually assigned.
+   * This will be our trained model.
+   *
+   * @throws Exception If there is a low-level I/O error
+   */
+  private void indexTrainingSet() throws Exception {
+    //class1
+    addDoc(adoc(ID, "1",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"), CHAIN);
+    addDoc(adoc(ID, "2",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"), CHAIN);
+    addDoc(adoc(ID, "3",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"), CHAIN);
+    addDoc(adoc(ID, "4",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"), CHAIN);
+    //class2
+    addDoc(adoc(ID, "5",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "Name Surname",
+        CLASS, "class2"), CHAIN);
+    addDoc(adoc(ID, "6",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "Name Surname",
+        CLASS, "class2"), CHAIN);
+    addDoc(adoc(ID, "7",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "Name Surname",
+        CLASS, "class2"), CHAIN);
+    addDoc(adoc(ID, "8",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "Name Surname",
+        CLASS, "class2"), CHAIN);
+    //class3
+    addDoc(adoc(ID, "9",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class3"), CHAIN);
+    addDoc(adoc(ID, "10",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class3"), CHAIN);
+    addDoc(adoc(ID, "11",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class3"), CHAIN);
+    addDoc(adoc(ID, "12",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class3"), CHAIN);
+    addDoc(commit());
+  }
+
+  private Document getDoc(String id) throws IOException {
+    try (SolrQueryRequest req = req()) {
+      SolrIndexSearcher searcher = req.getSearcher();
+      TermQuery query = new TermQuery(new Term(ID, id));
+      TopDocs doc1 = searcher.search(query, 1);
+      ScoreDoc scoreDoc = doc1.scoreDocs[0];
+      return searcher.doc(scoreDoc.doc);
+    }
+  }
+
+  private void addDoc(String doc) throws Exception {
+    addDoc(doc, CHAIN);
+  }
+}
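The integration test above exercises the chain server-side; from a client, the same chain is selected per request with the update.chain parameter. A hypothetical SolrJ sketch (URL, core name and field values are placeholders, not part of the commit):

    import org.apache.solr.client.solrj.SolrClient;
    import org.apache.solr.client.solrj.impl.HttpSolrClient;
    import org.apache.solr.client.solrj.request.UpdateRequest;
    import org.apache.solr.common.SolrInputDocument;
    import org.apache.solr.common.params.UpdateParams;

    class ClassificationChainClientSketch {
      public static void main(String[] args) throws Exception {
        try (SolrClient client = new HttpSolrClient.Builder("http://localhost:8983/solr/collection1").build()) {
          SolrInputDocument doc = new SolrInputDocument();
          doc.addField("id", "42");
          doc.addField("title", "word4 word4 word4");
          doc.addField("content", "word5 word5");
          // no "cat" value: the ClassificationUpdateProcessor fills the predicted class field
          UpdateRequest update = new UpdateRequest();
          update.setParam(UpdateParams.UPDATE_CHAIN, "classification");
          update.add(doc);
          update.process(client);
          client.commit();
        }
      }
    }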

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
new file mode 100644
index 0000000..938dfc5
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
@@ -0,0 +1,507 @@
+package org.apache.solr.update.processor;
+
+import java.io.IOException;
+import java.util.ArrayList;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.MockTokenizer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.store.Directory;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.update.AddUpdateCommand;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Mockito.mock;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Tests for {@link ClassificationUpdateProcessor}
+ */
+public class ClassificationUpdateProcessorTest extends SolrTestCaseJ4 {
+  /* field names are used in accordance with the solrconfig and schema supplied */
+  private static final String ID = "id";
+  private static final String TITLE = "title";
+  private static final String CONTENT = "content";
+  private static final String AUTHOR = "author";
+  private static final String TRAINING_CLASS = "cat";
+  private static final String PREDICTED_CLASS = "predicted";
+  public static final String KNN = "knn";
+
+  protected Directory directory;
+  protected IndexReader reader;
+  protected IndexSearcher searcher;
+  protected Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
+  private ClassificationUpdateProcessor updateProcessorToTest;
+
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    System.setProperty("enable.update.log", "false");
+    initCore("solrconfig-classification.xml", "schema-classification.xml");
+  }
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+  }
+
+  @Override
+  public void tearDown() throws Exception {
+    reader.close();
+    directory.close();
+    analyzer.close();
+    super.tearDown();
+  }
+
+  @Test
+  public void classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setPredictedClassField(PREDICTED_CLASS);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(PREDICTED_CLASS),is("class1"));
+  }
+
+  @Test
+  public void knnMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
+  }
+
+  @Test
+  public void knnMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
+  }
+
+  @Test
+  public void bayesMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class1"));
+  }
+
+  @Test
+  public void knnMonoClass_contextQueryFiltered_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "a");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    Query class3DocsChunk=new TermQuery(new Term(TITLE,"word6"));
+    params.setTrainingFilterQuery(class3DocsChunk);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class3"));
+  }
+
+  @Test
+  public void bayesMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMonoClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    assertThat(unseenDocument1.getFieldValue(TRAINING_CLASS),is("class2"));
+  }
+
+  @Test
+  public void knnClassification_maxOutputClassesGreaterThanAvailable_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setMaxPredictedClasses(100);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    ArrayList assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
+    assertThat(assignedClasses.get(0),is("class2"));
+    assertThat(assignedClasses.get(1),is("class1"));
+  }
+
+  @Test
+  public void knnMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    ArrayList assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
+    assertThat(assignedClasses.size(),is(2));
+    assertThat(assignedClasses.get(0),is("class2"));
+    assertThat(assignedClasses.get(1),is("class1"));
+  }
+
+  @Test
+  public void bayesMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    ArrayList assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
+    assertThat(assignedClasses.size(),is(2));
+    assertThat(assignedClasses.get(0),is("class2"));
+    assertThat(assignedClasses.get(1),is("class1"));
+  }
+
+  @Test
+  public void knnMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    ArrayList assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
+    assertThat(assignedClasses.size(),is(2));
+    assertThat(assignedClasses.get(0),is("class4"));
+    assertThat(assignedClasses.get(1),is("class6"));
+  }
+
+  @Test
+  public void bayesMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor=mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update=new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc=unseenDocument1;
+
+    ClassificationUpdateProcessorParams params= initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    params.setInputFieldNames(new String[]{TITLE+"^1.5",CONTENT+"^0.5",AUTHOR+"^2.5"});
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest=new ClassificationUpdateProcessor(params,mockProcessor,reader,req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    ArrayList assignedClasses = (ArrayList)unseenDocument1.getFieldValues(TRAINING_CLASS);
+    assertThat(assignedClasses.size(),is(2));
+    assertThat(assignedClasses.get(0),is("class4"));
+    assertThat(assignedClasses.get(1),is("class6"));
+  }
+
+  private ClassificationUpdateProcessorParams initParams(ClassificationUpdateProcessorFactory.Algorithm classificationAlgorithm) {
+    ClassificationUpdateProcessorParams params= new ClassificationUpdateProcessorParams();
+    params.setInputFieldNames(new String[]{TITLE,CONTENT,AUTHOR});
+    params.setTrainingClassField(TRAINING_CLASS);
+    params.setPredictedClassField(TRAINING_CLASS);
+    params.setMinTf(1);
+    params.setMinDf(1);
+    params.setK(5);
+    params.setAlgorithm(classificationAlgorithm);
+    params.setMaxPredictedClasses(1);
+    return params;
+  }
+
+  /**
+   * Index some example documents with a class manually assigned.
+   * This will be our trained model.
+   *
+   * @throws Exception If there is a low-level I/O error
+   */
+  private void prepareTrainedIndexMonoClass() throws Exception {
+    directory = newDirectory();
+    RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
+
+    //class1
+    addDoc(writer, buildLuceneDocument(ID, "1",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class1"));
+    addDoc(writer, buildLuceneDocument(ID, "2",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class1"));
+    addDoc(writer, buildLuceneDocument(ID, "3",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class1"));
+    addDoc(writer, buildLuceneDocument(ID, "4",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class1"));
+    //class2
+    addDoc(writer, buildLuceneDocument(ID, "5",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "c",
+        TRAINING_CLASS, "class2"));
+    addDoc(writer, buildLuceneDocument(ID, "6",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "c",
+        TRAINING_CLASS, "class2"));
+    addDoc(writer, buildLuceneDocument(ID, "7",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "c",
+        TRAINING_CLASS, "class2"));
+    addDoc(writer, buildLuceneDocument(ID, "8",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "c",
+        TRAINING_CLASS, "class2"));
+    //class3
+    addDoc(writer, buildLuceneDocument(ID, "9",
+        TITLE, "word6",
+        CONTENT, "word7",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class3"));
+    addDoc(writer, buildLuceneDocument(ID, "10",
+        TITLE, "word6",
+        CONTENT, "word7",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class3"));
+    addDoc(writer, buildLuceneDocument(ID, "11",
+        TITLE, "word6",
+        CONTENT, "word7",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class3"));
+    addDoc(writer, buildLuceneDocument(ID, "12",
+        TITLE, "word6",
+        CONTENT, "word7",
+        AUTHOR, "a",
+        TRAINING_CLASS, "class3"));
+
+    reader = writer.getReader();
+    writer.close();
+    searcher = newSearcher(reader);
+  }
+
+  private void prepareTrainedIndexMultiClass() throws Exception {
+
+  private void prepareTrainedIndexMultiClass() throws Exception {
+    directory = newDirectory();
+    RandomIndexWriter writer = new RandomIndexWriter(random(), directory);
+
+    //class1
+    addDoc(writer, buildLuceneDocument(ID, "1",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        TRAINING_CLASS, "class1",
+        TRAINING_CLASS, "class2"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "2",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname",
+        TRAINING_CLASS, "class3",
+        TRAINING_CLASS, "class2"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "3",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2",
+        AUTHOR, "Name Surname",
+        TRAINING_CLASS, "class1",
+        TRAINING_CLASS, "class2"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "4",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        TRAINING_CLASS, "class1",
+        TRAINING_CLASS, "class2"
+    ));
+    //class2
+    addDoc(writer, buildLuceneDocument(ID, "5",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "Name1 Surname1",
+        TRAINING_CLASS, "class6",
+        TRAINING_CLASS, "class4"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "6",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "Name1 Surname1",
+        TRAINING_CLASS, "class5",
+        TRAINING_CLASS, "class4"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "7",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        TRAINING_CLASS, "class6",
+        TRAINING_CLASS, "class4"
+    ));
+    addDoc(writer, buildLuceneDocument(ID, "8",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        TRAINING_CLASS, "class6",
+        TRAINING_CLASS, "class4"
+    ));
+
+    reader = writer.getReader();
+    writer.close();
+    searcher = newSearcher(reader);
+  }
+
+  public static Document buildLuceneDocument(Object... fieldsAndValues) {
+    Document luceneDoc = new Document();
+    for (int i = 0; i
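[Reviewer note] The body of buildLuceneDocument(...) above did not survive intact in this message. A minimal sketch of such a varargs builder, assuming the test framework's newTextField helper and stored fields; the committed loop body may differ:

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.Field;

    // Interpret the varargs as alternating (field name, field value) pairs,
    // matching the addDoc(writer, buildLuceneDocument(...)) calls above.
    public static Document buildLuceneDocument(Object... fieldsAndValues) {
      Document luceneDoc = new Document();
      for (int i = 0; i < fieldsAndValues.length; i += 2) {
        luceneDoc.add(newTextField((String) fieldsAndValues[i],
            (String) fieldsAndValues[i + 1], Field.Store.YES));
      }
      return luceneDoc;
    }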
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java
----------------------------------------------------------------------

-    Map<String, String[]> params = new HashMap<>();
-    MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
-    params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
-    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
-        (SolrParams) mmparams) {
-    };
-
-    UpdateRequestHandler handler = new UpdateRequestHandler();
-    handler.init(null);
-    ArrayList<ContentStream> streams = new ArrayList<>(2);
-    streams.add(new ContentStreamBase.StringStream(doc));
-    req.setContentStreams(streams);
-    handler.handleRequestBody(req, new SolrQueryResponse());
-    req.close();
-  }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java b/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
index d494eb6..bab5cd3 100644
--- a/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
+++ b/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
@@ -25,8 +25,6 @@ import org.junit.Test;
 
 import java.util.Map;
 
-import static org.apache.solr.update.processor.SignatureUpdateProcessorFactoryTest.addDoc;
-
 public class TestPartialUpdateDeduplication extends SolrTestCaseJ4 {
   @BeforeClass
   public static void beforeClass() throws Exception {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
----------------------------------------------------------------------
diff --git a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
index 3adad49..19bf601 100644
--- a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
+++ b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
@@ -83,7 +83,11 @@ import org.apache.solr.common.SolrInputDocument;
 import org.apache.solr.common.SolrInputField;
 import org.apache.solr.common.params.CommonParams;
 import org.apache.solr.common.params.ModifiableSolrParams;
+import org.apache.solr.common.params.MultiMapSolrParams;
 import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.params.UpdateParams;
+import org.apache.solr.common.util.ContentStream;
+import org.apache.solr.common.util.ContentStreamBase;
 import org.apache.solr.common.util.ObjectReleaseTracker;
 import org.apache.solr.common.util.XML;
 import org.apache.solr.core.CoreContainer;
@@ -96,7 +100,9 @@ import org.apache.solr.core.SolrXmlConfig;
 import org.apache.solr.handler.UpdateRequestHandler;
 import org.apache.solr.request.LocalSolrQueryRequest;
 import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.request.SolrQueryRequestBase;
 import org.apache.solr.request.SolrRequestHandler;
+import org.apache.solr.response.SolrQueryResponse;
 import org.apache.solr.schema.IndexSchema;
 import org.apache.solr.schema.SchemaField;
 import org.apache.solr.search.SolrIndexSearcher;
@@ -1009,6 +1015,22 @@ public abstract class SolrTestCaseJ4 extends LuceneTestCase {
     return out.toString();
   }
 
+  public static void addDoc(String doc, String updateRequestProcessorChain) throws Exception {
+    Map<String, String[]> params = new HashMap<>();
+    MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
+    params.put(UpdateParams.UPDATE_CHAIN, new String[]{updateRequestProcessorChain});
+    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
+        (SolrParams) mmparams) {
+    };
+
+    UpdateRequestHandler handler = new UpdateRequestHandler();
+    handler.init(null);
+    ArrayList<ContentStream> streams = new ArrayList<>(2);
+    streams.add(new ContentStreamBase.StringStream(doc));
+    req.setContentStreams(streams);
+    handler.handleRequestBody(req, new SolrQueryResponse());
+    req.close();
+  }
 
   /**
    * Generates an <add><doc>... XML String with options
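[Reviewer note] With addDoc promoted to SolrTestCaseJ4, any test can push a raw XML document through a named update chain. A usage sketch; the chain name and field are examples, not taken from this diff:

    // adoc(...) builds the <add><doc>...</doc></add> XML; "dedupe" names an
    // update chain configured in the test's solrconfig.
    addDoc(adoc("id", "1", "v_t", "some text"), "dedupe");
    assertU(commit());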