// blob: 062718dc5fe4c324328fe81b77b8644ea4fb43fe [file] [log] [blame]
package org.apache.lucene.search;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.lucene.util.PriorityQueue;
import java.io.IOException;
/** Represents hits returned by {@link
* IndexSearcher#search(Query,Filter,int)} and {@link
* IndexSearcher#search(Query,int)}. */
public class TopDocs {

  /** The total number of hits for the query. */
  public int totalHits;

  /** The top hits for the query. */
  public ScoreDoc[] scoreDocs;

  /** Stores the maximum score value encountered, needed for normalizing. */
  private float maxScore;

  /**
   * Returns the maximum score value encountered. Note that in case
   * scores are not tracked, this returns {@link Float#NaN}.
   */
  public float getMaxScore() {
    return maxScore;
  }

  /** Sets the maximum score value encountered. */
  public void setMaxScore(float maxScore) {
    this.maxScore = maxScore;
  }

  /** Constructs a TopDocs with a default maxScore=Float.NaN. */
  TopDocs(int totalHits, ScoreDoc[] scoreDocs) {
    this(totalHits, scoreDocs, Float.NaN);
  }

  /**
   * Constructs a TopDocs from its three components.
   *
   * @param totalHits the total number of hits for the query
   * @param scoreDocs the top hits for the query
   * @param maxScore the maximum score value encountered
   */
  public TopDocs(int totalHits, ScoreDoc[] scoreDocs, float maxScore) {
    this.totalHits = totalHits;
    this.scoreDocs = scoreDocs;
    this.maxScore = maxScore;
  }

  // Refers to one hit:
  private static class ShardRef {
    // Which shard (index into shardHits[]):
    final int shardIndex;

    // Which hit within the shard:
    int hitIndex;

    public ShardRef(int shardIndex) {
      this.shardIndex = shardIndex;
    }

    @Override
    public String toString() {
      return "ShardRef(shardIndex=" + shardIndex + " hitIndex=" + hitIndex + ")";
    }
  }

  /**
   * Shared tie break used by both merge queues once the primary
   * comparison found the two hits equal: the earlier shard wins, and
   * within one shard the hit with the lower index wins (ie, the order
   * in which the shard itself returned the hits is kept).
   */
  private static boolean tieBreakLessThan(ShardRef first, ShardRef second) {
    if (first.shardIndex < second.shardIndex) {
      return true;
    } else if (first.shardIndex > second.shardIndex) {
      return false;
    }
    // Tie break in same shard: resolve however the
    // shard had resolved it:
    assert first.hitIndex != second.hitIndex;
    return first.hitIndex < second.hitIndex;
  }

  // Specialized MergeSortQueue that just merges by
  // relevance score, descending:
  private static class ScoreMergeSortQueue extends PriorityQueue<ShardRef> {
    final ScoreDoc[][] shardHits;

    public ScoreMergeSortQueue(TopDocs[] shardHits) {
      super(shardHits.length);
      this.shardHits = new ScoreDoc[shardHits.length][];
      for (int shardIDX = 0; shardIDX < shardHits.length; shardIDX++) {
        this.shardHits[shardIDX] = shardHits[shardIDX].scoreDocs;
      }
    }

    // Returns true if first is < second, ie first should be
    // popped before second (higher score first):
    @Override
    public boolean lessThan(ShardRef first, ShardRef second) {
      assert first != second;
      final float firstScore = shardHits[first.shardIndex][first.hitIndex].score;
      final float secondScore = shardHits[second.shardIndex][second.hitIndex].score;

      // Two explicit comparisons (rather than a single !=) so that a
      // NaN score falls through to the tie break:
      if (firstScore < secondScore) {
        return false;
      } else if (firstScore > secondScore) {
        return true;
      }
      return tieBreakLessThan(first, second);
    }
  }

  // Merges hits that were sorted by an arbitrary Sort, comparing the
  // sort values stored in each FieldDoc:
  private static class MergeSortQueue extends PriorityQueue<ShardRef> {
    // These are really FieldDoc instances:
    final ScoreDoc[][] shardHits;
    final FieldComparator<?>[] comparators;
    final int[] reverseMul;

    public MergeSortQueue(Sort sort, TopDocs[] shardHits) throws IOException {
      super(shardHits.length);
      this.shardHits = new ScoreDoc[shardHits.length][];
      for (int shardIDX = 0; shardIDX < shardHits.length; shardIDX++) {
        final ScoreDoc[] shard = shardHits[shardIDX].scoreDocs;
        if (shard != null) {
          this.shardHits[shardIDX] = shard;
          // Fail gracefully if API is misused:
          for (int hitIDX = 0; hitIDX < shard.length; hitIDX++) {
            final ScoreDoc sd = shard[hitIDX];
            if (!(sd instanceof FieldDoc)) {
              throw new IllegalArgumentException("shard " + shardIDX + " was not sorted by the provided Sort (expected FieldDoc but got ScoreDoc)");
            }
            final FieldDoc fd = (FieldDoc) sd;
            if (fd.fields == null) {
              throw new IllegalArgumentException("shard " + shardIDX + " did not set sort field values (FieldDoc.fields is null); you must pass fillFields=true to IndexSearcher.search on each shard");
            }
          }
        }
      }

      final SortField[] sortFields = sort.getSort();
      comparators = new FieldComparator<?>[sortFields.length];
      reverseMul = new int[sortFields.length];
      for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
        final SortField sortField = sortFields[compIDX];
        comparators[compIDX] = sortField.getComparator(1, compIDX);
        reverseMul[compIDX] = sortField.getReverse() ? -1 : 1;
      }
    }

