lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jbern...@apache.org
Subject lucene-solr:branch_7x: SOLR-11422: Add probabilities parameter to the enumeratedDistribution Stream Evaluator
Date Fri, 29 Sep 2017 17:29:15 GMT
Repository: lucene-solr
Updated Branches:
  refs/heads/branch_7x 71032d7c6 -> f19777b6d


SOLR-11422: Add probabilities parameter to the enumeratedDistribution Stream Evaluator


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/f19777b6
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/f19777b6
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/f19777b6

Branch: refs/heads/branch_7x
Commit: f19777b6ddfb3cb1a179cf8b941e89f9a870f63f
Parents: 71032d7
Author: Joel Bernstein <jbernste@apache.org>
Authored: Fri Sep 29 13:10:26 2017 -0400
Committer: Joel Bernstein <jbernste@apache.org>
Committed: Fri Sep 29 13:25:03 2017 -0400

----------------------------------------------------------------------
 .../eval/EnumeratedDistributionEvaluator.java   | 19 ++++++++++++++-----
 .../solrj/io/stream/StreamExpressionTest.java   | 20 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f19777b6/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java
index a14e54b..4a8b7f7 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EnumeratedDistributionEvaluator.java
@@ -25,7 +25,7 @@ import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
-public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
+public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
 
   private static final long serialVersionUID = 1;
 
@@ -34,12 +34,21 @@ public class EnumeratedDistributionEvaluator extends RecursiveNumericEvaluator i
   }
 
   @Override
-  public Object doWork(Object first) throws IOException{
-    if(null == first){
+  public Object doWork(Object... values) throws IOException{
+    if(values.length == 0){
       throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
     }
 
-     int[] samples = ((List)first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
-     return new EnumeratedIntegerDistribution(samples);
+    if(values.length == 1) {
+      List<Number> first = (List<Number>)values[0];
+      int[] samples = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
+      return new EnumeratedIntegerDistribution(samples);
+    } else {
+      List<Number> first = (List<Number>)values[0];
+      List<Number> second = (List<Number>)values[1];
+      int[] singletons = ((List) first).stream().mapToInt(value -> ((BigDecimal) value).intValue()).toArray();
+      double[] probs = ((List) second).stream().mapToDouble(value -> ((BigDecimal) value).doubleValue()).toArray();
+      return new EnumeratedIntegerDistribution(singletons, probs);
+    }
   }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f19777b6/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index d84120f..9e4c6c3 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6480,6 +6480,26 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertEquals(prob.doubleValue(), 0.1, 0.07);
     Number cprob = (Number)tuples.get(0).get("c");
     assertEquals(cprob.doubleValue(), 0.5, 0.07);
+
+
+    cexpr = "let(a=sample(enumeratedDistribution(array(1,2,3,4), array(40, 30, 20, 10)), 50000),"+
+                "b=freqTable(a),"+
+                "y=col(b, pct))";
+
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertTrue(tuples.size() == 1);
+    List<Number> freqs = (List<Number>)tuples.get(0).get("y");
+    assertEquals(freqs.get(0).doubleValue(), .40, .03);
+    assertEquals(freqs.get(1).doubleValue(), .30, .03);
+    assertEquals(freqs.get(2).doubleValue(), .20, .03);
+    assertEquals(freqs.get(3).doubleValue(), .10, .03);
   }
 
   @Test


Mime
View raw message