Skip to content

Commit

Permalink
Let DocIdSetIterator optimize loading into a FixedBitSet. (#14069)
Browse files Browse the repository at this point in the history
This is an iteration on #14064. The benefits of this approach are that the API
is a bit nicer and allows optimizing not only when doc IDs are stored in an
int[]. The downside is that it only helps non-scoring disjunctions for now, but
we can look into scoring disjunctions later on.
  • Loading branch information
jpountz authored Dec 17, 2024
1 parent 5f0fa2b commit e74f19b
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 24 deletions.
3 changes: 2 additions & 1 deletion lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ Other

API Changes
---------------------
(No changes)
* GITHUB#14069: Added DocIdSetIterator#intoBitSet API to let implementations
optimize loading doc IDs into a bit set. (Adrien Grand)

New Features
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.VectorUtil;

Expand Down Expand Up @@ -875,6 +877,63 @@ public int advance(int target) throws IOException {
return doc;
}

@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
if (doc >= upTo) {
return;
}

// Handle the current doc separately, it may be on the previous docBuffer.
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}

for (; ; ) {
if (docBufferUpto == BLOCK_SIZE) {
// refill
moveToNextLevel0Block();
}

int start = docBufferUpto;
int end = computeBufferEndBoundary(upTo);
if (end != 0) {
bufferIntoBitSet(start, end, acceptDocs, bitSet, offset);
doc = docBuffer[end - 1];
}
docBufferUpto = end;

if (end != BLOCK_SIZE) {
// Either the block is a tail block, or the block did not fully match, we're done.
nextDoc();
assert doc >= upTo;
break;
}
}
}

private int computeBufferEndBoundary(int upTo) {
if (docBufferSize != 0 && docBuffer[docBufferSize - 1] < upTo) {
// All docs in the buffer are under upTo
return docBufferSize;
} else {
// Find the index of the first doc that is greater than or equal to upTo
return VectorUtil.findNextGEQ(docBuffer, upTo, docBufferUpto, docBufferSize);
}
}

private void bufferIntoBitSet(
int start, int end, Bits acceptDocs, FixedBitSet bitSet, int offset) throws IOException {
// acceptDocs#get (if backed by FixedBitSet), bitSet#set and `doc - offset` get
// auto-vectorized
for (int i = start; i < end; ++i) {
int doc = docBuffer[i];
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}
}
}

