lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sar...@apache.org
Subject lucene-solr:master: LUCENE-7974: Add N-dimensional FloatPoint K-nearest-neighbor implementation
Date Sun, 01 Oct 2017 22:51:14 GMT
Repository: lucene-solr
Updated Branches:
  refs/heads/master 472d52022 -> d52564c07


LUCENE-7974: Add N-dimensional FloatPoint K-nearest-neighbor implementation


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/d52564c0
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/d52564c0
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/d52564c0

Branch: refs/heads/master
Commit: d52564c079bb7ca842a1041bc8baff468e1938d4
Parents: 472d520
Author: Steve Rowe <sarowe@apache.org>
Authored: Sun Oct 1 18:50:06 2017 -0400
Committer: Steve Rowe <sarowe@apache.org>
Committed: Sun Oct 1 18:50:55 2017 -0400

----------------------------------------------------------------------
 lucene/CHANGES.txt                              |   3 +
 .../document/FloatPointNearestNeighbor.java     | 382 +++++++++++++++++++
 .../document/TestFloatPointNearestNeighbor.java | 239 ++++++++++++
 3 files changed, 624 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d52564c0/lucene/CHANGES.txt
----------------------------------------------------------------------
diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 16482e9..e0c8124 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -34,6 +34,9 @@ New Features
 
 * LUCENE-7973: Update dictionary version for Ukrainian analyzer to 3.9.0 (Andriy
   Rysin via Dawid Weiss)
