lucene-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tom...@apache.org
Subject [lucene-solr] 05/10: Add MoreLikeThis like factory method
Date Sun, 12 Jan 2020 06:11:40 GMT
This is an automated email from the ASF dual-hosted git repository.

tomoko pushed a commit to branch jira/lucene-9004-aknn-2
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git

commit 624632cc308c42af4d166446841b110f1bffddad
Author: Tomoko Uchida <tomoko@apache.org>
AuthorDate: Sat Dec 21 19:45:47 2019 +0900

    Add MoreLikeThis like factory method
---
 .../codecs/lucene90/Lucene90FieldInfosFormat.java  |   2 +-
 .../org/apache/lucene/index/MultiVectorValues.java | 145 +++++++++++++++++++++
 .../java/org/apache/lucene/index/VectorValues.java |  16 ++-
 .../org/apache/lucene/search/KnnGraphQuery.java    |  27 +++-
 .../org/apache/lucene/search/KnnScoreWeight.java   |  12 +-
 .../org/apache/lucene/util/hnsw/HNSWGraph.java     |   6 +-
 .../apache/lucene/util/hnsw/HNSWGraphReader.java   |   8 +-
 7 files changed, 205 insertions(+), 11 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java
b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java
index e91ab88..1ddcf97 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java
@@ -241,8 +241,8 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
       case NONE:
       case MANHATTAN:
       case EUCLIDEAN:
-        return (byte)distFunc.getId();
       case COSINE:
+        return (byte)distFunc.getId();
       default:
         // BUG
         throw new AssertionError("unhandled DistanceFunction: " + distFunc);
