Skip to content

Commit 9cb40ac

Browse files
committed
score: introduce query.Boost to scale score
This commit introduces a new primitive, Boost, to our query language. It allows boosting (or dampening) the contribution that the matches of a query atom make to the score. To achieve this we introduce boostMatchTree, which records this weight. We then adjust visitMatches to take an initial score weight (1.0), and each time we recurse through a boostMatchTree the score weight is multiplied by the boost weight. Additionally, candidateMatch now has a new field, scoreWeight, which records the weight at the time of candidate collection. Without boosting in the query this value will always be 1. Finally, when scoring a candidateMatch we take the final score for it and multiply it by scoreWeight. Note: we do not expose a way to set this in the query language, only in the query API. Test Plan: This functionality is currently untested. However, none of our existing tests have broken, so this is technically safe to land. TODO: add testing for boost query.
1 parent cdb1665 commit 9cb40ac

File tree

7 files changed

+121
-14
lines changed

7 files changed

+121
-14
lines changed

api_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func TestMatchSize(t *testing.T) {
152152
size: 112,
153153
}, {
154154
v: candidateMatch{},
155-
size: 72,
155+
size: 80,
156156
}, {
157157
v: candidateChunk{},
158158
size: 40,

contentprovider.go

+5
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,11 @@ func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language str
660660
}
661661
}
662662

