lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jbern...@apache.org
Subject [3/3] lucene-solr:master: SOLR-9252: Feature selection and logistic regression on text
Date Wed, 03 Aug 2016 17:32:29 GMT
SOLR-9252: Feature selection and logistic regression on text


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/87938e00
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/87938e00
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/87938e00

Branch: refs/heads/master
Commit: 87938e00e9f1006801fbf0e8c0d7b2a84b5eda48
Parents: 9fc4624
Author: jbernste <jbernste@apache.org>
Authored: Wed Aug 3 11:12:57 2016 -0400
Committer: jbernste <jbernste@apache.org>
Committed: Wed Aug 3 11:43:00 2016 -0400

----------------------------------------------------------------------
 .../org/apache/solr/handler/StreamHandler.java  |   2 +
 .../solr/search/IGainTermsQParserPlugin.java    | 239 +++++++
 .../org/apache/solr/search/QParserPlugin.java   |  12 +-
 .../TextLogisticRegressionQParserPlugin.java    | 283 ++++++++
 .../apache/solr/search/QueryEqualityTest.java   |  18 +
 .../solrj/io/ClassificationEvaluation.java      |  85 +++
 .../io/stream/FeaturesSelectionStream.java      | 436 ++++++++++++
 .../client/solrj/io/stream/TextLogitStream.java | 657 +++++++++++++++++++
 .../solrj/io/stream/expr/Explanation.java       |   1 +
 .../solrj/solr/configsets/ml/conf/schema.xml    |  77 +++
 .../solr/configsets/ml/conf/solrconfig.xml      |  51 ++
 .../solrj/io/stream/StreamExpressionTest.java   | 185 +++++-
 .../stream/StreamExpressionToExpessionTest.java |  37 +-
 .../StreamExpressionToExplanationTest.java      |   1 -
 14 files changed, 2076 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index b34cff5..e97df34 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -122,6 +122,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
       .withFunctionName("intersect", IntersectStream.class)
       .withFunctionName("complement", ComplementStream.class)
       .withFunctionName("sort", SortStream.class)
+      .withFunctionName("train", TextLogitStream.class)
+      .withFunctionName("features", FeaturesSelectionStream.class)
       .withFunctionName("daemon", DaemonStream.class)
       .withFunctionName("shortestPath", ShortestPathStream.class)
       .withFunctionName("gatherNodes", GatherNodesStream.class)

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/core/src/java/org/apache/solr/search/IGainTermsQParserPlugin.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/search/IGainTermsQParserPlugin.java b/solr/core/src/java/org/apache/solr/search/IGainTermsQParserPlugin.java
new file mode 100644
index 0000000..6c99813
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/IGainTermsQParserPlugin.java
@@ -0,0 +1,239 @@
+package org.apache.solr.search;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.TreeSet;
+
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.SparseFixedBitSet;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.handler.component.ResponseBuilder;
+import org.apache.solr.request.SolrQueryRequest;
+
+public class IGainTermsQParserPlugin extends QParserPlugin {
+
+  public static final String NAME = "igain";
+
+  @Override
+  public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
+    return new IGainTermsQParser(qstr, localParams, params, req);
+  }
+
+  private static class IGainTermsQParser extends QParser {
+
+    public IGainTermsQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
+      super(qstr, localParams, params, req);
+    }
+
+    @Override
+    public Query parse() throws SyntaxError {
+
+      String field = getParam("field");
+      String outcome = getParam("outcome");
+      int numTerms = Integer.parseInt(getParam("numTerms"));
+      int positiveLabel = Integer.parseInt(getParam("positiveLabel"));
+
+      return new IGainTermsQuery(field, outcome, positiveLabel, numTerms);
+    }
+  }
+
+  private static class IGainTermsQuery extends AnalyticsQuery {
+
+    private String field;
+    private String outcome;
+    private int numTerms;
+    private int positiveLabel;
+
+    public IGainTermsQuery(String field, String outcome, int positiveLabel, int numTerms) {
+      this.field = field;
+      this.outcome = outcome;
+      this.numTerms = numTerms;
+      this.positiveLabel = positiveLabel;
+    }
+
+    @Override
+    public DelegatingCollector getAnalyticsCollector(ResponseBuilder rb, IndexSearcher searcher) {
+      return new IGainTermsCollector(rb, searcher, field, outcome, positiveLabel, numTerms);
+    }
+  }
+
+  private static class IGainTermsCollector extends DelegatingCollector {
+
+    private String field;
+    private String outcome;
+    private IndexSearcher searcher;
+    private ResponseBuilder rb;
+    private int positiveLabel;
+    private int numTerms;
+    private int count;
+
+    private NumericDocValues leafOutcomeValue;
+    private SparseFixedBitSet positiveSet;
+    private SparseFixedBitSet negativeSet;
+
+
+    private int numPositiveDocs;
+
+
+    public IGainTermsCollector(ResponseBuilder rb, IndexSearcher searcher, String field, String outcome, int positiveLabel, int numTerms) {
+      this.rb = rb;
+      this.searcher = searcher;
+      this.field = field;
+      this.outcome = outcome;
+      this.positiveSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
+      this.negativeSet = new SparseFixedBitSet(searcher.getIndexReader().maxDoc());
+
+      this.numTerms = numTerms;
+      this.positiveLabel = positiveLabel;
+    }
+
+    @Override
+    protected void doSetNextReader(LeafReaderContext context) throws IOException {
+      super.doSetNextReader(context);
+      LeafReader reader = context.reader();
+      leafOutcomeValue = reader.getNumericDocValues(outcome);
+    }
+
+    @Override
+    public void collect(int doc) throws IOException {
+      super.collect(doc);
+      ++count;
+      if (leafOutcomeValue.get(doc) == positiveLabel) {
+        positiveSet.set(context.docBase + doc);
+        numPositiveDocs++;
+      } else {
+        negativeSet.set(context.docBase + doc);
+      }
+    }
+
+    @Override
+    public void finish() throws IOException {
+      NamedList<Double> analytics = new NamedList<Double>();
+      NamedList<Integer> topFreq = new NamedList();
+
+      NamedList<Integer> allFreq = new NamedList();
+
+      rb.rsp.add("featuredTerms", analytics);
+      rb.rsp.add("docFreq", topFreq);
+      rb.rsp.add("numDocs", count);
+
+      TreeSet<TermWithScore> topTerms = new TreeSet<>();
+
+      double numDocs = count;
+      double pc = numPositiveDocs / numDocs;
+      double entropyC = binaryEntropy(pc);
+
+      Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(field);
+      TermsEnum termsEnum = terms.iterator();
+      BytesRef term;
+      PostingsEnum postingsEnum = null;
+      while ((term = termsEnum.next()) != null) {
+        postingsEnum = termsEnum.postings(postingsEnum);
+        int xc = 0;
+        int nc = 0;
+        while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+          if (positiveSet.get(postingsEnum.docID())) {
+            xc++;
+          } else if (negativeSet.get(postingsEnum.docID())) {
+            nc++;
+          }
+        }
+
+        int docFreq = xc+nc;
+
+        double entropyContainsTerm = binaryEntropy( (double) xc / docFreq );
+        double entropyNotContainsTerm = binaryEntropy( (double) (numPositiveDocs - xc) / (numDocs - docFreq + 1) );
+        double score = entropyC - ( (docFreq / numDocs) * entropyContainsTerm + (1.0 - docFreq / numDocs) * entropyNotContainsTerm);
+
+        topFreq.add(term.utf8ToString(), docFreq);
+        if (topTerms.size() < numTerms) {
+          topTerms.add(new TermWithScore(term.utf8ToString(), score));
+        } else  {
+          if (topTerms.first().score < score) {
+            topTerms.pollFirst();
+            topTerms.add(new TermWithScore(term.utf8ToString(), score));
+          }
+        }
+      }
+
+      for (TermWithScore topTerm : topTerms) {
+        analytics.add(topTerm.term, topTerm.score);
+        topFreq.add(topTerm.term, allFreq.get(topTerm.term));
+      }
+
+      if (this.delegate instanceof DelegatingCollector) {
+        ((DelegatingCollector) this.delegate).finish();
+      }
+    }
+
+    private double binaryEntropy(double prob) {
+      if (prob == 0 || prob == 1) return 0;
+      return (-1 * prob * Math.log(prob)) + (-1 * (1.0 - prob) * Math.log(1.0 - prob));
+    }
+
+  }
+
+
+
+  private static class TermWithScore implements Comparable<TermWithScore>{
+    public final String term;
+    public final double score;
+
+    public TermWithScore(String term, double score) {
+      this.term = term;
+      this.score = score;
+    }
+
+    @Override
+    public int hashCode() {
+      return term.hashCode();
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj == null) return false;
+      if (obj.getClass() != getClass()) return false;
+      TermWithScore other = (TermWithScore) obj;
+      return other.term.equals(this.term);
+    }
+
+    @Override
+    public int compareTo(TermWithScore o) {
+      int cmp = Double.compare(this.score, o.score);
+      if (cmp == 0) {
+        return this.term.compareTo(o.term);
+      } else {
+        return cmp;
+      }
+    }
+  }
+}
+
+

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/core/src/java/org/apache/solr/search/QParserPlugin.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/search/QParserPlugin.java b/solr/core/src/java/org/apache/solr/search/QParserPlugin.java
index 7a6247f..573286b 100644
--- a/solr/core/src/java/org/apache/solr/search/QParserPlugin.java
+++ b/solr/core/src/java/org/apache/solr/search/QParserPlugin.java
@@ -16,6 +16,11 @@
  */
 package org.apache.solr.search;
 
