Skip to content

Commit

Permalink
make solr/modules/ltr/src/java compile
Browse files Browse the repository at this point in the history
  • Loading branch information
cpoerschke committed Jan 29, 2025
1 parent f98f703 commit 494b6b4
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int top
return firstPassTopDocs;
}
final ScoreDoc[] firstPassResults = getFirstPassDocsRanked(firstPassTopDocs);
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value()));

final ScoreDoc[] reranked = rerank(searcher, topN, firstPassResults);

Expand All @@ -148,7 +148,7 @@ protected static ScoreDoc[] getFirstPassDocsRanked(TopDocs firstPassTopDocs) {
final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
Arrays.sort(hits, docComparator);

assert firstPassTopDocs.totalHits.relation == TotalHits.Relation.EQUAL_TO;
assert firstPassTopDocs.totalHits.relation() == TotalHits.Relation.EQUAL_TO;
return hits;
}

Expand Down Expand Up @@ -180,7 +180,7 @@ public void scoreFeatures(
// We advanced to another segment
if (readerContext != null) {
docBase = readerContext.docBase;
scorer = modelWeight.scorer(readerContext);
scorer = modelWeight.modelScorer(readerContext);
}
if (scoreSingleHit(topN, docBase, hitUpto, hit, docID, scorer, reranked)) {
logSingleHit(indexSearcher, modelWeight, hit.doc, scoringQuery);
Expand Down Expand Up @@ -299,7 +299,7 @@ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(
final int n = ReaderUtil.subIndex(docid, leafContexts);
final LeafReaderContext atomicContext = leafContexts.get(n);
final int deBasedDoc = docid - atomicContext.docBase;
final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.scorer(atomicContext);
final LTRScoringQuery.ModelWeight.ModelScorer r = modelWeight.modelScorer(atomicContext);
if ((r == null) || (r.iterator().advance(deBasedDoc) != deBasedDoc)) {
return new LTRScoringQuery.FeatureInfo[0];
} else {
Expand Down
30 changes: 18 additions & 12 deletions solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.util.SolrDefaultScorerSupplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -465,7 +467,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
Explanation e = ltrScoringModel.getNormalizerExplanation(explanations[f.getIndex()], idx);
featureExplanations.add(e);
}
final ModelScorer bs = scorer(context);
final ModelScorer bs = modelScorer(context);
bs.iterator().advance(doc);

final float finalScore = bs.score();
Expand All @@ -485,12 +487,16 @@ protected void reset() {
}

@Override
public ModelScorer scorer(LeafReaderContext context) throws IOException {
/**
 * Supplies the scorer for this segment by wrapping the {@link ModelScorer}
 * obtained from {@link #modelScorer(LeafReaderContext)} in a
 * {@link SolrDefaultScorerSupplier}, satisfying the Lucene 10
 * {@code Weight.scorerSupplier} contract.
 *
 * @param context the leaf (segment) reader context to score against
 * @return a supplier yielding the per-segment model scorer
 * @throws IOException if building the underlying feature scorers fails
 */
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
  final ModelScorer segmentScorer = modelScorer(context);
  return new SolrDefaultScorerSupplier(segmentScorer);
}

public ModelScorer modelScorer(LeafReaderContext context) throws IOException {

final List<Feature.FeatureWeight.FeatureScorer> featureScorers =
new ArrayList<Feature.FeatureWeight.FeatureScorer>(extractedFeatureWeights.length);
for (final Feature.FeatureWeight featureWeight : extractedFeatureWeights) {
final Feature.FeatureWeight.FeatureScorer scorer = featureWeight.scorer(context);
final Feature.FeatureWeight.FeatureScorer scorer = featureWeight.featureScorer(context);
if (scorer != null) {
featureScorers.add(scorer);
}
Expand Down Expand Up @@ -518,7 +524,6 @@ public DocInfo getDocInfo() {
}

public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
docInfo = new DocInfo();
for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
subScorer.setDocInfo(docInfo);
Expand Down Expand Up @@ -564,14 +569,13 @@ private class SparseModelScorer extends Scorer {
private int activeDoc = -1;

private SparseModelScorer(
Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
if (featureScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
}
subScorers = new DisiPriorityQueue(featureScorers.size());
for (final Scorer scorer : featureScorers) {
final DisiWrapper w = new DisiWrapper(scorer);
final DisiWrapper w = new DisiWrapper(scorer, false /* impacts */);
subScorers.add(w);
}

Expand All @@ -594,8 +598,9 @@ public float score() throws IOException {
reset();
if (activeDoc == targetDoc) {
for (DisiWrapper w = topList; w != null; w = w.next) {
final Scorer subScorer = w.scorer;
Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight();
final Feature.FeatureWeight.FeatureScorer subScorer =
(Feature.FeatureWeight.FeatureScorer) w.scorer;
Feature.FeatureWeight scFW = subScorer.getWeight();
final int featureId = scFW.getIndex();
featuresInfo[featureId].setValue(subScorer.score());
featuresInfo[featureId].setUsed(true);
Expand Down Expand Up @@ -662,8 +667,7 @@ private class DenseModelScorer extends Scorer {
private final List<Feature.FeatureWeight.FeatureScorer> featureScorers;

private DenseModelScorer(
Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
Weight unusedWeight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
this.featureScorers = featureScorers;
}

Expand All @@ -680,7 +684,9 @@ public float score() throws IOException {
for (final Scorer scorer : featureScorers) {
if (scorer.docID() == activeDoc) {
freq++;
Feature.FeatureWeight scFW = (Feature.FeatureWeight) scorer.getWeight();
Feature.FeatureWeight.FeatureScorer featureScorer =
(Feature.FeatureWeight.FeatureScorer) scorer;
Feature.FeatureWeight scFW = featureScorer.getWeight();
final int featureId = scFW.getIndex();
featuresInfo[featureId].setValue(scorer.score());
featuresInfo[featureId].setUsed(true);
Expand Down
17 changes: 14 additions & 3 deletions solr/modules/ltr/src/java/org/apache/solr/ltr/feature/Feature.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.DocInfo;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.request.macro.MacroExpander;
import org.apache.solr.util.SolrDefaultScorerSupplier;
import org.apache.solr.util.SolrPluginUtils;

/**
Expand Down Expand Up @@ -267,8 +269,12 @@ public float getDefaultValue() {
return Feature.this.getDefaultValue();
}

public abstract FeatureScorer featureScorer(LeafReaderContext context) throws IOException;

@Override
public abstract FeatureScorer scorer(LeafReaderContext context) throws IOException;
/**
 * Adapts {@link #featureScorer(LeafReaderContext)} to the Lucene 10
 * {@code Weight.scorerSupplier} API by wrapping the per-segment feature
 * scorer in a {@link SolrDefaultScorerSupplier}.
 *
 * @param context the leaf (segment) reader context to score against
 * @return a supplier yielding this feature's scorer for the segment
 * @throws IOException if creating the feature scorer fails
 */
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
  final FeatureScorer perSegmentScorer = featureScorer(context);
  return new SolrDefaultScorerSupplier(perSegmentScorer);
}

@Override
public boolean isCacheable(LeafReaderContext ctx) {
Expand All @@ -277,7 +283,7 @@ public boolean isCacheable(LeafReaderContext ctx) {

@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
final FeatureScorer r = scorer(context);
final FeatureScorer r = featureScorer(context);
float score = getDefaultValue();
if (r != null) {
r.iterator().advance(doc);
Expand Down Expand Up @@ -306,10 +312,11 @@ public abstract static class FeatureScorer extends Scorer {

protected final String name;
private DocInfo docInfo;
private final Feature.FeatureWeight weight;
protected final DocIdSetIterator itr;

public FeatureScorer(Feature.FeatureWeight weight, DocIdSetIterator itr) {
super(weight);
this.weight = weight;
this.itr = itr;
name = weight.getName();
docInfo = null;
Expand All @@ -332,6 +339,10 @@ public int docID() {
return itr.docID();
}

/**
 * Returns the {@link Feature.FeatureWeight} that created this scorer.
 * Replaces the removed Lucene {@code Scorer.getWeight()} accessor, which
 * callers here use to recover the feature's index via {@code getIndex()}.
 */
public Feature.FeatureWeight getWeight() {
  return this.weight;
}

@Override
public DocIdSetIterator iterator() {
return itr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ public FieldLengthFeatureWeight(
super(FieldLengthFeature.this, searcher, request, originalQuery, efi);
}

@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
public FeatureScorer featureScorer(LeafReaderContext context) throws IOException {
NumericDocValues norms = context.reader().getNormValues(field);
if (norms == null) {
return new ValueFeatureScorer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ protected long readNumericDocValuesDate(long val) {
* @return FeatureScorer for the current segment and field
* @throws IOException as defined by abstract class Feature
*/
@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
public FeatureScorer featureScorer(LeafReaderContext context) throws IOException {
if (schemaField != null
&& (!schemaField.stored() || useDocValuesForStored)
&& schemaField.hasDocValues()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ public String toString() {
return "OriginalScoreFeature [query:" + originalQuery.toString() + "]";
}

@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
public FeatureScorer featureScorer(LeafReaderContext context) throws IOException {

final Scorer originalScorer = w.scorer(context);
return new OriginalScoreScorer(this, originalScorer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,7 @@ private LocalSolrQueryRequest makeRequest(
}
}

@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
public FeatureScorer featureScorer(LeafReaderContext context) throws IOException {
Scorer solrScorer = solrQueryWeight.scorer(context);
if (solrScorer == null) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ public ValueFeatureWeight(
}
}

@Override
public FeatureScorer scorer(LeafReaderContext context) throws IOException {
public FeatureScorer featureScorer(LeafReaderContext context) throws IOException {
if (featureValue != null) {
return new ValueFeatureScorer(
this, featureValue, DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int top
System.arraycopy(
firstPassTopDocs.scoreDocs, 0, firstPassResults, 0, firstPassTopDocs.scoreDocs.length);
}
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value));
topN = Math.toIntExact(Math.min(topN, firstPassTopDocs.totalHits.value()));

ScoreDoc[][] reRankedPerModel =
rerank(searcher, topN, getFirstPassDocsRanked(firstPassTopDocs));
Expand Down Expand Up @@ -144,7 +144,7 @@ public void scoreFeatures(
docBase = readerContext.docBase;
for (int i = 0; i < modelWeights.length; i++) {
if (modelWeights[i] != null) {
scorers[i] = modelWeights[i].scorer(readerContext);
scorers[i] = modelWeights[i].modelScorer(readerContext);
}
}
}
Expand Down

0 comments on commit 494b6b4

Please sign in to comment.