    // Returns true if first is < second
    @Override
    @SuppressWarnings({"unchecked","rawtypes"})
    public boolean lessThan(ShardRef first, ShardRef second) {
      assert first != second;
      final FieldDoc firstFD = (FieldDoc) shardHits[first.shardIndex][first.hitIndex];
      final FieldDoc secondFD = (FieldDoc) shardHits[second.shardIndex][second.hitIndex];

      // The first sort field whose values differ decides the order:
      for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
        // Raw type on purpose: the value type of each comparator is
        // not known statically here.
        final FieldComparator comp = comparators[compIDX];
        final int cmp = reverseMul[compIDX] * comp.compareValues(firstFD.fields[compIDX], secondFD.fields[compIDX]);
        if (cmp != 0) {
          return cmp < 0;
        }
      }

      // All sort fields compared equal:
      return tieBreakLessThan(first, second);
    }
  }

  /** Returns a new TopDocs, containing topN results across
   *  the provided TopDocs, sorting by the specified {@link
   *  Sort}.  Each of the TopDocs must have been sorted by
   *  the same Sort, and sort field values must have been
   *  filled (ie, <code>fillFields=true</code> must be
   *  passed to {@link TopFieldCollector#create}).
   *
   * <p>Pass sort=null to merge sort by score descending.
   *
   * @lucene.experimental */
  public static TopDocs merge(Sort sort, int topN, TopDocs[] shardHits) throws IOException {
    return merge(sort, 0, topN, shardHits);
  }

  /**
   * Same as {@link #merge(Sort, int, TopDocs[])} but also slices the result at the same time based
   * on the provided start and size. The return TopDocs will always have a scoreDocs with length of at most size.
   *
   * <p>Each merged hit has its <code>shardIndex</code> set to the index of the
   * shard (within <code>shardHits</code>) it came from; the {@link ScoreDoc}
   * instances of the input are modified in place and reused in the result.
   *
   * @param sort the sort the shards were sorted by, or null to merge by score descending
   * @param start number of leading merged hits to skip
   * @param size maximum number of hits to return
   * @param shardHits the per-shard results to merge
   * @return a {@link TopDocs} when sort is null, otherwise a {@link TopFieldDocs};
   *         its maxScore is {@link Float#NaN} when no shard provided any hit
   */
  public static TopDocs merge(Sort sort, int start, int size, TopDocs[] shardHits) throws IOException {
    final PriorityQueue<ShardRef> queue;
    if (sort == null) {
      queue = new ScoreMergeSortQueue(shardHits);
    } else {
      queue = new MergeSortQueue(sort, shardHits);
    }

    int totalHitCount = 0;
    int availHitCount = 0;
    // Start from -infinity, not Float.MIN_VALUE: MIN_VALUE is the
    // smallest *positive* float and would win over shards whose max
    // score is zero or negative.
    float maxScore = Float.NEGATIVE_INFINITY;
    for (int shardIDX = 0; shardIDX < shardHits.length; shardIDX++) {
      final TopDocs shard = shardHits[shardIDX];
      // totalHits can be non-zero even if no hits were
      // collected, when searchAfter was used:
      totalHitCount += shard.totalHits;
      if (shard.scoreDocs != null && shard.scoreDocs.length > 0) {
        availHitCount += shard.scoreDocs.length;
        queue.add(new ShardRef(shardIDX));
        maxScore = Math.max(maxScore, shard.getMaxScore());
      }
    }

    if (availHitCount == 0) {
      maxScore = Float.NaN;
    }

    final ScoreDoc[] hits;
    if (availHitCount <= start) {
      hits = new ScoreDoc[0];
    } else {
      hits = new ScoreDoc[Math.min(size, availHitCount - start)];
      // Equal to min(availHitCount, start + size), but computed from
      // the already-clamped array length so that a huge size cannot
      // overflow the int addition:
      final int numIterOnHits = start + hits.length;
      int hitUpto = 0;
      while (hitUpto < numIterOnHits) {
        assert queue.size() > 0;
        final ShardRef ref = queue.pop();
        final ScoreDoc hit = shardHits[ref.shardIndex].scoreDocs[ref.hitIndex++];
        hit.shardIndex = ref.shardIndex;
        // Hits before start are consumed but not returned:
        if (hitUpto >= start) {
          hits[hitUpto - start] = hit;
        }
        hitUpto++;

        if (ref.hitIndex < shardHits[ref.shardIndex].scoreDocs.length) {
          // Not done with this shard's TopDocs yet:
          queue.add(ref);
        }
      }
    }

    if (sort == null) {
      return new TopDocs(totalHitCount, hits, maxScore);
    } else {
      return new TopFieldDocs(totalHitCount, hits, sort.getSort(), maxScore);
    }
  }
}