Skip to content

Commit

Permalink
fix(rewrite): support the case where upstream doesn't echo the question
Browse files Browse the repository at this point in the history
Apparently Tailscale's magic DNS does this.
  • Loading branch information
ThinkChaos committed Dec 20, 2023
1 parent c6304e9 commit dece894
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 17 deletions.
35 changes: 26 additions & 9 deletions resolver/rewriter_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,36 +95,53 @@ func (r *RewriterResolver) Resolve(ctx context.Context, request *model.Request)
return r.next.Resolve(ctx, request)
}

if rewritten == nil {
return response, nil
}

// Revert the rewrite in r.inner's response
if rewritten != nil {
for i := range originalNames {
response.Res.Question[i].Name = originalNames[i]

if i < len(response.Res.Answer) {
response.Res.Answer[i].Header().Name = originalNames[i]
n := max(len(response.Res.Question), len(response.Res.Answer))
for i := 0; i < n; i++ {
if i < len(response.Res.Question) {
original, ok := originalNames[response.Res.Question[i].Name]
if ok {
response.Res.Question[i].Name = original
}
}

if i < len(response.Res.Answer) {
original, ok := originalNames[response.Res.Answer[i].Header().Name]
if ok {
response.Res.Answer[i].Header().Name = original
}
}
}

return response, nil
}

func (r *RewriterResolver) rewriteRequest(logger *logrus.Entry, request *dns.Msg) (rewritten *dns.Msg, originalNames []string) { //nolint: lll
originalNames = make([]string, len(request.Question))
func (r *RewriterResolver) rewriteRequest(
logger *logrus.Entry, request *dns.Msg,
) (rewritten *dns.Msg, originalNames map[string]string) {
originalNames = make(map[string]string, len(request.Question))

for i := range request.Question {
nameOriginal := request.Question[i].Name
originalNames[i] = nameOriginal

domainOriginal := util.ExtractDomainOnly(nameOriginal)
domainRewritten, rewriteKey := r.rewriteDomain(domainOriginal)

if domainRewritten != domainOriginal {
rewrittenFQDN := dns.Fqdn(domainRewritten)

originalNames[rewrittenFQDN] = nameOriginal

if rewritten == nil {
rewritten = request.Copy()
}

rewritten.Question[i].Name = dns.Fqdn(domainRewritten)
rewritten.Question[i].Name = rewrittenFQDN

logger.WithFields(logrus.Fields{
"rewrite": util.Obfuscate(rewriteKey) + ":" + util.Obfuscate(r.cfg.Rewrite[rewriteKey]),
Expand Down
34 changes: 26 additions & 8 deletions resolver/rewriter_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,14 @@ var _ = Describe("RewriterResolver", func() {
})

When("has rewrite", func() {
var request *model.Request
var expectNilAnswer bool
var (
request *model.Request
expectNilAnswer bool
)

BeforeEach(func() {
expectNilAnswer = false
})

AfterEach(func() {
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))

mInner.On("Resolve", mock.Anything)
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
Expect(req).Should(Equal(request.Req))

Expand All @@ -95,11 +92,19 @@ var _ = Describe("RewriterResolver", func() {

return res
}
})

AfterEach(func() {
request = newRequest(fqdnOriginal, dns.Type(dns.TypeA))

mInner.On("Resolve", mock.Anything)

resp, err := sut.Resolve(context.Background(), request)
Expect(err).Should(Succeed())
if resp != mNextResponse {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
if len(resp.Res.Question) != 0 {
Expect(resp.Res.Question[0].Name).Should(Equal(fqdnOriginal))
}
if expectNilAnswer {
Expect(resp.Res.Answer).Should(BeEmpty())
} else {
Expand Down Expand Up @@ -128,6 +133,19 @@ var _ = Describe("RewriterResolver", func() {
fqdnRewritten = fqdnOriginal
})

It("should support a reply without the question", func() {
fqdnOriginal = sampleOriginal
fqdnRewritten = sampleRewritten

origResponseFn := mInner.ResponseFn
mInner.ResponseFn = func(req *dns.Msg) *dns.Msg {
res := origResponseFn(req)
res.Question = nil

return res
}
})

It("should call next resolver", func() {
fqdnOriginal = sampleOriginal
fqdnRewritten = sampleRewritten
Expand Down

0 comments on commit dece894

Please sign in to comment.