+  
+* LUCENE-7974: Add FloatPointNearestNeighbor, an N-dimensional FloatPoint
+  K-nearest-neighbor search implementation.  (Steve Rowe)
 
 Optimizations
 

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d52564c0/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java
----------------------------------------------------------------------
diff --git a/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java
b/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java
new file mode 100644
index 0000000..d3360a8
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/document/FloatPointNearestNeighbor.java
@@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.document;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.PriorityQueue;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopFieldDocs;
+import org.apache.lucene.util.Bits;
+import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.bkd.BKDReader;
+
+/**
+ * KNN search on top of N dimensional indexed float points.
+ *
+ * @lucene.experimental
+ */
+public class FloatPointNearestNeighbor {
+
+  /**
+   * A node of one reader's BKD tree waiting to be explored. Cells compare by
+   * {@link #distanceSquared}, so a queue of cells yields the one closest to the origin first.
+   */
+  static class Cell implements Comparable<Cell> {
+    final int readerIndex;
+    final byte[] minPacked;
+    final byte[] maxPacked;
+    final BKDReader.IndexTree index;
+
+    /** The closest possible distance^2 of all points in this cell */
+    final double distanceSquared;
+
+    Cell(BKDReader.IndexTree index, int readerIndex, byte[] minPacked, byte[] maxPacked, double distanceSquared) {
+      this.index = index;
+      this.readerIndex = readerIndex;
+      // clone so that this cell owns its own copy of the bounds
+      this.minPacked = minPacked.clone();
+      this.maxPacked = maxPacked.clone();
+      this.distanceSquared = distanceSquared;
+    }
+
+    @Override
+    public int compareTo(Cell other) {
+      return Double.compare(distanceSquared, other.distanceSquared);
+    }
+
+    @Override
+    public String toString() {
+      return "Cell(readerIndex=" + readerIndex + " nodeID=" + index.getNodeID()
+          + " isLeaf=" + index.isLeafNode() + " distanceSquared=" + distanceSquared + ")";
+    }
+  }
+
+  /**
+   * Collects the {@code topN} closest live points from the leaf blocks it is shown. Once the hit
+   * queue is full it also maintains a per-dimension box ({@code min}/{@code max}) around the
+   * origin; points (and, in the caller, cells) that fall outside that box are skipped.
+   */
+  private static class NearestVisitor implements PointValues.IntersectVisitor {
+    private static final int MANTISSA_BITS = 23;
+
+    int curDocBase;
+    Bits curLiveDocs;
+    final int topN;
+    final PriorityQueue<NearestHit> hitQueue;
+    final float[] origin;
+    private final int dims;
+    private final float[] min;
+    private final float[] max;
+    private int updateMinMaxCounter;
+
+    public NearestVisitor(PriorityQueue<NearestHit> hitQueue, int topN, float[] origin) {
+      this.hitQueue = hitQueue;
+      this.topN = topN;
+      this.origin = origin;
+      dims = origin.length;
+      min = new float[dims];
+      max = new float[dims];
+      // start with an unbounded box: nothing is rejected until the box is first tightened
+      Arrays.fill(min, Float.NEGATIVE_INFINITY);
+      Arrays.fill(max, Float.POSITIVE_INFINITY);
+    }
+
+    @Override
+    public void visit(int docID) {
+      // leaf blocks are always visited together with their values
+      throw new AssertionError();
+    }
+
+    /**
+     * Returns the minimum value that will change the given distance when added to it.
+     *
+     * This value is calculated from the distance exponent reduced by (at most) 23,
+     * the number of bits in a float mantissa. This is necessary when the result of
+     * subtracting/adding the distance in a single dimension has an exponent that
+     * differs significantly from that of the distance value. Without this fudge
+     * factor (i.e. only subtracting/adding the distance), cells and values can be
+     * inappropriately judged as outside the search radius.
+     */
+    private float getMinDelta(float distance) {
+      int exponent = Float.floatToIntBits(distance) >> MANTISSA_BITS; // extract biased exponent (distance is positive)
+      if (exponent == 0) {
+        return Float.MIN_VALUE;
+      } else {
+        exponent = exponent <= MANTISSA_BITS ? 1 : exponent - MANTISSA_BITS; // Avoid underflow
+        return Float.intBitsToFloat(exponent << MANTISSA_BITS);
+      }
+    }
+
+    /**
+     * Re-derives the min/max box from the current worst hit in the queue. This runs on each of the
+     * first 1024 calls and on one call in every 64 after that.
+     */
+    private void maybeUpdateMinMax() {
+      if (updateMinMaxCounter < 1024 || (updateMinMaxCounter & 0x3F) == 0x3F) {
+        NearestHit hit = hitQueue.peek();
+        float distance = (float)Math.sqrt(hit.distanceSquared);
+        float minDelta = getMinDelta(distance);
+        for (int d = 0 ; d < dims ; ++d) {
+          min[d] = (origin[d] - distance) - minDelta;
+          max[d] = (origin[d] + distance) + minDelta;
+        }
+      }
+      ++updateMinMaxCounter;
+    }
+
+    @Override
+    public void visit(int docID, byte[] packedValue) {
+      if (curLiveDocs != null && curLiveDocs.get(docID) == false) {
+        return;
+      }
+
+      float[] docPoint = new float[dims];
+      for (int d = 0, offset = 0 ; d < dims ; ++d, offset += Float.BYTES) {
+        docPoint[d] = FloatPoint.decodeDimension(packedValue, offset);
+        if (docPoint[d] > max[d] || docPoint[d] < min[d]) {
+          // outside the current box in this dimension: skip this point
+          return;
+        }
+      }
+
+      double distanceSquared = euclideanDistanceSquared(origin, docPoint);
+      int fullDocID = curDocBase + docID;
+
+      if (hitQueue.size() == topN) { // queue already full
+        NearestHit bottom = hitQueue.peek();
+        if (distanceSquared < bottom.distanceSquared
+            // we don't collect docs in order here, so we must also test the tie-break case ourselves:
+            || (distanceSquared == bottom.distanceSquared && fullDocID < bottom.docID)) {
+          // reuse the evicted hit object for the new entry
+          hitQueue.poll();
+          bottom.docID = fullDocID;
+          bottom.distanceSquared = distanceSquared;
+          hitQueue.offer(bottom);
+          maybeUpdateMinMax();
+        }
+      } else {
+        NearestHit hit = new NearestHit();
+        hit.docID = fullDocID;
+        hit.distanceSquared = distanceSquared;
+        hitQueue.offer(hit);
+      }
+    }
+
+    @Override
+    public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
+      // the tree is walked by nearest(), never through the visitor's compare callback
+      throw new AssertionError();
+    }
+  }
+
+  /** One collected point: its index-wide docID and its squared distance to the origin. */
+  static class NearestHit {
+    public int docID;
+    public double distanceSquared;
+
+    @Override
+    public String toString() {
+      return "NearestHit(docID=" + docID + " distanceSquared=" + distanceSquared + ")";
+    }
+  }
+
+  /**
+   * Finds the {@code topN} points closest to {@code origin} across the given readers.
+   *
+   * @param readers one BKD reader per segment
+   * @param liveDocs per reader, the live docs bits; a {@code null} entry means no doc is filtered
+   * @param docBases per reader, the value added to segment-local docIDs
+   * @param topN maximum number of hits to return
+   * @param origin the query point, one value per dimension
+   * @return the hits ordered closest first; equal distances are ordered by increasing docID
+   * @throws IllegalStateException if the readers disagree on bytes per dimension
+   */
+  public static NearestHit[] nearest(List<BKDReader> readers, List<Bits> liveDocs, List<Integer> docBases, final int topN, float[] origin) throws IOException {
+
+    // Holds closest collected points seen so far:
+    // TODO: if we used lucene's PQ we could just updateTop instead of poll/offer:
+    final PriorityQueue<NearestHit> hitQueue = new PriorityQueue<>(topN, (a, b) -> {
+      // sort by opposite distance natural order
+      int cmp = Double.compare(a.distanceSquared, b.distanceSquared);
+      return cmp != 0 ? -cmp : Integer.compare(b.docID, a.docID); // tie-break by higher docID
+    });
+
+    // Holds all cells, sorted by closest to the point:
+    PriorityQueue<Cell> cellQueue = new PriorityQueue<>();
+
+    NearestVisitor visitor = new NearestVisitor(hitQueue, topN, origin);
+    List<BKDReader.IntersectState> states = new ArrayList<>();
+
+    // Add root cell for each reader into the queue:
+    int bytesPerDim = -1;
+
+    for (int i = 0 ; i < readers.size() ; ++i) {
+      BKDReader reader = readers.get(i);
+      if (bytesPerDim == -1) {
+        bytesPerDim = reader.getBytesPerDimension();
+      } else if (bytesPerDim != reader.getBytesPerDimension()) {
+        throw new IllegalStateException("bytesPerDim changed from " + bytesPerDim
+            + " to " + reader.getBytesPerDimension() + " across readers");
+      }
+      byte[] minPackedValue = reader.getMinPackedValue();
+      byte[] maxPackedValue = reader.getMaxPackedValue();
+      BKDReader.IntersectState state = reader.getIntersectState(visitor);
+      states.add(state);
+
+      cellQueue.offer(new Cell(state.index, i, minPackedValue, maxPackedValue,
+          approxBestDistanceSquared(minPackedValue, maxPackedValue, origin)));
+    }
+
+    LOOP_OVER_CELLS: while (cellQueue.size() > 0) {
+      Cell cell = cellQueue.poll();
+
+      // TODO: cell.distanceSquared is a lower bound for every point in the cell, so we could put an
+      // opto here to break once this "best" cell is fully outside of the hitQueue bottom's radius:
+      BKDReader reader = readers.get(cell.readerIndex);
+
+      if (cell.index.isLeafNode()) {
+        // Leaf block: visit all points and possibly collect them:
+        visitor.curDocBase = docBases.get(cell.readerIndex);
+        visitor.curLiveDocs = liveDocs.get(cell.readerIndex);
+        reader.visitLeafBlockValues(cell.index, states.get(cell.readerIndex));
+      } else {
+        // Non-leaf block: split into two cells and put them back into the queue:
+
+        if (hitQueue.size() == topN) {
+          for (int d = 0, offset = 0; d < visitor.dims; ++d, offset += Float.BYTES) {
+            float cellMaxAtDim = FloatPoint.decodeDimension(cell.maxPacked, offset);
+            float cellMinAtDim = FloatPoint.decodeDimension(cell.minPacked, offset);
+            if (cellMaxAtDim < visitor.min[d] || cellMinAtDim > visitor.max[d]) {
+              // this cell is outside our search radius; don't bother exploring any more
+              continue LOOP_OVER_CELLS;
+            }
+          }
+        }
+        BytesRef splitValue = BytesRef.deepCopyOf(cell.index.getSplitDimValue());
+        int splitDim = cell.index.getSplitDim();
+
+        // we must clone the index so that we can recurse left and right "concurrently":
+        BKDReader.IndexTree newIndex = cell.index.clone();
+
+        // left child: same min, max capped at the split value in the split dimension
+        byte[] splitPackedValue = cell.maxPacked.clone();
+        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
+
+        cell.index.pushLeft();
+        cellQueue.offer(new Cell(cell.index, cell.readerIndex, cell.minPacked, splitPackedValue,
+            approxBestDistanceSquared(cell.minPacked, splitPackedValue, origin)));
+
+        // right child: same max, min raised to the split value in the split dimension
+        splitPackedValue = cell.minPacked.clone();
+        System.arraycopy(splitValue.bytes, splitValue.offset, splitPackedValue, splitDim * bytesPerDim, bytesPerDim);
+
+        newIndex.pushRight();
+        cellQueue.offer(new Cell(newIndex, cell.readerIndex, splitPackedValue, cell.maxPacked,
+            approxBestDistanceSquared(splitPackedValue, cell.maxPacked, origin)));
+      }
+    }
+
+    // the queue yields the farthest hit first, so fill the result array back to front:
+    NearestHit[] hits = new NearestHit[hitQueue.size()];
+    int downTo = hitQueue.size()-1;
+    while (hitQueue.size() != 0) {
+      hits[downTo] = hitQueue.poll();
+      downTo--;
+    }
+    return hits;
+  }
+
+  /**
+   * Returns the squared distance from {@code value} to the closest point of the box described by
+   * the packed min/max values, or zero when {@code value} lies inside the box.
+   *
+   * Only dimensions in which {@code value} falls outside the box's range contribute; a dimension
+   * where {@code value} is within [min, max] adds nothing, since the box contains a point with the
+   * same coordinate there.
+   */
+  private static double approxBestDistanceSquared(byte[] minPackedValue, byte[] maxPackedValue, float[] value) {
+    double sumOfSquaredDiffs = 0.0d;
+    for (int i = 0, offset = 0 ; i < value.length ; ++i, offset += Float.BYTES) {
+      double v = value[i];
+      double min = FloatPoint.decodeDimension(minPackedValue, offset);
+      double max = FloatPoint.decodeDimension(maxPackedValue, offset);
+      double diff;
+      if (v < min) {
+        diff = min - v;
+      } else if (v > max) {
+        diff = v - max;
+      } else {
+        // within the box's range in this dimension
+        continue;
+      }
+      sumOfSquaredDiffs += diff * diff;
+    }
+    return sumOfSquaredDiffs;
+  }
+
+  /** Returns the sum over all dimensions of the squared difference between {@code a} and {@code b}. */
+  static double euclideanDistanceSquared(float[] a, float[] b) {
+    double sumOfSquaredDifferences = 0.0d;
+    for (int d = 0 ; d < a.length ; ++d) {
+      double diff = (double)a[d] - (double)b[d];
+      sumOfSquaredDifferences += diff * diff;
+    }
+    return sumOfSquaredDifferences;
+  }
+
+  /**
+   * Finds the {@code topN} indexed points in {@code field} that are closest to {@code origin}.
+   *
+   * @param searcher searcher over the index to query
+   * @param field name of the indexed point field
+   * @param topN maximum number of hits to return; must be at least 1
+   * @param origin the query point, one value per dimension of the field
+   * @return hits ordered closest first; each is a {@link FieldDoc} whose single field value is the
+   *         distance to {@code origin} as a {@code Float}
+   * @throws IllegalArgumentException if an argument is invalid, if {@code origin} does not have as
+   *         many values as the field has dimensions, or if a segment's points are not backed by a
+   *         {@link BKDReader}
+   */
+  public static TopFieldDocs nearest(IndexSearcher searcher, String field, int topN, float... origin) throws IOException {
+    if (topN < 1) {
+      throw new IllegalArgumentException("topN must be at least 1; got " + topN);
+    }
+    if (field == null) {
+      throw new IllegalArgumentException("field must not be null");
+    }
+    if (searcher == null) {
+      throw new IllegalArgumentException("searcher must not be null");
+    }
+    if (origin == null) {
+      throw new IllegalArgumentException("origin must not be null");
+    }
+    List<BKDReader> readers = new ArrayList<>();
+    List<Integer> docBases = new ArrayList<>();
+    List<Bits> liveDocs = new ArrayList<>();
+    int totalHits = 0;
+    for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
+      PointValues points = leaf.reader().getPointValues(field);
+      if (points != null) {
+        if (points instanceof BKDReader == false) {
+          throw new IllegalArgumentException("can only run on Lucene60PointsReader points implementation, but got " + points);
+        }
+        if (points.getNumDimensions() != origin.length) {
+          throw new IllegalArgumentException("field=\"" + field + "\" was indexed with " + points.getNumDimensions()
+              + " dimensions, but origin has " + origin.length);
+        }
+        totalHits += points.getDocCount();
+        readers.add((BKDReader)points);
+        docBases.add(leaf.docBase);
+        liveDocs.add(leaf.reader().getLiveDocs());
+      }
+    }
+
+    NearestHit[] hits = nearest(readers, liveDocs, docBases, topN, origin);
+
+    // Convert to TopFieldDocs:
+    ScoreDoc[] scoreDocs = new ScoreDoc[hits.length];
+    for(int i=0;i<hits.length;i++) {
+      NearestHit hit = hits[i];
+      scoreDocs[i] = new FieldDoc(hit.docID, 0.0f, new Object[] { (float)Math.sqrt(hit.distanceSquared) });
+    }
+    return new TopFieldDocs(totalHits, scoreDocs, null, 0.0f);
+  }
+}

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/d52564c0/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java
----------------------------------------------------------------------
diff --git a/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java
b/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java
new file mode 100644
index 0000000..9e3c3c1
--- /dev/null
+++ b/lucene/sandbox/src/test/org/apache/lucene/document/TestFloatPointNearestNeighbor.java
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.document;
+
+import java.util.Arrays;
+
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.PointValues;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.SerialMergeScheduler;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.FieldDoc;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.lucene.util.TestUtil;
+
+public class TestFloatPointNearestNeighbor extends LuceneTestCase {
+
+  public void testNearestNeighborWithDeletedDocs() throws Exception {
+    Directory dir = newDirectory();
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, getIndexWriterConfig());
+    Document doc = new Document();
+    doc.add(new FloatPoint("point", 40.0f, 50.0f));
+    doc.add(new StringField("id", "0", Field.Store.YES));
+    w.addDocument(doc);
+
+    doc = new Document();
+    doc.add(new FloatPoint("point", 45.0f, 55.0f));
+    doc.add(new StringField("id", "1", Field.Store.YES));
+    w.addDocument(doc);
+
+    DirectoryReader r = w.getReader();
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    IndexSearcher s = newSearcher(r, false);
+    FieldDoc hit = (FieldDoc)FloatPointNearestNeighbor.nearest(s, "point", 1, 40.0f, 50.0f).scoreDocs[0];
+    assertEquals("0", r.document(hit.doc).getField("id").stringValue());
+    r.close();
+
+    w.deleteDocuments(new Term("id", "0"));
+    r = w.getReader();
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    s = newSearcher(r, false);
+    hit = (FieldDoc)LatLonPoint.nearest(s, "point", 40.0, 50.0, 1).scoreDocs[0];
+    assertEquals("1", r.document(hit.doc).getField("id").stringValue());
+    r.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testNearestNeighborWithAllDeletedDocs() throws Exception {
+    Directory dir = newDirectory();
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, getIndexWriterConfig());
+    Document doc = new Document();
+    doc.add(new FloatPoint("point", 40.0f, 50.0f));
+    doc.add(new StringField("id", "0", Field.Store.YES));
+    w.addDocument(doc);
+    doc = new Document();
+    doc.add(new FloatPoint("point", 45.0f, 55.0f));
+    doc.add(new StringField("id", "1", Field.Store.YES));
+    w.addDocument(doc);
+
+    DirectoryReader r = w.getReader();
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    IndexSearcher s = newSearcher(r, false);
+    FieldDoc hit = (FieldDoc)FloatPointNearestNeighbor.nearest(s, "point", 1, 40.0f, 50.0f).scoreDocs[0];
+    assertEquals("0", r.document(hit.doc).getField("id").stringValue());
+    r.close();
+
+    w.deleteDocuments(new Term("id", "0"));
+    w.deleteDocuments(new Term("id", "1"));
+    r = w.getReader();
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    s = newSearcher(r, false);
+    assertEquals(0, FloatPointNearestNeighbor.nearest(s, "point", 1, 40.0f, 50.0f).scoreDocs.length);
+    r.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testTieBreakByDocID() throws Exception {
+    Directory dir = newDirectory();
+    IndexWriter w = new IndexWriter(dir, getIndexWriterConfig());
+    Document doc = new Document();
+    doc.add(new FloatPoint("point", 40.0f, 50.0f));
+    doc.add(new StringField("id", "0", Field.Store.YES));
+    w.addDocument(doc);
+    doc = new Document();
+    doc.add(new FloatPoint("point", 40.0f, 50.0f));
+    doc.add(new StringField("id", "1", Field.Store.YES));
+    w.addDocument(doc);
+
+    DirectoryReader r = DirectoryReader.open(w);
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    ScoreDoc[] hits = FloatPointNearestNeighbor.nearest(newSearcher(r, false), "point", 2,
45.0f, 50.0f).scoreDocs;
+    assertEquals("0", r.document(hits[0].doc).getField("id").stringValue());
+    assertEquals("1", r.document(hits[1].doc).getField("id").stringValue());
+
+    r.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testNearestNeighborWithNoDocs() throws Exception {
+    Directory dir = newDirectory();
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, getIndexWriterConfig());
+    DirectoryReader r = w.getReader();
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    assertEquals(0, FloatPointNearestNeighbor.nearest(newSearcher(r, false), "point", 1,
40.0f, 50.0f).scoreDocs.length);
+    r.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testNearestNeighborRandom() throws Exception {
+    Directory dir;
+    int numPoints = atLeast(5000);
+    if (numPoints > 100000) {
+      dir = newFSDirectory(createTempDir(getClass().getSimpleName()));
+    } else {
+      dir = newDirectory();
+    }
+    IndexWriterConfig iwc = getIndexWriterConfig();
+    iwc.setMergePolicy(newLogMergePolicy());
+    iwc.setMergeScheduler(new SerialMergeScheduler());
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
+    int dims = TestUtil.nextInt(random(), 1, PointValues.MAX_DIMENSIONS);
+    float[][] values = new float[numPoints][dims];
+    for (int id = 0 ; id < numPoints ; ++id) {
+      for (int dim = 0 ; dim < dims ; ++dim) {
+        Float f = Float.NaN;
+        while (f.isNaN()) {
+          f = Float.intBitsToFloat(random().nextInt());
+        }
+        values[id][dim] = f;
+      }
+      Document doc = new Document();
+      doc.add(new FloatPoint("point", values[id]));
+      doc.add(new StoredField("id", id));
+      w.addDocument(doc);
+    }
+
+    if (random().nextBoolean()) {
+      w.forceMerge(1);
+    }
+
+    DirectoryReader r = w.getReader();
+    if (VERBOSE) {
+      System.out.println("TEST: reader=" + r);
+    }
+    // can't wrap because we require Lucene60PointsFormat directly but e.g. ParallelReader
wraps with its own points impl:
+    IndexSearcher s = newSearcher(r, false);
+    int iters = atLeast(100);
+    for (int iter = 0 ; iter < iters ; ++iter) {
+      if (VERBOSE) {
+        System.out.println("\nTEST: iter=" + iter);
+      }
+      float[] origin = new float[dims];
+      for (int dim = 0 ; dim < dims ; ++dim) {
+        Float f = Float.NaN;
+        while (f.isNaN()) {
+          f = Float.intBitsToFloat(random().nextInt());
+        }
+        origin[dim] = f;
+      }
+
+      // dumb brute force search to get the expected result:
+      FloatPointNearestNeighbor.NearestHit[] expectedHits = new FloatPointNearestNeighbor.NearestHit[numPoints];
+      for (int id = 0 ; id < numPoints ; ++id) {
+        FloatPointNearestNeighbor.NearestHit hit = new FloatPointNearestNeighbor.NearestHit();
+        hit.distanceSquared = FloatPointNearestNeighbor.euclideanDistanceSquared(origin,
values[id]);
+        hit.docID = id;
+        expectedHits[id] = hit;
+      }
+
+      Arrays.sort(expectedHits, (a, b) -> {
+        int cmp = Double.compare(a.distanceSquared, b.distanceSquared);
+        return cmp != 0 ? cmp : a.docID - b.docID; // tie break by smaller id
+      });
+
+      int topK = TestUtil.nextInt(random(), 1, numPoints);
+
+      if (VERBOSE) {
+        System.out.println("\nhits for origin=" + Arrays.toString(origin));
+      }
+
+      ScoreDoc[] hits = FloatPointNearestNeighbor.nearest(s, "point", topK, origin).scoreDocs;
+      assertEquals("fewer than expected hits: ", topK, hits.length);
+
+      if (VERBOSE) {
+        for (int i = 0 ; i < topK ; ++i) {
+          FloatPointNearestNeighbor.NearestHit expected = expectedHits[i];
+          FieldDoc actual = (FieldDoc)hits[i];
+          Document actualDoc = r.document(actual.doc);
+          System.out.println("hit " + i);
+          System.out.println("  expected id=" + expected.docID + "  " + Arrays.toString(values[expected.docID])
+              + "  distance=" + (float)Math.sqrt(expected.distanceSquared) + "  distanceSquared="
+ expected.distanceSquared);
+          System.out.println("  actual id=" + actualDoc.getField("id") + " distance=" + actual.fields[0]);
+        }
+      }
+
+      for (int i = 0 ; i < topK ; ++i) {
+        FloatPointNearestNeighbor.NearestHit expected = expectedHits[i];
+        FieldDoc actual = (FieldDoc)hits[i];
+        assertEquals("hit " + i + ":", expected.docID, actual.doc);
+        assertEquals("hit " + i + ":", (float)Math.sqrt(expected.distanceSquared), (Float)actual.fields[0],
0.000001);
+      }
+    }
+
+    r.close();
+    w.close();
+    dir.close();
+  }
+
+  private IndexWriterConfig getIndexWriterConfig() {
+    IndexWriterConfig iwc = newIndexWriterConfig();
+    iwc.setCodec(Codec.forName("Lucene70"));
+    return iwc;
+  }
+}


Mime
View raw message