package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.FutureTask;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;

// NOTE(review): the decompiler reported this class's access was changed from package-private;
// the public modifier is kept so existing callers continue to compile.
/**
 * Base class for k-nearest-neighbour vector queries over a single field.
 *
 * <p>All the work happens in {@link #rewrite(IndexSearcher)}. Every leaf of the reader is searched
 * (sequentially, or in parallel when the searcher exposes a {@link TaskExecutor}). The per-leaf
 * hits are merged into a global top {@code k}. The query is then rewritten into a
 * {@link DocAndScoreQuery} that replays exactly those documents and scores, or into a
 * {@link MatchNoDocsQuery} when nothing was found.
 *
 * <p>Subclasses supply the approximate search ({@link #approximateSearch}) and a per-leaf
 * {@link VectorScorer} used for exact (brute-force) scoring.
 */
public abstract class AbstractKnnVectorQuery extends Query {

    /** Shared empty result returned when a leaf has nothing to contribute. */
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

    protected final String field;
    protected final int k;
    /** Optional restriction on candidate documents; {@code null} means no filter. */
    private final Query filter;

    /**
     * A query that matches a fixed, pre-computed list of documents with fixed scores.
     *
     * <p>Instances are built by {@link AbstractKnnVectorQuery#rewrite(IndexSearcher)} and are only
     * valid against the reader they were built from; {@link #createWeight} enforces this by
     * comparing reader-context identities.
     */
    public static class DocAndScoreQuery extends Query {
        private final int k;
        /** Global (top-level) doc ids, sorted ascending. */
        private final int[] docs;
        /** Scores parallel to {@link #docs}. */
        private final float[] scores;
        /**
         * For leaf ordinal {@code o}, the hits of that leaf occupy indices
         * {@code [segmentStarts[o], segmentStarts[o + 1])} of {@link #docs}/{@link #scores}.
         */
        private final int[] segmentStarts;
        /** Identity of the top-level reader context this query was created for. */
        private final Object contextIdentity;

        /**
         * @param k the {@code k} of the originating query (used in explanations and toString)
         * @param docs global doc ids, sorted ascending
         * @param scores scores parallel to {@code docs}
         * @param segmentStarts per-leaf start offsets into {@code docs}, with one trailing entry
         *     equal to {@code docs.length}
         * @param contextIdentity identity of the reader context the hits were computed on
         */
        DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
            this.k = k;
            this.docs = docs;
            this.scores = scores;
            this.segmentStarts = segmentStarts;
            this.contextIdentity = contextIdentity;
        }