+import java.net.URL;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.apache.solr.common.params.SolrParams;
 import org.apache.solr.common.util.NamedList;
 import org.apache.solr.core.SolrInfoMBean;
@@ -26,11 +31,6 @@ import org.apache.solr.search.join.GraphQParserPlugin;
 import org.apache.solr.search.mlt.MLTQParserPlugin;
 import org.apache.solr.util.plugin.NamedListInitializedPlugin;
 
-import java.net.URL;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
 public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrInfoMBean {
   /** internal use - name of the default parser */
   public static final String DEFAULT_QTYPE = LuceneQParserPlugin.NAME;
@@ -77,6 +77,8 @@ public abstract class QParserPlugin implements NamedListInitializedPlugin, SolrI
     map.put(GraphQParserPlugin.NAME, GraphQParserPlugin.class);
     map.put(XmlQParserPlugin.NAME, XmlQParserPlugin.class);
     map.put(GraphTermsQParserPlugin.NAME, GraphTermsQParserPlugin.class);
+    map.put(IGainTermsQParserPlugin.NAME, IGainTermsQParserPlugin.class);
+    map.put(TextLogisticRegressionQParserPlugin.NAME, TextLogisticRegressionQParserPlugin.class);
     standardPlugins = Collections.unmodifiableMap(map);
   }
 

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/core/src/java/org/apache/solr/search/TextLogisticRegressionQParserPlugin.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/search/TextLogisticRegressionQParserPlugin.java b/solr/core/src/java/org/apache/solr/search/TextLogisticRegressionQParserPlugin.java
new file mode 100644
index 0000000..e8fbaf6
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/TextLogisticRegressionQParserPlugin.java
@@ -0,0 +1,283 @@
+package org.apache.solr.search;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.SparseFixedBitSet;
+import org.apache.solr.client.solrj.io.ClassificationEvaluation;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.handler.component.ResponseBuilder;
+import org.apache.solr.request.SolrQueryRequest;
+
+/**
+ *   Returns an AnalyticsQuery implementation that performs
+ *   one Gradient Descent iteration of a result set to train a
+ *   logistic regression model
+ *
+ *   The TextLogitStream provides the parallel iterative framework for this class.
+ **/
+
+public class TextLogisticRegressionQParserPlugin extends QParserPlugin {
+  public static final String NAME = "tlogit";
+
+  @Override
+  public void init(NamedList args) {
+  }
+
+  @Override
+  public QParser createParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
+    return new TextLogisticRegressionQParser(qstr, localParams, params, req);
+  }
+
+  private static class TextLogisticRegressionQParser extends QParser{
+
+    TextLogisticRegressionQParser(String qstr, SolrParams localParams, SolrParams params, SolrQueryRequest req) {
+      super(qstr, localParams, params, req);
+    }
+
+    public Query parse() {
+
+      String fs = params.get("feature");
+      String[] terms = params.get("terms").split(",");
+      String ws = params.get("weights");
+      String dfsStr = params.get("idfs");
+      int iteration = params.getInt("iteration");
+      String outcome = params.get("outcome");
+      int positiveLabel = params.getInt("positiveLabel", 1);
+      double threshold = params.getDouble("threshold", 0.5);
+      double alpha = params.getDouble("alpha", 0.01);
+
+      double[] idfs = new double[terms.length];
+      String[] idfsArr = dfsStr.split(",");
+      for (int i = 0; i < idfsArr.length; i++) {
+        idfs[i] = Double.parseDouble(idfsArr[i]);
+      }
+
+      double[] weights = new double[terms.length+1];
+
+      if(ws != null) {
+        String[] wa = ws.split(",");
+        for (int i = 0; i < wa.length; i++) {
+          weights[i] = Double.parseDouble(wa[i]);
+        }
+      } else {
+        for(int i=0; i<weights.length; i++) {
+          weights[i]= 1.0d;
+        }
+      }
+
+      TrainingParams input = new TrainingParams(fs, terms, idfs, outcome, weights, iteration, alpha, positiveLabel, threshold);
+
+      return new TextLogisticRegressionQuery(input);
+    }
+  }
+
+  private static class TextLogisticRegressionQuery extends AnalyticsQuery {
+    private TrainingParams trainingParams;
+
+    public TextLogisticRegressionQuery(TrainingParams trainingParams) {
+      this.trainingParams = trainingParams;
+    }
+
+    public DelegatingCollector getAnalyticsCollector(ResponseBuilder rbsp, IndexSearcher indexSearcher) {
+      return new TextLogisticRegressionCollector(rbsp, indexSearcher, trainingParams);
+    }
+  }
+
+  private static class TextLogisticRegressionCollector extends DelegatingCollector {
+    private TrainingParams trainingParams;
+    private LeafReader leafReader;
+
+    private double[] workingDeltas;
+    private ClassificationEvaluation classificationEvaluation;
+    private double[] weights;
+
+    private ResponseBuilder rbsp;
+    private NumericDocValues leafOutcomeValue;
+    private double totalError;
+    private SparseFixedBitSet positiveDocsSet;
+    private SparseFixedBitSet docsSet;
+    private IndexSearcher searcher;
+
+    TextLogisticRegressionCollector(ResponseBuilder rbsp, IndexSearcher searcher,
+                                    TrainingParams trainingParams) {
+      this.trainingParams = trainingParams;
+      this.workingDeltas = new double[trainingParams.weights.length];
+      this.weights = Arrays.copyOf(trainingParams.weights, trainingParams.weights.length);
+      this.rbsp = rbsp;
+      this.classificationEvaluation = new ClassificationEvaluation();
+      this.searcher = searcher;
+      positiveDocsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
+      docsSet = new SparseFixedBitSet(searcher.getIndexReader().numDocs());
+    }
+
+    public void doSetNextReader(LeafReaderContext context) throws IOException {
+      super.doSetNextReader(context);
+      leafReader = context.reader();
+      leafOutcomeValue = leafReader.getNumericDocValues(trainingParams.outcome);
+    }
+
+    public void collect(int doc) throws IOException{
+
+      int outcome = (int) leafOutcomeValue.get(doc);
+      outcome = trainingParams.positiveLabel == outcome? 1 : 0;
+      if (outcome == 1) {
+        positiveDocsSet.set(context.docBase + doc);
+      }
+      docsSet.set(context.docBase+doc);
+
+    }
+
+    public void finish() throws IOException {
+
+      Map<Integer, double[]> docVectors = new HashMap<>();
+      Terms terms = MultiFields.getFields(searcher.getIndexReader()).terms(trainingParams.feature);
+      TermsEnum termsEnum = terms.iterator();
+      PostingsEnum postingsEnum = null;
+      int termIndex = 0;
+      for (String termStr : trainingParams.terms) {
+        BytesRef term = new BytesRef(termStr);
+        if (termsEnum.seekExact(term)) {
+          postingsEnum = termsEnum.postings(postingsEnum);
+          while (postingsEnum.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+            int docId = postingsEnum.docID();
+            if (docsSet.get(docId)) {
+              double[] vector = docVectors.get(docId);
+              if (vector == null) {
+                vector = new double[trainingParams.terms.length+1];
+                vector[0] = 1.0;
+                docVectors.put(docId, vector);
+              }
+              vector[termIndex + 1] = trainingParams.idfs[termIndex] * (1.0 + Math.log(postingsEnum.freq()));
+            }
+          }
+        }
+        termIndex++;
+      }
+
+      for (Map.Entry<Integer, double[]> entry : docVectors.entrySet()) {
+        double[] vector = entry.getValue();
+        int outcome = 0;
+        if (positiveDocsSet.get(entry.getKey())) {
+          outcome = 1;
+        }
+        double sig = sigmoid(sum(multiply(vector, weights)));
+        double error = sig - outcome;
+        double lastSig = sigmoid(sum(multiply(vector, trainingParams.weights)));
+        totalError += Math.abs(lastSig - outcome);
+        classificationEvaluation.count(outcome,  lastSig >= trainingParams.threshold ? 1 : 0);
+
+        workingDeltas = multiply(error * trainingParams.alpha, vector);
+
+        for(int i = 0; i< workingDeltas.length; i++) {
+          weights[i] -= workingDeltas[i];
+        }
+      }
+
+      NamedList analytics = new NamedList();
+      rbsp.rsp.add("logit", analytics);
+
+      List<Double> outWeights = new ArrayList<>();
+      for(Double d : weights) {
+        outWeights.add(d);
+      }
+
+      analytics.add("weights", outWeights);
+      analytics.add("error", totalError);
+      analytics.add("evaluation", classificationEvaluation.toMap());
+      analytics.add("feature", trainingParams.feature);
+      analytics.add("positiveLabel", trainingParams.positiveLabel);
+      if(this.delegate instanceof DelegatingCollector) {
+        ((DelegatingCollector)this.delegate).finish();
+      }
+    }
+
+    private double sigmoid(double in) {
+      double d = 1.0 / (1+Math.exp(-in));
+      return d;
+    }
+
+    private double[] multiply(double[] vals, double[] weights) {
+      for(int i = 0; i < vals.length; ++i) {
+        workingDeltas[i] = vals[i] * weights[i];
+      }
+
+      return workingDeltas;
+    }
+
+    private double[] multiply(double d, double[] vals) {
+      for(int i = 0; i<vals.length; ++i) {
+        workingDeltas[i] = vals[i] * d;
+      }
+
+      return workingDeltas;
+    }
+
+    private double sum(double[] vals) {
+      double d = 0.0d;
+      for(double val : vals) {
+        d += val;
+      }
+
+      return d;
+    }
+
+  }
+
+  private static class TrainingParams {
+    public final String feature;
+    public final String[] terms;
+    public final double[] idfs;
+    public final String outcome;
+    public final double[] weights;
+    public final int interation;
+    public final int positiveLabel;
+    public final double threshold;
+    public final double alpha;
+
+    public TrainingParams(String feature, String[] terms, double[] idfs, String outcome, double[] weights, int interation, double alpha, int positiveLabel, double threshold) {
+      this.feature = feature;
+      this.terms = terms;
+      this.idfs = idfs;
+      this.outcome = outcome;
+      this.weights = weights;
+      this.alpha = alpha;
+      this.interation = interation;
+      this.positiveLabel = positiveLabel;
+      this.threshold = threshold;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
index 9c51844..86c7ee8 100644
--- a/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
+++ b/solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
@@ -175,6 +175,24 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
     }
   }
 