663+
if m.scoreWeight != 1 { // should we be using epsilon comparison here?
664+
score.score = score.score * m.scoreWeight
665+
score.what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight)
666+
}
667+
663668
if score.score > maxScore.score {
664669
maxScore.score = score.score
665670
maxScore.what = score.what

eval.go

+30-7
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ nextFileMatch:
420420
// whether there's an exact match on a symbol, the number of query clauses that matched, etc.
421421
func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) {
422422
atomMatchCount := 0
423-
visitMatches(mt, known, func(mt matchTree) {
423+
visitMatchAtoms(mt, known, func(mt matchTree) {
424424
atomMatchCount++
425425
})
426426

@@ -544,6 +544,13 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
544544
return m[i].byteOffset < m[j].byteOffset
545545
}
546546

547+
func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch {
548+
for _, m := range cm {
549+
m.scoreWeight = scoreWeight
550+
}
551+
return cm
552+
}
553+
547554
// Gather matches from this document. This never returns a mixture of
548555
// filename/content matches: if there are content matches, all
549556
// filename matches are trimmed from the result. The matches are
@@ -554,18 +561,20 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
554561
// but adjacent matches will remain.
555562
func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candidateMatch {
556563
var cands []*candidateMatch
557-
visitMatches(mt, known, func(mt matchTree) {
564+
visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) {
565+
// TODO apply scoreWeight to candidates
566+
_ = scoreWeight
558567
if smt, ok := mt.(*substrMatchTree); ok {
559-
cands = append(cands, smt.current...)
568+
cands = append(cands, setScoreWeight(scoreWeight, smt.current)...)
560569
}
561570
if rmt, ok := mt.(*regexpMatchTree); ok {
562-
cands = append(cands, rmt.found...)
571+
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
563572
}
564573
if rmt, ok := mt.(*wordMatchTree); ok {
565-
cands = append(cands, rmt.found...)
574+
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
566575
}
567576
if smt, ok := mt.(*symbolRegexpMatchTree); ok {
568-
cands = append(cands, smt.found...)
577+
cands = append(cands, setScoreWeight(scoreWeight, smt.found)...)
569578
}
570579
})
571580

@@ -590,6 +599,7 @@ func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candid
590599
// are non-overlapping.
591600
sort.Sort((sortByOffsetSlice)(cands))
592601
res = cands[:0]
602+
mergeRun := 1
593603
for i, c := range cands {
594604
if i == 0 {
595605
res = append(res, c)
@@ -599,10 +609,23 @@ func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candid
599609
lastEnd := last.byteOffset + last.byteMatchSz
600610
end := c.byteOffset + c.byteMatchSz
601611
if lastEnd >= c.byteOffset {
612+
mergeRun++
613+
614+
// Average out the score across the merged candidates. Only do it if
615+
// we are boosting to avoid floating point funkiness in the normal
616+
// case.
617+
if last.scoreWeight != 1 && c.scoreWeight != 1 {
618+
last.scoreWeight = ((last.scoreWeight * float64(mergeRun-1)) + c.scoreWeight) / float64(mergeRun)
619+
}
620+
621+
// latest candidate goes further, update our end
602622
if end > lastEnd {
603623
last.byteMatchSz = end - last.byteOffset
604624
}
625+
605626
continue
627+
} else {
628+
mergeRun = 1
606629
}
607630

608631
res = append(res, c)
@@ -649,7 +672,7 @@ func (d *indexData) branchIndex(docID uint32) int {
649672
// returns all branches containing docID.
650673
func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string {
651674
var mask uint64
652-
visitMatches(mt, known, func(mt matchTree) {
675+
visitMatchAtoms(mt, known, func(mt matchTree) {
653676
bq, ok := mt.(*branchQueryMatchTree)
654677
if !ok {
655678
return

matchiter.go

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ type candidateMatch struct {
2727
substrBytes []byte
2828
substrLowered []byte
2929

30+
scoreWeight float64
31+
3032
file uint32
3133
symbolIdx uint32
3234

matchtree.go

+50-6
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ type fileNameMatchTree struct {
170170
child matchTree
171171
}
172172

173+
type boostMatchTree struct {
174+
child matchTree
175+
weight float64
176+
}
177+
173178
// Don't visit this subtree for collecting matches.
174179
type noVisitMatchTree struct {
175180
matchTree
@@ -392,6 +397,10 @@ func (t *fileNameMatchTree) prepare(doc uint32) {
392397
t.child.prepare(doc)
393398
}
394399

400+
func (t *boostMatchTree) prepare(doc uint32) {
401+
t.child.prepare(doc)
402+
}
403+
395404
func (t *substrMatchTree) prepare(nextDoc uint32) {
396405
t.matchIterator.prepare(nextDoc)
397406
t.current = t.matchIterator.candidates()
@@ -455,6 +464,10 @@ func (t *fileNameMatchTree) nextDoc() uint32 {
455464
return t.child.nextDoc()
456465
}
457466

467+
func (t *boostMatchTree) nextDoc() uint32 {
468+
return t.child.nextDoc()
469+
}
470+
458471
func (t *branchQueryMatchTree) nextDoc() uint32 {
459472
var start uint32
460473
if t.firstDone {
@@ -515,6 +528,10 @@ func (t *fileNameMatchTree) String() string {
515528
return fmt.Sprintf("f(%v)", t.child)
516529
}
517530

531+
func (t *boostMatchTree) String() string {
532+
return fmt.Sprintf("boost(%f, %v)", t.weight, t.child)
533+
}
534+
518535
func (t *substrMatchTree) String() string {
519536
f := ""
520537
if t.fileName {
@@ -556,6 +573,8 @@ func visitMatchTree(t matchTree, f func(matchTree)) {
556573
visitMatchTree(s.child, f)
557574
case *fileNameMatchTree:
558575
visitMatchTree(s.child, f)
576+
case *boostMatchTree:
577+
visitMatchTree(s.child, f)
559578
case *symbolSubstrMatchTree:
560579
visitMatchTree(s.substrMatchTree, f)
561580
case *symbolRegexpMatchTree:
@@ -575,33 +594,41 @@ func updateMatchTreeStats(mt matchTree, stats *Stats) {
575594
})
576595
}
577596

597+
func visitMatchAtoms(t matchTree, known map[matchTree]bool, f func(matchTree)) {
598+
visitMatches(t, known, 1, func(mt matchTree, _ float64) {
599+
f(mt)
600+
})
601+
}
602+
578603
// visitMatches visits all atoms which can contribute matches. Note: This
579604
// skips noVisitMatchTree.
580-
func visitMatches(t matchTree, known map[matchTree]bool, f func(matchTree)) {
605+
func visitMatches(t matchTree, known map[matchTree]bool, weight float64, f func(matchTree, float64)) {
581606
switch s := t.(type) {
582607
case *andMatchTree:
583608
for _, ch := range s.children {
584609
if known[ch] {
585-
visitMatches(ch, known, f)
610+
visitMatches(ch, known, weight, f)
586611
}
587612
}
588613
case *andLineMatchTree:
589-
visitMatches(&s.andMatchTree, known, f)
614+
visitMatches(&s.andMatchTree, known, weight, f)
590615
case *orMatchTree:
591616
for _, ch := range s.children {
592617
if known[ch] {
593-
visitMatches(ch, known, f)
618+
visitMatches(ch, known, weight, f)
594619
}
595620
}
621+
case *boostMatchTree:
622+
visitMatches(s.child, known, weight*s.weight, f)
596623
case *symbolSubstrMatchTree:
597-
visitMatches(s.substrMatchTree, known, f)
624+
visitMatches(s.substrMatchTree, known, weight, f)
598625
case *notMatchTree:
599626
case *noVisitMatchTree:
600627
// don't collect into negative trees.
601628
case *fileNameMatchTree:
602629
// We will just gather the filename if we do not visit this tree.
603630
default:
604-
f(s)
631+
f(s, weight)
605632
}
606633
}
607634

@@ -876,6 +903,10 @@ func (t *fileNameMatchTree) matches(cp *contentProvider, cost int, known map[mat
876903
return evalMatchTree(cp, cost, known, t.child)
877904
}
878905

906+
func (t *boostMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
907+
return evalMatchTree(cp, cost, known, t.child)
908+
}
909+
879910
func (t *substrMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
880911
if t.contEvaluated {
881912
return matchesStateForSlice(t.current)
@@ -997,6 +1028,17 @@ func (d *indexData) newMatchTree(q query.Q, opt matchTreeOpt) (matchTree, error)
9971028
child: ct,
9981029
}, nil
9991030

1031+
case *query.Boost:
1032+
ct, err := d.newMatchTree(s.Child, opt)
1033+
if err != nil {
1034+
return nil, err
1035+
}
1036+
1037+
return &boostMatchTree{
1038+
child: ct,
1039+
weight: s.Weight,
1040+
}, nil
1041+
10001042
case *query.Substring:
10011043
return d.newSubstringMatchTree(s)
10021044

@@ -1288,6 +1330,8 @@ func pruneMatchTree(mt matchTree) (matchTree, error) {
12881330
}
12891331
case *fileNameMatchTree:
12901332
mt.child, err = pruneMatchTree(mt.child)
1333+
case *boostMatchTree:
1334+
mt.child, err = pruneMatchTree(mt.child)
12911335
case *andLineMatchTree:
12921336
child, err := pruneMatchTree(&mt.andMatchTree)
12931337
if err != nil {

query/query.go

+21
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,19 @@ func (q *Type) String() string {
386386
}
387387
}
388388

389+
// Boost scales the contribution to score of descendents.
390+
type Boost struct {
391+
Child Q
392+
// Weight will multiply the score of its descendents. Weights less than 1
393+
// will give less importance while values greater than 1 will give more
394+
// importance.
395+
Weight float64
396+
}
397+
398+
func (q *Boost) String() string {
399+
return fmt.Sprintf("(boost %f %s)", q.Weight, q.Child)
400+
}
401+
389402
// Substring is the most basic query: a query for a substring.
390403
type Substring struct {
391404
Pattern string
@@ -609,6 +622,9 @@ func flatten(q Q) (Q, bool) {
609622
case *Type:
610623
child, changed := flatten(s.Child)
611624
return &Type{Child: child, Type: s.Type}, changed
625+
case *Boost:
626+
child, changed := flatten(s.Child)
627+
return &Boost{Child: child, Weight: s.Weight}, changed
612628
default:
613629
return q, false
614630
}
@@ -680,6 +696,8 @@ func evalConstants(q Q) Q {
680696
return ch
681697
}
682698
return &Type{Child: ch, Type: s.Type}
699+
case *Boost:
700+
return &Boost{Weight: s.Weight, Child: evalConstants(s.Child)}
683701
case *Substring:
684702
if len(s.Pattern) == 0 {
685703
return &Const{true}
@@ -728,6 +746,8 @@ func Map(q Q, f func(q Q) Q) Q {
728746
q = &Not{Child: Map(s.Child, f)}
729747
case *Type:
730748
q = &Type{Type: s.Type, Child: Map(s.Child, f)}
749+
case *Boost:
750+
q = &Boost{Weight: s.Weight, Child: Map(s.Child, f)}
731751
}
732752
return f(q)
733753
}
@@ -768,6 +788,7 @@ func VisitAtoms(q Q, v func(q Q)) {
768788
case *Or:
769789
case *Not:
770790
case *Type:
791+
case *Boost:
771792
default:
772793
v(iQ)
773794
}

web/server.go

+12
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,18 @@ func (s *Server) serveSearchErr(r *http.Request) (*ApiSearchResult, error) {
241241
return nil, err
242242
}
243243

244+
// Experimental support to take the query string and boost exact phrases of
245+
// it.
246+
if phraseBoost, err := strconv.ParseFloat(qvals.Get("phrase-boost"), 64); err == nil {
247+
q = query.NewOr(
248+
&query.Boost{
249+
Weight: phraseBoost,
250+
Child: &query.Substring{Pattern: queryStr, Content: true},
251+
},
252+
q,
253+
)
254+
}
255+
244256
repoOnly := true
245257
query.VisitAtoms(q, func(q query.Q) {
246258
_, ok := q.(*query.Repo)

0 commit comments

Comments
 (0)