        /**
         * Creates a weight that replays the stored hits, multiplying every stored score by
         * {@code boost}.
         *
         * @throws IllegalStateException if {@code indexSearcher} reads a different reader than
         *     the one this query was created for
         */
        @Override
        public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, final float boost) throws IOException {
            if (indexSearcher.getIndexReader().getContext().id() != contextIdentity) {
                throw new IllegalStateException("This DocAndScore query was created by a different reader");
            }
            return new Weight(this) {

                @Override
                public Explanation explain(LeafReaderContext context, int doc) {
                    int found = Arrays.binarySearch(docs, doc + context.docBase);
                    if (found < 0) {
                        return Explanation.noMatch("not in top " + k);
                    }
                    return Explanation.match(scores[found] * boost, "within top " + k);
                }

                @Override
                public int count(LeafReaderContext context) {
                    return segmentStarts[context.ord + 1] - segmentStarts[context.ord];
                }

                @Override
                public Scorer scorer(LeafReaderContext context) {
                    // Index range of this leaf's hits inside docs/scores.
                    final int lower = segmentStarts[context.ord];
                    final int upper = segmentStarts[context.ord + 1];
                    if (lower == upper) {
                        return null;
                    }
                    return new Scorer(this) {
                        /** Index of the current hit in docs/scores; -1 before the first nextDoc(). */
                        int upTo = -1;

                        @Override
                        public DocIdSetIterator iterator() {
                            return new DocIdSetIterator() {
                                @Override
                                public int docID() {
                                    return currentDocId();
                                }

                                @Override
                                public int nextDoc() {
                                    upTo = (upTo == -1) ? lower : upTo + 1;
                                    return currentDocId();
                                }

                                @Override
                                public int advance(int target) throws IOException {
                                    return slowAdvance(target);
                                }

                                @Override
                                public long cost() {
                                    return upper - lower;
                                }
                            };
                        }

                        /**
                         * Returns the maximum boosted score of this leaf's hits from the current
                         * position up to and including segment-local doc {@code docId}.
                         */
                        @Override
                        public float getMaxScore(int docId) {
                            int globalDocId = docId + context.docBase;
                            float maxScore = 0f;
                            // Start at this leaf's first hit when unpositioned; never scan hits
                            // that belong to earlier leaves.
                            for (int idx = Math.max(lower, upTo); idx < upper && docs[idx] <= globalDocId; idx++) {
                                maxScore = Math.max(maxScore, scores[idx]);
                            }
                            // Multiply by the query boost (the previous code squared the max).
                            return maxScore * boost;
                        }

                        @Override
                        public float score() {
                            return scores[upTo] * boost;
                        }

                        /**
                         * Returns the segment-local id of the first hit at or after {@code target},
                         * or {@link DocIdSetIterator#NO_MORE_DOCS} when there is none.
                         */
                        @Override
                        public int advanceShallow(int target) {
                            int idx = Arrays.binarySearch(docs, Math.max(upTo, lower), upper, target + context.docBase);
                            if (idx < 0) {
                                // Decode the insertion point of a missing key.
                                idx = -1 - idx;
                            }
                            if (idx >= upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            // Convert back to a segment-local id, like every other id this scorer exposes.
                            return docs[idx] - context.docBase;
                        }

                        /** Segment-local id of the current hit, -1 if unpositioned, NO_MORE_DOCS if exhausted. */
                        private int currentDocId() {
                            if (upTo == -1) {
                                return -1;
                            }
                            if (upTo >= upper) {
                                return DocIdSetIterator.NO_MORE_DOCS;
                            }
                            return docs[upTo] - context.docBase;
                        }

                        @Override
                        public int docID() {
                            return currentDocId();
                        }
                    };
                }

                @Override
                public boolean isCacheable(LeafReaderContext context) {
                    return true;
                }
            };
        }

        @Override
        public String toString(String field) {
            return "DocAndScore[" + k + "]";
        }

        @Override
        public void visit(QueryVisitor visitor) {
            visitor.visitLeaf(this);
        }

        @Override
        public boolean equals(Object obj) {
            if (!sameClassAs(obj)) {
                return false;
            }
            DocAndScoreQuery other = (DocAndScoreQuery) obj;
            return contextIdentity == other.contextIdentity
                    && Arrays.equals(docs, other.docs)
                    && Arrays.equals(scores, other.scores);
        }

        @Override
        public int hashCode() {
            return Objects.hash(classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores));
        }
    }

    /**
     * @param field the vector field to search; must not be null
     * @param k number of hits to keep; must be at least 1
     * @param filter optional filter restricting candidate documents, or {@code null}
     * @throws NullPointerException if {@code field} is null
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public AbstractKnnVectorQuery(String field, int k, Query filter) {
        this.field = Objects.requireNonNull(field, "field");
        this.k = k;
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        this.filter = filter;
    }

    /**
     * Runs the per-leaf searches and rewrites this query into its concrete top-k hits.
     *
     * <p>When a filter is set, it is combined with a {@link FieldExistsQuery} on {@link #field}
     * (both as FILTER clauses), rewritten, and turned into a non-scoring weight that selects
     * candidate documents.
     *
     * @return a {@link MatchNoDocsQuery} if no leaf produced hits, otherwise a
     *     {@link DocAndScoreQuery} over the merged top {@code k}
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        IndexReader reader = indexSearcher.getIndexReader();

        Weight filterWeight = null;
        if (filter != null) {
            BooleanQuery candidates = new BooleanQuery.Builder()
                    .add(filter, BooleanClause.Occur.FILTER)
                    .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER)
                    .build();
            Query rewrittenFilter = indexSearcher.rewrite(candidates);
            filterWeight = indexSearcher.createWeight(rewrittenFilter, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        }

        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        TopDocs[] perLeafResults = taskExecutor == null
                ? sequentialSearch(reader.leaves(), filterWeight)
                : parallelSearch(reader.leaves(), filterWeight, taskExecutor);

        TopDocs topK = TopDocs.merge(k, perLeafResults);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return createRewrittenQuery(reader, topK);
    }

    /** Searches every leaf on the calling thread; the result array is indexed by leaf ordinal. */
    private TopDocs[] sequentialSearch(List<LeafReaderContext> leaves, Weight filterWeight) throws IOException {
        TopDocs[] results = new TopDocs[leaves.size()];
        for (LeafReaderContext leaf : leaves) {
            results[leaf.ord] = searchLeaf(leaf, filterWeight);
        }
        return results;
    }

    /** Searches every leaf through {@code taskExecutor}; results are in the order of {@code leaves}. */
    private TopDocs[] parallelSearch(List<LeafReaderContext> leaves, Weight filterWeight, TaskExecutor taskExecutor) throws IOException {
        List<FutureTask<TopDocs>> tasks = new ArrayList<>(leaves.size());
        for (LeafReaderContext leaf : leaves) {
            tasks.add(new FutureTask<>(() -> searchLeaf(leaf, filterWeight)));
        }
        return taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
    }

    /** Searches one leaf and rebases its hits from segment-local to global doc ids. */
    private TopDocs searchLeaf(LeafReaderContext context, Weight filterWeight) throws IOException {
        TopDocs results = getLeafResults(context, filterWeight);
        if (context.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += context.docBase;
            }
        }
        return results;
    }

    /**
     * Computes the segment-local top hits for one leaf.
     *
     * <ul>
     *   <li>No filter: approximate search over live docs with {@code Integer.MAX_VALUE} as limit.
     *   <li>Filter matches nothing in this leaf: empty result.
     *   <li>Accepted docs (filter AND live) number at most {@code k}: exact search.
     *   <li>Otherwise: approximate search limited by the accepted-doc count. If its total-hits
     *       relation is not {@code EQUAL_TO}, the result is treated as incomplete and exact search
     *       is used instead.
     * </ul>
     */
    private TopDocs getLeafResults(LeafReaderContext context, Weight filterWeight) throws IOException {
        Bits liveDocs = context.reader().getLiveDocs();
        int maxDoc = context.reader().maxDoc();

        if (filterWeight == null) {
            return approximateSearch(context, liveDocs, Integer.MAX_VALUE);
        }

        Scorer filterScorer = filterWeight.scorer(context);
        if (filterScorer == null) {
            return NO_RESULTS;
        }

        BitSet acceptDocs = createBitSet(filterScorer.iterator(), liveDocs, maxDoc);
        int cost = acceptDocs.cardinality();
        if (cost <= k) {
            return exactSearch(context, new BitSetIterator(acceptDocs, cost));
        }

        TopDocs results = approximateSearch(context, acceptDocs, cost);
        if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
            return results;
        }
        return exactSearch(context, new BitSetIterator(acceptDocs, cost));
    }

    /**
     * Materializes the docs of {@code iterator} that are also live into a bit set.
     *
     * <p>When there are no deletions and the iterator already wraps a bit set, that bit set is
     * returned directly instead of being copied.
     */
    private BitSet createBitSet(DocIdSetIterator iterator, final Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null && iterator instanceof BitSetIterator) {
            return ((BitSetIterator) iterator).getBitSet();
        }
        FilteredDocIdSetIterator liveOnly = new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs == null || liveDocs.get(doc);
            }
        };
        return BitSet.of(liveOnly, maxDoc);
    }

    /**
     * Subclass-provided approximate search over one leaf.
     *
     * @param context the leaf to search
     * @param acceptDocs docs allowed in the result (live docs, or filter AND live docs); may be null
     * @param limit {@code Integer.MAX_VALUE} when unfiltered, otherwise the number of accepted
     *     docs. NOTE(review): presumably a cap on search effort; confirm with implementations.
     * @return hits with segment-local doc ids. A total-hits relation other than
     *     {@code EQUAL_TO} makes this class fall back to exact search.
     */
    protected abstract TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int limit) throws IOException;

    /** Returns a scorer for the vectors of {@code fieldInfo} in the given leaf. */
    abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fieldInfo) throws IOException;

    /**
     * Brute-force scores every doc of {@code acceptIterator} and keeps the top {@code k}.
     *
     * @return an empty result if the field is absent or has no vectors in this leaf. Otherwise
     *     hits sorted by descending score, with {@code acceptIterator.cost()} as an exact total.
     */
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException {
        FieldInfo fieldInfo = context.reader().getFieldInfos().fieldInfo(field);
        if (fieldInfo == null || fieldInfo.getVectorDimension() == 0) {
            return NO_RESULTS;
        }

        VectorScorer vectorScorer = createVectorScorer(context, fieldInfo);
        // Created pre-populated (second argument true): the top entry is the current bottom of the top k.
        HitQueue queue = new HitQueue(k, true);
        ScoreDoc bottom = queue.top();
        int doc;
        while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            boolean advanced = vectorScorer.advanceExact(doc);
            assert advanced;

            float score = vectorScorer.score();
            if (score > bottom.score) {
                bottom.score = score;
                bottom.doc = doc;
                bottom = queue.updateTop();
            }
        }

        // Drop the remaining pre-populated placeholders.
        // NOTE(review): assumes real similarity scores are never negative.
        while (queue.size() > 0 && queue.top().score < 0) {
            queue.pop();
        }

        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        for (int i = topScoreDocs.length - 1; i >= 0; i--) {
            topScoreDocs[i] = queue.pop();
        }
        return new TopDocs(new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
    }

    /** Builds a {@link DocAndScoreQuery} from merged hits; sorts {@code topK.scoreDocs} by doc id in place. */
    private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
        ScoreDoc[] hits = topK.scoreDocs;
        Arrays.sort(hits, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
        int[] docs = new int[hits.length];
        float[] scores = new float[hits.length];
        for (int i = 0; i < hits.length; i++) {
            docs[i] = hits[i].doc;
            scores[i] = hits[i].score;
        }
        int[] segmentStarts = findSegmentStarts(reader, docs);
        return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
    }

    /**
     * For each leaf, finds the first index in the sorted {@code docs} whose doc id is at or after
     * that leaf's {@code docBase}.
     *
     * @return an array of length {@code leaves + 1}. Entry 0 is 0 and the last entry is
     *     {@code docs.length}, so leaf {@code o} owns indices {@code [starts[o], starts[o + 1])}.
     */
    static int[] findSegmentStarts(IndexReader reader, int[] docs) {
        List<LeafReaderContext> leaves = reader.leaves();
        int[] starts = new int[leaves.size() + 1];
        starts[starts.length - 1] = docs.length;
        if (starts.length == 2) {
            return starts;
        }
        int resultIndex = 0;
        for (int leaf = 1; leaf < starts.length - 1; leaf++) {
            // docs is sorted, so each search can resume from the previous leaf's start.
            resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, leaves.get(leaf).docBase);
            if (resultIndex < 0) {
                resultIndex = -1 - resultIndex;
            }
            starts[leaf] = resultIndex;
        }
        return starts;
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(field)) {
            visitor.visitLeaf(this);
        }
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        AbstractKnnVectorQuery other = (AbstractKnnVectorQuery) obj;
        return k == other.k && Objects.equals(field, other.field) && Objects.equals(filter, other.filter);
    }

    @Override
    public int hashCode() {
        return Objects.hash(field, k, filter);
    }

    /** Returns the vector field searched by this query. */
    public String getField() {
        return field;
    }

    /** Returns the number of hits this query keeps. */
    public int getK() {
        return k;
    }

    /** Returns the optional filter, or {@code null} if there is none. */
    public Query getFilter() {
        return filter;
    }
}
