lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jbern...@apache.org
Subject lucene-solr:master: SOLR-11785: Add multiVariateNormalDistribution Stream Evaluator
Date Wed, 20 Dec 2017 18:56:16 GMT
Repository: lucene-solr
Updated Branches:
  refs/heads/master d9695cca5 -> 960a5fd79


SOLR-11785: Add multiVariateNormalDistribution Stream Evaluator


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/960a5fd7
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/960a5fd7
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/960a5fd7

Branch: refs/heads/master
Commit: 960a5fd793aa26546379a17ee29c026fcb198a37
Parents: d9695cc
Author: Joel Bernstein <jbernste@apache.org>
Authored: Wed Dec 20 13:51:34 2017 -0500
Committer: Joel Bernstein <jbernste@apache.org>
Committed: Wed Dec 20 13:51:47 2017 -0500

----------------------------------------------------------------------
 .../org/apache/solr/handler/StreamHandler.java  |  4 +-
 ...MultiVariateNormalDistributionEvaluator.java | 54 ++++++++++++++++++++
 .../client/solrj/io/eval/SampleEvaluator.java   | 27 +++++++++-
 .../solrj/io/stream/StreamExpressionTest.java   | 47 +++++++++++++++++
 4 files changed, 128 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/960a5fd7/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index 8a83160..949a040 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -283,11 +283,11 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
         .withFunctionName("spline", SplineEvaluator.class)
         .withFunctionName("ttest", TTestEvaluator.class)
         .withFunctionName("pairedTtest", PairedTTestEvaluator.class)
-
+        .withFunctionName("multiVariateNormalDistribution", MultiVariateNormalDistributionEvaluator.class)
 
         // Boolean Stream Evaluators
 