+  public void testTlogitQuery() throws Exception {
+    SolrQueryRequest req = req("q", "*:*", "feature", "f", "terms","a,b,c", "weights", "100,200,300", "idfs","1,5,7","iteration","1", "outcome","a","positiveLabel","1");
+    try {
+      assertQueryEquals("tlogit", req, "{!tlogit}");
+    } finally {
+      req.close();
+    }
+  }
+
+  public void testIGainQuery() throws Exception {
+    SolrQueryRequest req = req("q", "*:*", "outcome", "b", "positiveLabel", "1", "field", "x", "numTerms","200");
+    try {
+      assertQueryEquals("igain", req, "{!igain}");
+    } finally {
+      req.close();
+    }
+  }
+
   public void testQuerySwitch() throws Exception {
     SolrQueryRequest req = req("myXXX", "XXX", 
                                "myField", "foo_s",

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/java/org/apache/solr/client/solrj/io/ClassificationEvaluation.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/ClassificationEvaluation.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ClassificationEvaluation.java
new file mode 100644
index 0000000..470f985
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/ClassificationEvaluation.java
@@ -0,0 +1,85 @@
+package org.apache.solr.client.solrj.io;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class ClassificationEvaluation {
+  private long truePositive;
+  private long falsePositive;
+  private long trueNegative;
+  private long falseNegative;
+
+  public void count(int actual, int predicted) {
+    if (predicted == 1) {
+      if (actual == 1) truePositive++;
+      else falsePositive++;
+    } else {
+      if (actual == 0) trueNegative++;
+      else falseNegative++;
+    }
+  }
+
+  public void putToMap(Map map) {
+    map.put("truePositive_i",truePositive);
+    map.put("trueNegative_i",trueNegative);
+    map.put("falsePositive_i",falsePositive);
+    map.put("falseNegative_i",falseNegative);
+  }
+
+  public Map toMap() {
+    HashMap map = new HashMap();
+    putToMap(map);
+    return map;
+  }
+
+  public static ClassificationEvaluation create(Map map) {
+    ClassificationEvaluation evaluation = new ClassificationEvaluation();
+    evaluation.addEvaluation(map);
+    return evaluation;
+  }
+
+  public void addEvaluation(Map map) {
+    this.truePositive += (long) map.get("truePositive_i");
+    this.trueNegative += (long) map.get("trueNegative_i");
+    this.falsePositive += (long) map.get("falsePositive_i");
+    this.falseNegative += (long) map.get("falseNegative_i");
+  }
+
+  public double getPrecision() {
+    if (truePositive + falsePositive == 0) return 0;
+    return (double) truePositive / (truePositive + falsePositive);
+  }
+
+  public double getRecall() {
+    if (truePositive + falseNegative == 0) return 0;
+    return (double) truePositive / (truePositive + falseNegative);
+  }
+
+  public double getF1() {
+    double precision = getPrecision();
+    double recall = getRecall();
+    if (precision + recall == 0) return 0;
+    return 2 * (precision * recall) / (precision + recall);
+  }
+
+  public double getAccuracy() {
+    return (double) (truePositive + trueNegative) / (truePositive + trueNegative + falseNegative + falsePositive);
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FeaturesSelectionStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FeaturesSelectionStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FeaturesSelectionStream.java
new file mode 100644
index 0000000..007e3d8
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FeaturesSelectionStream.java
@@ -0,0 +1,436 @@
+package org.apache.solr.client.solrj.io.stream;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.stream.Stream;
+
+import org.apache.solr.client.solrj.impl.CloudSolrClient;
+import org.apache.solr.client.solrj.impl.HttpSolrClient;
+import org.apache.solr.client.solrj.io.SolrClientCache;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.client.solrj.request.QueryRequest;
+import org.apache.solr.client.solrj.response.QueryResponse;
+import org.apache.solr.common.cloud.ClusterState;
+import org.apache.solr.common.cloud.Replica;
+import org.apache.solr.common.cloud.Slice;
+import org.apache.solr.common.cloud.ZkCoreNodeProps;
+import org.apache.solr.common.cloud.ZkStateReader;
+import org.apache.solr.common.params.ModifiableSolrParams;
+import org.apache.solr.common.util.ExecutorUtil;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.common.util.SolrjNamedThreadFactory;
+
+public class FeaturesSelectionStream extends TupleStream implements Expressible{
+
+  private static final long serialVersionUID = 1;
+
+  protected String zkHost;
+  protected String collection;
+  protected Map<String,String> params;
+  protected Iterator<Tuple> tupleIterator;
+  protected String field;
+  protected String outcome;
+  protected String featureSet;
+  protected int positiveLabel;
+  protected int numTerms;
+
+
+
+  protected transient SolrClientCache cache;
+  protected transient boolean isCloseCache;
+  protected transient CloudSolrClient cloudSolrClient;
+
+  protected transient StreamContext streamContext;
+  protected ExecutorService executorService;
+
+
+  public FeaturesSelectionStream(String zkHost,
+                     String collectionName,
+                     Map params,
+                     String field,
+                     String outcome,
+                     String featureSet,
+                     int positiveLabel,
+                     int numTerms) throws IOException {
+
+    init(collectionName, zkHost, params, field, outcome, featureSet, positiveLabel, numTerms);
+  }
+
+  /**
+   *   logit(collection, zkHost="", features="a,b,c,d,e,f,g", outcome="y", maxIteration="20")
+   **/
+
+  public FeaturesSelectionStream(StreamExpression expression, StreamFactory factory) throws IOException{
+    // grab all parameters out
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+
+    // Validate there are no unknown parameters - zkHost and alias are namedParameter so we don't need to count it twice
+    if(expression.getParameters().size() != 1 + namedParams.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
+    }
+
+    // Collection Name
+    if(null == collectionName){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    // Named parameters - passed directly to solr as solrparams
+    if(0 == namedParams.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
+    }
+
+    Map<String,String> params = new HashMap<String,String>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(!namedParam.getName().equals("zkHost")) {
+        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
+      }
+    }
+
+    String fieldParam = params.get("field");
+    if(fieldParam != null) {
+      params.remove("field");
+    } else {
+      throw new IOException("field param cannot be null for FeaturesSelectionStream");
+    }
+
+    String outcomeParam = params.get("outcome");
+    if(outcomeParam != null) {
+      params.remove("outcome");
+    } else {
+      throw new IOException("outcome param cannot be null for FeaturesSelectionStream");
+    }
+
+    String featureSetParam = params.get("featureSet");
+    if(featureSetParam != null) {
+      params.remove("featureSet");
+    } else {
+      throw new IOException("featureSet param cannot be null for FeaturesSelectionStream");
+    }
+
+    String positiveLabelParam = params.get("positiveLabel");
+    int positiveLabel = 1;
+    if(positiveLabelParam != null) {
+      params.remove("positiveLabel");
+      positiveLabel = Integer.parseInt(positiveLabelParam);
+    }
+
+    String numTermsParam = params.get("numTerms");
+    int numTerms = 1;
+    if(numTermsParam != null) {
+      numTerms = Integer.parseInt(numTermsParam);
+      params.remove("numTerms");
+    } else {
+      throw new IOException("numTerms param cannot be null for FeaturesSelectionStream");
+    }
+
+    // zkHost, optional - if not provided then will look into factory list to get
+    String zkHost = null;
+    if(null == zkHostExpression){
+      zkHost = factory.getCollectionZkHost(collectionName);
+    }
+    else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(collectionName, zkHost, params, fieldParam, outcomeParam, featureSetParam, positiveLabel, numTerms);
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    // functionName(collectionName, param1, param2, ..., paramN, sort="comp", [aliases="field=alias,..."])
+
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+    // collection
+    expression.addParameter(collection);
+
+    // parameters
+    for(Map.Entry<String,String> param : params.entrySet()){
+      expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("field", field));
+    expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
+    expression.addParameter(new StreamExpressionNamedParameter("featureSet", featureSet));
+    expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", String.valueOf(positiveLabel)));
+    expression.addParameter(new StreamExpressionNamedParameter("numTerms", String.valueOf(numTerms)));
+
+    // zkHost
+    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+
+    return expression;
+  }
+
+  private void init(String collectionName,
+                    String zkHost,
+                    Map params,
+                    String field,
+                    String outcome,
+                    String featureSet,
+                    int positiveLabel, int numTopTerms) throws IOException {
+    this.zkHost = zkHost;
+    this.collection = collectionName;
+    this.params = params;
+    this.field = field;
+    this.outcome = outcome;
+    this.featureSet = featureSet;
+    this.positiveLabel = positiveLabel;
+    this.numTerms = numTopTerms;
+  }
+
+  public void setStreamContext(StreamContext context) {
+    this.cache = context.getSolrClientCache();
+    this.streamContext = context;
+  }
+
+  /**
+   * Opens the CloudSolrStream
+   *
+   ***/
+  public void open() throws IOException {
+    if (cache == null) {
+      isCloseCache = true;
+      cache = new SolrClientCache();
+    } else {
+      isCloseCache = false;
+    }
+
+    this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
+    this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
+  }
+
+  public List<TupleStream> children() {
+    return null;
+  }
+
+  private List<String> getShardUrls() throws IOException {
+
+    try {
+
+      ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
+      ClusterState clusterState = zkStateReader.getClusterState();
+
+      Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
+      Set<String> liveNodes = clusterState.getLiveNodes();
+
+      List<String> baseUrls = new ArrayList<>();
+
+      for(Slice slice : slices) {
+        Collection<Replica> replicas = slice.getReplicas();
+        List<Replica> shuffler = new ArrayList<>();
+        for(Replica replica : replicas) {
+          if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
+            shuffler.add(replica);
+          }
+        }
+
+        Collections.shuffle(shuffler, new Random());
+        Replica rep = shuffler.get(0);
+        ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
+        String url = zkProps.getCoreUrl();
+        baseUrls.add(url);
+      }
+
+      return baseUrls;
+
+    } catch (Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
+
+    List<Future<NamedList>> futures = new ArrayList<>();
+    for (String baseUrl : baseUrls) {
+      FeaturesSelectionCall lc = new FeaturesSelectionCall(baseUrl,
+          this.params,
+          this.field,
+          this.outcome);
+
+      Future<NamedList> future = executorService.submit(lc);
+      futures.add(future);
+    }
+
+    return futures;
+  }
+
+  public void close() throws IOException {
+    if (isCloseCache) {
+      cache.close();
+    }
+
+    executorService.shutdown();
+  }
+
+  /** Return the stream sort - ie, the order in which records are returned */
+  public StreamComparator getStreamSort(){
+    return null;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    return new StreamExplanation(getStreamNodeId().toString())
+        .withFunctionName(factory.getFunctionName(this.getClass()))
+        .withImplementingClass(this.getClass().getName())
+        .withExpressionType(Explanation.ExpressionType.STREAM_DECORATOR)
+        .withExpression(toExpression(factory).toString());
+  }
+
+  public Tuple read() throws IOException {
+    try {
+      if (tupleIterator == null) {
+        Map<String, Double> termScores = new HashMap<>();
+        Map<String, Long> docFreqs = new HashMap<>();
+
+
+        long numDocs = 0;
+        for (Future<NamedList> getTopTermsCall : callShards(getShardUrls())) {
+          NamedList resp = getTopTermsCall.get();
+
+          NamedList<Double> shardTopTerms = (NamedList<Double>)resp.get("featuredTerms");
+          NamedList<Integer> shardDocFreqs = (NamedList<Integer>)resp.get("docFreq");
+
+          numDocs += (Integer)resp.get("numDocs");
+
+          for (int i = 0; i < shardTopTerms.size(); i++) {
+            String term = shardTopTerms.getName(i);
+            double score = shardTopTerms.getVal(i);
+            int docFreq = shardDocFreqs.get(term);
+            double prevScore = termScores.containsKey(term) ? termScores.get(term) : 0;
+            long prevDocFreq = docFreqs.containsKey(term) ? docFreqs.get(term) : 0;
+            termScores.put(term, prevScore + score);
+            docFreqs.put(term, prevDocFreq + docFreq);
+
+          }
+        }
+
+        List<Tuple> tuples = new ArrayList<>(numTerms);
+        termScores = sortByValue(termScores);
+        int index = 0;
+        for (Map.Entry<String, Double> termScore : termScores.entrySet()) {
+          if (tuples.size() == numTerms) break;
+          index++;
+          Map map = new HashMap();
+          map.put("id", featureSet + "_" + index);
+          map.put("index_i", index);
+          map.put("term_s", termScore.getKey());
+          map.put("score_f", termScore.getValue());
+          map.put("featureSet_s", featureSet);
+          long docFreq = docFreqs.get(termScore.getKey());
+          double d = Math.log(((double)numDocs / (double)(docFreq + 1)));
+          map.put("idf_d", d);
+          tuples.add(new Tuple(map));
+        }
+
+        Map map = new HashMap();
+        map.put("EOF", true);
+        tuples.add(new Tuple(map));
+
+        tupleIterator = tuples.iterator();
+      }
+
+      return tupleIterator.next();
+    } catch(Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  private  <K, V extends Comparable<? super V>> Map<K, V> sortByValue( Map<K, V> map )
+  {
+    Map<K, V> result = new LinkedHashMap<>();
+    Stream<Map.Entry<K, V>> st = map.entrySet().stream();
+
+    st.sorted( Map.Entry.comparingByValue(
+        (c1, c2) -> c2.compareTo(c1)
+    ) ).forEachOrdered( e -> result.put(e.getKey(), e.getValue()) );
+
+    return result;
+  }
+
+  protected class FeaturesSelectionCall implements Callable<NamedList> {
+
+    private String baseUrl;
+    private String outcome;
+    private String field;
+    private Map<String, String> paramsMap;
+
+    public FeaturesSelectionCall(String baseUrl,
+                                 Map<String, String> paramsMap,
+                                 String field,
+                                 String outcome) {
+
+      this.baseUrl = baseUrl;
+      this.outcome = outcome;
+      this.field = field;
+      this.paramsMap = paramsMap;
+    }
+
+    public NamedList<Double> call() throws Exception {
+      ModifiableSolrParams params = new ModifiableSolrParams();
+      HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
+
+      params.add("distrib", "false");
+      params.add("fq","{!igain}");
+
+      for(String key : paramsMap.keySet()) {
+        params.add(key, paramsMap.get(key));
+      }
+
+      params.add("outcome", outcome);
+      params.add("positiveLabel", Integer.toString(positiveLabel));
+      params.add("field", field);
+      params.add("numTerms", String.valueOf(numTerms));
+
+      QueryRequest request= new QueryRequest(params);
+      QueryResponse response = request.process(solrClient);
+      NamedList res = response.getResponse();
+      return res;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TextLogitStream.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TextLogitStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TextLogitStream.java
new file mode 100644
index 0000000..f49168f
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TextLogitStream.java
@@ -0,0 +1,657 @@
+package org.apache.solr.client.solrj.io.stream;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.solr.client.solrj.SolrRequest;
+import org.apache.solr.client.solrj.SolrServerException;
+import org.apache.solr.client.solrj.impl.CloudSolrClient;
+import org.apache.solr.client.solrj.impl.HttpSolrClient;
+import org.apache.solr.client.solrj.io.ClassificationEvaluation;
+import org.apache.solr.client.solrj.io.SolrClientCache;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.client.solrj.request.QueryRequest;
+import org.apache.solr.client.solrj.response.QueryResponse;
+import org.apache.solr.common.cloud.ClusterState;
+import org.apache.solr.common.cloud.Replica;
+import org.apache.solr.common.cloud.Slice;
+import org.apache.solr.common.cloud.ZkCoreNodeProps;
+import org.apache.solr.common.cloud.ZkStateReader;
+import org.apache.solr.common.params.ModifiableSolrParams;
+import org.apache.solr.common.util.ExecutorUtil;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.common.util.SolrjNamedThreadFactory;
+
+public class TextLogitStream extends TupleStream implements Expressible {
+
+  private static final long serialVersionUID = 1;
+
+  protected String zkHost;
+  protected String collection;
+  protected Map<String,String> params;
+  protected String field;
+  protected String name;
+  protected String outcome;
+  protected int positiveLabel;
+  protected double threshold;
+  protected List<Double> weights;
+  protected int maxIterations;
+  protected int iteration;
+  protected double error;
+  protected List<Double> idfs;
+  protected ClassificationEvaluation evaluation;
+
+  protected transient SolrClientCache cache;
+  protected transient boolean isCloseCache;
+  protected transient CloudSolrClient cloudSolrClient;
+
+  protected transient StreamContext streamContext;
+  protected ExecutorService executorService;
+  protected TupleStream termsStream;
+  private List<String> terms;
+
+  private double learningRate = 0.01;
+  private double lastError = 0;
+
+  public TextLogitStream(String zkHost,
+                     String collectionName,
+                     Map params,
+                     String name,
+                     String field,
+                     TupleStream termsStream,
+                     List<Double> weights,
+                     String outcome,
+                     int positiveLabel,
+                     double threshold,
+                     int maxIterations) throws IOException {
+
+    init(collectionName, zkHost, params, name, field, termsStream, weights, outcome, positiveLabel, threshold, maxIterations, iteration);
+  }
+
+  /**
+   *   logit(collection, zkHost="", features="a,b,c,d,e,f,g", outcome="y", maxIteration="20")
+   **/
+
+  public TextLogitStream(StreamExpression expression, StreamFactory factory) throws IOException{
+    // grab all parameters out
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+    List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
+
+    // Validate there are no unknown parameters - zkHost and alias are namedParameter so we don't need to count it twice
+    if(expression.getParameters().size() != 1 + namedParams.size() + streamExpressions.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - unknown operands found",expression));
+    }
+
+    // Collection Name
+    if(null == collectionName){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    // Named parameters - passed directly to solr as solrparams
+    if(0 == namedParams.size()){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - at least one named parameter expected. eg. 'q=*:*'",expression));
+    }
+
+    Map<String,String> params = new HashMap<String,String>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(!namedParam.getName().equals("zkHost")) {
+        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
+      }
+    }
+
+    String name = params.get("name");
+    if (name != null) {
+      params.remove("name");
+    } else {
+      throw new IOException("name param cannot be null for TextLogitStream");
+    }
+
+    String feature = params.get("field");
+    if (feature != null) {
+      params.remove("field");
+    } else {
+      throw new IOException("field param cannot be null for TextLogitStream");
+    }
+
+    TupleStream stream = null;
+
+    if (streamExpressions.size() > 0) {
+      stream = factory.constructStream(streamExpressions.get(0));
+    } else {
+      throw new IOException("features must be present for TextLogitStream");
+    }
+
+    String maxIterationsParam = params.get("maxIterations");
+    int maxIterations = 0;
+    if(maxIterationsParam != null) {
+      maxIterations = Integer.parseInt(maxIterationsParam);
+      params.remove("maxIterations");
+    } else {
+      throw new IOException("maxIterations param cannot be null for TextLogitStream");
+    }
+
+    String outcomeParam = params.get("outcome");
+
+    if(outcomeParam != null) {
+      params.remove("outcome");
+    } else {
+      throw new IOException("outcome param cannot be null for TextLogitStream");
+    }
+
+    String positiveLabelParam = params.get("positiveLabel");
+    int positiveLabel = 1;
+    if(positiveLabelParam != null) {
+      positiveLabel = Integer.parseInt(positiveLabelParam);
+      params.remove("positiveLabel");
+    }
+
+    String thresholdParam = params.get("threshold");
+    double threshold = 0.5;
+    if(thresholdParam != null) {
+      threshold = Double.parseDouble(thresholdParam);
+      params.remove("threshold");
+    }
+
+    int iteration = 0;
+    String iterationParam = params.get("iteration");
+    if(iterationParam != null) {
+      iteration = Integer.parseInt(iterationParam);
+      params.remove("iteration");
+    }
+
+    List<Double> weights = null;
+    String weightsParam = params.get("weights");
+    if(weightsParam != null) {
+      weights = new ArrayList<>();
+      String[] weightsArray = weightsParam.split(",");
+      for(String weightString : weightsArray) {
+        weights.add(Double.parseDouble(weightString));
+      }
+      params.remove("weights");
+    }
+
+    // zkHost, optional - if not provided then will look into factory list to get
+    String zkHost = null;
+    if(null == zkHostExpression){
+      zkHost = factory.getCollectionZkHost(collectionName);
+    }
+    else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(collectionName, zkHost, params, name, feature, stream, weights, outcomeParam, positiveLabel, threshold, maxIterations, iteration);
+  }
+
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+    return toExpression(factory, true);
+  }
+
+  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
+    // function name
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+    // collection
+    expression.addParameter(collection);
+
+    if (includeStreams && !(termsStream instanceof TermsStream)) {
+      if (termsStream instanceof Expressible) {
+        expression.addParameter(((Expressible)termsStream).toExpression(factory));
+      } else {
+        throw new IOException("This TextLogitStream contains a non-expressible TupleStream - it cannot be converted to an expression");
+      }
+    }
+
+    // parameters
+    for(Entry<String,String> param : params.entrySet()){
+      expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("field", field));
+    expression.addParameter(new StreamExpressionNamedParameter("name", name));
+    if (termsStream instanceof TermsStream) {
+      loadTerms();
+      expression.addParameter(new StreamExpressionNamedParameter("terms", toString(terms)));
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("outcome", outcome));
+    if(weights != null) {
+      expression.addParameter(new StreamExpressionNamedParameter("weights", toString(weights)));
+    }
+    expression.addParameter(new StreamExpressionNamedParameter("maxIterations", Integer.toString(maxIterations)));
+
+    if(iteration > 0) {
+      expression.addParameter(new StreamExpressionNamedParameter("iteration", Integer.toString(iteration)));
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", Integer.toString(positiveLabel)));
+    expression.addParameter(new StreamExpressionNamedParameter("threshold", Double.toString(threshold)));
+
+    // zkHost
+    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+
+    return expression;
+  }
+
+  private void init(String collectionName,
+                    String zkHost,
+                    Map params,
+                    String name,
+                    String feature,
+                    TupleStream termsStream,
+                    List<Double> weights,
+                    String outcome,
+                    int positiveLabel,
+                    double threshold,
+                    int maxIterations,
+                    int iteration) throws IOException {
+    this.zkHost = zkHost;
+    this.collection = collectionName;
+    this.params = params;
+    this.name = name;
+    this.field = feature;
+    this.termsStream = termsStream;
+    this.outcome = outcome;
+    this.positiveLabel = positiveLabel;
+    this.threshold = threshold;
+    this.weights = weights;
+    this.maxIterations = maxIterations;
+    this.iteration = iteration;
+  }
+
+  public void setStreamContext(StreamContext context) {
+    this.cache = context.getSolrClientCache();
+    this.streamContext = context;
+    this.termsStream.setStreamContext(context);
+  }
+
+  /**
+   * Opens the CloudSolrStream
+   *
+   ***/
+  public void open() throws IOException {
+    if (cache == null) {
+      isCloseCache = true;
+      cache = new SolrClientCache();
+    } else {
+      isCloseCache = false;
+    }
+
+    this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
+    this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("TextLogitSolrStream"));
+  }
+
+  public List<TupleStream> children() {
+    List<TupleStream> l =  new ArrayList();
+    l.add(termsStream);
+    return l;
+  }
+
+  protected List<String> getShardUrls() throws IOException {
+
+    try {
+
+      ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
+      ClusterState clusterState = zkStateReader.getClusterState();
+      Set<String> liveNodes = clusterState.getLiveNodes();
+
+      Collection<Slice> slices = clusterState.getActiveSlices(this.collection);
+      List baseUrls = new ArrayList();
+
+      for(Slice slice : slices) {
+        Collection<Replica> replicas = slice.getReplicas();
+        List<Replica> shuffler = new ArrayList();
+        for(Replica replica : replicas) {
+          if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
+            shuffler.add(replica);
+          }
+        }
+
+        Collections.shuffle(shuffler, new Random());
+        Replica rep = shuffler.get(0);
+        ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
+        String url = zkProps.getCoreUrl();
+        baseUrls.add(url);
+      }
+
+      return baseUrls;
+
+    } catch (Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  private List<Future<Tuple>> callShards(List<String> baseUrls) throws IOException {
+
+    List<Future<Tuple>> futures = new ArrayList();
+    for (String baseUrl : baseUrls) {
+      LogitCall lc = new LogitCall(baseUrl,
+          this.params,
+          this.field,
+          this.terms,
+          this.weights,
+          this.outcome,
+          this.positiveLabel,
+          this.learningRate,
+          this.iteration);
+
+      Future<Tuple> future = executorService.submit(lc);
+      futures.add(future);
+    }
+
+    return futures;
+  }
+
+  public void close() throws IOException {
+    if (isCloseCache) {
+      cache.close();
+    }
+
+    executorService.shutdown();
+    termsStream.close();
+  }
+
+  /** Return the stream sort - ie, the order in which records are returned */
+  public StreamComparator getStreamSort(){
+    return null;
+  }
+
+  @Override
+  public Explanation toExplanation(StreamFactory factory) throws IOException {
+    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
+    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
+    explanation.setImplementingClass(this.getClass().getName());
+    explanation.setExpressionType(Explanation.ExpressionType.MACHINE_LEARNING_MODEL);
+    explanation.setExpression(toExpression(factory).toString());
+
+    explanation.addChild(termsStream.toExplanation(factory));
+
+    return explanation;
+  }
+
+  public void loadTerms() throws IOException {
+    if (this.terms == null) {
+      termsStream.open();
+      this.terms = new ArrayList<>();
+      this.idfs = new ArrayList();
+
+      while (true) {
+        Tuple termTuple = termsStream.read();
+        if (termTuple.EOF) {
+          break;
+        } else {
+          terms.add(termTuple.getString("term_s"));
+          idfs.add(termTuple.getDouble("idf_d"));
+        }
+      }
+      termsStream.close();
+    }
+  }
+
+  public Tuple read() throws IOException {
+    try {
+
+      if(++iteration > maxIterations) {
+        Map map = new HashMap();
+        map.put("EOF", true);
+        return new Tuple(map);
+      } else {
+
+        if (this.idfs == null) {
+          loadTerms();
+
+          if (weights != null && terms.size() + 1 != weights.size()) {
+            throw new IOException(String.format(Locale.ROOT,"invalid expression %s - the number of weights must be %d, found %d", terms.size()+1, weights.size()));
+          }
+        }
+
+        List<List<Double>> allWeights = new ArrayList();
+        this.evaluation = new ClassificationEvaluation();
+
+        this.error = 0;
+        for (Future<Tuple> logitCall : callShards(getShardUrls())) {
+
+          Tuple tuple = logitCall.get();
+          List<Double> shardWeights = (List<Double>) tuple.get("weights");
+          allWeights.add(shardWeights);
+          this.error += tuple.getDouble("error");
+          Map shardEvaluation = (Map) tuple.get("evaluation");
+          this.evaluation.addEvaluation(shardEvaluation);
+        }
+
+        this.weights = averageWeights(allWeights);
+        Map map = new HashMap();
+        map.put("id", name+"_"+iteration);
+        map.put("name_s", name);
+        map.put("field_s", field);
+        map.put("terms_ss", terms);
+        map.put("iteration_i", iteration);
+
+        if(weights != null) {
+          map.put("weights_ds", weights);
+        }
+
+        map.put("error_d", error);
+        evaluation.putToMap(map);
+        map.put("alpha_d", this.learningRate);
+        map.put("idfs_ds", this.idfs);
+
+        if (iteration != 1) {
+          if (lastError <= error) {
+            this.learningRate *= 0.5;
+          } else {
+            this.learningRate *= 1.05;
+          }
+        }
+
+        lastError = error;
+
+        return new Tuple(map);
+      }
+
+    } catch(Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  private List<Double> averageWeights(List<List<Double>> allWeights) {
+    double[] working = new double[allWeights.get(0).size()];
+    for(List<Double> shardWeights: allWeights) {
+      for(int i=0; i<working.length; i++) {
+        working[i] += shardWeights.get(i);
+      }
+    }
+
+    for(int i=0; i<working.length; i++) {
+      working[i] = working[i] / allWeights.size();
+    }
+
+    List<Double> ave = new ArrayList();
+    for(double d : working) {
+      ave.add(d);
+    }
+
+    return ave;
+  }
+
+  static String toString(List items) {
+    StringBuilder buf = new StringBuilder();
+    for(Object item : items) {
+      if(buf.length() > 0) {
+        buf.append(",");
+      }
+
+      buf.append(item.toString());
+    }
+
+    return buf.toString();
+  }
+
+  protected class TermsStream extends TupleStream {
+
+    private List<String> terms;
+    private Iterator<String> it;
+
+    public TermsStream(List<String> terms) {
+      this.terms = terms;
+    }
+
+    @Override
+    public void setStreamContext(StreamContext context) {}
+
+    @Override
+    public List<TupleStream> children() { return new ArrayList<>(); }
+
+    @Override
+    public void open() throws IOException { this.it = this.terms.iterator();}
+
+    @Override
+    public void close() throws IOException {}
+
+    @Override
+    public Tuple read() throws IOException {
+      HashMap map = new HashMap();
+      if(it.hasNext()) {
+        map.put("term_s",it.next());
+        map.put("score_f",1.0);
+        return new Tuple(map);
+      } else {
+        map.put("EOF", true);
+        return new Tuple(map);
+      }
+    }
+
+    @Override
+    public StreamComparator getStreamSort() {return null;}
+
+    @Override
+    public Explanation toExplanation(StreamFactory factory) throws IOException {
+      return new StreamExplanation(getStreamNodeId().toString())
+          .withFunctionName("non-expressible")
+          .withImplementingClass(this.getClass().getName())
+          .withExpressionType(Explanation.ExpressionType.STREAM_SOURCE)
+          .withExpression("non-expressible");
+    }
+  }
+
+  protected class LogitCall implements Callable<Tuple> {
+
+    private String baseUrl;
+    private String feature;
+    private List<String> terms;
+    private List<Double> weights;
+    private int iteration;
+    private String outcome;
+    private int positiveLabel;
+    private double learningRate;
+    private Map<String, String> paramsMap;
+
+    public LogitCall(String baseUrl,
+                     Map<String, String> paramsMap,
+                     String feature,
+                     List<String> terms,
+                     List<Double> weights,
+                     String outcome,
+                     int positiveLabel,
+                     double learningRate,
+                     int iteration) {
+
+      this.baseUrl = baseUrl;
+      this.feature = feature;
+      this.terms = terms;
+      this.weights = weights;
+      this.iteration = iteration;
+      this.outcome = outcome;
+      this.positiveLabel = positiveLabel;
+      this.learningRate = learningRate;
+      this.paramsMap = paramsMap;
+    }
+
+    public Tuple call() throws Exception {
+      ModifiableSolrParams params = new ModifiableSolrParams();
+      HttpSolrClient solrClient = cache.getHttpSolrClient(baseUrl);
+
+      params.add("distrib", "false");
+      params.add("fq","{!tlogit}");
+      params.add("feature", feature);
+      params.add("terms", TextLogitStream.toString(terms));
+      params.add("idfs", TextLogitStream.toString(idfs));
+
+      for(String key : paramsMap.keySet()) {
+        params.add(key, paramsMap.get(key));
+      }
+
+      if(weights != null) {
+        params.add("weights", TextLogitStream.toString(weights));
+      }
+
+      params.add("iteration", Integer.toString(iteration));
+      params.add("outcome", outcome);
+      params.add("positiveLabel", Integer.toString(positiveLabel));
+      params.add("threshold", Double.toString(threshold));
+      params.add("alpha", Double.toString(learningRate));
+
+      QueryRequest  request= new QueryRequest(params, SolrRequest.METHOD.POST);
+      QueryResponse response = request.process(solrClient);
+      NamedList res = response.getResponse();
+
+      NamedList logit = (NamedList)res.get("logit");
+
+      List<Double> shardWeights = (List<Double>)logit.get("weights");
+      double shardError = (double)logit.get("error");
+
+      Map map = new HashMap();
+
+      map.put("error", shardError);
+      map.put("weights", shardWeights);
+      map.put("evaluation", logit.get("evaluation"));
+
+      return new Tuple(map);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/expr/Explanation.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/expr/Explanation.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/expr/Explanation.java
index 5f028aa..5db9779 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/expr/Explanation.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/expr/Explanation.java
@@ -147,6 +147,7 @@ public class Explanation {
   
   public static interface ExpressionType{
     public static final String GRAPH_SOURCE = "graph-source";
+    public static final String MACHINE_LEARNING_MODEL = "ml-model";
     public static final String STREAM_SOURCE = "stream-source";
     public static final String STREAM_DECORATOR = "stream-decorator";
     public static final String DATASTORE = "datastore";

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
new file mode 100644
index 0000000..3206811
--- /dev/null
+++ b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/schema.xml
@@ -0,0 +1,77 @@
+<?xml version="1.0" ?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<!-- The Solr schema file. This file should be named "schema.xml" and
+     should be located where the classloader for the Solr webapp can find it.
+
+     This schema is used for testing, and as such has everything and the
+     kitchen sink thrown in. See example/solr/conf/schema.xml for a
+     more concise example.
+
+  -->
+
+<schema name="test" version="1.6">
+
+    <fieldType name="int" docValues="true" class="solr.TrieIntField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="float" docValues="true" class="solr.TrieFloatField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="long" class="solr.TrieLongField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="double" class="solr.TrieDoubleField" precisionStep="0" omitNorms="true" positionIncrementGap="0"/>
+
+    <fieldType name="tint" class="solr.TrieIntField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="tfloat" class="solr.TrieFloatField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="tlong" class="solr.TrieLongField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
+    <fieldType name="tdouble" class="solr.TrieDoubleField" precisionStep="8" omitNorms="true" positionIncrementGap="0"/>
+
+    <fieldType name="random" class="solr.RandomSortField" indexed="true" />
+
+    <fieldtype name="boolean" class="solr.BoolField" sortMissingLast="true"/>
+    <fieldtype name="string" class="solr.StrField" sortMissingLast="true" docValues="true"/>
+
+    <!-- format for date is 1995-12-31T23:59:59.999Z and only the fractional
+         seconds part (.999) is optional.
+      -->
+    <fieldtype name="date" class="solr.TrieDateField" precisionStep="0"/>
+    <fieldtype name="tdate" class="solr.TrieDateField" precisionStep="6"/>
+
+
+    <field name="id" type="string" indexed="true" stored="true" multiValued="false" required="false"/>
+    <field name="_version_" type="long" indexed="true" stored="true"/>
+
+    <!-- Dynamic field definitions.  If a field name is not found, dynamicFields
+         will be used if the name matches any of the patterns.
+         RESTRICTION: the glob-like pattern in the name attribute must have
+         a "*" only at the start or the end.
+         EXAMPLE:  name="*_i" will match any field ending in _i (like myid_i, z_i)
+         Longer patterns will be matched first.  if equal size patterns
+         both match, the first appearing in the schema will be used.
+    -->
+    <dynamicField name="*_b"  type="boolean" indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_bs"  type="boolean" indexed="true"  stored="true" multiValued="true"/>
+    <dynamicField name="*_i"  type="int"    indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_is"  type="int"    indexed="true" stored="true" multiValued="true"/>
+    <dynamicField name="*_l"  type="long"   indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_ls"  type="long"   indexed="true"  stored="true" multiValued="true"/>
+    <dynamicField name="*_f"  type="float"   indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_fs"  type="float"   indexed="true"  stored="true" multiValued="true"/>
+    <dynamicField name="*_d"  type="double"   indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_ds"  type="double"   indexed="true"  stored="true" multiValued="true"/>
+
+    <dynamicField name="*_s"  type="string"  indexed="true"  stored="true" multiValued="false"/>
+    <dynamicField name="*_ss" type="string"  indexed="true"  stored="true" multiValued="true"/>
+  <uniqueKey>id</uniqueKey>
+</schema>

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/solrconfig.xml
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/solrconfig.xml b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/solrconfig.xml
new file mode 100644
index 0000000..6b10869
--- /dev/null
+++ b/solr/solrj/src/test-files/solrj/solr/configsets/ml/conf/solrconfig.xml
@@ -0,0 +1,51 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<!--
+ This is a stripped down config file used for a simple example...
+ It is *not* a good example to work from.
+-->
+<config>
+  <luceneMatchVersion>${tests.luceneMatchVersion:LUCENE_CURRENT}</luceneMatchVersion>
+  <indexConfig>
+    <useCompoundFile>${useCompoundFile:false}</useCompoundFile>
+  </indexConfig>
+  <dataDir>${solr.data.dir:}</dataDir>
+  <directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.StandardDirectoryFactory}"/>
+  <schemaFactory class="ClassicIndexSchemaFactory"/>
+
+  <updateHandler class="solr.DirectUpdateHandler2">
+    <updateLog>
+      <str name="dir">${solr.data.dir:}</str>
+    </updateLog>
+  </updateHandler>
+
+
+  <requestDispatcher handleSelect="true" >
+    <requestParsers enableRemoteStreaming="false" multipartUploadLimitInKB="2048" />
+  </requestDispatcher>
+
+  <requestHandler name="standard" class="solr.StandardRequestHandler" default="true" />
+
+  <!-- config for the admin interface -->
+  <admin>
+    <defaultQuery>solr</defaultQuery>
+  </admin>
+
+</config>
+

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/87938e00/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index f2446f3..74e7fb1 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -16,16 +16,23 @@
  */
 package org.apache.solr.client.solrj.io.stream;
 
+import java.io.BufferedReader;
 import java.io.IOException;
+import java.io.InputStreamReader;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipFile;
 
 import org.apache.lucene.util.LuceneTestCase;
 import org.apache.lucene.util.LuceneTestCase.Slow;
 import org.apache.solr.client.solrj.embedded.JettySolrRunner;
+import org.apache.solr.client.solrj.io.ClassificationEvaluation;
 import org.apache.solr.client.solrj.io.SolrClientCache;
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
@@ -71,6 +78,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
   public static void setupCluster() throws Exception {
     configureCluster(4)
         .addConfig("conf", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("streaming").resolve("conf"))
+        .addConfig("ml", getFile("solrj").toPath().resolve("solr").resolve("configsets").resolve("ml").resolve("conf"))
         .configure();
 
     CollectionAdminRequest.createCollection(COLLECTION, "conf", 2, 1).process(cluster.getSolrClient());
@@ -2773,6 +2781,8 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assert(tuple.getDouble("a_f") == 4.0);
     assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
     assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
+
+    CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
   }
 
   @Test
@@ -2863,6 +2873,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
     assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
 
+    CollectionAdminRequest.deleteCollection("parallelDestinationCollection").process(cluster.getSolrClient());
   }
 
   @Test
@@ -3025,6 +3036,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertList(tuple.getStrings("s_multi"), "aaaa3", "bbbb3");
     assertList(tuple.getLongs("i_multi"), Long.parseLong("4444"), Long.parseLong("7777"));
 
+    CollectionAdminRequest.deleteCollection("parallelDestinationCollection1").process(cluster.getSolrClient());
   }
 
   @Test
@@ -3061,10 +3073,125 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     tuples = getTuples(stream);
     
     assert(tuples.size() == 5);
-    assertOrder(tuples, 0,7,3,4,8);
+    assertOrder(tuples, 0, 7, 3, 4, 8);
 
   }
 
+
+  @Test
+  public void testBasicTextLogitStream() throws Exception {
+    CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
+    AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
+        false, true, TIMEOUT);
+
+    UpdateRequest updateRequest = new UpdateRequest();
+    for (int i = 0; i < 5000; i+=2) {
+      updateRequest.add(id, String.valueOf(i), "tv_text", "a b c c d", "out_i", "1");
+      updateRequest.add(id, String.valueOf(i+1), "tv_text", "a b e e f", "out_i", "0");
+    }
+    updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+    StreamExpression expression;
+    TupleStream stream;
+    List<Tuple> tuples;
+
+    StreamFactory factory = new StreamFactory()
+        .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
+        .withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
+        .withFunctionName("features", FeaturesSelectionStream.class)
+        .withFunctionName("train", TextLogitStream.class)
+        .withFunctionName("search", CloudSolrStream.class)
+        .withFunctionName("update", UpdateStream.class);
+
+    expression = StreamExpressionParser.parse("features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4)");
+    stream = new FeaturesSelectionStream(expression, factory);
+    tuples = getTuples(stream);
+
+    assert(tuples.size() == 4);
+    HashSet<String> terms = new HashSet<>();
+    for (Tuple tuple : tuples) {
+      terms.add((String) tuple.get("term_s"));
+    }
+    assertTrue(terms.contains("d"));
+    assertTrue(terms.contains("c"));
+    assertTrue(terms.contains("e"));
+    assertTrue(terms.contains("f"));
+
+    String textLogitExpression = "train(" +
+        "collection1, " +
+        "features(collection1, q=\"*:*\", featureSet=\"first\", field=\"tv_text\", outcome=\"out_i\", numTerms=4),"+
+        "q=\"*:*\", " +
+        "name=\"model\", " +
+        "field=\"tv_text\", " +
+        "outcome=\"out_i\", " +
+        "maxIterations=100)";
+    stream = factory.constructStream(textLogitExpression);
+    tuples = getTuples(stream);
+    Tuple lastTuple = tuples.get(tuples.size() - 1);
+    List<Double> lastWeights = lastTuple.getDoubles("weights_ds");
+    Double[] lastWeightsArray = lastWeights.toArray(new Double[lastWeights.size()]);
+
+    // first feature is bias value
+    Double[] testRecord = {1.0, 1.17, 0.691, 0.0, 0.0};
+    double d = sum(multiply(testRecord, lastWeightsArray));
+    double prob = sigmoid(d);
+    assertEquals(prob, 1.0, 0.1);
+
+    // first feature is bias value
+    Double[] testRecord2 = {1.0, 0.0, 0.0, 1.17, 0.691};
+    d = sum(multiply(testRecord2, lastWeightsArray));
+    prob = sigmoid(d);
+    assertEquals(prob, 0, 0.1);
+
+    stream = factory.constructStream("update(destinationCollection, batchSize=5, "+textLogitExpression+")");
+    getTuples(stream);
+    cluster.getSolrClient().commit("destinationCollection");
+
+    stream = factory.constructStream("search(destinationCollection, " +
+        "q=*:*, " +
+        "fl=\"iteration_i,* \", " +
+        "rows=100, " +
+        "sort=\"iteration_i desc\")");
+    tuples = getTuples(stream);
+    assertEquals(100, tuples.size());
+    Tuple lastModel = tuples.get(0);
+    ClassificationEvaluation evaluation = ClassificationEvaluation.create(lastModel.fields);
+    assertTrue(evaluation.getF1() >= 1.0);
+    assertEquals(Math.log( 5000.0 / (2500 + 1)), lastModel.getDoubles("idfs_ds").get(0), 0.0001);
+    // make sure the tuples is retrieved in correct order
+    Tuple firstTuple = tuples.get(99);
+    assertEquals(1L, (long) firstTuple.getLong("iteration_i"));
+
+    CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
+  }
+
+  private double sigmoid(double in) {
+
+    double d = 1.0 / (1+Math.exp(-in));
+    return d;
+  }
+
+  private double[] multiply(Double[] vec1, Double[] vec2) {
+    double[] working = new double[vec1.length];
+    for(int i=0; i<vec1.length; i++) {
+      working[i]= vec1[i]*vec2[i];
+    }
+
+    return working;
+  }
+
+
+  private double sum(double[] vec) {
+    double d = 0.0;
+
+    for(double v : vec) {
+      d += v;
+    }
+
+    return d;
+  }
+
+
   @Test
   public void testParallelIntersectStream() throws Exception {
 
@@ -3104,6 +3231,62 @@ public class StreamExpressionTest extends SolrCloudTestCase {
   }
 
   @Test
+  public void testFeaturesSelectionStream() throws Exception {
+
+    CollectionAdminRequest.createCollection("destinationCollection", "ml", 2, 1).process(cluster.getSolrClient());
+    AbstractDistribZkTestBase.waitForRecoveriesToFinish("destinationCollection", cluster.getSolrClient().getZkStateReader(),
+        false, true, TIMEOUT);
+
+    UpdateRequest updateRequest = new UpdateRequest();
+    for (int i = 0; i < 5000; i+=2) {
+      updateRequest.add(id, String.valueOf(i), "whitetok", "a b c d", "out_i", "1");
+      updateRequest.add(id, String.valueOf(i+1), "whitetok", "a b e f", "out_i", "0");
+    }
+    updateRequest.commit(cluster.getSolrClient(), COLLECTION);
+
+    StreamExpression expression;
+    TupleStream stream;
+    List<Tuple> tuples;
+
+    StreamFactory factory = new StreamFactory()
+        .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
+        .withCollectionZkHost("destinationCollection", cluster.getZkServer().getZkAddress())
+        .withFunctionName("featuresSelection", FeaturesSelectionStream.class)
+        .withFunctionName("search", CloudSolrStream.class)
+        .withFunctionName("update", UpdateStream.class);
+
+    String featuresExpression = "featuresSelection(collection1, q=\"*:*\", featureSet=\"first\", field=\"whitetok\", outcome=\"out_i\", numTerms=4)";
+    // basic
+    expression = StreamExpressionParser.parse(featuresExpression);
+    stream = new FeaturesSelectionStream(expression, factory);
+    tuples = getTuples(stream);
+
+    assert(tuples.size() == 4);
+
+    assertTrue(tuples.get(0).get("term_s").equals("c"));
+    assertTrue(tuples.get(1).get("term_s").equals("d"));
+    assertTrue(tuples.get(2).get("term_s").equals("e"));
+    assertTrue(tuples.get(3).get("term_s").equals("f"));
+
+    // update
+    expression = StreamExpressionParser.parse("update(destinationCollection, batchSize=5, "+featuresExpression+")");
+    stream = new UpdateStream(expression, factory);
+    getTuples(stream);
+    cluster.getSolrClient().commit("destinationCollection");
+
+    expression = StreamExpressionParser.parse("search(destinationCollection, q=featureSet_s:first, fl=\"index_i, term_s\", sort=\"index_i asc\")");
+    stream = new CloudSolrStream(expression, factory);
+    tuples = getTuples(stream);
+    assertEquals(4, tuples.size());
+    assertTrue(tuples.get(0).get("term_s").equals("c"));
+    assertTrue(tuples.get(1).get("term_s").equals("d"));
+    assertTrue(tuples.get(2).get("term_s").equals("e"));
+    assertTrue(tuples.get(3).get("term_s").equals("f"));
+
+    CollectionAdminRequest.deleteCollection("destinationCollection").process(cluster.getSolrClient());
+  }
+
+  @Test
   public void testComplementStream() throws Exception {
 
     new UpdateRequest()


Mime
View raw message