Skip to content

Commit

Permalink
score: introduce query.Boost to scale score
Browse files Browse the repository at this point in the history
This commit introduces a new primitive Boost to our query language. It
allows boosting (or dampening) the contribution to the score a query
atoms will match contribute.

To achieve this we introduce boostMatchTree which records this weight.
We then adjust the visitMatches to take an initial score weight (1.0),
and then each time we recurse through a boostMatchTree the score weight
is multiplied by the boost weight. Additionally candidateMatch now has a
new field, scoreWeight, which records the weight at time of candidate
collection. Without boosting in the query this value will always be 1.

Finally when scoring a candidateMatch we take the final score for it and
multiply it by scoreWeight.

Note: we do not expose a way to set this in the query language, only the
query API.

Test Plan: This functionality is currently untested. However, none of
our existings tests have broken so this is technically safe to land.
TODO add testing for boost query.
  • Loading branch information
keegancsmith committed Jan 29, 2024
1 parent cdb1665 commit 9cb40ac
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 14 deletions.
2 changes: 1 addition & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func TestMatchSize(t *testing.T) {
size: 112,
}, {
v: candidateMatch{},
size: 72,
size: 80,
}, {
v: candidateChunk{},
size: 40,
Expand Down
5 changes: 5 additions & 0 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,11 @@ func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language str
}
}

if m.scoreWeight != 1 { // should we be using epsilon comparison here?
score.score = score.score * m.scoreWeight
score.what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight)
}