-        .withFunctionName("and", AndEvaluator.class)
+            .withFunctionName("and", AndEvaluator.class)
         .withFunctionName("eor", ExclusiveOrEvaluator.class)
         .withFunctionName("eq", EqualToEvaluator.class)
         .withFunctionName("gt", GreaterThanEvaluator.class)

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/960a5fd7/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java
new file mode 100644
index 0000000..bc2fbcb
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MultiVariateNormalDistributionEvaluator.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.List;
+
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+public class MultiVariateNormalDistributionEvaluator extends RecursiveObjectEvaluator implements
TwoValueWorker {
+
+  private static final long serialVersionUID = 1;
+
+  public MultiVariateNormalDistributionEvaluator(StreamExpression expression, StreamFactory
factory) throws IOException {
+    super(expression, factory);
+  }
+
+  @Override
+  public Object doWork(Object first, Object second) throws IOException{
+    if(null == first){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found
for the first value",toExpression(constructingFactory)));
+    }
+    if(null == second){
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found
for the second value",toExpression(constructingFactory)));
+    }
+
+    List<Number> means = (List<Number>)first;
+    Matrix covar = (Matrix)second;
+
+    double[] m = new double[means.size()];
+    for(int i=0; i< m.length; i++) {
+      m[i] = means.get(i).doubleValue();
+    }
+
+    return new MultivariateNormalDistribution(m, covar.getData());
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/960a5fd7/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java
index 9b7aca5..5ea29e6 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/SampleEvaluator.java
@@ -18,12 +18,16 @@
 package org.apache.solr.client.solrj.io.eval;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Locale;
 import java.util.stream.Collectors;
+import java.util.List;
 
 import org.apache.commons.math3.distribution.IntegerDistribution;
+import org.apache.commons.math3.distribution.MultivariateRealDistribution;
 import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
@@ -43,7 +47,7 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements ManyVal
 
     Object first = objects[0];
 
-    if(!(first instanceof RealDistribution) && !(first instanceof IntegerDistribution)
&& !(first instanceof MarkovChainEvaluator.MarkovChain)){
+    if(!(first instanceof MultivariateRealDistribution) && !(first instanceof RealDistribution)
&& !(first instanceof IntegerDistribution) && !(first instanceof MarkovChainEvaluator.MarkovChain)){
       throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type
%s for the first value, expecting a Markov Chain, Real or Integer Distribution",toExpression(constructingFactory),
first.getClass().getSimpleName()));
     }
 
@@ -61,11 +65,30 @@ public class SampleEvaluator extends RecursiveObjectEvaluator implements ManyVal
       }
     } else if (first instanceof RealDistribution) {
       RealDistribution realDistribution = (RealDistribution) first;
-      if(second != null) {
+      if (second != null) {
         return Arrays.stream(realDistribution.sample(((Number) second).intValue())).mapToObj(item
-> item).collect(Collectors.toList());
       } else {
         return realDistribution.sample();
       }
+    }else if(first instanceof MultivariateNormalDistribution) {
+      if(second != null) {
+        MultivariateNormalDistribution multivariateNormalDistribution = (MultivariateNormalDistribution)first;
+        int size = ((Number)second).intValue();
+        double[][] samples = new double[size][];
+        for(int i=0; i<size; ++i) {
+          samples[i] =  multivariateNormalDistribution.sample();
+        }
+
+        return new Matrix(samples);
+      } else {
+        MultivariateNormalDistribution multivariateNormalDistribution = (MultivariateNormalDistribution)first;
+        double[] sample = multivariateNormalDistribution.sample();
+        List<Number> sampleList = new ArrayList(sample.length);
+        for(int i=0; i<sample.length; i++) {
+          sampleList.add(sample[i]);
+        }
+        return sampleList;
+      }
     } else {
       IntegerDistribution integerDistribution = (IntegerDistribution) first;
       if(second != null) {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/960a5fd7/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index c8fe2ff..01ba1dd 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -7179,6 +7179,53 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertEquals(pval3.doubleValue(), 0.0404907407662755, .0001);
   }
 
+  @Test
+  public void testMultiVariateNormalDistribution() throws Exception {
+    String cexpr = "let(echo=true," +
+        "     a=array(1,2,3,4,5,6,7)," +
+        "     b=array(100, 110, 120, 130,140,150,180)," +
+        "     c=transpose(matrix(a, b))," +
+        "     d=array(mean(a), mean(b))," +
+        "     e=cov(c)," +
+        "     f=multiVariateNormalDistribution(d, e)," +
+        "     g=sample(f, 10000)," +
+        "     h=cov(g)," +
+        "     i=sample(f))";
+    // h is the covariance of 10000 draws from f, so it should approximate the input covariance e.
+    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    TupleStream solrStream = new SolrStream(url, paramsLoc);
+    StreamContext context = new StreamContext();
+    solrStream.setStreamContext(context);
+    List<Tuple> tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    List<List<Number>> cov = (List<List<Number>>)tuples.get(0).get("h");
+    assertEquals(2, cov.size());
+    List<Number> row1 = cov.get(0);
+    assertEquals(2, row1.size());
+
+    double a = row1.get(0).doubleValue();
+    double b = row1.get(1).doubleValue();
+    assertEquals(4.666666666666667, a, 2.5);
+    assertEquals(56.66666666666667, b, 7);
+
+    List<Number> row2 = cov.get(1);
+    assertEquals(2, row2.size());
+    double c = row2.get(0).doubleValue();
+    double d = row2.get(1).doubleValue();
+    assertEquals(56.66666666666667, c, 7);
+    assertEquals(723.8095238095239, d, 50);
+    // i is a single draw from f: one value per dimension, near the means of a and b.
+    List<Number> sample = (List<Number>)tuples.get(0).get("i");
+    assertEquals(2, sample.size());
+    Number sample1 = sample.get(0);
+    Number sample2 = sample.get(1);
+    assertTrue(sample1.doubleValue() > -30 && sample1.doubleValue() < 30);
+    assertTrue(sample2.doubleValue() > 50 && sample2.doubleValue() < 250);
+  }
+
 
   @Test
   public void testLoess() throws Exception {


Mime
View raw message