diff --git a/lucene/core/src/java/org/apache/lucene/index/MultiVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/MultiVectorValues.java
new file mode 100644
index 0000000..b8abaaa
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/index/MultiVectorValues.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.lucene.index;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Exposes the {@link VectorValues} of a (possibly composite) {@link IndexReader} as a single
+ * iterator over top-level doc IDs, i.e. per-leaf doc IDs shifted by each leaf's {@code docBase}.
+ */
+public class MultiVectorValues {
+
+  /**
+   * Returns a vector values view over all leaves of a reader.
+   *
+   * @param r the reader whose leaves are merged
+   * @param field the vector field
+   * @return {@code null} if {@code r} has no leaves or no leaf has a non-zero vector dimension for
+   *         {@code field}; the leaf's own result of {@code getVectorValues(field)} if {@code r}
+   *         has exactly one leaf; otherwise a merged view over all leaves
+   * @throws IOException if a leaf fails to provide its vector values
+   */
+  public static VectorValues getVectorValues(final IndexReader r, final String field)
+      throws IOException {
+    final List<LeafReaderContext> leaves = r.leaves();
+    final int size = leaves.size();
+    if (size == 0) {
+      return null;
+    } else if (size == 1) {
+      return leaves.get(0).reader().getVectorValues(field);
+    }
+
+    boolean anyReal = false;
+    for (LeafReaderContext leaf : leaves) {
+      FieldInfo fieldInfo = leaf.reader().getFieldInfos().fieldInfo(field);
+      if (fieldInfo != null && fieldInfo.getVectorNumDimensions() != 0) {
+        anyReal = true;
+        break;
+      }
+    }
+
+    if (anyReal == false) {
+      return null;
+    }
+
+    return new VectorValues() {
+      // index of the next leaf to open; every leaf before it has already been entered
+      private int nextLeaf;
+      // values of the leaf we are positioned in; null between leaves
+      private VectorValues currentValues;
+      private LeafReaderContext currentLeaf;
+      // current top-level doc ID
+      private int docID = -1;
+
+      @Override
+      public int docID() {
+        return docID;
+      }
+
+      @Override
+      public int nextDoc() throws IOException {
+        while (true) {
+          // open leaves in order until one has values for this field
+          while (currentValues == null) {
+            if (nextLeaf == leaves.size()) {
+              docID = NO_MORE_DOCS;
+              return docID;
+            }
+            currentLeaf = leaves.get(nextLeaf);
+            currentValues = currentLeaf.reader().getVectorValues(field);
+            nextLeaf++;
+          }
+
+          int newDocID = currentValues.nextDoc();
+          if (newDocID == NO_MORE_DOCS) {
+            // this leaf is exhausted; fall through to the next one
+            currentValues = null;
+          } else {
+            docID = currentLeaf.docBase + newDocID;
+            return docID;
+          }
+        }
+      }
+
+      @Override
+      public int advance(int target) throws IOException {
+        if (target <= docID) {
+          throw new IllegalArgumentException("can only advance beyond current document: on docID="
+              + docID + " but targetDocID=" + target);
+        }
+        int readerIndex = ReaderUtil.subIndex(target, leaves);
+        if (readerIndex >= nextLeaf) {
+          // target lies in a leaf that has not been entered yet: jump straight to it
+          if (readerIndex == leaves.size()) {
+            currentValues = null;
+            docID = NO_MORE_DOCS;
+            return docID;
+          }
+          currentLeaf = leaves.get(readerIndex);
+          currentValues = currentLeaf.reader().getVectorValues(field);
+          nextLeaf = readerIndex + 1;
+          if (currentValues == null) {
+            // no vectors in the target leaf; every doc of later leaves is > target
+            return nextDoc();
+          }
+        }
+        int newDocID = currentValues.advance(target - currentLeaf.docBase);
+        if (newDocID == NO_MORE_DOCS) {
+          currentValues = null;
+          return nextDoc();
+        } else {
+          docID = currentLeaf.docBase + newDocID;
+          return docID;
+        }
+      }
+
+      @Override
+      public float[] vectorValue() throws IOException {
+        // only valid while positioned on a document (currentValues != null)
+        return currentValues.vectorValue();
+      }
+
+      /**
+       * Positions this iterator on top-level doc {@code target}. Only the leaf that owns
+       * {@code target} is consulted. On success the iterator continues from that leaf and
+       * {@link #docID()} returns {@code target}. On failure this iterator's own position
+       * fields are left unchanged.
+       */
+      @Override
+      public boolean seek(int target) throws IOException {
+        if (target < 0 || target >= r.maxDoc()) {
+          return false;
+        }
+        int readerIndex = ReaderUtil.subIndex(target, leaves);
+        LeafReaderContext leaf = leaves.get(readerIndex);
+        VectorValues values = leaf.reader().getVectorValues(field);
+        // leaves without this field return null values: treat as "no vector for target"
+        if (values == null || values.seek(target - leaf.docBase) == false) {
+          return false;
+        }
+        currentLeaf = leaf;
+        currentValues = values;
+        // the owning leaf is now entered; nextDoc() must not reopen it once it is exhausted
+        nextLeaf = readerIndex + 1;
+        docID = target;
+        return true;
+      }
+
+      @Override
+      public long cost() {
+        // TODO: estimate the cost, e.g. by summing per-leaf costs; currently always 0
+        return 0;
+      }
+    };
+  }
+
+}
diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
index b0aed22..e83d174 100644
--- a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
+++ b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java
@@ -153,8 +153,20 @@ public abstract class VectorValues extends DocIdSetIterator {
     COSINE(3) {
       @Override
       float distance(float[] v1, float[] v2) {
-        // TODO
-        return 0.0f;
+        assert v1.length == v2.length;
+        if (Arrays.equals(v1, v2)) {
+          return 0.0f;
+        }
+        float sum = 0.0f;
+        float squareSum1 = 0.0f;
+        float squareSum2 = 0.0f;
+        int dim = v1.length;
+        for (int i = 0; i < dim; i++) {
+          sum += v1[i] * v2[i];
+          squareSum1 += v1[i] * v1[i];
+          squareSum2 += v2[i] * v2[i];
+        }
+        return 1.0f - sum / ((float)Math.sqrt(squareSum1) * (float)Math.sqrt(squareSum2));
       }
     };
 
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnGraphQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnGraphQuery.java
index 0ee5ab2..b1c14f6 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnGraphQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnGraphQuery.java
@@ -20,9 +20,17 @@ package org.apache.lucene.search;
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Objects;
+import java.util.Set;
 
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FieldInfos;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.MultiFields;
+import org.apache.lucene.index.MultiVectorValues;
+import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.util.Accountable;
+import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.hnsw.HNSWGraphReader;
 
@@ -71,16 +79,31 @@ public class KnnGraphQuery extends Query implements Accountable {
    * @param ef number of per-segment candidates to be scored/collected. the collector does
not return results exceeding {@code ef}.
    *           increasing this value leads higher recall at the expense of the search speed.
    * @param reader index reader
+   * @param forceReload if true forcibly reloads kNN graph
    */
-  public KnnGraphQuery(String field, float[] queryVector, int ef, IndexReader reader) throws
IOException {
+  public KnnGraphQuery(String field, float[] queryVector, int ef, IndexReader reader, boolean
forceReload) throws IOException {
     this.field = field;
     this.queryVector = queryVector;
     this.ef = ef;
     if (reader != null) {
-      this.bytesUsed = HNSWGraphReader.reloadGraph(field, reader);
+      this.bytesUsed = HNSWGraphReader.loadGraphs(field, reader, forceReload);
     }
   }
 
+  /**
+   * Creates a query for the nearest neighbors of the vector already indexed for {@code docId}
+   * in {@code field}.
+   *
+   * @param field the vector field
+   * @param docId top-level doc ID (relative to {@code reader}) whose vector becomes the query
+   * @param ef number of per-segment candidates to be scored/collected
+   * @param reader index reader used to read the document's vector and to load the kNN graphs
+   * @param forceReload if true forcibly reloads kNN graph
+   * @throws IllegalArgumentException if {@code field} has no vectors in {@code reader}, or
+   *         {@code docId} has no vector value
+   * @throws IOException if reading vector values or loading graphs fails
+   */
+  public static KnnGraphQuery like(String field, int docId, int ef, IndexReader reader,
+      boolean forceReload) throws IOException {
+    FieldInfo fi = FieldInfos.getMergedFieldInfos(reader).fieldInfo(field);
+    // fieldInfo() yields null for a field absent from the index; check before dereferencing
+    if (fi == null || fi.getVectorNumDimensions() == 0) {
+      throw new IllegalArgumentException("Field \"" + field + "\" has no vector values.");
+    }
+    VectorValues vectorValues = MultiVectorValues.getVectorValues(reader, field);
+    if (vectorValues == null || vectorValues.seek(docId) == false) {
+      throw new IllegalArgumentException("Doc " + docId + " has no vector values.");
+    }
+    // the query keeps this array; copy it in case the iterator reuses its buffer (TODO confirm)
+    float[] queryVector = vectorValues.vectorValue().clone();
+    return new KnnGraphQuery(field, queryVector, ef, reader, forceReload);
+  }
+
   @Override
   public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws
IOException {
     return new KnnScoreWeight(this, boost, scoreMode, field, queryVector, ef);
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnScoreWeight.java b/lucene/core/src/java/org/apache/lucene/search/KnnScoreWeight.java
index 73c8071..02b6a5c 100644
--- a/lucene/core/src/java/org/apache/lucene/search/KnnScoreWeight.java
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnScoreWeight.java
@@ -107,7 +107,17 @@ class KnnScoreWeight extends ConstantScoreWeight {
                     score = 0.0f;
                   } else {
                     doc = next.docId();
-                    score = 1.0f / (next.distance() / numDimensions + 0.01f);
+                    switch (fi.getVectorDistFunc()) {
+                      case MANHATTAN:
+                      case EUCLIDEAN:
+                        score = 1.0f / (next.distance() / numDimensions + 0.01f);
+                        break;
+                      case COSINE:
+                        score = 1.0f - next.distance();
+                        break;
+                      default:
+                        break;
+                    }
                   }
                 }
                 return doc;
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraph.java
index 6722e76..e2acd0b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraph.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraph.java
@@ -40,6 +40,7 @@ public final class HNSWGraph implements Accountable {
   private final List<Layer> layers;
 
   private boolean frozen = false;
+  private long bytesUsed;
 
   public HNSWGraph(VectorValues.DistanceFunction distFunc) {
     this.distFunc = distFunc;
@@ -261,7 +262,10 @@ public final class HNSWGraph implements Accountable {
 
   @Override
   public long ramBytesUsed() {
-    return RamUsageEstimator.sizeOfCollection(layers);
+    if (bytesUsed == 0) {
+      bytesUsed = RamUsageEstimator.sizeOfCollection(layers);
+    }
+    return bytesUsed;
   }
 
 }
diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraphReader.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraphReader.java
index 8a9d0dd..b92e96b 100644
--- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraphReader.java
+++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HNSWGraphReader.java
@@ -61,19 +61,19 @@ public final class HNSWGraphReader {
     return hnsw.searchLayer(query, ep, ef, 0, vectorValues);
   }
 
-  public static long reloadGraph(String field, IndexReader reader) throws IOException {
+  public static long loadGraphs(String field, IndexReader reader, boolean forceReload) throws
IOException {
     long bytesUsed = 0L;
     for (LeafReaderContext ctx : reader.leaves()) {
-      HNSWGraph hnsw = get(field, ctx, true);
+      HNSWGraph hnsw = get(field, ctx, forceReload);
       bytesUsed += hnsw.ramBytesUsed();
     }
     return bytesUsed;
   }
 
-  private static HNSWGraph get(String field, LeafReaderContext context, boolean reload) throws
IOException {
+  private static HNSWGraph get(String field, LeafReaderContext context, boolean forceReload)
throws IOException {
     GraphKey key = new GraphKey(field, context.id());
     IOException[] exc = new IOException[]{null};
-    if (reload) {
+    if (forceReload) {
       cache.put(key, load(field, context));
     } else {
       cache.computeIfAbsent(key, (k -> {


Mime
View raw message