private void skipPositions(int freq) throws IOException {
// Skip positions now:
int toSkip = posPendingCount - freq;
Expand Down
39 changes: 17 additions & 22 deletions lucene/core/src/java/org/apache/lucene/search/BooleanScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import org.apache.lucene.internal.hppc.LongArrayList;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.PriorityQueue;

/**
Expand All @@ -34,8 +34,6 @@ final class BooleanScorer extends BulkScorer {
static final int SHIFT = 12;
static final int SIZE = 1 << SHIFT;
static final int MASK = SIZE - 1;
static final int SET_SIZE = 1 << (SHIFT - 6);
static final int SET_MASK = SET_SIZE - 1;

static class Bucket {
double score;
Expand Down Expand Up @@ -74,8 +72,7 @@ public DisiWrapper get(int i) {
// One bucket per doc ID in the window, non-null if scores are needed or if frequencies need to be
// counted
final Bucket[] buckets;
// This is basically an inlined FixedBitSet... seems to help with bound checks
final long[] matching = new long[SET_SIZE];
final FixedBitSet matching = new FixedBitSet(SIZE);

final DisiWrapper[] leads;
final HeadPriorityQueue head;
Expand All @@ -91,11 +88,12 @@ final class DocIdStreamView extends DocIdStream {

@Override
public void forEach(CheckedIntConsumer<IOException> consumer) throws IOException {
long[] matching = BooleanScorer.this.matching;
FixedBitSet matching = BooleanScorer.this.matching;
Bucket[] buckets = BooleanScorer.this.buckets;
int base = this.base;
for (int idx = 0; idx < matching.length; idx++) {
long bits = matching[idx];
long[] bitArray = matching.getBits();
for (int idx = 0; idx < bitArray.length; idx++) {
long bits = bitArray[idx];
while (bits != 0L) {
int ntz = Long.numberOfTrailingZeros(bits);
if (buckets != null) {
Expand All @@ -121,11 +119,7 @@ public int count() throws IOException {
// We can't just count bits in that case
return super.count();
}
int count = 0;
for (long l : matching) {
count += Long.bitCount(l);
}
return count;
return matching.cardinality();
}
}

Expand Down Expand Up @@ -173,7 +167,7 @@ public long cost() {
private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min, int max)
throws IOException {
boolean needsScores = BooleanScorer.this.needsScores;
long[] matching = BooleanScorer.this.matching;
FixedBitSet matching = BooleanScorer.this.matching;
Bucket[] buckets = BooleanScorer.this.buckets;

DocIdSetIterator it = w.iterator;
Expand All @@ -182,12 +176,13 @@ private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min,
if (doc < min) {
doc = it.advance(min);
}
for (; doc < max; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc & MASK;
final int idx = i >> 6;
matching[idx] |= 1L << i;
if (buckets != null) {
if (buckets == null) {
it.intoBitSet(acceptDocs, max, matching, doc & ~MASK);
} else {
for (; doc < max; doc = it.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc & MASK;
matching.set(i);
final Bucket bucket = buckets[i];
bucket.freq++;
if (needsScores) {
Expand All @@ -197,7 +192,7 @@ private void scoreDisiWrapperIntoBitSet(DisiWrapper w, Bits acceptDocs, int min,
}
}

w.doc = doc;
w.doc = it.docID();
}

private void scoreWindowIntoBitSetAndReplay(
Expand All @@ -218,7 +213,7 @@ private void scoreWindowIntoBitSetAndReplay(
docIdStreamView.base = base;
collector.collect(docIdStreamView);

Arrays.fill(matching, 0L);
matching.clear();
}

private DisiWrapper advance(int min) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/**
* This abstract class defines methods to iterate over a set of non-decreasing doc ids. Note that
Expand Down Expand Up @@ -211,4 +213,33 @@ protected final int slowAdvance(int target) throws IOException {
* may be a rough heuristic, hardcoded value, or otherwise completely inaccurate.
*/
public abstract long cost();

/**
* Load doc IDs into a {@link FixedBitSet}. This should behave exactly as if implemented as below,
* which is the default implementation:
*
* <pre class="prettyprint">
* for (int doc = docID(); doc &lt; upTo; doc = nextDoc()) {
* if (acceptDocs == null || acceptDocs.get(doc)) {
* bitSet.set(doc - offset);
* }
* }
* </pre>
*
* <p><b>Note</b>: {@code offset} must be less than or equal to the {@link #docID() current doc
* ID}.
*
* <p><b>Note</b>: It is important not to clear bits from {@code bitSet} that may be already set.
*
* @lucene.internal
*/
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
assert offset <= docID();
for (int doc = docID(); doc < upTo; doc = nextDoc()) {
if (acceptDocs == null || acceptDocs.get(doc)) {
bitSet.set(doc - offset);
}
}
}
}
4 changes: 3 additions & 1 deletion lucene/core/src/java/org/apache/lucene/util/FixedBitSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ public void or(DocIdSetIterator iter) throws IOException {
DocBaseBitSetIterator baseIter = (DocBaseBitSetIterator) iter;
or(baseIter.getDocBase() >> 6, baseIter.getBitSet());
} else {
super.or(iter);
checkUnpositioned(iter);
iter.nextDoc();
iter.intoBitSet(null, DocIdSetIterator.NO_MORE_DOCS, this, 0);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil;
import org.apache.lucene.tests.util.automaton.AutomatonTestUtil.RandomAcceptedStrings;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.UnicodeUtil;
Expand Down Expand Up @@ -110,6 +111,9 @@ public enum Option {
// Sometimes don't fully consume positions at each doc
PARTIAL_POS_CONSUME,

// Check DocIdSetIterator#intoBitSet
INTO_BIT_SET,

// Sometimes check payloads
PAYLOADS,

Expand Down Expand Up @@ -1364,6 +1368,54 @@ private void verifyEnum(
idx <= impactsCopy.size() && impactsCopy.get(idx).norm <= norm);
}
}

if (options.contains(Option.INTO_BIT_SET)) {
int flags = PostingsEnum.FREQS;
if (doCheckPositions) {
flags |= PostingsEnum.POSITIONS;
if (doCheckOffsets) {
flags |= PostingsEnum.OFFSETS;
}
if (doCheckPayloads) {
flags |= PostingsEnum.PAYLOADS;
}
}
PostingsEnum pe1 = termsEnum.postings(null, flags);
if (random.nextBoolean()) {
pe1.advance(maxDoc / 2);
pe1 = termsEnum.postings(pe1, flags);
}
PostingsEnum pe2 = termsEnum.postings(null, flags);
FixedBitSet set1 = new FixedBitSet(1024);
FixedBitSet set2 = new FixedBitSet(1024);
FixedBitSet acceptDocs = new FixedBitSet(maxDoc);
for (int i = 0; i < maxDoc; i += 2) {
acceptDocs.set(i);
}

while (true) {
pe1.nextDoc();
pe2.nextDoc();

int offset =
TestUtil.nextInt(random, Math.max(0, pe1.docID() - set1.length()), pe1.docID());
int upTo = offset + random.nextInt(set1.length());
pe1.intoBitSet(acceptDocs, upTo, set1, offset);
for (int d = pe2.docID(); d < upTo; d = pe2.nextDoc()) {
if (acceptDocs.get(d)) {
set2.set(d - offset);
}
}

assertEquals(set1, set2);
assertEquals(pe1.docID(), pe2.docID());
if (pe1.docID() == DocIdSetIterator.NO_MORE_DOCS) {
break;
}
set1.clear();
set2.clear();
}
}
}

private static class TestThread extends Thread {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;

/** Wraps a Scorer with additional checks */
public class AssertingScorer extends Scorer {
Expand Down Expand Up @@ -192,6 +194,15 @@ public int advance(int target) throws IOException {
public long cost() {
return in.cost();
}

@Override
public void intoBitSet(Bits acceptDocs, int upTo, FixedBitSet bitSet, int offset)
throws IOException {
assert docID() != -1;
assert offset <= docID();
in.intoBitSet(acceptDocs, upTo, bitSet, offset);
assert docID() >= upTo;
}
};
}

Expand Down

0 comments on commit e74f19b

Please sign in to comment.