if score.score > maxScore.score {
maxScore.score = score.score
maxScore.what = score.what
Expand Down
37 changes: 30 additions & 7 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ nextFileMatch:
// whether there's an exact match on a symbol, the number of query clauses that matched, etc.
func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) {
atomMatchCount := 0
visitMatches(mt, known, func(mt matchTree) {
visitMatchAtoms(mt, known, func(mt matchTree) {
atomMatchCount++
})

Expand Down Expand Up @@ -544,6 +544,13 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
return m[i].byteOffset < m[j].byteOffset
}

func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch {
for _, m := range cm {
m.scoreWeight = scoreWeight
}
return cm
}

// Gather matches from this document. This never returns a mixture of
// filename/content matches: if there are content matches, all
// filename matches are trimmed from the result. The matches are
Expand All @@ -554,18 +561,20 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
// but adjacent matches will remain.
func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candidateMatch {
var cands []*candidateMatch
visitMatches(mt, known, func(mt matchTree) {
visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) {
// TODO apply scoreWeight to candidates
_ = scoreWeight
if smt, ok := mt.(*substrMatchTree); ok {
cands = append(cands, smt.current...)
cands = append(cands, setScoreWeight(scoreWeight, smt.current)...)
}
if rmt, ok := mt.(*regexpMatchTree); ok {
cands = append(cands, rmt.found...)
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
}
if rmt, ok := mt.(*wordMatchTree); ok {
cands = append(cands, rmt.found...)
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
}
if smt, ok := mt.(*symbolRegexpMatchTree); ok {
cands = append(cands, smt.found...)
cands = append(cands, setScoreWeight(scoreWeight, smt.found)...)
}
})

Expand All @@ -590,6 +599,7 @@ func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candid
// are non-overlapping.
sort.Sort((sortByOffsetSlice)(cands))
res = cands[:0]
mergeRun := 1
for i, c := range cands {
if i == 0 {
res = append(res, c)
Expand All @@ -599,10 +609,23 @@ func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candid
lastEnd := last.byteOffset + last.byteMatchSz
end := c.byteOffset + c.byteMatchSz
if lastEnd >= c.byteOffset {
mergeRun++

// Average out the score across the merged candidates. Only do it if
// we are boosting to avoid floating point funkiness in the normal
// case.
if last.scoreWeight != 1 && c.scoreWeight != 1 {
last.scoreWeight = ((last.scoreWeight * float64(mergeRun-1)) + c.scoreWeight) / float64(mergeRun)
}

// latest candidate goes further, update our end
if end > lastEnd {
last.byteMatchSz = end - last.byteOffset
}

continue
} else {
mergeRun = 1
}

res = append(res, c)
Expand Down Expand Up @@ -649,7 +672,7 @@ func (d *indexData) branchIndex(docID uint32) int {
// returns all branches containing docID.
func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string {
var mask uint64
visitMatches(mt, known, func(mt matchTree) {
visitMatchAtoms(mt, known, func(mt matchTree) {
bq, ok := mt.(*branchQueryMatchTree)
if !ok {
return
Expand Down
2 changes: 2 additions & 0 deletions matchiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type candidateMatch struct {
substrBytes []byte
substrLowered []byte

scoreWeight float64

file uint32
symbolIdx uint32

Expand Down
56 changes: 50 additions & 6 deletions matchtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ type fileNameMatchTree struct {
child matchTree
}

type boostMatchTree struct {
child matchTree
weight float64
}

// Don't visit this subtree for collecting matches.
type noVisitMatchTree struct {
matchTree
Expand Down Expand Up @@ -392,6 +397,10 @@ func (t *fileNameMatchTree) prepare(doc uint32) {
t.child.prepare(doc)
}

func (t *boostMatchTree) prepare(doc uint32) {
t.child.prepare(doc)
}

func (t *substrMatchTree) prepare(nextDoc uint32) {
t.matchIterator.prepare(nextDoc)
t.current = t.matchIterator.candidates()
Expand Down Expand Up @@ -455,6 +464,10 @@ func (t *fileNameMatchTree) nextDoc() uint32 {
return t.child.nextDoc()
}

func (t *boostMatchTree) nextDoc() uint32 {
return t.child.nextDoc()
}

func (t *branchQueryMatchTree) nextDoc() uint32 {
var start uint32
if t.firstDone {
Expand Down Expand Up @@ -515,6 +528,10 @@ func (t *fileNameMatchTree) String() string {
return fmt.Sprintf("f(%v)", t.child)
}

func (t *boostMatchTree) String() string {
return fmt.Sprintf("boost(%f, %v)", t.weight, t.child)
}

func (t *substrMatchTree) String() string {
f := ""
if t.fileName {
Expand Down Expand Up @@ -556,6 +573,8 @@ func visitMatchTree(t matchTree, f func(matchTree)) {
visitMatchTree(s.child, f)
case *fileNameMatchTree:
visitMatchTree(s.child, f)
case *boostMatchTree:
visitMatchTree(s.child, f)
case *symbolSubstrMatchTree:
visitMatchTree(s.substrMatchTree, f)
case *symbolRegexpMatchTree:
Expand All @@ -575,33 +594,41 @@ func updateMatchTreeStats(mt matchTree, stats *Stats) {
})
}

func visitMatchAtoms(t matchTree, known map[matchTree]bool, f func(matchTree)) {
visitMatches(t, known, 1, func(mt matchTree, _ float64) {
f(mt)
})
}

// visitMatches visits all atoms which can contribute matches. Note: This
// skips noVisitMatchTree.
func visitMatches(t matchTree, known map[matchTree]bool, f func(matchTree)) {
func visitMatches(t matchTree, known map[matchTree]bool, weight float64, f func(matchTree, float64)) {
switch s := t.(type) {
case *andMatchTree:
for _, ch := range s.children {
if known[ch] {
visitMatches(ch, known, f)
visitMatches(ch, known, weight, f)
}
}
case *andLineMatchTree:
visitMatches(&s.andMatchTree, known, f)
visitMatches(&s.andMatchTree, known, weight, f)
case *orMatchTree:
for _, ch := range s.children {
if known[ch] {
visitMatches(ch, known, f)
visitMatches(ch, known, weight, f)
}
}
case *boostMatchTree:
visitMatches(s.child, known, weight*s.weight, f)
case *symbolSubstrMatchTree:
visitMatches(s.substrMatchTree, known, f)
visitMatches(s.substrMatchTree, known, weight, f)
case *notMatchTree:
case *noVisitMatchTree:
// don't collect into negative trees.
case *fileNameMatchTree:
// We will just gather the filename if we do not visit this tree.
default:
f(s)
f(s, weight)
}
}

Expand Down Expand Up @@ -876,6 +903,10 @@ func (t *fileNameMatchTree) matches(cp *contentProvider, cost int, known map[mat
return evalMatchTree(cp, cost, known, t.child)
}

func (t *boostMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
return evalMatchTree(cp, cost, known, t.child)
}

func (t *substrMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
if t.contEvaluated {
return matchesStateForSlice(t.current)
Expand Down Expand Up @@ -997,6 +1028,17 @@ func (d *indexData) newMatchTree(q query.Q, opt matchTreeOpt) (matchTree, error)
child: ct,
}, nil

case *query.Boost:
ct, err := d.newMatchTree(s.Child, opt)
if err != nil {
return nil, err
}

return &boostMatchTree{
child: ct,
weight: s.Weight,
}, nil

case *query.Substring:
return d.newSubstringMatchTree(s)

Expand Down Expand Up @@ -1288,6 +1330,8 @@ func pruneMatchTree(mt matchTree) (matchTree, error) {
}
case *fileNameMatchTree:
mt.child, err = pruneMatchTree(mt.child)
case *boostMatchTree:
mt.child, err = pruneMatchTree(mt.child)
case *andLineMatchTree:
child, err := pruneMatchTree(&mt.andMatchTree)
if err != nil {
Expand Down
21 changes: 21 additions & 0 deletions query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,19 @@ func (q *Type) String() string {
}
}

// Boost scales the contribution to score of descendents.
type Boost struct {
Child Q
// Weight will multiply the score of its descendents. Weights less than 1
// will give less importance while values greater than 1 will give more
// importance.
Weight float64
}

func (q *Boost) String() string {
return fmt.Sprintf("(boost %f %s)", q.Weight, q.Child)
}

// Substring is the most basic query: a query for a substring.
type Substring struct {
Pattern string
Expand Down Expand Up @@ -609,6 +622,9 @@ func flatten(q Q) (Q, bool) {
case *Type:
child, changed := flatten(s.Child)
return &Type{Child: child, Type: s.Type}, changed
case *Boost:
child, changed := flatten(s.Child)
return &Boost{Child: child, Weight: s.Weight}, changed
default:
return q, false
}
Expand Down Expand Up @@ -680,6 +696,8 @@ func evalConstants(q Q) Q {
return ch
}
return &Type{Child: ch, Type: s.Type}
case *Boost:
return &Boost{Weight: s.Weight, Child: evalConstants(s.Child)}
case *Substring:
if len(s.Pattern) == 0 {
return &Const{true}
Expand Down Expand Up @@ -728,6 +746,8 @@ func Map(q Q, f func(q Q) Q) Q {
q = &Not{Child: Map(s.Child, f)}
case *Type:
q = &Type{Type: s.Type, Child: Map(s.Child, f)}
case *Boost:
q = &Boost{Weight: s.Weight, Child: Map(s.Child, f)}
}
return f(q)
}
Expand Down Expand Up @@ -768,6 +788,7 @@ func VisitAtoms(q Q, v func(q Q)) {
case *Or:
case *Not:
case *Type:
case *Boost:
default:
v(iQ)
}
Expand Down
12 changes: 12 additions & 0 deletions web/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,18 @@ func (s *Server) serveSearchErr(r *http.Request) (*ApiSearchResult, error) {
return nil, err
}

// Experimental support to take the query string and boost exact phrases of
// it.
if phraseBoost, err := strconv.ParseFloat(qvals.Get("phrase-boost"), 64); err == nil {
q = query.NewOr(
&query.Boost{
Weight: phraseBoost,
Child: &query.Substring{Pattern: queryStr, Content: true},
},
q,
)
}

repoOnly := true
query.VisitAtoms(q, func(q query.Q) {
_, ok := q.(*query.Repo)
Expand Down

0 comments on commit 9cb40ac

Please sign in to comment.