lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jbern...@apache.org
Subject lucene-solr:master: SOLR-12671: Add robust flag to knnRegress Stream Evaluator
Date Fri, 17 Aug 2018 19:49:09 GMT
Repository: lucene-solr
Updated Branches:
  refs/heads/master 124be4e20 -> 52f9cee97


SOLR-12671: Add robust flag to knnRegress Stream Evaluator


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/52f9cee9
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/52f9cee9
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/52f9cee9

Branch: refs/heads/master
Commit: 52f9cee97b4f293af26de0e7b4ec534cb6b11b10
Parents: 124be4e
Author: Joel Bernstein <jbernste@apache.org>
Authored: Fri Aug 17 14:26:05 2018 -0400
Committer: Joel Bernstein <jbernste@apache.org>
Committed: Fri Aug 17 14:26:17 2018 -0400

----------------------------------------------------------------------
 .../solr/client/solrj/io/eval/KnnEvaluator.java |  4 ++
 .../solrj/io/eval/KnnRegressionEvaluator.java   | 65 ++++++++++++++++----
 .../client/solrj/io/eval/PredictEvaluator.java  |  8 ++-
 .../solrj/io/stream/MathExpressionTest.java     | 36 ++++++++---
 4 files changed, 92 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
index 81607cf..17fb011 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
@@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
     }
 
     public int compareTo(Neighbor neighbor) {
+      if(this.distance.compareTo(neighbor.getDistance()) == 0) {
+        return row-neighbor.getRow();
+      }
+
       return this.distance.compareTo(neighbor.getDistance());
     }
   }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
index 957936e..e6f6d80 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -25,15 +25,32 @@ import java.util.HashMap;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.ml.distance.DistanceMeasure;
 import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
 
 public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker
{
   protected static final long serialVersionUID = 1L;
 
+  private boolean robust=false;
+  private boolean scale=false;
+
   public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws
IOException{
     super(expression, factory);
+
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(namedParam.getName().equals("scale")){
+        this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
+      } else if(namedParam.getName().equals("robust")) {
+        this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
+      } else {
+        throw new IOException("Unexpected named parameter:"+namedParam.getName());
+      }
+    }
   }
 
   @Override
@@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
     map.put("features", observations.getColumnCount());
     map.put("distance", distanceMeasure.getClass().getSimpleName());
 
-    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
+    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale,
robust);
   }
 
 
@@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
     private double[] outcomes;
     private int k;
     private DistanceMeasure distanceMeasure;
+    private boolean scale;
+    private boolean robust;
 
     public KnnRegressionTuple(Matrix observations,
                               double[] outcomes,
                               int k,
                               DistanceMeasure distanceMeasure,
-                              Map<?,?> map) {
+                              Map<?,?> map,
+                              boolean scale,
+                              boolean robust) {
       super(map);
       this.observations = observations;
       this.outcomes = outcomes;
       this.k = k;
       this.distanceMeasure = distanceMeasure;
+      this.scale = scale;
+      this.robust = robust;
+    }
+
+    public boolean getScale() {
+      return this.scale;
     }
 
     //MinMax Scale both the observations and the predictors
@@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
 
     public double predict(double[] values) {
 
-      Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
+      Matrix obs = scaledObservations != null ? scaledObservations : observations;
+      Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure);
       List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
 
-      double sum = 0;
-
-      //Collect the outcomes for the nearest neighbors
-      for(Number n : indexes) {
-        sum += outcomes[n.intValue()];
+      if(robust) {
+        //Get the median of the results.
+        double[] vals = new double[indexes.size()];
+        Percentile percentile = new Percentile();
+        int i=0;
+        for (Number n : indexes) {
+           vals[i++]=outcomes[n.intValue()];
+        }
+
+        //Return 50 percentile.
+        return percentile.evaluate(vals, 50);
+      } else {
+        //Get the average of the results
+        double sum = 0;
+
+        //Collect the outcomes for the nearest neighbors
+        for (Number n : indexes) {
+          sum += outcomes[n.intValue()];
+        }
+
+        //Return the average of the outcomes as the prediction.
+        return sum / ((double) indexes.size());
       }
-
-      //Return the average of the outcomes as the prediction.
-
-      return sum/((double)indexes.size());
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index 9385928..c8e83ba 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
           predictors[i] = list.get(i).doubleValue();
         }
 
-        predictors = regressedTuple.scale(predictors);
+        if(regressedTuple.getScale()) {
+          predictors = regressedTuple.scale(predictors);
+        }
 
         return regressedTuple.predict(predictors);
       } else if (second instanceof Matrix) {
 
         Matrix m = (Matrix) second;
-        m = regressedTuple.scale(m);
+        if(regressedTuple.getScale()) {
+          m = regressedTuple.scale(m);
+        }
         double[][] data = m.getData();
         List<Number> predictions = new ArrayList();
         for (double[] predictors : data) {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index bfd4160..6565b76 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
                                   "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981,
10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
                                   "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061,
54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
         "e=transpose(matrix(a, b, c))," +
-        "f=knnRegress(e, d, 1)," +
+        "f=knnRegress(e, d, 1, scale=true)," +
         "g=predict(f, e))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
@@ -3480,7 +3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
         "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809,
5.900000095, 20.79999924, 7.900000095)," +
         "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924,
79.40000153, 91, 135.3999939, 89.30000305)," +
         "e=transpose(matrix(a, b, c))," +
-        "f=knnRegress(e, d, 1)," +
+        "f=knnRegress(e, d, 1, scale=true)," +
         "g=predict(f, array(8, 5, 4)))";
     paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
@@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase {
     Number prediction = (Number)tuples.get(0).get("g");
     assertEquals(prediction.doubleValue(), 85.09999847, 0);
 
-    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905,
3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
-        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048,
3.700000048, 7.599999905, 5.699999809, 4.5)," +
-        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809,
5.900000095, 20.79999924, 4.900000095)," +
+    //Test robust. Take the median rather then average
+
+    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905,
3.5, 9.199999809, 9, 8.10000038, 8.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048,
3.700000048, 5.599999905, 5.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809,
5.900000095, 4.79999924, 4.900000095)," +
         "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924,
79.40000153, 91, 135.3999939, 89.30000305)," +
         "e=transpose(matrix(a, b, c))," +
-        "f=knnRegress(e, d, 2)," +
+        "f=knnRegress(e, d, 3, scale=true, robust=true)," +
         "g=predict(f, array(8, 5, 4)))";
     paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
@@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
     tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
     prediction = (Number)tuples.get(0).get("g");
-    assertEquals(prediction.doubleValue(), 87.20000076, 0);
+    assertEquals(prediction.doubleValue(), 89.30000305, 0);
+
+
+    //Test univariate regression with scaling off
+
+    cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
+        "b=transpose(matrix(a))," +
+        "c=knnRegress(b, a, 3)," +
+        "d=predict(c, array(3)))";
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertTrue(tuples.size() == 1);
+    prediction = (Number)tuples.get(0).get("d");
+    assertEquals(prediction.doubleValue(), 3, 0);
+
   }
 
   @Test


Mime
View raw message